将训练模型信息保存到数据库
This commit is contained in:
parent
3aaddcd658
commit
a02bc11921
@ -35,48 +35,49 @@
|
|||||||
</el-form>
|
</el-form>
|
||||||
|
|
||||||
<el-table :data="models" stripe v-loading="loading">
|
<el-table :data="models" stripe v-loading="loading">
|
||||||
<el-table-column prop="model_id" label="模型ID" min-width="150">
|
<el-table-column prop="display_name" label="模型名称" min-width="180">
|
||||||
|
<template #default="{ row }">
|
||||||
|
<el-tooltip :content="`UID: ${row.model_uid}`" placement="top">
|
||||||
|
<span>{{ row.display_name }}</span>
|
||||||
|
</el-tooltip>
|
||||||
|
</template>
|
||||||
|
</el-table-column>
|
||||||
|
<el-table-column prop="model_type" label="模型类型" min-width="120">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
<el-tag size="small" :type="getModelTagType(row.model_type)">
|
<el-tag size="small" :type="getModelTagType(row.model_type)">
|
||||||
{{ row.model_type }}
|
{{ row.model_type }}
|
||||||
</el-tag>
|
</el-tag>
|
||||||
<span style="margin-left: 8px">{{ row.product_id }}</span>
|
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
<el-table-column prop="product_name" label="产品名称" min-width="120"></el-table-column>
|
<el-table-column prop="training_mode" label="训练模式" min-width="120"></el-table-column>
|
||||||
<el-table-column prop="created_at" label="创建时间" min-width="160">
|
<el-table-column prop="created_at" label="创建时间" min-width="160">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
{{ formatDateTime(row.created_at) }}
|
{{ formatDateTime(row.created_at) }}
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
<el-table-column label="评估指标" min-width="200">
|
<el-table-column label="评估指标" min-width="220">
|
||||||
<template #default="{ row }">
|
<template #default="{ row }">
|
||||||
<div class="metrics-display">
|
<div class="metrics-display" v-if="row.performance_metrics && Object.keys(row.performance_metrics).length > 0">
|
||||||
<el-tooltip effect="dark" content="R平方值" placement="top">
|
<el-tooltip effect="dark" content="R平方值" placement="top">
|
||||||
<div class="metric-item">
|
<div class="metric-item">
|
||||||
<span class="metric-label">R²:</span>
|
<span class="metric-label">R²:</span>
|
||||||
<span class="metric-value">{{ row.metrics?.R2?.toFixed(4) || 'N/A' }}</span>
|
<span class="metric-value">{{ row.performance_metrics.R2?.toFixed(4) || 'N/A' }}</span>
|
||||||
</div>
|
</div>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
<el-tooltip effect="dark" content="均方根误差" placement="top">
|
<el-tooltip effect="dark" content="均方根误差" placement="top">
|
||||||
<div class="metric-item">
|
<div class="metric-item">
|
||||||
<span class="metric-label">RMSE:</span>
|
<span class="metric-label">RMSE:</span>
|
||||||
<span class="metric-value">{{ row.metrics?.RMSE?.toFixed(4) || 'N/A' }}</span>
|
<span class="metric-value">{{ row.performance_metrics.RMSE?.toFixed(4) || 'N/A' }}</span>
|
||||||
</div>
|
</div>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
<el-tooltip effect="dark" content="平均绝对误差" placement="top">
|
<el-tooltip effect="dark" content="平均绝对误差" placement="top">
|
||||||
<div class="metric-item">
|
<div class="metric-item">
|
||||||
<span class="metric-label">MAE:</span>
|
<span class="metric-label">MAE:</span>
|
||||||
<span class="metric-value">{{ row.metrics?.MAE?.toFixed(4) || 'N/A' }}</span>
|
<span class="metric-value">{{ row.performance_metrics.MAE?.toFixed(4) || 'N/A' }}</span>
|
||||||
</div>
|
|
||||||
</el-tooltip>
|
|
||||||
<el-tooltip effect="dark" content="平均绝对百分比误差" placement="top">
|
|
||||||
<div class="metric-item">
|
|
||||||
<span class="metric-label">MAPE:</span>
|
|
||||||
<span class="metric-value">{{ row.metrics?.MAPE?.toFixed(2) || 'N/A' }}%</span>
|
|
||||||
</div>
|
</div>
|
||||||
</el-tooltip>
|
</el-tooltip>
|
||||||
</div>
|
</div>
|
||||||
|
<span v-else>无</span>
|
||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
<el-table-column label="操作" width="280" fixed="right">
|
<el-table-column label="操作" width="280" fixed="right">
|
||||||
@ -254,7 +255,7 @@ const fetchModels = async () => {
|
|||||||
const response = await axios.get('/api/models', { params })
|
const response = await axios.get('/api/models', { params })
|
||||||
if (response.data.status === 'success') {
|
if (response.data.status === 'success') {
|
||||||
models.value = response.data.data.map(model => {
|
models.value = response.data.data.map(model => {
|
||||||
model.metrics = normalizeMetricsKeys(model.metrics);
|
model.performance_metrics = normalizeMetricsKeys(model.performance_metrics);
|
||||||
return model;
|
return model;
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -291,7 +292,23 @@ const viewDetails = async (model) => {
|
|||||||
detailsDialogVisible.value = true;
|
detailsDialogVisible.value = true;
|
||||||
selectedModelDetails.value = null; // 重置
|
selectedModelDetails.value = null; // 重置
|
||||||
try {
|
try {
|
||||||
const response = await axios.get(`/api/models/${model.model_type}/${model.product_id}/details`);
|
// 新逻辑: 直接使用行数据,因为列表API已返回足够信息
|
||||||
|
const details = {
|
||||||
|
model_info: {
|
||||||
|
model_id: model.model_uid,
|
||||||
|
model_type: model.model_type,
|
||||||
|
product_name: model.display_name,
|
||||||
|
created_at: model.created_at,
|
||||||
|
},
|
||||||
|
training_metrics: model.performance_metrics,
|
||||||
|
chart_data: {
|
||||||
|
loss_chart: model.artifacts?.loss_curve_data || { epochs: [], train_loss: [], test_loss: [] }
|
||||||
|
}
|
||||||
|
};
|
||||||
|
selectedModelDetails.value = details;
|
||||||
|
nextTick(() => {
|
||||||
|
initLossChart();
|
||||||
|
});
|
||||||
if (response.data.status === 'success') {
|
if (response.data.status === 'success') {
|
||||||
const details = response.data.data;
|
const details = response.data.data;
|
||||||
details.training_metrics = normalizeMetricsKeys(details.training_metrics);
|
details.training_metrics = normalizeMetricsKeys(details.training_metrics);
|
||||||
@ -311,7 +328,7 @@ const viewDetails = async (model) => {
|
|||||||
|
|
||||||
const deleteModel = async (model) => {
|
const deleteModel = async (model) => {
|
||||||
try {
|
try {
|
||||||
const response = await axios.delete(`/api/models/${model.model_id}`)
|
const response = await axios.delete(`/api/models/${model.model_uid}`)
|
||||||
if (response.data.status === 'success') {
|
if (response.data.status === 'success') {
|
||||||
ElMessage.success('模型删除成功')
|
ElMessage.success('模型删除成功')
|
||||||
fetchModels()
|
fetchModels()
|
||||||
@ -325,7 +342,7 @@ const deleteModel = async (model) => {
|
|||||||
|
|
||||||
const exportModel = async (model) => {
|
const exportModel = async (model) => {
|
||||||
try {
|
try {
|
||||||
const response = await axios.get(`/api/models/${model.model_id}/export`, { responseType: 'blob' })
|
const response = await axios.get(`/api/models/${model.model_uid}/export`, { responseType: 'blob' })
|
||||||
const url = window.URL.createObjectURL(new Blob([response.data]))
|
const url = window.URL.createObjectURL(new Blob([response.data]))
|
||||||
const link = document.createElement('a')
|
const link = document.createElement('a')
|
||||||
link.href = url
|
link.href = url
|
||||||
|
@ -96,17 +96,18 @@
|
|||||||
<el-col :span="6">
|
<el-col :span="6">
|
||||||
<el-form-item label="模型版本">
|
<el-form-item label="模型版本">
|
||||||
<el-select
|
<el-select
|
||||||
v-model="form.version"
|
v-model="form.model_uid"
|
||||||
placeholder="选择版本"
|
placeholder="选择一个具体的模型"
|
||||||
style="width: 100%"
|
style="width: 100%"
|
||||||
:disabled="!availableVersions.length"
|
:disabled="!availableVersions.length"
|
||||||
:loading="versionsLoading"
|
:loading="versionsLoading"
|
||||||
|
value-key="model_uid"
|
||||||
>
|
>
|
||||||
<el-option
|
<el-option
|
||||||
v-for="version in availableVersions"
|
v-for="model in availableVersions"
|
||||||
:key="version"
|
:key="model.model_uid"
|
||||||
:label="version"
|
:label="`${model.display_name} (v${model.version})`"
|
||||||
:value="version"
|
:value="model.model_uid"
|
||||||
/>
|
/>
|
||||||
</el-select>
|
</el-select>
|
||||||
<div class="version-info">
|
<div class="version-info">
|
||||||
@ -176,16 +177,6 @@
|
|||||||
开始预测
|
开始预测
|
||||||
</el-button>
|
</el-button>
|
||||||
|
|
||||||
<el-button
|
|
||||||
v-if="predictionResult"
|
|
||||||
type="success"
|
|
||||||
size="large"
|
|
||||||
@click="savePrediction"
|
|
||||||
:loading="saving"
|
|
||||||
>
|
|
||||||
<el-icon><Download /></el-icon>
|
|
||||||
保存结果
|
|
||||||
</el-button>
|
|
||||||
</div>
|
</div>
|
||||||
</el-card>
|
</el-card>
|
||||||
|
|
||||||
@ -274,7 +265,7 @@ const form = reactive({
|
|||||||
product_id: '',
|
product_id: '',
|
||||||
store_id: '',
|
store_id: '',
|
||||||
model_type: '',
|
model_type: '',
|
||||||
version: '',
|
model_uid: '',
|
||||||
future_days: 7,
|
future_days: 7,
|
||||||
start_date: '',
|
start_date: '',
|
||||||
analyze_result: true
|
analyze_result: true
|
||||||
@ -282,18 +273,8 @@ const form = reactive({
|
|||||||
|
|
||||||
// 计算属性
|
// 计算属性
|
||||||
const canPredict = computed(() => {
|
const canPredict = computed(() => {
|
||||||
const baseCheck = form.training_mode && form.model_type && form.version
|
return !!form.model_uid;
|
||||||
|
});
|
||||||
if (form.training_mode === 'product') {
|
|
||||||
return baseCheck && form.product_id
|
|
||||||
} else if (form.training_mode === 'store') {
|
|
||||||
return baseCheck && form.store_id
|
|
||||||
} else if (form.training_mode === 'global') {
|
|
||||||
return baseCheck
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
const predictionTableData = computed(() => {
|
const predictionTableData = computed(() => {
|
||||||
if (!predictionResult.value || !predictionResult.value.predictions) return []
|
if (!predictionResult.value || !predictionResult.value.predictions) return []
|
||||||
@ -329,24 +310,21 @@ const fetchAvailableVersions = async () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
versionsLoading.value = true
|
versionsLoading.value = true;
|
||||||
let url = ''
|
const params = {
|
||||||
|
training_mode: form.training_mode,
|
||||||
|
model_type: form.model_type,
|
||||||
|
product_id: form.training_mode === 'product' ? form.product_id : undefined,
|
||||||
|
store_id: form.training_mode === 'store' ? form.store_id : undefined,
|
||||||
|
page_size: 100 // 获取足够多的模型以供选择
|
||||||
|
};
|
||||||
|
|
||||||
if (form.training_mode === 'product' && form.product_id) {
|
const response = await axios.get('/api/models', { params });
|
||||||
url = `/api/models/${form.product_id}/${form.model_type}/versions`
|
|
||||||
} else if (form.training_mode === 'store' && form.store_id) {
|
|
||||||
url = `/api/models/store/${form.store_id}/${form.model_type}/versions`
|
|
||||||
} else if (form.training_mode === 'global') {
|
|
||||||
url = `/api/models/global/${form.model_type}/versions`
|
|
||||||
}
|
|
||||||
|
|
||||||
if (url) {
|
|
||||||
const response = await axios.get(url)
|
|
||||||
if (response.data.status === 'success') {
|
if (response.data.status === 'success') {
|
||||||
availableVersions.value = response.data.data.versions || []
|
availableVersions.value = response.data.data || [];
|
||||||
if (response.data.data.latest_version) {
|
if (availableVersions.value.length > 0) {
|
||||||
form.version = response.data.data.latest_version
|
// 默认选中第一个模型
|
||||||
}
|
form.model_uid = availableVersions.value[0].model_uid;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@ -361,22 +339,22 @@ const handleTrainingModeChange = () => {
|
|||||||
form.product_id = ''
|
form.product_id = ''
|
||||||
form.store_id = ''
|
form.store_id = ''
|
||||||
form.model_type = ''
|
form.model_type = ''
|
||||||
form.version = ''
|
form.model_uid = ''
|
||||||
availableVersions.value = []
|
availableVersions.value = []
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleProductChange = () => {
|
const handleProductChange = () => {
|
||||||
form.version = ''
|
form.model_uid = ''
|
||||||
fetchAvailableVersions()
|
fetchAvailableVersions()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleStoreChange = () => {
|
const handleStoreChange = () => {
|
||||||
form.version = ''
|
form.model_uid = ''
|
||||||
fetchAvailableVersions()
|
fetchAvailableVersions()
|
||||||
}
|
}
|
||||||
|
|
||||||
const handleModelTypeChange = () => {
|
const handleModelTypeChange = () => {
|
||||||
form.version = ''
|
form.model_uid = ''
|
||||||
fetchAvailableVersions()
|
fetchAvailableVersions()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,21 +363,14 @@ const startPrediction = async () => {
|
|||||||
predicting.value = true
|
predicting.value = true
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
model_type: form.model_type,
|
model_uid: form.model_uid,
|
||||||
version: form.version,
|
|
||||||
future_days: form.future_days,
|
future_days: form.future_days,
|
||||||
start_date: form.start_date,
|
start_date: form.start_date,
|
||||||
analyze_result: form.analyze_result
|
include_visualization: form.analyze_result, // 后端现在叫 include_visualization
|
||||||
}
|
history_lookback_days: 30
|
||||||
|
};
|
||||||
|
|
||||||
// 根据训练模式添加相应参数
|
const response = await axios.post('/api/prediction', payload);
|
||||||
if (form.training_mode === 'product') {
|
|
||||||
payload.product_id = form.product_id
|
|
||||||
} else if (form.training_mode === 'store') {
|
|
||||||
payload.store_id = form.store_id
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await axios.post('/api/predict', payload)
|
|
||||||
|
|
||||||
if (response.data.status === 'success') {
|
if (response.data.status === 'success') {
|
||||||
predictionResult.value = response.data.data
|
predictionResult.value = response.data.data
|
||||||
@ -470,29 +441,6 @@ const renderChart = () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const savePrediction = async () => {
|
|
||||||
try {
|
|
||||||
saving.value = true
|
|
||||||
|
|
||||||
const saveData = {
|
|
||||||
...predictionResult.value,
|
|
||||||
training_mode: form.training_mode,
|
|
||||||
parameters: { ...form }
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = await axios.post('/api/predictions/save', saveData)
|
|
||||||
|
|
||||||
if (response.data.status === 'success') {
|
|
||||||
ElMessage.success('预测结果已保存')
|
|
||||||
} else {
|
|
||||||
ElMessage.error('保存失败')
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
ElMessage.error('保存请求失败')
|
|
||||||
} finally {
|
|
||||||
saving.value = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 辅助方法
|
// 辅助方法
|
||||||
const getPredictionInfoText = () => {
|
const getPredictionInfoText = () => {
|
||||||
|
@ -129,33 +129,20 @@ const pagination = reactive({
|
|||||||
pageSize: 8
|
pageSize: 8
|
||||||
})
|
})
|
||||||
|
|
||||||
const storeNameMap = computed(() => {
|
|
||||||
return stores.value.reduce((acc, store) => {
|
|
||||||
acc[store.store_id] = store.store_name
|
|
||||||
return acc
|
|
||||||
}, {})
|
|
||||||
})
|
|
||||||
|
|
||||||
const modelsWithNames = computed(() => {
|
|
||||||
return modelList.value.map(model => ({
|
|
||||||
...model,
|
|
||||||
store_name: storeNameMap.value[model.store_id] || model.store_id
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
return modelsWithNames.value.filter(model => {
|
return modelList.value.filter(model => {
|
||||||
const storeMatch = !filters.store_id || model.store_id === filters.store_id
|
// 简化逻辑:直接使用后端返回的数据进行筛选
|
||||||
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
|
const storeMatch = !filters.store_id || (model.training_scope?.store?.id === filters.store_id);
|
||||||
return storeMatch && modelTypeMatch
|
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id;
|
||||||
})
|
return storeMatch && modelTypeMatch;
|
||||||
})
|
});
|
||||||
|
});
|
||||||
|
|
||||||
const paginatedModelList = computed(() => {
|
const paginatedModelList = computed(() => {
|
||||||
const start = (pagination.currentPage - 1) * pagination.pageSize
|
const start = (pagination.currentPage - 1) * pagination.pageSize;
|
||||||
const end = start + pagination.pageSize
|
const end = start + pagination.pageSize;
|
||||||
return filteredModelList.value.slice(start, end)
|
return filteredModelList.value.slice(start, end);
|
||||||
})
|
});
|
||||||
|
|
||||||
const handlePageChange = (page) => {
|
const handlePageChange = (page) => {
|
||||||
pagination.currentPage = page
|
pagination.currentPage = page
|
||||||
|
@ -604,14 +604,45 @@ const startTraining = async () => {
|
|||||||
? "/api/training/retrain"
|
? "/api/training/retrain"
|
||||||
: "/api/training";
|
: "/api/training";
|
||||||
|
|
||||||
|
// 构造 training_scope 对象
|
||||||
|
const training_scope = {
|
||||||
|
stores: "all",
|
||||||
|
products: "all"
|
||||||
|
};
|
||||||
|
|
||||||
|
if (form.training_scope === 'selected_stores' || form.training_scope === 'custom') {
|
||||||
|
if (form.store_ids.length > 0) {
|
||||||
|
training_scope.stores = form.store_ids.map(sid => {
|
||||||
|
const store = stores.value.find(s => s.store_id === sid);
|
||||||
|
return { id: sid, name: store ? store.store_name : sid };
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// 如果没有选择店铺,则默认为 all
|
||||||
|
training_scope.stores = "all";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (form.training_scope === 'selected_products' || form.training_scope === 'custom') {
|
||||||
|
if (form.product_ids.length > 0) {
|
||||||
|
training_scope.products = form.product_ids.map(pid => {
|
||||||
|
const product = products.value.find(p => p.product_id === pid);
|
||||||
|
return { id: pid, name: product ? product.product_name : pid };
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// 如果没有选择药品,则默认为 all
|
||||||
|
training_scope.products = "all";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
model_type: form.model_type,
|
model_type: form.model_type,
|
||||||
epochs: form.epochs,
|
epochs: form.epochs,
|
||||||
training_mode: 'global', // 标识这是全局训练模式
|
training_mode: 'global', // 标识这是全局训练模式
|
||||||
training_scope: form.training_scope,
|
training_scope: training_scope, // 使用新构建的对象
|
||||||
aggregation_method: form.aggregation_method
|
aggregation_method: form.aggregation_method
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 为了向后兼容或调试,仍然可以发送原始ID列表
|
||||||
if (form.store_ids.length > 0) {
|
if (form.store_ids.length > 0) {
|
||||||
payload.store_ids = form.store_ids;
|
payload.store_ids = form.store_ids;
|
||||||
}
|
}
|
||||||
|
@ -541,12 +541,31 @@ const startTraining = async () => {
|
|||||||
? "/api/training/retrain"
|
? "/api/training/retrain"
|
||||||
: "/api/training";
|
: "/api/training";
|
||||||
|
|
||||||
|
// 构造 training_scope 对象
|
||||||
|
const selectedProduct = products.value.find(p => p.product_id === form.product_id);
|
||||||
|
const training_scope = {
|
||||||
|
product: {
|
||||||
|
id: form.product_id,
|
||||||
|
name: selectedProduct ? selectedProduct.product_name : form.product_id
|
||||||
|
},
|
||||||
|
stores: "all"
|
||||||
|
};
|
||||||
|
|
||||||
|
if (form.data_scope === 'specific' && form.store_id) {
|
||||||
|
const selectedStore = stores.value.find(s => s.store_id === form.store_id);
|
||||||
|
training_scope.stores = [{
|
||||||
|
id: form.store_id,
|
||||||
|
name: selectedStore ? selectedStore.store_name : form.store_id
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
product_id: form.product_id,
|
product_id: form.product_id,
|
||||||
store_id: form.data_scope === 'global' ? null : form.store_id,
|
store_id: form.data_scope === 'global' ? null : form.store_id,
|
||||||
model_type: form.model_type,
|
model_type: form.model_type,
|
||||||
epochs: form.epochs,
|
epochs: form.epochs,
|
||||||
training_mode: 'product' // 标识这是药品训练模式
|
training_mode: 'product', // 标识这是药品训练模式
|
||||||
|
training_scope: training_scope // 添加符合新后端要求的 training_scope
|
||||||
};
|
};
|
||||||
|
|
||||||
if (form.training_type === "retrain") {
|
if (form.training_type === "retrain") {
|
||||||
|
@ -588,12 +588,33 @@ const startTraining = async () => {
|
|||||||
? "/api/training/retrain"
|
? "/api/training/retrain"
|
||||||
: "/api/training";
|
: "/api/training";
|
||||||
|
|
||||||
|
// 构造 training_scope 对象
|
||||||
|
const selectedStore = stores.value.find(s => s.store_id === form.store_id);
|
||||||
|
const training_scope = {
|
||||||
|
store: {
|
||||||
|
id: form.store_id,
|
||||||
|
name: selectedStore ? selectedStore.store_name : form.store_id
|
||||||
|
},
|
||||||
|
products: "all"
|
||||||
|
};
|
||||||
|
|
||||||
|
if (form.product_scope === 'specific' && form.product_ids.length > 0) {
|
||||||
|
training_scope.products = form.product_ids.map(pid => {
|
||||||
|
const product = storeProducts.value.find(p => p.product_id === pid);
|
||||||
|
return {
|
||||||
|
id: pid,
|
||||||
|
name: product ? product.product_name : pid
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
store_id: form.store_id,
|
store_id: form.store_id,
|
||||||
model_type: form.model_type,
|
model_type: form.model_type,
|
||||||
epochs: form.epochs,
|
epochs: form.epochs,
|
||||||
training_mode: 'store', // 标识这是店铺训练模式
|
training_mode: 'store', // 标识这是店铺训练模式
|
||||||
product_scope: form.product_scope
|
product_scope: form.product_scope,
|
||||||
|
training_scope: training_scope // 添加符合新后端要求的 training_scope
|
||||||
};
|
};
|
||||||
|
|
||||||
if (form.product_scope === 'specific') {
|
if (form.product_scope === 'specific') {
|
||||||
|
Binary file not shown.
645
server/api.py
645
server/api.py
@ -65,6 +65,7 @@ from core.config import (
|
|||||||
from utils.multi_store_data_utils import (
|
from utils.multi_store_data_utils import (
|
||||||
get_available_stores, get_available_products, get_sales_statistics
|
get_available_stores, get_available_products, get_sales_statistics
|
||||||
)
|
)
|
||||||
|
from utils.database_utils import query_models_from_db, find_model_by_uid, save_prediction_to_db
|
||||||
|
|
||||||
# 导入数据库初始化工具
|
# 导入数据库初始化工具
|
||||||
from init_multi_store_db import get_db_connection
|
from init_multi_store_db import get_db_connection
|
||||||
@ -119,60 +120,68 @@ def init_db():
|
|||||||
conn = sqlite3.connect('prediction_history.db')
|
conn = sqlite3.connect('prediction_history.db')
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 创建预测历史表
|
# 废弃旧的 model_versions 表
|
||||||
cursor.execute('''
|
cursor.execute('DROP TABLE IF EXISTS model_versions')
|
||||||
CREATE TABLE IF NOT EXISTS prediction_history (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
prediction_id TEXT UNIQUE NOT NULL,
|
|
||||||
product_id TEXT NOT NULL,
|
|
||||||
product_name TEXT NOT NULL,
|
|
||||||
model_type TEXT NOT NULL,
|
|
||||||
model_id TEXT NOT NULL,
|
|
||||||
start_date TEXT,
|
|
||||||
future_days INTEGER,
|
|
||||||
created_at TEXT NOT NULL,
|
|
||||||
predictions_data TEXT,
|
|
||||||
metrics TEXT,
|
|
||||||
chart_data TEXT,
|
|
||||||
analysis TEXT,
|
|
||||||
file_path TEXT
|
|
||||||
)
|
|
||||||
''')
|
|
||||||
|
|
||||||
# 创建模型版本表
|
# 创建新的 models 表
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE TABLE IF NOT EXISTS model_versions (
|
CREATE TABLE IF NOT EXISTS models (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
product_id TEXT NOT NULL,
|
model_uid TEXT UNIQUE NOT NULL,
|
||||||
|
display_name TEXT,
|
||||||
model_type TEXT NOT NULL,
|
model_type TEXT NOT NULL,
|
||||||
|
training_mode TEXT NOT NULL,
|
||||||
|
training_scope TEXT,
|
||||||
|
parent_model_id INTEGER,
|
||||||
version TEXT NOT NULL,
|
version TEXT NOT NULL,
|
||||||
file_path TEXT NOT NULL,
|
status TEXT DEFAULT 'active',
|
||||||
created_at TEXT NOT NULL,
|
training_params TEXT,
|
||||||
metrics TEXT,
|
performance_metrics TEXT,
|
||||||
is_active INTEGER DEFAULT 1,
|
artifacts TEXT,
|
||||||
UNIQUE(product_id, model_type, version)
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY (parent_model_id) REFERENCES models (id)
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
|
|
||||||
# 创建索引以提高查询性能
|
# 更新 prediction_history 表
|
||||||
cursor.execute('''
|
# 为了平滑过渡,我们先检查表是否存在,然后尝试添加新列
|
||||||
CREATE INDEX IF NOT EXISTS idx_prediction_product_model
|
cursor.execute("PRAGMA table_info(prediction_history)")
|
||||||
ON prediction_history(product_id, model_type)
|
columns = [row[1] for row in cursor.fetchall()]
|
||||||
''')
|
|
||||||
|
|
||||||
|
if 'prediction_uid' not in columns:
|
||||||
|
# 如果表结构很旧,重建它
|
||||||
|
cursor.execute('DROP TABLE IF EXISTS prediction_history')
|
||||||
cursor.execute('''
|
cursor.execute('''
|
||||||
CREATE INDEX IF NOT EXISTS idx_model_versions_product_type
|
CREATE TABLE prediction_history (
|
||||||
ON model_versions(product_id, model_type)
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
prediction_uid TEXT UNIQUE NOT NULL,
|
||||||
|
model_id TEXT NOT NULL,
|
||||||
|
model_type TEXT,
|
||||||
|
product_name TEXT,
|
||||||
|
prediction_scope TEXT,
|
||||||
|
prediction_params TEXT,
|
||||||
|
metrics TEXT,
|
||||||
|
result_file_path TEXT,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
''')
|
''')
|
||||||
|
else:
|
||||||
|
# 否则,只添加缺失的列
|
||||||
|
if 'prediction_scope' not in columns:
|
||||||
|
cursor.execute('ALTER TABLE prediction_history ADD COLUMN prediction_scope TEXT')
|
||||||
|
if 'result_file_path' not in columns:
|
||||||
|
cursor.execute('ALTER TABLE prediction_history ADD COLUMN result_file_path TEXT')
|
||||||
|
# 确保 model_id 字段存在且类型正确
|
||||||
|
# 在SQLite中修改列类型比较复杂,通常建议重建。此处简化处理。
|
||||||
|
|
||||||
cursor.execute('''
|
# 创建索引
|
||||||
CREATE INDEX IF NOT EXISTS idx_model_versions_active
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_models_uid ON models(model_uid)')
|
||||||
ON model_versions(product_id, model_type, is_active)
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_models_type_mode ON models(model_type, training_mode)')
|
||||||
''')
|
cursor.execute('CREATE INDEX IF NOT EXISTS idx_prediction_history_model_id ON prediction_history(model_id)')
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
print("数据库初始化完成,包含模型版本管理表")
|
print("数据库初始化完成,已更新为新的 `models` 和 `prediction_history` 表结构。")
|
||||||
|
|
||||||
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
|
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
|
||||||
class CustomJSONEncoder(json.JSONEncoder):
|
class CustomJSONEncoder(json.JSONEncoder):
|
||||||
@ -219,6 +228,8 @@ def convert_numpy_types_for_json(obj):
|
|||||||
return obj
|
return obj
|
||||||
|
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
# 解决jsonify中文显示为unicode的问题
|
||||||
|
app.config['JSON_AS_ASCII'] = False
|
||||||
# 设置自定义JSON编码器
|
# 设置自定义JSON编码器
|
||||||
app.json_encoder = CustomJSONEncoder
|
app.json_encoder = CustomJSONEncoder
|
||||||
app.config['SECRET_KEY'] = 'pharmacy_prediction_secret_key'
|
app.config['SECRET_KEY'] = 'pharmacy_prediction_secret_key'
|
||||||
@ -874,7 +885,7 @@ def start_training():
|
|||||||
product_ids = data.get('product_ids', [])
|
product_ids = data.get('product_ids', [])
|
||||||
store_ids = data.get('store_ids', [])
|
store_ids = data.get('store_ids', [])
|
||||||
product_scope = data.get('product_scope', 'all')
|
product_scope = data.get('product_scope', 'all')
|
||||||
training_scope = data.get('training_scope', 'all_stores_all_products')
|
# training_scope = data.get('training_scope', 'all_stores_all_products') # 已废弃,由下面的 training_scope_obj 替代
|
||||||
aggregation_method = data.get('aggregation_method', 'sum')
|
aggregation_method = data.get('aggregation_method', 'sum')
|
||||||
|
|
||||||
if not model_type:
|
if not model_type:
|
||||||
@ -896,12 +907,18 @@ def start_training():
|
|||||||
|
|
||||||
# 使用新的训练进程管理器提交任务
|
# 使用新的训练进程管理器提交任务
|
||||||
try:
|
try:
|
||||||
|
# 修正: 直接从请求中获取完整的 training_scope 对象
|
||||||
|
training_scope_obj = data.get('training_scope')
|
||||||
|
if not training_scope_obj or not isinstance(training_scope_obj, dict):
|
||||||
|
return jsonify({'error': '请求体中缺少有效(必须是JSON对象)的 training_scope 参数'}), 400
|
||||||
|
|
||||||
task_id = training_manager.submit_task(
|
task_id = training_manager.submit_task(
|
||||||
product_id=product_id or "unknown",
|
product_id=product_id or "unknown",
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
epochs=epochs
|
epochs=epochs,
|
||||||
|
training_scope=training_scope_obj
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
|
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
|
||||||
@ -1273,8 +1290,8 @@ def get_training_status(task_id):
|
|||||||
@app.route('/api/prediction', methods=['POST'])
|
@app.route('/api/prediction', methods=['POST'])
|
||||||
@swag_from({
|
@swag_from({
|
||||||
'tags': ['模型预测'],
|
'tags': ['模型预测'],
|
||||||
'summary': '使用模型进行预测',
|
'summary': '使用模型进行预测 (v2)',
|
||||||
'description': '使用指定模型预测未来销售数据',
|
'description': '使用指定模型UID预测未来销售数据',
|
||||||
'parameters': [
|
'parameters': [
|
||||||
{
|
{
|
||||||
'name': 'body',
|
'name': 'body',
|
||||||
@ -1283,165 +1300,61 @@ def get_training_status(task_id):
|
|||||||
'schema': {
|
'schema': {
|
||||||
'type': 'object',
|
'type': 'object',
|
||||||
'properties': {
|
'properties': {
|
||||||
'product_id': {'type': 'string'},
|
'model_uid': {'type': 'string', 'description': '要用于预测的模型的唯一ID'},
|
||||||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn']},
|
'future_days': {'type': 'integer', 'default': 7},
|
||||||
'store_id': {'type': 'string', 'description': '店铺ID,如 S001。为空时使用全局模型'},
|
'start_date': {'type': 'string', 'description': '预测起始日期 (YYYY-MM-DD)'},
|
||||||
'version': {'type': 'string'},
|
'include_visualization': {'type': 'boolean', 'default': True},
|
||||||
'future_days': {'type': 'integer'},
|
'history_lookback_days': {'type': 'integer', 'default': 30}
|
||||||
'include_visualization': {'type': 'boolean'},
|
|
||||||
'start_date': {'type': 'string', 'description': '预测起始日期,格式为YYYY-MM-DD'}
|
|
||||||
},
|
},
|
||||||
'required': ['product_id', 'model_type']
|
'required': ['model_uid']
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
'responses': {
|
'responses': {
|
||||||
200: {
|
200: {'description': '预测成功'},
|
||||||
'description': '预测成功',
|
400: {'description': '请求错误'},
|
||||||
'schema': {
|
404: {'description': '模型不存在'},
|
||||||
'type': 'object',
|
500: {'description': '服务器内部错误'}
|
||||||
'properties': {
|
|
||||||
'status': {'type': 'string'},
|
|
||||||
'data': {
|
|
||||||
'type': 'object',
|
|
||||||
'properties': {
|
|
||||||
'product_id': {'type': 'string'},
|
|
||||||
'product_name': {'type': 'string'},
|
|
||||||
'model_type': {'type': 'string'},
|
|
||||||
'predictions': {'type': 'array'},
|
|
||||||
'visualization': {'type': 'string'}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
400: {
|
|
||||||
'description': '请求错误,缺少必要参数或参数格式错误'
|
|
||||||
},
|
|
||||||
404: {
|
|
||||||
'description': '产品或模型不存在'
|
|
||||||
},
|
|
||||||
500: {
|
|
||||||
'description': '服务器内部错误'
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
def predict():
|
def predict():
|
||||||
"""
|
"""
|
||||||
使用指定的模型进行预测
|
使用指定的模型进行预测 (v2 - 基于数据库)
|
||||||
---
|
|
||||||
tags:
|
|
||||||
- 模型预测
|
|
||||||
parameters:
|
|
||||||
- in: body
|
|
||||||
name: body
|
|
||||||
schema:
|
|
||||||
type: object
|
|
||||||
required:
|
|
||||||
- product_id
|
|
||||||
- model_type
|
|
||||||
properties:
|
|
||||||
product_id:
|
|
||||||
type: string
|
|
||||||
description: 产品ID
|
|
||||||
model_type:
|
|
||||||
type: string
|
|
||||||
description: "模型类型 (mlstm, kan, transformer)"
|
|
||||||
version:
|
|
||||||
type: string
|
|
||||||
description: "模型版本 (v1, v2, v3 等),如果不指定则使用最新版本"
|
|
||||||
future_days:
|
|
||||||
type: integer
|
|
||||||
description: 预测未来天数
|
|
||||||
default: 7
|
|
||||||
start_date:
|
|
||||||
type: string
|
|
||||||
description: 预测起始日期,格式为YYYY-MM-DD
|
|
||||||
default: ''
|
|
||||||
responses:
|
|
||||||
200:
|
|
||||||
description: 预测成功
|
|
||||||
400:
|
|
||||||
description: 请求参数错误
|
|
||||||
404:
|
|
||||||
description: 模型文件未找到
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data = request.json
|
data = request.json
|
||||||
model_type = data.get('model_type')
|
model_uid = data.get('model_uid')
|
||||||
version = data.get('version')
|
if not model_uid:
|
||||||
|
return jsonify({"status": "error", "message": "缺少 'model_uid' 参数"}), 400
|
||||||
|
|
||||||
|
# 从数据库查找模型记录
|
||||||
|
model_record = find_model_by_uid(model_uid)
|
||||||
|
if not model_record:
|
||||||
|
return jsonify({"status": "error", "message": f"模型UID '{model_uid}' 不存在"}), 404
|
||||||
|
|
||||||
|
# 解析必要的模型元数据
|
||||||
|
model_type = model_record.get('model_type')
|
||||||
|
training_mode = model_record.get('training_mode')
|
||||||
|
version = model_record.get('version')
|
||||||
|
|
||||||
|
# 解析 artifacts 找到模型文件路径
|
||||||
|
artifacts = json.loads(model_record.get('artifacts', '{}'))
|
||||||
|
model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model')
|
||||||
|
if not model_file_path or not os.path.exists(model_file_path):
|
||||||
|
return jsonify({"status": "error", "message": "找不到模型文件或文件路径无效"}), 404
|
||||||
|
|
||||||
|
# 解析 training_scope 获取 product_id 或 store_id
|
||||||
|
training_scope = json.loads(model_record.get('training_scope', '{}'))
|
||||||
|
product_id = training_scope.get('product', {}).get('id')
|
||||||
|
store_id = training_scope.get('store', {}).get('id')
|
||||||
|
|
||||||
|
# 获取其他预测参数
|
||||||
future_days = int(data.get('future_days', 7))
|
future_days = int(data.get('future_days', 7))
|
||||||
start_date = data.get('start_date', '')
|
start_date = data.get('start_date', '')
|
||||||
include_visualization = data.get('include_visualization', False)
|
include_visualization = data.get('include_visualization', True)
|
||||||
history_lookback_days = int(data.get('history_lookback_days', 30)) # 新增参数
|
history_lookback_days = int(data.get('history_lookback_days', 30))
|
||||||
|
|
||||||
# 确定训练模式和标识符
|
# 调用核心预测函数
|
||||||
training_mode = data.get('training_mode', 'product')
|
|
||||||
product_id = data.get('product_id')
|
|
||||||
store_id = data.get('store_id')
|
|
||||||
|
|
||||||
# v2版:根据训练模式和ID构建模型标识符
|
|
||||||
aggregation_method = data.get('aggregation_method', 'sum') # 全局模式需要
|
|
||||||
if training_mode == 'global':
|
|
||||||
model_identifier = f"global_{aggregation_method}"
|
|
||||||
product_name = f"全局聚合数据 ({aggregation_method})"
|
|
||||||
elif training_mode == 'store':
|
|
||||||
if not store_id:
|
|
||||||
return jsonify({"status": "error", "error": "店铺模式需要 store_id"}), 400
|
|
||||||
model_identifier = f"store_{store_id}"
|
|
||||||
product_name = f"店铺 {store_id} 整体"
|
|
||||||
else: # 默认为 'product' 模式
|
|
||||||
if not product_id:
|
|
||||||
return jsonify({"status": "error", "error": "药品模式需要 product_id"}), 400
|
|
||||||
model_identifier = product_id
|
|
||||||
product_name = get_product_name(product_id) or product_id
|
|
||||||
|
|
||||||
print(f"API接收到预测请求: mode={training_mode}, model_identifier='{model_identifier}', model_type='{model_type}', version='{version}'")
|
|
||||||
|
|
||||||
if not model_type:
|
|
||||||
return jsonify({"status": "error", "error": "model_type 是必需的"}), 400
|
|
||||||
|
|
||||||
# 获取模型版本
|
|
||||||
if not version:
|
|
||||||
version = get_latest_model_version(model_identifier, model_type)
|
|
||||||
|
|
||||||
if not version:
|
|
||||||
return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier} 的 {model_type} 类型模型"}), 404
|
|
||||||
|
|
||||||
# v2版:使用 ModelManager 查找模型文件,不再使用旧的 get_model_file_path
|
|
||||||
from utils.model_manager import model_manager
|
|
||||||
|
|
||||||
# 智能修正 training_mode (兼容前端可能发送的错误模式)
|
|
||||||
if model_identifier.startswith('store_'):
|
|
||||||
training_mode = 'store'
|
|
||||||
store_id = model_identifier.split('_')[1]
|
|
||||||
elif model_identifier.startswith('global_'):
|
|
||||||
training_mode = 'global'
|
|
||||||
|
|
||||||
# 使用 model_manager 查找模型
|
|
||||||
models_result = model_manager.list_models(
|
|
||||||
model_type=model_type,
|
|
||||||
store_id=store_id if training_mode == 'store' else None,
|
|
||||||
product_id=product_id if training_mode == 'product' else None,
|
|
||||||
training_mode=training_mode
|
|
||||||
)
|
|
||||||
|
|
||||||
found_model = None
|
|
||||||
for model in models_result.get('models', []):
|
|
||||||
if model.get('version') == version:
|
|
||||||
found_model = model
|
|
||||||
break
|
|
||||||
|
|
||||||
if not found_model or not found_model.get('file_path'):
|
|
||||||
error_msg = f"在系统中未找到匹配的模型: mode={training_mode}, type={model_type}, id='{model_identifier}', version={version}"
|
|
||||||
print(error_msg)
|
|
||||||
return jsonify({"status": "error", "error": error_msg}), 404
|
|
||||||
|
|
||||||
model_file_path = found_model['file_path']
|
|
||||||
|
|
||||||
model_id = f"{model_identifier}_{model_type}_{version}"
|
|
||||||
|
|
||||||
# v3版:直接调用核心预测函数
|
|
||||||
prediction_result = load_model_and_predict(
|
prediction_result = load_model_and_predict(
|
||||||
model_path=model_file_path,
|
model_path=model_file_path,
|
||||||
product_id=product_id,
|
product_id=product_id,
|
||||||
@ -1456,49 +1369,36 @@ def predict():
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prediction_result is None:
|
if prediction_result is None:
|
||||||
return jsonify({"status": "error", "error": "预测失败,核心预测器返回None"}), 500
|
return jsonify({"status": "error", "message": "预测失败,核心预测器返回None"}), 500
|
||||||
|
|
||||||
# 核心函数已处理好所有数据格式,此处直接构建最终响应
|
# 保存预测记录到数据库
|
||||||
response_data = {
|
prediction_uid = str(uuid.uuid4())
|
||||||
'status': 'success',
|
result_file_path = os.path.join('saved_predictions', f'prediction_{prediction_uid}.json')
|
||||||
'data': prediction_result, # 包含所有信息的完整结果
|
os.makedirs(os.path.dirname(result_file_path), exist_ok=True)
|
||||||
'history_data': prediction_result.get('history_data', []),
|
|
||||||
'prediction_data': prediction_result.get('prediction_data', [])
|
with open(result_file_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(prediction_result, f, ensure_ascii=False, cls=CustomJSONEncoder)
|
||||||
|
|
||||||
|
db_payload = {
|
||||||
|
"prediction_uid": prediction_uid,
|
||||||
|
"model_id": model_uid,
|
||||||
|
"model_type": model_type,
|
||||||
|
"product_name": model_record.get('display_name'),
|
||||||
|
"prediction_scope": {"product_id": product_id, "store_id": store_id},
|
||||||
|
"prediction_params": {"future_days": future_days, "start_date": start_date},
|
||||||
|
"metrics": prediction_result.get('analysis', {}).get('metrics', {}),
|
||||||
|
"result_file_path": result_file_path
|
||||||
}
|
}
|
||||||
|
save_prediction_to_db(db_payload)
|
||||||
|
|
||||||
# 调试日志
|
return jsonify({
|
||||||
print("=== 预测API响应数据结构 (v3) ===")
|
'status': 'success',
|
||||||
print(f"history_data 长度: {len(response_data['history_data'])}")
|
'data': prediction_result
|
||||||
print(f"prediction_data 长度: {len(response_data['prediction_data'])}")
|
})
|
||||||
print("================================")
|
|
||||||
|
|
||||||
# 重新加入保存预测结果的逻辑
|
|
||||||
try:
|
|
||||||
model_id_to_save = f"{model_identifier}_{model_type}_{version}"
|
|
||||||
product_name_to_save = prediction_result.get('product_name', product_id or store_id or 'global')
|
|
||||||
|
|
||||||
# 调用辅助函数保存结果
|
|
||||||
save_prediction_result(
|
|
||||||
prediction_result=prediction_result.copy(),
|
|
||||||
product_id=product_id or store_id or 'global',
|
|
||||||
product_name=product_name_to_save,
|
|
||||||
model_type=model_type,
|
|
||||||
model_id=model_id_to_save,
|
|
||||||
start_date=start_date,
|
|
||||||
future_days=future_days
|
|
||||||
)
|
|
||||||
print(f"✅ 预测结果已成功保存到历史记录。")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ 警告: 保存预测结果到历史记录失败: {str(e)}")
|
logger.error(f"预测失败: {e}\n{traceback.format_exc()}")
|
||||||
traceback.print_exc()
|
return jsonify({"status": "error", "message": str(e)}), 500
|
||||||
# 不应阻止向用户返回结果,因此只打印警告
|
|
||||||
|
|
||||||
return jsonify(response_data)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"预测失败: {str(e)}")
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return jsonify({"status": "error", "error": str(e)}), 500
|
|
||||||
|
|
||||||
@app.route('/api/prediction/compare', methods=['POST'])
|
@app.route('/api/prediction/compare', methods=['POST'])
|
||||||
@swag_from({
|
@swag_from({
|
||||||
@ -2042,154 +1942,87 @@ def delete_prediction(prediction_id):
|
|||||||
})
|
})
|
||||||
def list_models():
|
def list_models():
|
||||||
"""
|
"""
|
||||||
列出所有可用的模型 - 使用统一模型管理器
|
列出所有可用的模型 - v2版,从数据库查询
|
||||||
---
|
|
||||||
tags:
|
|
||||||
- 模型管理
|
|
||||||
parameters:
|
|
||||||
- name: product_id
|
|
||||||
in: query
|
|
||||||
type: string
|
|
||||||
required: false
|
|
||||||
description: 按产品ID筛选
|
|
||||||
- name: model_type
|
|
||||||
in: query
|
|
||||||
type: string
|
|
||||||
required: false
|
|
||||||
description: "按模型类型筛选 (mlstm, kan, transformer, tcn)"
|
|
||||||
- name: store_id
|
|
||||||
in: query
|
|
||||||
type: string
|
|
||||||
required: false
|
|
||||||
description: 按店铺ID筛选
|
|
||||||
- name: training_mode
|
|
||||||
in: query
|
|
||||||
type: string
|
|
||||||
required: false
|
|
||||||
description: "按训练模式筛选 (product, store, global)"
|
|
||||||
responses:
|
|
||||||
200:
|
|
||||||
description: 模型列表
|
|
||||||
schema:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
status:
|
|
||||||
type: string
|
|
||||||
example: success
|
|
||||||
data:
|
|
||||||
type: array
|
|
||||||
items:
|
|
||||||
type: object
|
|
||||||
properties:
|
|
||||||
model_id:
|
|
||||||
type: string
|
|
||||||
product_id:
|
|
||||||
type: string
|
|
||||||
product_name:
|
|
||||||
type: string
|
|
||||||
model_type:
|
|
||||||
type: string
|
|
||||||
training_mode:
|
|
||||||
type: string
|
|
||||||
store_id:
|
|
||||||
type: string
|
|
||||||
version:
|
|
||||||
type: string
|
|
||||||
created_at:
|
|
||||||
type: string
|
|
||||||
metrics:
|
|
||||||
type: object
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from utils.model_manager import ModelManager
|
|
||||||
|
|
||||||
# 创建新的ModelManager实例以避免缓存问题
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
project_root = os.path.dirname(current_dir) # 向上一级到项目根目录
|
|
||||||
model_dir = os.path.join(project_root, 'saved_models')
|
|
||||||
model_manager = ModelManager(model_dir)
|
|
||||||
|
|
||||||
logger.info(f"[API] 获取模型列表请求")
|
|
||||||
logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}")
|
|
||||||
logger.info(f"[API] 目录存在: {os.path.exists(model_manager.model_dir)}")
|
|
||||||
|
|
||||||
# 获取查询参数
|
# 获取查询参数
|
||||||
product_id_filter = request.args.get('product_id')
|
filters = {
|
||||||
model_type_filter = request.args.get('model_type')
|
'product_id': request.args.get('product_id'),
|
||||||
store_id_filter = request.args.get('store_id')
|
'model_type': request.args.get('model_type'),
|
||||||
training_mode_filter = request.args.get('training_mode')
|
'store_id': request.args.get('store_id'),
|
||||||
|
'training_mode': request.args.get('training_mode')
|
||||||
|
}
|
||||||
|
# 移除值为None的过滤器
|
||||||
|
filters = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
|
||||||
# 获取分页参数
|
# 获取分页参数
|
||||||
page = request.args.get('page', type=int)
|
page = request.args.get('page', 1, type=int)
|
||||||
page_size = request.args.get('page_size', type=int, default=10)
|
page_size = request.args.get('page_size', 10, type=int)
|
||||||
|
|
||||||
logger.info(f"[API] 分页参数: page={page}, page_size={page_size}")
|
# 调用新的数据库查询函数
|
||||||
|
result = query_models_from_db(filters=filters, page=page, page_size=page_size)
|
||||||
|
|
||||||
# 使用模型管理器获取模型列表
|
models = result.get('models', [])
|
||||||
result = model_manager.list_models(
|
pagination = result.get('pagination', {})
|
||||||
product_id=product_id_filter,
|
|
||||||
model_type=model_type_filter,
|
|
||||||
store_id=store_id_filter,
|
|
||||||
training_mode=training_mode_filter,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size
|
|
||||||
)
|
|
||||||
|
|
||||||
models = result['models']
|
# 格式化响应数据,解析JSON字段
|
||||||
pagination = result['pagination']
|
|
||||||
|
|
||||||
# 格式化响应数据
|
|
||||||
formatted_models = []
|
formatted_models = []
|
||||||
for model in models:
|
for model in models:
|
||||||
# 生成唯一且有意义的model_id
|
try:
|
||||||
model_id = model.get('filename', '').replace('.pth', '')
|
model['training_scope'] = json.loads(model['training_scope']) if model.get('training_scope') else {}
|
||||||
if not model_id:
|
model['training_params'] = json.loads(model['training_params']) if model.get('training_params') else {}
|
||||||
# 备用方案:基于模型信息生成ID
|
model['performance_metrics'] = json.loads(model['performance_metrics']) if model.get('performance_metrics') else {}
|
||||||
product_id = model.get('product_id', 'unknown')
|
model['artifacts'] = json.loads(model['artifacts']) if model.get('artifacts') else {}
|
||||||
model_type = model.get('model_type', 'unknown')
|
# =================================================================
|
||||||
version = model.get('version', 'v1')
|
# 核心修复 v4:精确提供前端所需的 product_name 和 store_name 字段
|
||||||
training_mode = model.get('training_mode', 'product')
|
# =================================================================
|
||||||
store_id = model.get('store_id')
|
scope = model.get('training_scope', {})
|
||||||
|
mode = model.get('training_mode')
|
||||||
|
|
||||||
if training_mode == 'store' and store_id:
|
# 初始化所有可能的名称字段
|
||||||
model_id = f"{model_type}_store_{store_id}_{product_id}_{version}"
|
model['product_name'] = None
|
||||||
elif training_mode == 'global':
|
model['store_name'] = None
|
||||||
aggregation_method = model.get('aggregation_method', 'mean')
|
model['display_name'] = model.get('display_name') # 优先使用数据库值
|
||||||
model_id = f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
|
|
||||||
else:
|
|
||||||
model_id = f"{model_type}_product_{product_id}_{version}"
|
|
||||||
|
|
||||||
formatted_model = {
|
if isinstance(scope, dict):
|
||||||
'model_id': model_id,
|
product_info = scope.get('product')
|
||||||
'filename': model.get('filename', ''),
|
store_info = scope.get('store')
|
||||||
'product_id': model.get('product_id', ''),
|
|
||||||
'product_name': model.get('product_name', model.get('product_id', '')),
|
|
||||||
'model_type': model.get('model_type', ''),
|
|
||||||
'training_mode': model.get('training_mode', 'product'),
|
|
||||||
'store_id': model.get('store_id'),
|
|
||||||
'aggregation_method': model.get('aggregation_method'),
|
|
||||||
'version': model.get('version', 'v1'),
|
|
||||||
'created_at': model.get('created_at', model.get('modified_at', '')),
|
|
||||||
'file_size': model.get('file_size', 0),
|
|
||||||
'metrics': model.get('metrics', {}),
|
|
||||||
'config': model.get('config', {})
|
|
||||||
}
|
|
||||||
formatted_models.append(formatted_model)
|
|
||||||
|
|
||||||
logger.info(f"[API] 成功获取 {len(formatted_models)} 个模型")
|
if mode == 'product' and isinstance(product_info, dict):
|
||||||
for i, model in enumerate(formatted_models):
|
model['product_name'] = product_info.get('name')
|
||||||
logger.info(f"[API] 模型 {i+1}: id='{model.get('model_id', 'EMPTY')}', filename='{model.get('filename', 'MISSING')}'")
|
model['display_name'] = model['product_name']
|
||||||
|
|
||||||
# Manually convert numpy types to prevent JSON serialization errors
|
elif mode == 'store' and isinstance(store_info, dict):
|
||||||
processed_models = convert_numpy_types_for_json(formatted_models)
|
model['store_name'] = store_info.get('name')
|
||||||
|
model['display_name'] = model['store_name'] # 默认显示店铺名
|
||||||
|
|
||||||
|
elif mode == 'global':
|
||||||
|
model['display_name'] = "全局模型"
|
||||||
|
|
||||||
|
# 提供最终的后备方案
|
||||||
|
if not model.get('display_name'):
|
||||||
|
model['display_name'] = "信息不完整"
|
||||||
|
# =================================================================
|
||||||
|
|
||||||
|
formatted_models.append(model)
|
||||||
|
except (json.JSONDecodeError, TypeError) as e:
|
||||||
|
logger.error(f"解析模型JSON数据失败 (model_uid: {model.get('model_uid')}): {e}")
|
||||||
|
# 即使解析失败,也添加部分数据
|
||||||
|
model['training_scope'] = {"error": "invalid json"}
|
||||||
|
model['performance_metrics'] = {"error": "invalid json"}
|
||||||
|
model['artifacts'] = {"error": "invalid json"}
|
||||||
|
model['display_name'] = "元数据解析失败"
|
||||||
|
model['product_name'] = "N/A"
|
||||||
|
model['store_name'] = "N/A"
|
||||||
|
formatted_models.append(model)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"data": processed_models,
|
"data": formatted_models,
|
||||||
"pagination": pagination
|
"pagination": pagination
|
||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"获取模型列表失败: {e}")
|
logger.error(f"获取模型列表失败: {e}\n{traceback.format_exc()}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"获取模型列表失败: {str(e)}",
|
"message": f"获取模型列表失败: {str(e)}",
|
||||||
@ -3061,103 +2894,6 @@ def analyze_prediction(prediction_result):
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 保存预测结果的辅助函数
|
|
||||||
def save_prediction_result(prediction_result, product_id, product_name, model_type, model_id, start_date, future_days):
|
|
||||||
"""
|
|
||||||
保存预测结果到文件和数据库
|
|
||||||
|
|
||||||
返回:
|
|
||||||
(prediction_id, file_path) - 预测ID和文件路径
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 生成唯一的预测ID
|
|
||||||
prediction_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# 确保目录存在
|
|
||||||
os.makedirs('static/predictions', exist_ok=True)
|
|
||||||
|
|
||||||
# 限制数据量
|
|
||||||
if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], list):
|
|
||||||
history_data = prediction_result['history_data']
|
|
||||||
if len(history_data) > 30:
|
|
||||||
print(f"保存时历史数据超过30天,进行裁剪,原始数量: {len(history_data)}")
|
|
||||||
prediction_result['history_data'] = history_data[-30:] # 只保留最近30天
|
|
||||||
|
|
||||||
if 'prediction_data' in prediction_result and isinstance(prediction_result['prediction_data'], list):
|
|
||||||
prediction_data = prediction_result['prediction_data']
|
|
||||||
if len(prediction_data) > 7:
|
|
||||||
print(f"保存时预测数据超过7天,进行裁剪,原始数量: {len(prediction_data)}")
|
|
||||||
prediction_result['prediction_data'] = prediction_data[:7] # 只保留前7天
|
|
||||||
|
|
||||||
# 处理预测结果中可能存在的NumPy类型
|
|
||||||
def convert_numpy_types(obj):
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {k: convert_numpy_types(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [convert_numpy_types(item) for item in obj]
|
|
||||||
elif isinstance(obj, pd.DataFrame):
|
|
||||||
return obj.to_dict(orient='records')
|
|
||||||
elif isinstance(obj, pd.Series):
|
|
||||||
return obj.to_dict()
|
|
||||||
elif isinstance(obj, np.generic):
|
|
||||||
return obj.item() # 将NumPy标量转换为Python原生类型
|
|
||||||
elif isinstance(obj, np.ndarray):
|
|
||||||
return obj.tolist()
|
|
||||||
elif pd.isna(obj):
|
|
||||||
return None
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
# 转换整个预测结果对象
|
|
||||||
prediction_result = convert_numpy_types(prediction_result)
|
|
||||||
|
|
||||||
# 保存预测结果到JSON文件
|
|
||||||
file_name = f"prediction_{prediction_id}.json"
|
|
||||||
file_path = os.path.join('static/predictions', file_name)
|
|
||||||
|
|
||||||
with open(file_path, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(prediction_result, f, ensure_ascii=False, cls=CustomJSONEncoder)
|
|
||||||
|
|
||||||
# 将预测记录保存到数据库
|
|
||||||
try:
|
|
||||||
conn = get_db_connection()
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# 提取并序列化需要存储的数据
|
|
||||||
predictions_data_json = json.dumps(prediction_result.get('prediction_data', []), cls=CustomJSONEncoder)
|
|
||||||
|
|
||||||
# 从分析结果中获取指标,如果分析结果不存在,则使用空字典
|
|
||||||
analysis_data = prediction_result.get('analysis', {})
|
|
||||||
metrics_data = analysis_data.get('metrics', {}) if isinstance(analysis_data, dict) else {}
|
|
||||||
metrics_json = json.dumps(metrics_data, cls=CustomJSONEncoder)
|
|
||||||
|
|
||||||
chart_data_json = json.dumps(prediction_result.get('chart_data', {}), cls=CustomJSONEncoder)
|
|
||||||
analysis_json = json.dumps(analysis_data, cls=CustomJSONEncoder)
|
|
||||||
|
|
||||||
cursor.execute('''
|
|
||||||
INSERT INTO prediction_history (
|
|
||||||
prediction_id, product_id, product_name, model_type, model_id,
|
|
||||||
start_date, future_days, created_at, file_path,
|
|
||||||
predictions_data, metrics, chart_data, analysis
|
|
||||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
''', (
|
|
||||||
prediction_id, product_id, product_name, model_type, model_id,
|
|
||||||
start_date if start_date else datetime.now().strftime('%Y-%m-%d'),
|
|
||||||
future_days, datetime.now().isoformat(), file_path,
|
|
||||||
predictions_data_json, metrics_json, chart_data_json, analysis_json
|
|
||||||
))
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
|
||||||
print(f"保存预测记录到数据库失败: {str(e)}")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return prediction_id, file_path
|
|
||||||
except Exception as e:
|
|
||||||
print(f"保存预测结果失败: {str(e)}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# 添加模型性能分析接口
|
# 添加模型性能分析接口
|
||||||
@app.route('/api/models/analyze-metrics', methods=['POST'])
|
@app.route('/api/models/analyze-metrics', methods=['POST'])
|
||||||
@ -4391,13 +4127,10 @@ def test_models_fix():
|
|||||||
from utils.model_manager import ModelManager
|
from utils.model_manager import ModelManager
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 强制创建新的ModelManager实例
|
# 修正: 直接使用默认的相对路径
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
manager = ModelManager()
|
||||||
project_root = os.path.dirname(current_dir)
|
|
||||||
model_dir = os.path.join(project_root, 'saved_models')
|
|
||||||
manager = ModelManager(model_dir)
|
|
||||||
|
|
||||||
models = manager.list_models()
|
models = manager.list_models()['models']
|
||||||
|
|
||||||
# 简化的响应格式
|
# 简化的响应格式
|
||||||
test_result = {
|
test_result = {
|
||||||
|
@ -2178,11 +2178,8 @@ def list_models():
|
|||||||
try:
|
try:
|
||||||
from utils.model_manager import ModelManager
|
from utils.model_manager import ModelManager
|
||||||
|
|
||||||
# 创建新的ModelManager实例以避免缓存问题
|
# 修正: 直接使用默认的相对路径
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
model_manager = ModelManager()
|
||||||
project_root = os.path.dirname(current_dir) # 向上一级到项目根目录
|
|
||||||
model_dir = os.path.join(project_root, 'saved_models')
|
|
||||||
model_manager = ModelManager(model_dir)
|
|
||||||
|
|
||||||
logger.info(f"[API] 获取模型列表请求")
|
logger.info(f"[API] 获取模型列表请求")
|
||||||
logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}")
|
logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}")
|
||||||
@ -4465,13 +4462,10 @@ def test_models_fix():
|
|||||||
from utils.model_manager import ModelManager
|
from utils.model_manager import ModelManager
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 强制创建新的ModelManager实例
|
# 修正: 直接使用默认的相对路径
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
manager = ModelManager()
|
||||||
project_root = os.path.dirname(current_dir)
|
|
||||||
model_dir = os.path.join(project_root, 'saved_models')
|
|
||||||
manager = ModelManager(model_dir)
|
|
||||||
|
|
||||||
models = manager.list_models()
|
models = manager.list_models()['models']
|
||||||
|
|
||||||
# 简化的响应格式
|
# 简化的响应格式
|
||||||
test_result = {
|
test_result = {
|
||||||
|
@ -10,13 +10,6 @@ import os
|
|||||||
import re
|
import re
|
||||||
import glob
|
import glob
|
||||||
|
|
||||||
# 项目根目录
|
|
||||||
# __file__ 是当前文件 (config.py) 的路径
|
|
||||||
# os.path.dirname(__file__) 是 server/core
|
|
||||||
# os.path.join(..., '..') 是 server
|
|
||||||
# os.path.join(..., '..', '..') 是项目根目录
|
|
||||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
||||||
|
|
||||||
# 解决画图中文显示问题
|
# 解决画图中文显示问题
|
||||||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||||||
plt.rcParams['axes.unicode_minus'] = False
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
@ -34,8 +27,9 @@ DEVICE = get_device()
|
|||||||
|
|
||||||
# 数据相关配置
|
# 数据相关配置
|
||||||
# 使用 os.path.join 构造跨平台的路径
|
# 使用 os.path.join 构造跨平台的路径
|
||||||
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
|
# 修正: 改为相对路径
|
||||||
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
|
DEFAULT_DATA_PATH = os.path.join('data', 'timeseries_training_data_sample_10s50p.parquet')
|
||||||
|
DEFAULT_MODEL_DIR = 'saved_models'
|
||||||
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||||
|
|
||||||
# 时间序列参数
|
# 时间序列参数
|
||||||
|
@ -225,11 +225,13 @@ class PharmacyPredictor:
|
|||||||
result = trainer_function(**valid_args)
|
result = trainer_function(**valid_args)
|
||||||
|
|
||||||
# 根据返回值的数量解析metrics
|
# 根据返回值的数量解析metrics
|
||||||
if isinstance(result, tuple) and len(result) >= 2:
|
if isinstance(result, tuple) and len(result) >= 3:
|
||||||
metrics = result[1] # 通常第二个返回值是metrics
|
metrics = result[1] # 第二个返回值是metrics
|
||||||
|
artifacts = result[2] # 第三个返回值是artifacts
|
||||||
else:
|
else:
|
||||||
log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
|
log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics和artifacts: {type(result)}", 'warning')
|
||||||
metrics = None
|
metrics = None
|
||||||
|
artifacts = None
|
||||||
|
|
||||||
|
|
||||||
# 检查和打印返回的metrics
|
# 检查和打印返回的metrics
|
||||||
@ -249,7 +251,7 @@ class PharmacyPredictor:
|
|||||||
else:
|
else:
|
||||||
log_message(f"⚠️ metrics为空或None", 'warning')
|
log_message(f"⚠️ metrics为空或None", 'warning')
|
||||||
|
|
||||||
return metrics
|
return metrics, artifacts
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message(f"模型训练失败: {e}", 'error')
|
log_message(f"模型训练失败: {e}", 'error')
|
||||||
|
@ -65,9 +65,21 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
|
|
||||||
# --- 3. 训练循环与早停 ---
|
# --- 3. 训练循环与早停 ---
|
||||||
print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")
|
print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")
|
||||||
|
|
||||||
|
# 版本锁定:在训练开始前确定本次训练的版本号
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type='cnn_bilstm_attention',
|
||||||
|
product_id=product_id,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
loss_history = {'train': [], 'val': []}
|
loss_history = {'train': [], 'val': []}
|
||||||
best_val_loss = float('inf')
|
best_val_loss = float('inf')
|
||||||
best_model_state = None
|
best_model_state = None
|
||||||
|
best_model_path = None # 用于存储最佳模型的路径
|
||||||
patience = kwargs.get('patience', 15)
|
patience = kwargs.get('patience', 15)
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
|
|
||||||
@ -99,6 +111,31 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
best_model_state = copy.deepcopy(model.state_dict())
|
best_model_state = copy.deepcopy(model.state_dict())
|
||||||
patience_counter = 0
|
patience_counter = 0
|
||||||
print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")
|
print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")
|
||||||
|
|
||||||
|
# 立即保存最佳模型
|
||||||
|
best_model_data = {
|
||||||
|
'model_state_dict': best_model_state,
|
||||||
|
'scaler_X': scaler_X,
|
||||||
|
'scaler_y': scaler_y,
|
||||||
|
'config': {
|
||||||
|
'model_type': 'cnn_bilstm_attention',
|
||||||
|
'input_dim': input_dim,
|
||||||
|
'output_dim': forecast_horizon,
|
||||||
|
'sequence_length': sequence_length,
|
||||||
|
'features': features
|
||||||
|
},
|
||||||
|
'epoch': epoch + 1
|
||||||
|
}
|
||||||
|
best_model_path, _ = model_manager.save_model(
|
||||||
|
model_data=best_model_data,
|
||||||
|
product_id=product_id,
|
||||||
|
model_type='cnn_bilstm_attention',
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method,
|
||||||
|
product_name=product_name,
|
||||||
|
version=f"{current_version}_best"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
patience_counter += 1
|
patience_counter += 1
|
||||||
if patience_counter >= patience:
|
if patience_counter >= patience:
|
||||||
@ -128,17 +165,32 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
print("\n最佳模型评估指标:")
|
print("\n最佳模型评估指标:")
|
||||||
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
|
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
|
||||||
|
|
||||||
# 绘制损失曲线
|
# --- 5. 保存工件 ---
|
||||||
|
|
||||||
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = product_id
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
loss_curve_path = plot_loss_curve(
|
loss_curve_path = plot_loss_curve(
|
||||||
loss_history['train'],
|
train_losses=loss_history['train'],
|
||||||
loss_history['val'],
|
val_losses=loss_history['val'],
|
||||||
product_name,
|
model_type='cnn_bilstm_attention',
|
||||||
'cnn_bilstm_attention',
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
model_dir=model_dir
|
model_dir=model_dir
|
||||||
)
|
)
|
||||||
print(f"📈 损失曲线已保存到: {loss_curve_path}")
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
# --- 5. 模型保存 ---
|
# 准备要保存的最终模型数据
|
||||||
model_data = {
|
model_data = {
|
||||||
'model_state_dict': best_model_state, # 保存最佳模型的状态
|
'model_state_dict': best_model_state, # 保存最佳模型的状态
|
||||||
'scaler_X': scaler_X,
|
'scaler_X': scaler_X,
|
||||||
@ -151,11 +203,11 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
'features': features
|
'features': features
|
||||||
},
|
},
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'loss_history': loss_history, # 保存损失历史
|
'loss_history': loss_history,
|
||||||
'loss_curve_path': loss_curve_path # 添加损失图路径
|
'loss_curve_path': loss_curve_path # 直接包含路径
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存最终版本模型
|
# 使用模型管理器保存最终模型
|
||||||
final_model_path, final_version = model_manager.save_model(
|
final_model_path, final_version = model_manager.save_model(
|
||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
product_id=product_id,
|
product_id=product_id,
|
||||||
@ -163,24 +215,20 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
|
version=current_version # 使用锁定的版本号
|
||||||
)
|
)
|
||||||
print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")
|
print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")
|
||||||
|
|
||||||
# 保存最佳版本模型
|
# 组装返回的工件
|
||||||
best_model_path, best_version = model_manager.save_model(
|
artifacts = {
|
||||||
model_data=model_data,
|
"versioned_model": final_model_path,
|
||||||
product_id=product_id,
|
"loss_curve_plot": loss_curve_path,
|
||||||
model_type='cnn_bilstm_attention',
|
"best_model": best_model_path,
|
||||||
store_id=store_id,
|
"version": final_version
|
||||||
training_mode=training_mode,
|
}
|
||||||
aggregation_method=aggregation_method,
|
|
||||||
product_name=product_name,
|
|
||||||
version='best' # 明确指定版本为 'best'
|
|
||||||
)
|
|
||||||
print(f"✅ CNN-BiLSTM-Attention 最佳模型已保存,版本: {best_version}")
|
|
||||||
|
|
||||||
return model, metrics, final_version, final_model_path
|
return model, metrics, artifacts
|
||||||
|
|
||||||
# --- 关键步骤: 将训练器注册到系统中 ---
|
# --- 关键步骤: 将训练器注册到系统中 ---
|
||||||
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
|
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
|
@ -165,10 +165,22 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||||
|
|
||||||
# 训练模型
|
# 训练模型
|
||||||
|
from utils.model_manager import model_manager
|
||||||
|
model_type_name = 'optimized_kan' if use_optimized else 'kan'
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type=model_type_name,
|
||||||
|
product_id=model_identifier,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
train_losses = []
|
train_losses = []
|
||||||
test_losses = []
|
test_losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
best_model_path = None
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
model.train()
|
model.train()
|
||||||
@ -253,7 +265,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
|
|
||||||
# 使用模型管理器保存 'best' 版本
|
# 使用模型管理器保存 'best' 版本
|
||||||
from utils.model_manager import model_manager
|
from utils.model_manager import model_manager
|
||||||
model_manager.save_model(
|
best_model_path, _ = model_manager.save_model(
|
||||||
model_data=best_model_data,
|
model_data=best_model_data,
|
||||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||||
model_type=model_type_name,
|
model_type=model_type_name,
|
||||||
@ -261,7 +273,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name,
|
product_name=product_name,
|
||||||
version='best' # 显式覆盖版本为'best'
|
version=f"{current_version}_best"
|
||||||
)
|
)
|
||||||
|
|
||||||
if (epoch + 1) % 10 == 0:
|
if (epoch + 1) % 10 == 0:
|
||||||
@ -271,15 +283,6 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
training_time = time.time() - start_time
|
training_time = time.time() - start_time
|
||||||
|
|
||||||
# 绘制损失曲线并保存到模型目录
|
# 绘制损失曲线并保存到模型目录
|
||||||
model_name = 'optimized_kan' if use_optimized else 'kan'
|
|
||||||
loss_curve_path = plot_loss_curve(
|
|
||||||
train_losses,
|
|
||||||
test_losses,
|
|
||||||
product_name,
|
|
||||||
model_type,
|
|
||||||
model_dir=model_dir
|
|
||||||
)
|
|
||||||
print(f"损失曲线已保存到: {loss_curve_path}")
|
|
||||||
|
|
||||||
# 评估模型
|
# 评估模型
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -307,11 +310,33 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
print(f"MAPE: {metrics['mape']:.2f}%")
|
print(f"MAPE: {metrics['mape']:.2f}%")
|
||||||
print(f"训练时间: {training_time:.2f}秒")
|
print(f"训练时间: {training_time:.2f}秒")
|
||||||
|
|
||||||
# 使用统一模型管理器保存模型
|
# --- 5. 保存工件 ---
|
||||||
from utils.model_manager import model_manager
|
|
||||||
|
|
||||||
model_type_name = 'optimized_kan' if use_optimized else 'kan'
|
model_type_name = 'optimized_kan' if use_optimized else 'kan'
|
||||||
|
|
||||||
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = model_identifier
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
|
loss_curve_path = plot_loss_curve(
|
||||||
|
train_losses=train_losses,
|
||||||
|
val_losses=test_losses,
|
||||||
|
model_type=model_type_name,
|
||||||
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
|
model_dir=model_dir
|
||||||
|
)
|
||||||
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
|
# 准备要保存的最终模型数据
|
||||||
model_data = {
|
model_data = {
|
||||||
'model_state_dict': model.state_dict(),
|
'model_state_dict': model.state_dict(),
|
||||||
'scaler_X': scaler_X,
|
'scaler_X': scaler_X,
|
||||||
@ -332,24 +357,32 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
|
|||||||
'test': test_losses,
|
'test': test_losses,
|
||||||
'epochs': list(range(1, epochs + 1))
|
'epochs': list(range(1, epochs + 1))
|
||||||
},
|
},
|
||||||
'loss_curve_path': loss_curve_path
|
'loss_curve_path': loss_curve_path # 直接包含路径
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存最终模型,让 model_manager 自动处理版本号
|
# 使用模型管理器保存最终模型
|
||||||
|
from utils.model_manager import model_manager
|
||||||
final_model_path, final_version = model_manager.save_model(
|
final_model_path, final_version = model_manager.save_model(
|
||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
product_id=model_identifier,
|
||||||
model_type=model_type_name,
|
model_type=model_type_name,
|
||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
# 注意:此处不传递version参数,由管理器自动生成
|
version=current_version # 使用锁定的版本
|
||||||
)
|
)
|
||||||
|
print(f"✅ {model_type_name} 最终模型已保存,版本: {final_version}")
|
||||||
|
|
||||||
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
|
# 组装返回的工件
|
||||||
|
artifacts = {
|
||||||
|
"versioned_model": final_model_path,
|
||||||
|
"loss_curve_plot": loss_curve_path,
|
||||||
|
"best_model": best_model_path,
|
||||||
|
"version": final_version
|
||||||
|
}
|
||||||
|
|
||||||
return model, metrics
|
return model, metrics, artifacts
|
||||||
|
|
||||||
# --- 将此训练器注册到系统中 ---
|
# --- 将此训练器注册到系统中 ---
|
||||||
from models.model_registry import register_trainer
|
from models.model_registry import register_trainer
|
||||||
|
@ -255,6 +255,16 @@ def train_product_model_with_mlstm(
|
|||||||
emit_progress("数据预处理完成,开始模型训练...", progress=10)
|
emit_progress("数据预处理完成,开始模型训练...", progress=10)
|
||||||
|
|
||||||
# 训练模型
|
# 训练模型
|
||||||
|
# 版本锁定
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type='mlstm',
|
||||||
|
product_id=model_identifier,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
train_losses = []
|
train_losses = []
|
||||||
test_losses = []
|
test_losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@ -263,6 +273,7 @@ def train_product_model_with_mlstm(
|
|||||||
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次,最少每1个epoch
|
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次,最少每1个epoch
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
|
best_model_path = None
|
||||||
|
|
||||||
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
|
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
|
||||||
|
|
||||||
@ -365,7 +376,7 @@ def train_product_model_with_mlstm(
|
|||||||
# 如果是最佳模型,额外保存一份
|
# 如果是最佳模型,额外保存一份
|
||||||
if test_loss < best_loss:
|
if test_loss < best_loss:
|
||||||
best_loss = test_loss
|
best_loss = test_loss
|
||||||
model_manager.save_model(
|
best_model_path, _ = model_manager.save_model(
|
||||||
model_data=checkpoint_data,
|
model_data=checkpoint_data,
|
||||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||||
model_type='mlstm',
|
model_type='mlstm',
|
||||||
@ -373,7 +384,7 @@ def train_product_model_with_mlstm(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name,
|
product_name=product_name,
|
||||||
version='best'
|
version=f"{current_version}_best"
|
||||||
)
|
)
|
||||||
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
@ -391,37 +402,6 @@ def train_product_model_with_mlstm(
|
|||||||
# 计算训练时间
|
# 计算训练时间
|
||||||
training_time = time.time() - start_time
|
training_time = time.time() - start_time
|
||||||
|
|
||||||
emit_progress("生成损失曲线...", progress=95)
|
|
||||||
|
|
||||||
# 确定模型保存目录(支持多店铺)
|
|
||||||
if store_id:
|
|
||||||
# 为特定店铺创建子目录
|
|
||||||
store_model_dir = os.path.join(model_dir, 'mlstm', store_id)
|
|
||||||
os.makedirs(store_model_dir, exist_ok=True)
|
|
||||||
loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
|
|
||||||
loss_curve_path = os.path.join(store_model_dir, loss_curve_filename)
|
|
||||||
else:
|
|
||||||
# 全局模型保存在global目录
|
|
||||||
global_model_dir = os.path.join(model_dir, 'mlstm', 'global')
|
|
||||||
os.makedirs(global_model_dir, exist_ok=True)
|
|
||||||
loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
|
|
||||||
loss_curve_path = os.path.join(global_model_dir, loss_curve_filename)
|
|
||||||
|
|
||||||
# 绘制损失曲线并保存到模型目录
|
|
||||||
plt.figure(figsize=(10, 6))
|
|
||||||
plt.plot(train_losses, label='Training Loss')
|
|
||||||
plt.plot(test_losses, label='Test Loss')
|
|
||||||
title_suffix = f" - {training_scope}" if store_id else " - 全局模型"
|
|
||||||
plt.title(f'mLSTM 模型训练损失曲线 - {product_name} ({version}){title_suffix}')
|
|
||||||
plt.xlabel('Epoch')
|
|
||||||
plt.ylabel('Loss')
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True)
|
|
||||||
plt.savefig(loss_curve_path, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
print(f"损失曲线已保存到: {loss_curve_path}")
|
|
||||||
|
|
||||||
emit_progress("模型评估中...", progress=98)
|
emit_progress("模型评估中...", progress=98)
|
||||||
|
|
||||||
# 评估模型
|
# 评估模型
|
||||||
@ -475,7 +455,7 @@ def train_product_model_with_mlstm(
|
|||||||
'model_type': 'mlstm'
|
'model_type': 'mlstm'
|
||||||
},
|
},
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'loss_curve_path': loss_curve_path,
|
'metrics': metrics,
|
||||||
'training_info': {
|
'training_info': {
|
||||||
'product_id': product_id,
|
'product_id': product_id,
|
||||||
'product_name': product_name,
|
'product_name': product_name,
|
||||||
@ -488,6 +468,31 @@ def train_product_model_with_mlstm(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = model_identifier
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
|
loss_curve_path = plot_loss_curve(
|
||||||
|
train_losses=train_losses,
|
||||||
|
val_losses=test_losses,
|
||||||
|
model_type='mlstm',
|
||||||
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
|
model_dir=model_dir
|
||||||
|
)
|
||||||
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
|
# 更新模型数据中的损失图路径
|
||||||
|
final_model_data['loss_curve_path'] = loss_curve_path
|
||||||
# 保存最终模型,让 model_manager 自动处理版本号
|
# 保存最终模型,让 model_manager 自动处理版本号
|
||||||
final_model_path, final_version = model_manager.save_model(
|
final_model_path, final_version = model_manager.save_model(
|
||||||
model_data=final_model_data,
|
model_data=final_model_data,
|
||||||
@ -496,7 +501,8 @@ def train_product_model_with_mlstm(
|
|||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
|
version=current_version
|
||||||
)
|
)
|
||||||
|
|
||||||
# 发送训练完成消息
|
# 发送训练完成消息
|
||||||
@ -514,7 +520,16 @@ def train_product_model_with_mlstm(
|
|||||||
|
|
||||||
emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||||
|
|
||||||
return model, metrics, epochs, final_model_path
|
# 组装 artifacts 字典
|
||||||
|
artifacts = {
|
||||||
|
"versioned_model": final_model_path,
|
||||||
|
"loss_curve_plot": loss_curve_path,
|
||||||
|
# 假设 best model 的路径可以从 model_manager 获取或推断
|
||||||
|
"best_model": best_model_path,
|
||||||
|
"version": final_version
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, metrics, artifacts
|
||||||
|
|
||||||
# --- 将此训练器注册到系统中 ---
|
# --- 将此训练器注册到系统中 ---
|
||||||
from models.model_registry import register_trainer
|
from models.model_registry import register_trainer
|
||||||
|
@ -164,8 +164,19 @@ def train_product_model_with_tcn(
|
|||||||
test_losses = []
|
test_losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 版本锁定
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type='tcn',
|
||||||
|
product_id=model_identifier,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
checkpoint_interval = max(1, epochs // 10)
|
checkpoint_interval = max(1, epochs // 10)
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
|
best_model_path = None
|
||||||
|
|
||||||
progress_manager.set_stage("model_training", 0)
|
progress_manager.set_stage("model_training", 0)
|
||||||
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
|
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
|
||||||
@ -269,7 +280,7 @@ def train_product_model_with_tcn(
|
|||||||
|
|
||||||
if test_loss < best_loss:
|
if test_loss < best_loss:
|
||||||
best_loss = test_loss
|
best_loss = test_loss
|
||||||
model_manager.save_model(
|
best_model_path, _ = model_manager.save_model(
|
||||||
model_data=checkpoint_data,
|
model_data=checkpoint_data,
|
||||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||||
model_type='tcn',
|
model_type='tcn',
|
||||||
@ -277,7 +288,7 @@ def train_product_model_with_tcn(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name,
|
product_name=product_name,
|
||||||
version='best'
|
version=f"{current_version}_best"
|
||||||
)
|
)
|
||||||
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
||||||
|
|
||||||
@ -289,14 +300,6 @@ def train_product_model_with_tcn(
|
|||||||
progress_manager.set_stage("model_saving", 0)
|
progress_manager.set_stage("model_saving", 0)
|
||||||
emit_progress("训练完成,正在保存模型...")
|
emit_progress("训练完成,正在保存模型...")
|
||||||
|
|
||||||
loss_curve_path = plot_loss_curve(
|
|
||||||
train_losses,
|
|
||||||
test_losses,
|
|
||||||
product_name,
|
|
||||||
'TCN',
|
|
||||||
model_dir=model_dir
|
|
||||||
)
|
|
||||||
print(f"损失曲线已保存到: {loss_curve_path}")
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -340,7 +343,7 @@ def train_product_model_with_tcn(
|
|||||||
'model_type': 'tcn'
|
'model_type': 'tcn'
|
||||||
},
|
},
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'loss_curve_path': loss_curve_path,
|
'metrics': metrics,
|
||||||
'training_info': {
|
'training_info': {
|
||||||
'product_id': product_id,
|
'product_id': product_id,
|
||||||
'product_name': product_name,
|
'product_name': product_name,
|
||||||
@ -361,7 +364,8 @@ def train_product_model_with_tcn(
|
|||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
|
version=current_version
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_manager.set_stage("model_saving", 100)
|
progress_manager.set_stage("model_saving", 100)
|
||||||
@ -379,7 +383,40 @@ def train_product_model_with_tcn(
|
|||||||
|
|
||||||
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||||
|
|
||||||
return model, metrics, epochs, final_model_path
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = model_identifier
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
|
loss_curve_path = plot_loss_curve(
|
||||||
|
train_losses=train_losses,
|
||||||
|
val_losses=test_losses,
|
||||||
|
model_type='tcn',
|
||||||
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
|
model_dir=model_dir
|
||||||
|
)
|
||||||
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
|
# 更新模型数据中的损失图路径
|
||||||
|
final_model_data['loss_curve_path'] = loss_curve_path
|
||||||
|
|
||||||
|
artifacts = {
|
||||||
|
"versioned_model": final_model_path,
|
||||||
|
"loss_curve_plot": loss_curve_path,
|
||||||
|
"best_model": best_model_path,
|
||||||
|
"version": final_version
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, metrics, artifacts
|
||||||
|
|
||||||
# --- 将此训练器注册到系统中 ---
|
# --- 将此训练器注册到系统中 ---
|
||||||
from models.model_registry import register_trainer
|
from models.model_registry import register_trainer
|
||||||
|
@ -187,9 +187,20 @@ def train_product_model_with_transformer(
|
|||||||
test_losses = []
|
test_losses = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 版本锁定
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type='transformer',
|
||||||
|
product_id=model_identifier,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
checkpoint_interval = max(1, epochs // 10)
|
checkpoint_interval = max(1, epochs // 10)
|
||||||
best_loss = float('inf')
|
best_loss = float('inf')
|
||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
|
best_model_path = None
|
||||||
|
|
||||||
progress_manager.set_stage("model_training", 0)
|
progress_manager.set_stage("model_training", 0)
|
||||||
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
|
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
|
||||||
@ -289,7 +300,8 @@ def train_product_model_with_transformer(
|
|||||||
|
|
||||||
if test_loss < best_loss:
|
if test_loss < best_loss:
|
||||||
best_loss = test_loss
|
best_loss = test_loss
|
||||||
model_manager.save_model(
|
# 修正: 保存最佳模型路径
|
||||||
|
best_model_path, _ = model_manager.save_model(
|
||||||
model_data=checkpoint_data,
|
model_data=checkpoint_data,
|
||||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||||
model_type='transformer',
|
model_type='transformer',
|
||||||
@ -297,7 +309,7 @@ def train_product_model_with_transformer(
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name,
|
product_name=product_name,
|
||||||
version='best'
|
version=f"{current_version}_best"
|
||||||
)
|
)
|
||||||
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
||||||
epochs_no_improve = 0
|
epochs_no_improve = 0
|
||||||
@ -316,14 +328,6 @@ def train_product_model_with_transformer(
|
|||||||
progress_manager.set_stage("model_saving", 0)
|
progress_manager.set_stage("model_saving", 0)
|
||||||
emit_progress("训练完成,正在保存模型...")
|
emit_progress("训练完成,正在保存模型...")
|
||||||
|
|
||||||
loss_curve_path = plot_loss_curve(
|
|
||||||
train_losses,
|
|
||||||
test_losses,
|
|
||||||
product_name,
|
|
||||||
'Transformer',
|
|
||||||
model_dir=model_dir
|
|
||||||
)
|
|
||||||
print(f"📈 损失曲线已保存到: {loss_curve_path}", flush=True)
|
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -366,7 +370,7 @@ def train_product_model_with_transformer(
|
|||||||
'model_type': 'transformer'
|
'model_type': 'transformer'
|
||||||
},
|
},
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'loss_curve_path': loss_curve_path,
|
'metrics': metrics,
|
||||||
'training_info': {
|
'training_info': {
|
||||||
'product_id': product_id,
|
'product_id': product_id,
|
||||||
'product_name': product_name,
|
'product_name': product_name,
|
||||||
@ -387,7 +391,8 @@ def train_product_model_with_transformer(
|
|||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
|
version=current_version
|
||||||
)
|
)
|
||||||
|
|
||||||
progress_manager.set_stage("model_saving", 100)
|
progress_manager.set_stage("model_saving", 100)
|
||||||
@ -406,7 +411,40 @@ def train_product_model_with_transformer(
|
|||||||
'version': final_version
|
'version': final_version
|
||||||
}
|
}
|
||||||
|
|
||||||
return model, final_metrics, epochs
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = model_identifier
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
|
loss_curve_path = plot_loss_curve(
|
||||||
|
train_losses=train_losses,
|
||||||
|
val_losses=test_losses,
|
||||||
|
model_type='transformer',
|
||||||
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
|
model_dir=model_dir
|
||||||
|
)
|
||||||
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
|
# 更新模型数据中的损失图路径
|
||||||
|
final_model_data['loss_curve_path'] = loss_curve_path
|
||||||
|
|
||||||
|
artifacts = {
|
||||||
|
"versioned_model": final_model_path,
|
||||||
|
"loss_curve_plot": loss_curve_path,
|
||||||
|
"best_model": best_model_path,
|
||||||
|
"version": final_version
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, final_metrics, artifacts
|
||||||
|
|
||||||
# --- 将此训练器注册到系统中 ---
|
# --- 将此训练器注册到系统中 ---
|
||||||
from models.model_registry import register_trainer
|
from models.model_registry import register_trainer
|
||||||
|
@ -76,6 +76,17 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
|
|||||||
n_estimators = kwargs.get('n_estimators', 500)
|
n_estimators = kwargs.get('n_estimators', 500)
|
||||||
|
|
||||||
print("开始训练XGBoost模型 (使用核心xgb.train API)...")
|
print("开始训练XGBoost模型 (使用核心xgb.train API)...")
|
||||||
|
|
||||||
|
# 版本锁定
|
||||||
|
current_version = model_manager.peek_next_version(
|
||||||
|
model_type='xgboost',
|
||||||
|
product_id=product_id,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
print(f"🔒 本次训练版本锁定为: {current_version}")
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
evals_result = {}
|
evals_result = {}
|
||||||
@ -109,16 +120,33 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
|
|||||||
# 提取损失并绘制曲线
|
# 提取损失并绘制曲线
|
||||||
train_losses = evals_result['train']['rmse']
|
train_losses = evals_result['train']['rmse']
|
||||||
test_losses = evals_result['test']['rmse']
|
test_losses = evals_result['test']['rmse']
|
||||||
|
|
||||||
|
# --- 5. 保存工件 ---
|
||||||
|
|
||||||
|
# 准备 scope 和 identifier 以生成标准化的文件名
|
||||||
|
scope = training_mode
|
||||||
|
if scope == 'product':
|
||||||
|
identifier = product_id
|
||||||
|
elif scope == 'store':
|
||||||
|
identifier = store_id
|
||||||
|
elif scope == 'global':
|
||||||
|
identifier = aggregation_method
|
||||||
|
else:
|
||||||
|
identifier = product_name # 后备方案
|
||||||
|
|
||||||
|
# 绘制带有版本号的损失曲线图
|
||||||
loss_curve_path = plot_loss_curve(
|
loss_curve_path = plot_loss_curve(
|
||||||
train_losses,
|
train_losses=train_losses,
|
||||||
test_losses,
|
val_losses=test_losses,
|
||||||
product_name,
|
model_type='xgboost',
|
||||||
'xgboost',
|
scope=scope,
|
||||||
|
identifier=identifier,
|
||||||
|
version=current_version, # 使用锁定的版本
|
||||||
model_dir=model_dir
|
model_dir=model_dir
|
||||||
)
|
)
|
||||||
print(f"📈 损失曲线已保存到: {loss_curve_path}")
|
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
|
||||||
|
|
||||||
# --- 5. 模型保存 (借道 utils.model_manager) ---
|
# 准备要保存的最终模型数据
|
||||||
model_data = {
|
model_data = {
|
||||||
'model_state_dict': model, # 直接保存模型对象
|
'model_state_dict': model, # 直接保存模型对象
|
||||||
'scaler_X': scaler_X,
|
'scaler_X': scaler_X,
|
||||||
@ -132,7 +160,7 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
|
|||||||
},
|
},
|
||||||
'metrics': metrics,
|
'metrics': metrics,
|
||||||
'loss_history': evals_result,
|
'loss_history': evals_result,
|
||||||
'loss_curve_path': loss_curve_path # 添加损失图路径
|
'loss_curve_path': loss_curve_path # 直接包含路径
|
||||||
}
|
}
|
||||||
|
|
||||||
# 保存最终版本模型
|
# 保存最终版本模型
|
||||||
@ -143,12 +171,14 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
|
|||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name
|
product_name=product_name,
|
||||||
|
version=current_version
|
||||||
)
|
)
|
||||||
print(f"✅ XGBoost最终模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}")
|
print(f"✅ XGBoost最终模型已通过统一管理器保存,版本: {final_version}")
|
||||||
|
|
||||||
# 保存最佳版本模型
|
# XGBoost的 `best_model` 就是它自己,因为 `xgb.train` 内部处理了早停。
|
||||||
best_model_path, best_version = model_manager.save_model(
|
# 我们创建一个指向最终模型的 "best" 版本文件,以保持接口一致性。
|
||||||
|
best_model_path, _ = model_manager.save_model(
|
||||||
model_data=model_data,
|
model_data=model_data,
|
||||||
product_id=product_id,
|
product_id=product_id,
|
||||||
model_type='xgboost',
|
model_type='xgboost',
|
||||||
@ -156,12 +186,19 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
|
|||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
aggregation_method=aggregation_method,
|
aggregation_method=aggregation_method,
|
||||||
product_name=product_name,
|
product_name=product_name,
|
||||||
version='best' # 明确指定版本为 'best'
|
version=f"{current_version}_best"
|
||||||
)
|
)
|
||||||
print(f"✅ XGBoost最佳模型已通过统一管理器保存,版本: {best_version}, 路径: {best_model_path}")
|
print(f"✅ XGBoost最佳模型引用已创建,版本: {current_version}_best")
|
||||||
|
|
||||||
# 返回值遵循统一格式
|
# 组装返回的工件
|
||||||
return model, metrics, final_version, final_model_path
|
artifacts = {
|
||||||
|
"versioned_model": final_model_path,
|
||||||
|
"loss_curve_plot": loss_curve_path,
|
||||||
|
"best_model": best_model_path,
|
||||||
|
"version": final_version
|
||||||
|
}
|
||||||
|
|
||||||
|
return model, metrics, artifacts
|
||||||
|
|
||||||
# --- 将此训练器注册到系统中 ---
|
# --- 将此训练器注册到系统中 ---
|
||||||
register_trainer('xgboost', train_product_model_with_xgboost)
|
register_trainer('xgboost', train_product_model_with_xgboost)
|
162
server/utils/database_utils.py
Normal file
162
server/utils/database_utils.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import sqlite3
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
def get_db_connection():
|
||||||
|
"""获取数据库连接"""
|
||||||
|
conn = sqlite3.connect('prediction_history.db', check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def save_model_to_db(model_data: dict):
|
||||||
|
"""
|
||||||
|
将训练完成的模型元数据保存到数据库的 models 表中。
|
||||||
|
"""
|
||||||
|
conn = get_db_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 生成唯一的模型UID
|
||||||
|
model_uid = f"{model_data.get('training_mode', 'model')}_{model_data.get('model_type', 'default')}_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO models (
|
||||||
|
model_uid, display_name, model_type, training_mode,
|
||||||
|
training_scope, parent_model_id, version, status,
|
||||||
|
training_params, performance_metrics, artifacts, created_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
model_uid,
|
||||||
|
model_data.get('display_name'),
|
||||||
|
model_data.get('model_type'),
|
||||||
|
model_data.get('training_mode'),
|
||||||
|
json.dumps(model_data.get('training_scope')),
|
||||||
|
model_data.get('parent_model_id'),
|
||||||
|
model_data.get('version'),
|
||||||
|
'active',
|
||||||
|
json.dumps(model_data.get('training_params')),
|
||||||
|
json.dumps(model_data.get('performance_metrics')),
|
||||||
|
json.dumps(model_data.get('artifacts')),
|
||||||
|
datetime.now().isoformat()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
print(f"✅ 模型元数据已成功保存到数据库,UID: {model_uid}")
|
||||||
|
return model_uid
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 保存模型元数据到数据库失败: {e}")
|
||||||
|
conn.rollback()
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
|
||||||
|
"""
|
||||||
|
从数据库查询模型列表,支持筛选和分页。
|
||||||
|
"""
|
||||||
|
conn = get_db_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = "SELECT * FROM models"
|
||||||
|
conditions = []
|
||||||
|
params = []
|
||||||
|
|
||||||
|
if filters.get('product_id'):
|
||||||
|
conditions.append("json_extract(training_scope, '$.product.id') = ?")
|
||||||
|
params.append(filters['product_id'])
|
||||||
|
|
||||||
|
if filters.get('model_type'):
|
||||||
|
conditions.append("model_type = ?")
|
||||||
|
params.append(filters['model_type'])
|
||||||
|
|
||||||
|
if filters.get('store_id'):
|
||||||
|
conditions.append("json_extract(training_scope, '$.store.id') = ?")
|
||||||
|
params.append(filters['store_id'])
|
||||||
|
|
||||||
|
if filters.get('training_mode'):
|
||||||
|
conditions.append("training_mode = ?")
|
||||||
|
params.append(filters['training_mode'])
|
||||||
|
|
||||||
|
if conditions:
|
||||||
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
# 获取总数
|
||||||
|
count_query = query.replace("*", "COUNT(*)")
|
||||||
|
cursor.execute(count_query, params)
|
||||||
|
total_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
# 添加分页
|
||||||
|
offset = (page - 1) * page_size
|
||||||
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
||||||
|
params.extend([page_size, offset])
|
||||||
|
|
||||||
|
cursor.execute(query, params)
|
||||||
|
models = [dict(row) for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"models": models,
|
||||||
|
"pagination": {
|
||||||
|
"total": total_count,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"total_pages": (total_count + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 从数据库查询模型失败: {e}")
|
||||||
|
return {"models": [], "pagination": {}}
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def save_prediction_to_db(prediction_data: dict):
|
||||||
|
"""
|
||||||
|
将单次预测记录保存到数据库。
|
||||||
|
"""
|
||||||
|
conn = get_db_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
try:
|
||||||
|
cursor.execute('''
|
||||||
|
INSERT INTO prediction_history (
|
||||||
|
prediction_uid, model_id, model_type, product_name,
|
||||||
|
prediction_scope, prediction_params, metrics, result_file_path
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
''', (
|
||||||
|
prediction_data.get('prediction_uid'),
|
||||||
|
prediction_data.get('model_id'), # This should be the model_uid from the 'models' table
|
||||||
|
prediction_data.get('model_type'),
|
||||||
|
prediction_data.get('product_name'),
|
||||||
|
json.dumps(prediction_data.get('prediction_scope')),
|
||||||
|
json.dumps(prediction_data.get('prediction_params')),
|
||||||
|
json.dumps(prediction_data.get('metrics')),
|
||||||
|
prediction_data.get('result_file_path')
|
||||||
|
))
|
||||||
|
conn.commit()
|
||||||
|
print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 保存预测记录到数据库失败: {e}")
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def find_model_by_uid(model_uid: str):
|
||||||
|
"""
|
||||||
|
根据 model_uid 从数据库中查找模型。
|
||||||
|
"""
|
||||||
|
conn = get_db_connection()
|
||||||
|
cursor = conn.cursor()
|
||||||
|
try:
|
||||||
|
cursor.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
|
||||||
|
model_record = cursor.fetchone()
|
||||||
|
if model_record:
|
||||||
|
return dict(model_record)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 根据UID查找模型失败: {e}")
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
@ -42,12 +42,26 @@ class ModelManager:
|
|||||||
|
|
||||||
max_version = 0
|
max_version = 0
|
||||||
for f in existing_files:
|
for f in existing_files:
|
||||||
match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
|
# 修正: 同时匹配 _v1.pth 和 _v1_best.pth 这样的文件名
|
||||||
|
match = re.search(r'_v(\d+)(_best)?\.pth$', os.path.basename(f))
|
||||||
if match:
|
if match:
|
||||||
max_version = max(max_version, int(match.group(1)))
|
max_version = max(max_version, int(match.group(1)))
|
||||||
|
|
||||||
return max_version + 1
|
return max_version + 1
|
||||||
|
|
||||||
|
def peek_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
预览下一个版本号字符串 (e.g., 'v3'),但不进行任何文件操作。
|
||||||
|
"""
|
||||||
|
next_version_num = self._get_next_version(
|
||||||
|
model_type=model_type,
|
||||||
|
product_id=product_id,
|
||||||
|
store_id=store_id,
|
||||||
|
training_mode=training_mode,
|
||||||
|
aggregation_method=aggregation_method
|
||||||
|
)
|
||||||
|
return f"v{next_version_num}"
|
||||||
|
|
||||||
def generate_model_filename(self,
|
def generate_model_filename(self,
|
||||||
model_type: str,
|
model_type: str,
|
||||||
version: str,
|
version: str,
|
||||||
@ -92,7 +106,9 @@ class ModelManager:
|
|||||||
返回:
|
返回:
|
||||||
(模型文件路径, 使用的版本号)
|
(模型文件路径, 使用的版本号)
|
||||||
"""
|
"""
|
||||||
|
# 修正: 简化版本处理逻辑,由调用方明确提供版本字符串
|
||||||
if version is None:
|
if version is None:
|
||||||
|
# 如果未提供版本,则自动生成新版本
|
||||||
next_version_num = self._get_next_version(
|
next_version_num = self._get_next_version(
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
product_id=product_id,
|
product_id=product_id,
|
||||||
@ -102,6 +118,7 @@ class ModelManager:
|
|||||||
)
|
)
|
||||||
version_str = f"v{next_version_num}"
|
version_str = f"v{next_version_num}"
|
||||||
else:
|
else:
|
||||||
|
# 直接使用调用方提供的版本字符串 (e.g., 'v3', 'v3_best')
|
||||||
version_str = version
|
version_str = version
|
||||||
|
|
||||||
filename = self.generate_model_filename(
|
filename = self.generate_model_filename(
|
||||||
@ -212,7 +229,8 @@ class ModelManager:
|
|||||||
|
|
||||||
# 添加文件信息
|
# 添加文件信息
|
||||||
model_info['filename'] = filename
|
model_info['filename'] = filename
|
||||||
model_info['file_path'] = model_file
|
# 修正: 确保返回的是相对路径
|
||||||
|
model_info['file_path'] = os.path.join(self.model_dir, filename)
|
||||||
model_info['file_size'] = os.path.getsize(model_file)
|
model_info['file_size'] = os.path.getsize(model_file)
|
||||||
model_info['modified_at'] = datetime.fromtimestamp(
|
model_info['modified_at'] = datetime.fromtimestamp(
|
||||||
os.path.getmtime(model_file)
|
os.path.getmtime(model_file)
|
||||||
@ -343,9 +361,5 @@ class ModelManager:
|
|||||||
|
|
||||||
|
|
||||||
# 全局模型管理器实例
|
# 全局模型管理器实例
|
||||||
# 确保使用项目根目录的saved_models,而不是相对于当前工作目录
|
# 修正: 直接使用在config.py中定义的相对路径
|
||||||
import os
|
model_manager = ModelManager()
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
project_root = os.path.dirname(os.path.dirname(current_dir)) # 向上两级到项目根目录
|
|
||||||
absolute_model_dir = os.path.join(project_root, 'saved_models')
|
|
||||||
model_manager = ModelManager(absolute_model_dir)
|
|
@ -24,6 +24,7 @@ server_dir = os.path.dirname(current_dir)
|
|||||||
sys.path.append(server_dir)
|
sys.path.append(server_dir)
|
||||||
|
|
||||||
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
|
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
|
||||||
|
from utils.database_utils import save_model_to_db
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
def convert_numpy_types(obj):
|
def convert_numpy_types(obj):
|
||||||
@ -45,6 +46,8 @@ class TrainingTask:
|
|||||||
training_mode: str
|
training_mode: str
|
||||||
store_id: Optional[str] = None
|
store_id: Optional[str] = None
|
||||||
epochs: int = 100
|
epochs: int = 100
|
||||||
|
training_scope: Optional[Dict[str, Any]] = None # 新增
|
||||||
|
artifacts: Optional[Dict[str, Any]] = None # 新增
|
||||||
status: str = "pending" # pending, running, completed, failed
|
status: str = "pending" # pending, running, completed, failed
|
||||||
start_time: Optional[str] = None
|
start_time: Optional[str] = None
|
||||||
end_time: Optional[str] = None
|
end_time: Optional[str] = None
|
||||||
@ -138,7 +141,7 @@ class TrainingWorker:
|
|||||||
training_logger.error(f"进度回调失败: {e}")
|
training_logger.error(f"进度回调失败: {e}")
|
||||||
|
|
||||||
# 执行真正的训练,传递进度回调
|
# 执行真正的训练,传递进度回调
|
||||||
metrics = predictor.train_model(
|
metrics, artifacts = predictor.train_model(
|
||||||
product_id=task.product_id,
|
product_id=task.product_id,
|
||||||
model_type=task.model_type,
|
model_type=task.model_type,
|
||||||
epochs=task.epochs,
|
epochs=task.epochs,
|
||||||
@ -181,6 +184,7 @@ class TrainingWorker:
|
|||||||
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
|
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
task.progress = 100.0
|
task.progress = 100.0
|
||||||
task.metrics = metrics
|
task.metrics = metrics
|
||||||
|
task.artifacts = artifacts if artifacts else {}
|
||||||
task.message = "训练完成"
|
task.message = "训练完成"
|
||||||
|
|
||||||
training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
|
training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
|
||||||
@ -282,7 +286,7 @@ class TrainingProcessManager:
|
|||||||
self.logger.info("✅ 训练进程管理器已停止")
|
self.logger.info("✅ 训练进程管理器已停止")
|
||||||
|
|
||||||
def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
|
def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
|
||||||
store_id: str = None, epochs: int = 100, **kwargs) -> str:
|
store_id: str = None, epochs: int = 100, training_scope: dict = None, **kwargs) -> str:
|
||||||
"""提交训练任务"""
|
"""提交训练任务"""
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
@ -292,7 +296,8 @@ class TrainingProcessManager:
|
|||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
training_mode=training_mode,
|
training_mode=training_mode,
|
||||||
store_id=store_id,
|
store_id=store_id,
|
||||||
epochs=epochs
|
epochs=epochs,
|
||||||
|
training_scope=training_scope
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.lock:
|
with self.lock:
|
||||||
@ -371,6 +376,34 @@ class TrainingProcessManager:
|
|||||||
'product_id': serializable_task_data.get('product_id'),
|
'product_id': serializable_task_data.get('product_id'),
|
||||||
'model_type': serializable_task_data.get('model_type')
|
'model_type': serializable_task_data.get('model_type')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 在此处调用函数,将模型元数据保存到数据库
|
||||||
|
# 注意:我们需要从task_data中构建一个符合save_model_to_db期望的字典
|
||||||
|
# 这是一个简化的示例,实际可能需要传递更多参数
|
||||||
|
# 构建一个更完整的模型数据字典用于保存
|
||||||
|
# 从 artifacts 中获取版本号
|
||||||
|
version = serializable_task_data.get('artifacts', {}).get('version')
|
||||||
|
|
||||||
|
# 构建带有版本号的 display_name
|
||||||
|
base_display_name = f"{serializable_task_data.get('product_id', 'N/A')} - {serializable_task_data.get('model_type')}"
|
||||||
|
display_name_with_version = f"{base_display_name} ({version})" if version else base_display_name
|
||||||
|
|
||||||
|
model_to_save = {
|
||||||
|
'model_uid': f"{serializable_task_data.get('training_mode')}_{serializable_task_data.get('model_type')}_{str(uuid.uuid4())[:8]}",
|
||||||
|
'display_name': display_name_with_version,
|
||||||
|
'model_type': serializable_task_data.get('model_type'),
|
||||||
|
'training_mode': serializable_task_data.get('training_mode'),
|
||||||
|
'training_scope': serializable_task_data.get('training_scope'),
|
||||||
|
'version': version,
|
||||||
|
'status': 'active',
|
||||||
|
'training_params': {
|
||||||
|
'epochs': serializable_task_data.get('epochs')
|
||||||
|
},
|
||||||
|
'performance_metrics': serializable_task_data.get('metrics'),
|
||||||
|
'artifacts': serializable_task_data.get('artifacts')
|
||||||
|
}
|
||||||
|
save_model_to_db(model_to_save)
|
||||||
|
|
||||||
elif action == 'error':
|
elif action == 'error':
|
||||||
# 训练失败
|
# 训练失败
|
||||||
self.websocket_callback('training_update', {
|
self.websocket_callback('training_update', {
|
||||||
|
@ -16,37 +16,45 @@ except ImportError:
|
|||||||
# 后备方案:使用默认值
|
# 后备方案:使用默认值
|
||||||
DEFAULT_MODEL_DIR = "saved_models"
|
DEFAULT_MODEL_DIR = "saved_models"
|
||||||
|
|
||||||
def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
||||||
"""
|
"""
|
||||||
绘制训练和验证损失曲线
|
绘制训练和验证损失曲线,并根据scope和identifier生成标准化的文件名。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
train_losses: 训练损失列表
|
train_losses: 训练损失列表
|
||||||
val_losses: 验证损失列表
|
val_losses: 验证损失列表
|
||||||
product_name: 产品名称
|
model_type: 模型类型 (e.g., 'xgboost')
|
||||||
model_type: 模型类型
|
scope: 训练范围 ('product', 'store', 'global')
|
||||||
|
identifier: 范围对应的标识符 (产品名, 店铺ID, 或聚合方法)
|
||||||
|
version: (可选) 模型版本号,用于生成唯一的文件名
|
||||||
save_path: 保存路径,如果为None则自动生成路径
|
save_path: 保存路径,如果为None则自动生成路径
|
||||||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
model_dir: 模型保存目录
|
||||||
"""
|
"""
|
||||||
plt.figure(figsize=(10, 5))
|
plt.figure(figsize=(10, 5))
|
||||||
plt.plot(train_losses, label='训练损失')
|
plt.plot(train_losses, label='训练损失')
|
||||||
plt.plot(val_losses, label='验证损失')
|
plt.plot(val_losses, label='验证损失')
|
||||||
plt.title(f'{product_name} - {model_type}模型训练和验证损失')
|
|
||||||
|
# 动态生成标题
|
||||||
|
title_identifier = identifier.replace('_', ' ')
|
||||||
|
title = f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失'
|
||||||
|
plt.title(title)
|
||||||
|
|
||||||
plt.xlabel('Epoch')
|
plt.xlabel('Epoch')
|
||||||
plt.ylabel('Loss')
|
plt.ylabel('Loss')
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True)
|
plt.grid(True)
|
||||||
|
|
||||||
if save_path:
|
if save_path:
|
||||||
# 如果提供了完整路径,直接使用
|
|
||||||
full_path = save_path
|
full_path = save_path
|
||||||
else:
|
else:
|
||||||
# 否则生成默认路径
|
|
||||||
# 确保模型目录存在
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
# 构建文件名:模型类型_产品名_loss_curve.png
|
# 标准化文件名生成逻辑,与 ModelManager 对齐
|
||||||
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png"
|
version_str = f"_{version}" if version else ""
|
||||||
|
# 清理标识符中的非法字符
|
||||||
|
safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_')
|
||||||
|
|
||||||
|
filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png"
|
||||||
full_path = os.path.join(model_dir, filename)
|
full_path = os.path.join(model_dir, filename)
|
||||||
|
|
||||||
plt.savefig(full_path)
|
plt.savefig(full_path)
|
||||||
|
54
xz数据库2025_07_24.md
Normal file
54
xz数据库2025_07_24.md
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# 数据库设计方案 (最终版)
|
||||||
|
|
||||||
|
本文档定义了项目重构后的核心数据库结构,旨在实现对机器学习模型和预测历史的统一、高效管理。该设计方案结合了业务需求、前端UI逻辑、后端代码实现以及现有数据库结构,并经过了多次迭代确认。
|
||||||
|
|
||||||
|
## 设计原则
|
||||||
|
|
||||||
|
- **模型集中化**: 所有模型的元数据统一存储,摆脱对文件系统的依赖。
|
||||||
|
- **数据一致性**: 通过逻辑外键确保预测记录与模型版本精确关联。
|
||||||
|
- **查询高性能**: 通过冗余关键字段,避免在查询列表时进行不必要的`JOIN`操作。
|
||||||
|
- **可扩展性**: 使用JSON字段存储灵活、复杂的范围定义和文件路径,适应未来业务变化。
|
||||||
|
- **结构清晰**: 物理文件(模型、日志、预测结果)与数据库记录分离,数据库只存元数据和路径,保持自身轻量。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 表结构定义
|
||||||
|
|
||||||
|
### 表1:`models`
|
||||||
|
|
||||||
|
这张表是模型管理的核心,负责存储所有模型版本的全生命周期信息。
|
||||||
|
|
||||||
|
| 字段名 | 类型 | 确认理由与说明 |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| `id` | INTEGER | 主键,系统内部使用。 |
|
||||||
|
| `model_uid` | TEXT | **[关键]** 用户可见的唯一ID,由后端生成,用于API调用和逻辑关联。 |
|
||||||
|
| `display_name` | TEXT | **[建议新增]** 用户可自定义的别名,如“夏季促销模型”,提升易用性。 |
|
||||||
|
| `model_type` | TEXT | **已确认.** 核心参数,如 `mlstm`, `kan`。 |
|
||||||
|
| `training_mode` | TEXT | **已确认.** 模型的训练范围: `product`, `store`, `global`。 |
|
||||||
|
| `training_scope` | TEXT (JSON) | **[最终版]** **精确描述训练范围,并包含中文名。**<br/>- **按药品训练**: `{"product": {"id": "P001", "name": "阿莫西林"}, "stores": "all"}` 或 `{"product": {"id": "P001", "name": "阿莫西林"}, "stores": [{"id": "S001", "name": "城西店"}]}`<br/>- **按店铺训练**: `{"store": {"id": "S001", "name": "城西店"}, "products": "all"}` 或 `{"store": {"id": "S001", "name": "城西店"}, "products": [{"id": "P001", "name": "阿莫西林"}, {"id": "P002", "name": "布洛芬"}]}`<br/>- **全局训练**: `{"stores": "all", "products": "all"}` 或 `{"stores": [{"id": "S001", "name": "城西店"}], "products": [{"id": "P001", "name": "阿莫西林"}]}` |
|
||||||
|
| `parent_model_id` | INTEGER | **已确认.** 外键,指向自身 `id`,用于实现“继续训练”功能,形成版本链。 |
|
||||||
|
| `version` | TEXT | **已确认.** 模型版本号,如 `v1`, `v2`。 |
|
||||||
|
| `status` | TEXT | **已确认.** 模型状态,如 `active`, `archived`,用于控制模型是否可用。 |
|
||||||
|
| `training_params` | TEXT (JSON) | **已确认.** 存储训练时的超参数,如 `{"epochs": 50, "aggregation_method": "sum"}`。 |
|
||||||
|
| `performance_metrics` | TEXT (JSON) | **已确认.** 存储性能指标,如 `{"R2": 0.85, "RMSE": ...}`。 |
|
||||||
|
| `artifacts` | TEXT (JSON) | **[采纳建议]** **存储与模型相关的所有文件路径。** 所有文件均采用**扁平化结构**存放在 `saved_models/` 目录下。示例:<br/>`{"best_model": "saved_models/product_P001_mlstm_best.pth", "versioned_model": "saved_models/product_P001_mlstm_v1.pth", "loss_curve_plot": "saved_models/product_P001_mlstm_v1_loss.png", "loss_curve_data": "saved_models/product_P001_mlstm_v1_history.json"}` |
|
||||||
|
| `created_at` | DATETIME | **已确认.** 模型的创建时间。 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 表2:`prediction_history`
|
||||||
|
|
||||||
|
这张表用于记录每一次预测任务的结果,其结构在现有基础上进行了优化,以实现高效查询和完整追溯。
|
||||||
|
|
||||||
|
| 字段名 | 类型 | 确认理由与说明 |
|
||||||
|
| :--- | :--- | :--- |
|
||||||
|
| `id` | INTEGER | **已确认.** 主键,自增。 |
|
||||||
|
| `prediction_uid` | TEXT | **已确认.** 唯一的预测ID (UUID),用于API调用。 |
|
||||||
|
| `model_id` | TEXT | **[核心变更]** **使用的模型的唯一标识符。** 在重构后,此字段的值应该与 `models` 表中的某条记录的 `model_uid` 相对应,从而建立起逻辑关联。 |
|
||||||
|
| `model_type` | TEXT | **已确认 (冗余).** 冗余存储模型类型(如 `mlstm`),用于在不关联查询的情况下快速筛选历史记录。 |
|
||||||
|
| `product_name` | TEXT | **已确认 (冗余).** **冗余存储产品/店铺的中文名**,用于在列表页快速展示,极大提升查询性能。 |
|
||||||
|
| `prediction_scope` | TEXT (JSON) | **[新增]** 描述本次预测的范围,与 `models` 表的 `training_scope` 结构类似,指明是为哪个产品/店铺做的预测。 |
|
||||||
|
| `prediction_params` | TEXT (JSON) | **已确认.** 存储预测参数,如 `{"start_date": "...", "future_days": 7}`。 |
|
||||||
|
| `metrics` | TEXT (JSON) | **已确认 (冗余).** 缓存的性能指标,用于列表展示和排序。 |
|
||||||
|
| `result_file_path` | TEXT | **[已采纳您的规范]** 指向预测结果JSON文件的**相对路径**。文件存储在 `saved_predictions/` 目录下,并根据模型名和时间戳命名,例如:`saved_predictions/cnn_bilstm_attention_global_sum_v6_pred_20250724111600.json`。 |
|
||||||
|
| `created_at` | DATETIME | **已确认.** 记录的创建时间。 |
|
Loading…
x
Reference in New Issue
Block a user