将训练模型信息保存到数据库

This commit is contained in:
xz2000 2025-07-24 17:55:10 +08:00
parent 3aaddcd658
commit a02bc11921
22 changed files with 1017 additions and 792 deletions

View File

@ -35,49 +35,50 @@
</el-form> </el-form>
<el-table :data="models" stripe v-loading="loading"> <el-table :data="models" stripe v-loading="loading">
<el-table-column prop="model_id" label="模型ID" min-width="150"> <el-table-column prop="display_name" label="模型名称" min-width="180">
<template #default="{ row }">
<el-tooltip :content="`UID: ${row.model_uid}`" placement="top">
<span>{{ row.display_name }}</span>
</el-tooltip>
</template>
</el-table-column>
<el-table-column prop="model_type" label="模型类型" min-width="120">
<template #default="{ row }"> <template #default="{ row }">
<el-tag size="small" :type="getModelTagType(row.model_type)"> <el-tag size="small" :type="getModelTagType(row.model_type)">
{{ row.model_type }} {{ row.model_type }}
</el-tag> </el-tag>
<span style="margin-left: 8px">{{ row.product_id }}</span>
</template> </template>
</el-table-column> </el-table-column>
<el-table-column prop="product_name" label="产品名称" min-width="120"></el-table-column> <el-table-column prop="training_mode" label="训练模式" min-width="120"></el-table-column>
<el-table-column prop="created_at" label="创建时间" min-width="160"> <el-table-column prop="created_at" label="创建时间" min-width="160">
<template #default="{ row }"> <template #default="{ row }">
{{ formatDateTime(row.created_at) }} {{ formatDateTime(row.created_at) }}
</template> </template>
</el-table-column> </el-table-column>
<el-table-column label="评估指标" min-width="200"> <el-table-column label="评估指标" min-width="220">
<template #default="{ row }"> <template #default="{ row }">
<div class="metrics-display"> <div class="metrics-display" v-if="row.performance_metrics && Object.keys(row.performance_metrics).length > 0">
<el-tooltip effect="dark" content="R平方值" placement="top"> <el-tooltip effect="dark" content="R平方值" placement="top">
<div class="metric-item"> <div class="metric-item">
<span class="metric-label">:</span> <span class="metric-label">:</span>
<span class="metric-value">{{ row.metrics?.R2?.toFixed(4) || 'N/A' }}</span> <span class="metric-value">{{ row.performance_metrics.R2?.toFixed(4) || 'N/A' }}</span>
</div>
</el-tooltip>
<el-tooltip effect="dark" content="均方根误差" placement="top">
<div class="metric-item">
<span class="metric-label">RMSE:</span>
<span class="metric-value">{{ row.metrics?.RMSE?.toFixed(4) || 'N/A' }}</span>
</div>
</el-tooltip>
<el-tooltip effect="dark" content="平均绝对误差" placement="top">
<div class="metric-item">
<span class="metric-label">MAE:</span>
<span class="metric-value">{{ row.metrics?.MAE?.toFixed(4) || 'N/A' }}</span>
</div>
</el-tooltip>
<el-tooltip effect="dark" content="平均绝对百分比误差" placement="top">
<div class="metric-item">
<span class="metric-label">MAPE:</span>
<span class="metric-value">{{ row.metrics?.MAPE?.toFixed(2) || 'N/A' }}%</span>
</div>
</el-tooltip>
</div> </div>
</template> </el-tooltip>
<el-tooltip effect="dark" content="均方根误差" placement="top">
<div class="metric-item">
<span class="metric-label">RMSE:</span>
<span class="metric-value">{{ row.performance_metrics.RMSE?.toFixed(4) || 'N/A' }}</span>
</div>
</el-tooltip>
<el-tooltip effect="dark" content="平均绝对误差" placement="top">
<div class="metric-item">
<span class="metric-label">MAE:</span>
<span class="metric-value">{{ row.performance_metrics.MAE?.toFixed(4) || 'N/A' }}</span>
</div>
</el-tooltip>
</div>
<span v-else></span>
</template>
</el-table-column> </el-table-column>
<el-table-column label="操作" width="280" fixed="right"> <el-table-column label="操作" width="280" fixed="right">
<template #default="{ row }"> <template #default="{ row }">
@ -254,7 +255,7 @@ const fetchModels = async () => {
const response = await axios.get('/api/models', { params }) const response = await axios.get('/api/models', { params })
if (response.data.status === 'success') { if (response.data.status === 'success') {
models.value = response.data.data.map(model => { models.value = response.data.data.map(model => {
model.metrics = normalizeMetricsKeys(model.metrics); model.performance_metrics = normalizeMetricsKeys(model.performance_metrics);
return model; return model;
}); });
@ -291,7 +292,23 @@ const viewDetails = async (model) => {
detailsDialogVisible.value = true; detailsDialogVisible.value = true;
selectedModelDetails.value = null; // selectedModelDetails.value = null; //
try { try {
const response = await axios.get(`/api/models/${model.model_type}/${model.product_id}/details`); // : 使API
const details = {
model_info: {
model_id: model.model_uid,
model_type: model.model_type,
product_name: model.display_name,
created_at: model.created_at,
},
training_metrics: model.performance_metrics,
chart_data: {
loss_chart: model.artifacts?.loss_curve_data || { epochs: [], train_loss: [], test_loss: [] }
}
};
selectedModelDetails.value = details;
nextTick(() => {
initLossChart();
});
if (response.data.status === 'success') { if (response.data.status === 'success') {
const details = response.data.data; const details = response.data.data;
details.training_metrics = normalizeMetricsKeys(details.training_metrics); details.training_metrics = normalizeMetricsKeys(details.training_metrics);
@ -311,7 +328,7 @@ const viewDetails = async (model) => {
const deleteModel = async (model) => { const deleteModel = async (model) => {
try { try {
const response = await axios.delete(`/api/models/${model.model_id}`) const response = await axios.delete(`/api/models/${model.model_uid}`)
if (response.data.status === 'success') { if (response.data.status === 'success') {
ElMessage.success('模型删除成功') ElMessage.success('模型删除成功')
fetchModels() fetchModels()
@ -325,7 +342,7 @@ const deleteModel = async (model) => {
const exportModel = async (model) => { const exportModel = async (model) => {
try { try {
const response = await axios.get(`/api/models/${model.model_id}/export`, { responseType: 'blob' }) const response = await axios.get(`/api/models/${model.model_uid}/export`, { responseType: 'blob' })
const url = window.URL.createObjectURL(new Blob([response.data])) const url = window.URL.createObjectURL(new Blob([response.data]))
const link = document.createElement('a') const link = document.createElement('a')
link.href = url link.href = url

View File

@ -96,17 +96,18 @@
<el-col :span="6"> <el-col :span="6">
<el-form-item label="模型版本"> <el-form-item label="模型版本">
<el-select <el-select
v-model="form.version" v-model="form.model_uid"
placeholder="选择版本" placeholder="选择一个具体的模型"
style="width: 100%" style="width: 100%"
:disabled="!availableVersions.length" :disabled="!availableVersions.length"
:loading="versionsLoading" :loading="versionsLoading"
value-key="model_uid"
> >
<el-option <el-option
v-for="version in availableVersions" v-for="model in availableVersions"
:key="version" :key="model.model_uid"
:label="version" :label="`${model.display_name} (v${model.version})`"
:value="version" :value="model.model_uid"
/> />
</el-select> </el-select>
<div class="version-info"> <div class="version-info">
@ -176,16 +177,6 @@
开始预测 开始预测
</el-button> </el-button>
<el-button
v-if="predictionResult"
type="success"
size="large"
@click="savePrediction"
:loading="saving"
>
<el-icon><Download /></el-icon>
保存结果
</el-button>
</div> </div>
</el-card> </el-card>
@ -274,7 +265,7 @@ const form = reactive({
product_id: '', product_id: '',
store_id: '', store_id: '',
model_type: '', model_type: '',
version: '', model_uid: '',
future_days: 7, future_days: 7,
start_date: '', start_date: '',
analyze_result: true analyze_result: true
@ -282,18 +273,8 @@ const form = reactive({
// //
const canPredict = computed(() => { const canPredict = computed(() => {
const baseCheck = form.training_mode && form.model_type && form.version return !!form.model_uid;
});
if (form.training_mode === 'product') {
return baseCheck && form.product_id
} else if (form.training_mode === 'store') {
return baseCheck && form.store_id
} else if (form.training_mode === 'global') {
return baseCheck
}
return false
})
const predictionTableData = computed(() => { const predictionTableData = computed(() => {
if (!predictionResult.value || !predictionResult.value.predictions) return [] if (!predictionResult.value || !predictionResult.value.predictions) return []
@ -329,24 +310,21 @@ const fetchAvailableVersions = async () => {
} }
try { try {
versionsLoading.value = true versionsLoading.value = true;
let url = '' const params = {
training_mode: form.training_mode,
model_type: form.model_type,
product_id: form.training_mode === 'product' ? form.product_id : undefined,
store_id: form.training_mode === 'store' ? form.store_id : undefined,
page_size: 100 //
};
if (form.training_mode === 'product' && form.product_id) { const response = await axios.get('/api/models', { params });
url = `/api/models/${form.product_id}/${form.model_type}/versions` if (response.data.status === 'success') {
} else if (form.training_mode === 'store' && form.store_id) { availableVersions.value = response.data.data || [];
url = `/api/models/store/${form.store_id}/${form.model_type}/versions` if (availableVersions.value.length > 0) {
} else if (form.training_mode === 'global') { //
url = `/api/models/global/${form.model_type}/versions` form.model_uid = availableVersions.value[0].model_uid;
}
if (url) {
const response = await axios.get(url)
if (response.data.status === 'success') {
availableVersions.value = response.data.data.versions || []
if (response.data.data.latest_version) {
form.version = response.data.data.latest_version
}
} }
} }
} catch (error) { } catch (error) {
@ -361,22 +339,22 @@ const handleTrainingModeChange = () => {
form.product_id = '' form.product_id = ''
form.store_id = '' form.store_id = ''
form.model_type = '' form.model_type = ''
form.version = '' form.model_uid = ''
availableVersions.value = [] availableVersions.value = []
} }
const handleProductChange = () => { const handleProductChange = () => {
form.version = '' form.model_uid = ''
fetchAvailableVersions() fetchAvailableVersions()
} }
const handleStoreChange = () => { const handleStoreChange = () => {
form.version = '' form.model_uid = ''
fetchAvailableVersions() fetchAvailableVersions()
} }
const handleModelTypeChange = () => { const handleModelTypeChange = () => {
form.version = '' form.model_uid = ''
fetchAvailableVersions() fetchAvailableVersions()
} }
@ -385,21 +363,14 @@ const startPrediction = async () => {
predicting.value = true predicting.value = true
const payload = { const payload = {
model_type: form.model_type, model_uid: form.model_uid,
version: form.version,
future_days: form.future_days, future_days: form.future_days,
start_date: form.start_date, start_date: form.start_date,
analyze_result: form.analyze_result include_visualization: form.analyze_result, // include_visualization
} history_lookback_days: 30
};
// const response = await axios.post('/api/prediction', payload);
if (form.training_mode === 'product') {
payload.product_id = form.product_id
} else if (form.training_mode === 'store') {
payload.store_id = form.store_id
}
const response = await axios.post('/api/predict', payload)
if (response.data.status === 'success') { if (response.data.status === 'success') {
predictionResult.value = response.data.data predictionResult.value = response.data.data
@ -470,29 +441,6 @@ const renderChart = () => {
}) })
} }
const savePrediction = async () => {
try {
saving.value = true
const saveData = {
...predictionResult.value,
training_mode: form.training_mode,
parameters: { ...form }
}
const response = await axios.post('/api/predictions/save', saveData)
if (response.data.status === 'success') {
ElMessage.success('预测结果已保存')
} else {
ElMessage.error('保存失败')
}
} catch (error) {
ElMessage.error('保存请求失败')
} finally {
saving.value = false
}
}
// //
const getPredictionInfoText = () => { const getPredictionInfoText = () => {

View File

@ -129,33 +129,20 @@ const pagination = reactive({
pageSize: 8 pageSize: 8
}) })
const storeNameMap = computed(() => {
return stores.value.reduce((acc, store) => {
acc[store.store_id] = store.store_name
return acc
}, {})
})
const modelsWithNames = computed(() => {
return modelList.value.map(model => ({
...model,
store_name: storeNameMap.value[model.store_id] || model.store_id
}))
})
const filteredModelList = computed(() => { const filteredModelList = computed(() => {
return modelsWithNames.value.filter(model => { return modelList.value.filter(model => {
const storeMatch = !filters.store_id || model.store_id === filters.store_id // 使
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id const storeMatch = !filters.store_id || (model.training_scope?.store?.id === filters.store_id);
return storeMatch && modelTypeMatch const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id;
}) return storeMatch && modelTypeMatch;
}) });
});
const paginatedModelList = computed(() => { const paginatedModelList = computed(() => {
const start = (pagination.currentPage - 1) * pagination.pageSize const start = (pagination.currentPage - 1) * pagination.pageSize;
const end = start + pagination.pageSize const end = start + pagination.pageSize;
return filteredModelList.value.slice(start, end) return filteredModelList.value.slice(start, end);
}) });
const handlePageChange = (page) => { const handlePageChange = (page) => {
pagination.currentPage = page pagination.currentPage = page

View File

@ -604,14 +604,45 @@ const startTraining = async () => {
? "/api/training/retrain" ? "/api/training/retrain"
: "/api/training"; : "/api/training";
// training_scope
const training_scope = {
stores: "all",
products: "all"
};
if (form.training_scope === 'selected_stores' || form.training_scope === 'custom') {
if (form.store_ids.length > 0) {
training_scope.stores = form.store_ids.map(sid => {
const store = stores.value.find(s => s.store_id === sid);
return { id: sid, name: store ? store.store_name : sid };
});
} else {
// all
training_scope.stores = "all";
}
}
if (form.training_scope === 'selected_products' || form.training_scope === 'custom') {
if (form.product_ids.length > 0) {
training_scope.products = form.product_ids.map(pid => {
const product = products.value.find(p => p.product_id === pid);
return { id: pid, name: product ? product.product_name : pid };
});
} else {
// all
training_scope.products = "all";
}
}
const payload = { const payload = {
model_type: form.model_type, model_type: form.model_type,
epochs: form.epochs, epochs: form.epochs,
training_mode: 'global', // training_mode: 'global', //
training_scope: form.training_scope, training_scope: training_scope, // 使
aggregation_method: form.aggregation_method aggregation_method: form.aggregation_method
}; };
// ID
if (form.store_ids.length > 0) { if (form.store_ids.length > 0) {
payload.store_ids = form.store_ids; payload.store_ids = form.store_ids;
} }

View File

@ -541,12 +541,31 @@ const startTraining = async () => {
? "/api/training/retrain" ? "/api/training/retrain"
: "/api/training"; : "/api/training";
// training_scope
const selectedProduct = products.value.find(p => p.product_id === form.product_id);
const training_scope = {
product: {
id: form.product_id,
name: selectedProduct ? selectedProduct.product_name : form.product_id
},
stores: "all"
};
if (form.data_scope === 'specific' && form.store_id) {
const selectedStore = stores.value.find(s => s.store_id === form.store_id);
training_scope.stores = [{
id: form.store_id,
name: selectedStore ? selectedStore.store_name : form.store_id
}];
}
const payload = { const payload = {
product_id: form.product_id, product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id, store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type, model_type: form.model_type,
epochs: form.epochs, epochs: form.epochs,
training_mode: 'product' // training_mode: 'product', //
training_scope: training_scope // training_scope
}; };
if (form.training_type === "retrain") { if (form.training_type === "retrain") {

View File

@ -588,12 +588,33 @@ const startTraining = async () => {
? "/api/training/retrain" ? "/api/training/retrain"
: "/api/training"; : "/api/training";
// training_scope
const selectedStore = stores.value.find(s => s.store_id === form.store_id);
const training_scope = {
store: {
id: form.store_id,
name: selectedStore ? selectedStore.store_name : form.store_id
},
products: "all"
};
if (form.product_scope === 'specific' && form.product_ids.length > 0) {
training_scope.products = form.product_ids.map(pid => {
const product = storeProducts.value.find(p => p.product_id === pid);
return {
id: pid,
name: product ? product.product_name : pid
};
});
}
const payload = { const payload = {
store_id: form.store_id, store_id: form.store_id,
model_type: form.model_type, model_type: form.model_type,
epochs: form.epochs, epochs: form.epochs,
training_mode: 'store', // training_mode: 'store', //
product_scope: form.product_scope product_scope: form.product_scope,
training_scope: training_scope // training_scope
}; };
if (form.product_scope === 'specific') { if (form.product_scope === 'specific') {

Binary file not shown.

View File

@ -65,6 +65,7 @@ from core.config import (
from utils.multi_store_data_utils import ( from utils.multi_store_data_utils import (
get_available_stores, get_available_products, get_sales_statistics get_available_stores, get_available_products, get_sales_statistics
) )
from utils.database_utils import query_models_from_db, find_model_by_uid, save_prediction_to_db
# 导入数据库初始化工具 # 导入数据库初始化工具
from init_multi_store_db import get_db_connection from init_multi_store_db import get_db_connection
@ -119,60 +120,68 @@ def init_db():
conn = sqlite3.connect('prediction_history.db') conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor() cursor = conn.cursor()
# 创建预测历史表 # 废弃旧的 model_versions 表
cursor.execute(''' cursor.execute('DROP TABLE IF EXISTS model_versions')
CREATE TABLE IF NOT EXISTS prediction_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prediction_id TEXT UNIQUE NOT NULL,
product_id TEXT NOT NULL,
product_name TEXT NOT NULL,
model_type TEXT NOT NULL,
model_id TEXT NOT NULL,
start_date TEXT,
future_days INTEGER,
created_at TEXT NOT NULL,
predictions_data TEXT,
metrics TEXT,
chart_data TEXT,
analysis TEXT,
file_path TEXT
)
''')
# 创建模型版本 # 创建新的 models 表
cursor.execute(''' cursor.execute('''
CREATE TABLE IF NOT EXISTS model_versions ( CREATE TABLE IF NOT EXISTS models (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
product_id TEXT NOT NULL, model_uid TEXT UNIQUE NOT NULL,
display_name TEXT,
model_type TEXT NOT NULL, model_type TEXT NOT NULL,
training_mode TEXT NOT NULL,
training_scope TEXT,
parent_model_id INTEGER,
version TEXT NOT NULL, version TEXT NOT NULL,
file_path TEXT NOT NULL, status TEXT DEFAULT 'active',
created_at TEXT NOT NULL, training_params TEXT,
metrics TEXT, performance_metrics TEXT,
is_active INTEGER DEFAULT 1, artifacts TEXT,
UNIQUE(product_id, model_type, version) created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (parent_model_id) REFERENCES models (id)
) )
''') ''')
# 创建索引以提高查询性能 # 更新 prediction_history 表
cursor.execute(''' # 为了平滑过渡,我们先检查表是否存在,然后尝试添加新列
CREATE INDEX IF NOT EXISTS idx_prediction_product_model cursor.execute("PRAGMA table_info(prediction_history)")
ON prediction_history(product_id, model_type) columns = [row[1] for row in cursor.fetchall()]
''')
cursor.execute(''' if 'prediction_uid' not in columns:
CREATE INDEX IF NOT EXISTS idx_model_versions_product_type # 如果表结构很旧,重建它
ON model_versions(product_id, model_type) cursor.execute('DROP TABLE IF EXISTS prediction_history')
''') cursor.execute('''
CREATE TABLE prediction_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prediction_uid TEXT UNIQUE NOT NULL,
model_id TEXT NOT NULL,
model_type TEXT,
product_name TEXT,
prediction_scope TEXT,
prediction_params TEXT,
metrics TEXT,
result_file_path TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
''')
else:
# 否则,只添加缺失的列
if 'prediction_scope' not in columns:
cursor.execute('ALTER TABLE prediction_history ADD COLUMN prediction_scope TEXT')
if 'result_file_path' not in columns:
cursor.execute('ALTER TABLE prediction_history ADD COLUMN result_file_path TEXT')
# 确保 model_id 字段存在且类型正确
# 在SQLite中修改列类型比较复杂通常建议重建。此处简化处理。
cursor.execute(''' # 创建索引
CREATE INDEX IF NOT EXISTS idx_model_versions_active cursor.execute('CREATE INDEX IF NOT EXISTS idx_models_uid ON models(model_uid)')
ON model_versions(product_id, model_type, is_active) cursor.execute('CREATE INDEX IF NOT EXISTS idx_models_type_mode ON models(model_type, training_mode)')
''') cursor.execute('CREATE INDEX IF NOT EXISTS idx_prediction_history_model_id ON prediction_history(model_id)')
conn.commit() conn.commit()
conn.close() conn.close()
print("数据库初始化完成,包含模型版本管理表") print("数据库初始化完成,已更新为新的 `models` 和 `prediction_history` 表结构。")
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型 # 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
class CustomJSONEncoder(json.JSONEncoder): class CustomJSONEncoder(json.JSONEncoder):
@ -219,6 +228,8 @@ def convert_numpy_types_for_json(obj):
return obj return obj
app = Flask(__name__) app = Flask(__name__)
# 解决jsonify中文显示为unicode的问题
app.config['JSON_AS_ASCII'] = False
# 设置自定义JSON编码器 # 设置自定义JSON编码器
app.json_encoder = CustomJSONEncoder app.json_encoder = CustomJSONEncoder
app.config['SECRET_KEY'] = 'pharmacy_prediction_secret_key' app.config['SECRET_KEY'] = 'pharmacy_prediction_secret_key'
@ -874,7 +885,7 @@ def start_training():
product_ids = data.get('product_ids', []) product_ids = data.get('product_ids', [])
store_ids = data.get('store_ids', []) store_ids = data.get('store_ids', [])
product_scope = data.get('product_scope', 'all') product_scope = data.get('product_scope', 'all')
training_scope = data.get('training_scope', 'all_stores_all_products') # training_scope = data.get('training_scope', 'all_stores_all_products') # 已废弃,由下面的 training_scope_obj 替代
aggregation_method = data.get('aggregation_method', 'sum') aggregation_method = data.get('aggregation_method', 'sum')
if not model_type: if not model_type:
@ -896,12 +907,18 @@ def start_training():
# 使用新的训练进程管理器提交任务 # 使用新的训练进程管理器提交任务
try: try:
# 修正: 直接从请求中获取完整的 training_scope 对象
training_scope_obj = data.get('training_scope')
if not training_scope_obj or not isinstance(training_scope_obj, dict):
return jsonify({'error': '请求体中缺少有效(必须是JSON对象)的 training_scope 参数'}), 400
task_id = training_manager.submit_task( task_id = training_manager.submit_task(
product_id=product_id or "unknown", product_id=product_id or "unknown",
model_type=model_type, model_type=model_type,
training_mode=training_mode, training_mode=training_mode,
store_id=store_id, store_id=store_id,
epochs=epochs epochs=epochs,
training_scope=training_scope_obj
) )
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}") logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
@ -1273,8 +1290,8 @@ def get_training_status(task_id):
@app.route('/api/prediction', methods=['POST']) @app.route('/api/prediction', methods=['POST'])
@swag_from({ @swag_from({
'tags': ['模型预测'], 'tags': ['模型预测'],
'summary': '使用模型进行预测', 'summary': '使用模型进行预测 (v2)',
'description': '使用指定模型预测未来销售数据', 'description': '使用指定模型UID预测未来销售数据',
'parameters': [ 'parameters': [
{ {
'name': 'body', 'name': 'body',
@ -1283,165 +1300,61 @@ def get_training_status(task_id):
'schema': { 'schema': {
'type': 'object', 'type': 'object',
'properties': { 'properties': {
'product_id': {'type': 'string'}, 'model_uid': {'type': 'string', 'description': '要用于预测的模型的唯一ID'},
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn']}, 'future_days': {'type': 'integer', 'default': 7},
'store_id': {'type': 'string', 'description': '店铺ID如 S001。为空时使用全局模型'}, 'start_date': {'type': 'string', 'description': '预测起始日期 (YYYY-MM-DD)'},
'version': {'type': 'string'}, 'include_visualization': {'type': 'boolean', 'default': True},
'future_days': {'type': 'integer'}, 'history_lookback_days': {'type': 'integer', 'default': 30}
'include_visualization': {'type': 'boolean'},
'start_date': {'type': 'string', 'description': '预测起始日期格式为YYYY-MM-DD'}
}, },
'required': ['product_id', 'model_type'] 'required': ['model_uid']
} }
} }
], ],
'responses': { 'responses': {
200: { 200: {'description': '预测成功'},
'description': '预测成功', 400: {'description': '请求错误'},
'schema': { 404: {'description': '模型不存在'},
'type': 'object', 500: {'description': '服务器内部错误'}
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'product_name': {'type': 'string'},
'model_type': {'type': 'string'},
'predictions': {'type': 'array'},
'visualization': {'type': 'string'}
}
}
}
}
},
400: {
'description': '请求错误,缺少必要参数或参数格式错误'
},
404: {
'description': '产品或模型不存在'
},
500: {
'description': '服务器内部错误'
}
} }
}) })
def predict(): def predict():
""" """
使用指定的模型进行预测 使用指定的模型进行预测 (v2 - 基于数据库)
---
tags:
- 模型预测
parameters:
- in: body
name: body
schema:
type: object
required:
- product_id
- model_type
properties:
product_id:
type: string
description: 产品ID
model_type:
type: string
description: "模型类型 (mlstm, kan, transformer)"
version:
type: string
description: "模型版本 (v1, v2, v3 等),如果不指定则使用最新版本"
future_days:
type: integer
description: 预测未来天数
default: 7
start_date:
type: string
description: 预测起始日期格式为YYYY-MM-DD
default: ''
responses:
200:
description: 预测成功
400:
description: 请求参数错误
404:
description: 模型文件未找到
""" """
try: try:
data = request.json data = request.json
model_type = data.get('model_type') model_uid = data.get('model_uid')
version = data.get('version') if not model_uid:
return jsonify({"status": "error", "message": "缺少 'model_uid' 参数"}), 400
# 从数据库查找模型记录
model_record = find_model_by_uid(model_uid)
if not model_record:
return jsonify({"status": "error", "message": f"模型UID '{model_uid}' 不存在"}), 404
# 解析必要的模型元数据
model_type = model_record.get('model_type')
training_mode = model_record.get('training_mode')
version = model_record.get('version')
# 解析 artifacts 找到模型文件路径
artifacts = json.loads(model_record.get('artifacts', '{}'))
model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model')
if not model_file_path or not os.path.exists(model_file_path):
return jsonify({"status": "error", "message": "找不到模型文件或文件路径无效"}), 404
# 解析 training_scope 获取 product_id 或 store_id
training_scope = json.loads(model_record.get('training_scope', '{}'))
product_id = training_scope.get('product', {}).get('id')
store_id = training_scope.get('store', {}).get('id')
# 获取其他预测参数
future_days = int(data.get('future_days', 7)) future_days = int(data.get('future_days', 7))
start_date = data.get('start_date', '') start_date = data.get('start_date', '')
include_visualization = data.get('include_visualization', False) include_visualization = data.get('include_visualization', True)
history_lookback_days = int(data.get('history_lookback_days', 30)) # 新增参数 history_lookback_days = int(data.get('history_lookback_days', 30))
# 确定训练模式和标识符 # 调用核心预测函数
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
store_id = data.get('store_id')
# v2版根据训练模式和ID构建模型标识符
aggregation_method = data.get('aggregation_method', 'sum') # 全局模式需要
if training_mode == 'global':
model_identifier = f"global_{aggregation_method}"
product_name = f"全局聚合数据 ({aggregation_method})"
elif training_mode == 'store':
if not store_id:
return jsonify({"status": "error", "error": "店铺模式需要 store_id"}), 400
model_identifier = f"store_{store_id}"
product_name = f"店铺 {store_id} 整体"
else: # 默认为 'product' 模式
if not product_id:
return jsonify({"status": "error", "error": "药品模式需要 product_id"}), 400
model_identifier = product_id
product_name = get_product_name(product_id) or product_id
print(f"API接收到预测请求: mode={training_mode}, model_identifier='{model_identifier}', model_type='{model_type}', version='{version}'")
if not model_type:
return jsonify({"status": "error", "error": "model_type 是必需的"}), 400
# 获取模型版本
if not version:
version = get_latest_model_version(model_identifier, model_type)
if not version:
return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier}{model_type} 类型模型"}), 404
# v2版使用 ModelManager 查找模型文件,不再使用旧的 get_model_file_path
from utils.model_manager import model_manager
# 智能修正 training_mode (兼容前端可能发送的错误模式)
if model_identifier.startswith('store_'):
training_mode = 'store'
store_id = model_identifier.split('_')[1]
elif model_identifier.startswith('global_'):
training_mode = 'global'
# 使用 model_manager 查找模型
models_result = model_manager.list_models(
model_type=model_type,
store_id=store_id if training_mode == 'store' else None,
product_id=product_id if training_mode == 'product' else None,
training_mode=training_mode
)
found_model = None
for model in models_result.get('models', []):
if model.get('version') == version:
found_model = model
break
if not found_model or not found_model.get('file_path'):
error_msg = f"在系统中未找到匹配的模型: mode={training_mode}, type={model_type}, id='{model_identifier}', version={version}"
print(error_msg)
return jsonify({"status": "error", "error": error_msg}), 404
model_file_path = found_model['file_path']
model_id = f"{model_identifier}_{model_type}_{version}"
# v3版直接调用核心预测函数
prediction_result = load_model_and_predict( prediction_result = load_model_and_predict(
model_path=model_file_path, model_path=model_file_path,
product_id=product_id, product_id=product_id,
@ -1456,49 +1369,36 @@ def predict():
) )
if prediction_result is None: if prediction_result is None:
return jsonify({"status": "error", "error": "预测失败核心预测器返回None"}), 500 return jsonify({"status": "error", "message": "预测失败核心预测器返回None"}), 500
# 核心函数已处理好所有数据格式,此处直接构建最终响应 # 保存预测记录到数据库
response_data = { prediction_uid = str(uuid.uuid4())
'status': 'success', result_file_path = os.path.join('saved_predictions', f'prediction_{prediction_uid}.json')
'data': prediction_result, # 包含所有信息的完整结果 os.makedirs(os.path.dirname(result_file_path), exist_ok=True)
'history_data': prediction_result.get('history_data', []),
'prediction_data': prediction_result.get('prediction_data', []) with open(result_file_path, 'w', encoding='utf-8') as f:
json.dump(prediction_result, f, ensure_ascii=False, cls=CustomJSONEncoder)
db_payload = {
"prediction_uid": prediction_uid,
"model_id": model_uid,
"model_type": model_type,
"product_name": model_record.get('display_name'),
"prediction_scope": {"product_id": product_id, "store_id": store_id},
"prediction_params": {"future_days": future_days, "start_date": start_date},
"metrics": prediction_result.get('analysis', {}).get('metrics', {}),
"result_file_path": result_file_path
} }
save_prediction_to_db(db_payload)
# 调试日志 return jsonify({
print("=== 预测API响应数据结构 (v3) ===") 'status': 'success',
print(f"history_data 长度: {len(response_data['history_data'])}") 'data': prediction_result
print(f"prediction_data 长度: {len(response_data['prediction_data'])}") })
print("================================")
# 重新加入保存预测结果的逻辑
try:
model_id_to_save = f"{model_identifier}_{model_type}_{version}"
product_name_to_save = prediction_result.get('product_name', product_id or store_id or 'global')
# 调用辅助函数保存结果
save_prediction_result(
prediction_result=prediction_result.copy(),
product_id=product_id or store_id or 'global',
product_name=product_name_to_save,
model_type=model_type,
model_id=model_id_to_save,
start_date=start_date,
future_days=future_days
)
print(f"✅ 预测结果已成功保存到历史记录。")
except Exception as e:
print(f"⚠️ 警告: 保存预测结果到历史记录失败: {str(e)}")
traceback.print_exc()
# 不应阻止向用户返回结果,因此只打印警告
return jsonify(response_data)
except Exception as e: except Exception as e:
print(f"预测失败: {str(e)}") logger.error(f"预测失败: {e}\n{traceback.format_exc()}")
import traceback return jsonify({"status": "error", "message": str(e)}), 500
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
@app.route('/api/prediction/compare', methods=['POST']) @app.route('/api/prediction/compare', methods=['POST'])
@swag_from({ @swag_from({
@ -2042,154 +1942,87 @@ def delete_prediction(prediction_id):
}) })
def list_models(): def list_models():
""" """
列出所有可用的模型 - 使用统一模型管理器 列出所有可用的模型 - v2版从数据库查询
---
tags:
- 模型管理
parameters:
- name: product_id
in: query
type: string
required: false
description: 按产品ID筛选
- name: model_type
in: query
type: string
required: false
description: "按模型类型筛选 (mlstm, kan, transformer, tcn)"
- name: store_id
in: query
type: string
required: false
description: 按店铺ID筛选
- name: training_mode
in: query
type: string
required: false
description: "按训练模式筛选 (product, store, global)"
responses:
200:
description: 模型列表
schema:
type: object
properties:
status:
type: string
example: success
data:
type: array
items:
type: object
properties:
model_id:
type: string
product_id:
type: string
product_name:
type: string
model_type:
type: string
training_mode:
type: string
store_id:
type: string
version:
type: string
created_at:
type: string
metrics:
type: object
""" """
try: try:
from utils.model_manager import ModelManager
# 创建新的ModelManager实例以避免缓存问题
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir) # 向上一级到项目根目录
model_dir = os.path.join(project_root, 'saved_models')
model_manager = ModelManager(model_dir)
logger.info(f"[API] 获取模型列表请求")
logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}")
logger.info(f"[API] 目录存在: {os.path.exists(model_manager.model_dir)}")
# 获取查询参数 # 获取查询参数
product_id_filter = request.args.get('product_id') filters = {
model_type_filter = request.args.get('model_type') 'product_id': request.args.get('product_id'),
store_id_filter = request.args.get('store_id') 'model_type': request.args.get('model_type'),
training_mode_filter = request.args.get('training_mode') 'store_id': request.args.get('store_id'),
'training_mode': request.args.get('training_mode')
}
# 移除值为None的过滤器
filters = {k: v for k, v in filters.items() if v is not None}
# 获取分页参数 # 获取分页参数
page = request.args.get('page', type=int) page = request.args.get('page', 1, type=int)
page_size = request.args.get('page_size', type=int, default=10) page_size = request.args.get('page_size', 10, type=int)
logger.info(f"[API] 分页参数: page={page}, page_size={page_size}") # 调用新的数据库查询函数
result = query_models_from_db(filters=filters, page=page, page_size=page_size)
# 使用模型管理器获取模型列表 models = result.get('models', [])
result = model_manager.list_models( pagination = result.get('pagination', {})
product_id=product_id_filter,
model_type=model_type_filter,
store_id=store_id_filter,
training_mode=training_mode_filter,
page=page,
page_size=page_size
)
models = result['models'] # 格式化响应数据解析JSON字段
pagination = result['pagination']
# 格式化响应数据
formatted_models = [] formatted_models = []
for model in models: for model in models:
# 生成唯一且有意义的model_id try:
model_id = model.get('filename', '').replace('.pth', '') model['training_scope'] = json.loads(model['training_scope']) if model.get('training_scope') else {}
if not model_id: model['training_params'] = json.loads(model['training_params']) if model.get('training_params') else {}
# 备用方案基于模型信息生成ID model['performance_metrics'] = json.loads(model['performance_metrics']) if model.get('performance_metrics') else {}
product_id = model.get('product_id', 'unknown') model['artifacts'] = json.loads(model['artifacts']) if model.get('artifacts') else {}
model_type = model.get('model_type', 'unknown') # =================================================================
version = model.get('version', 'v1') # 核心修复 v4精确提供前端所需的 product_name 和 store_name 字段
training_mode = model.get('training_mode', 'product') # =================================================================
store_id = model.get('store_id') scope = model.get('training_scope', {})
mode = model.get('training_mode')
if training_mode == 'store' and store_id: # 初始化所有可能的名称字段
model_id = f"{model_type}_store_{store_id}_{product_id}_{version}" model['product_name'] = None
elif training_mode == 'global': model['store_name'] = None
aggregation_method = model.get('aggregation_method', 'mean') model['display_name'] = model.get('display_name') # 优先使用数据库值
model_id = f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
else:
model_id = f"{model_type}_product_{product_id}_{version}"
formatted_model = { if isinstance(scope, dict):
'model_id': model_id, product_info = scope.get('product')
'filename': model.get('filename', ''), store_info = scope.get('store')
'product_id': model.get('product_id', ''),
'product_name': model.get('product_name', model.get('product_id', '')),
'model_type': model.get('model_type', ''),
'training_mode': model.get('training_mode', 'product'),
'store_id': model.get('store_id'),
'aggregation_method': model.get('aggregation_method'),
'version': model.get('version', 'v1'),
'created_at': model.get('created_at', model.get('modified_at', '')),
'file_size': model.get('file_size', 0),
'metrics': model.get('metrics', {}),
'config': model.get('config', {})
}
formatted_models.append(formatted_model)
logger.info(f"[API] 成功获取 {len(formatted_models)} 个模型") if mode == 'product' and isinstance(product_info, dict):
for i, model in enumerate(formatted_models): model['product_name'] = product_info.get('name')
logger.info(f"[API] 模型 {i+1}: id='{model.get('model_id', 'EMPTY')}', filename='{model.get('filename', 'MISSING')}'") model['display_name'] = model['product_name']
# Manually convert numpy types to prevent JSON serialization errors elif mode == 'store' and isinstance(store_info, dict):
processed_models = convert_numpy_types_for_json(formatted_models) model['store_name'] = store_info.get('name')
model['display_name'] = model['store_name'] # 默认显示店铺名
elif mode == 'global':
model['display_name'] = "全局模型"
# 提供最终的后备方案
if not model.get('display_name'):
model['display_name'] = "信息不完整"
# =================================================================
formatted_models.append(model)
except (json.JSONDecodeError, TypeError) as e:
logger.error(f"解析模型JSON数据失败 (model_uid: {model.get('model_uid')}): {e}")
# 即使解析失败,也添加部分数据
model['training_scope'] = {"error": "invalid json"}
model['performance_metrics'] = {"error": "invalid json"}
model['artifacts'] = {"error": "invalid json"}
model['display_name'] = "元数据解析失败"
model['product_name'] = "N/A"
model['store_name'] = "N/A"
formatted_models.append(model)
return jsonify({ return jsonify({
"status": "success", "status": "success",
"data": processed_models, "data": formatted_models,
"pagination": pagination "pagination": pagination
}) })
except Exception as e: except Exception as e:
print(f"获取模型列表失败: {e}") logger.error(f"获取模型列表失败: {e}\n{traceback.format_exc()}")
return jsonify({ return jsonify({
"status": "error", "status": "error",
"message": f"获取模型列表失败: {str(e)}", "message": f"获取模型列表失败: {str(e)}",
@ -3061,103 +2894,6 @@ def analyze_prediction(prediction_result):
traceback.print_exc() traceback.print_exc()
return None return None
# 保存预测结果的辅助函数
def save_prediction_result(prediction_result, product_id, product_name, model_type, model_id, start_date, future_days):
"""
保存预测结果到文件和数据库
返回:
(prediction_id, file_path) - 预测ID和文件路径
"""
try:
# 生成唯一的预测ID
prediction_id = str(uuid.uuid4())
# 确保目录存在
os.makedirs('static/predictions', exist_ok=True)
# 限制数据量
if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], list):
history_data = prediction_result['history_data']
if len(history_data) > 30:
print(f"保存时历史数据超过30天进行裁剪原始数量: {len(history_data)}")
prediction_result['history_data'] = history_data[-30:] # 只保留最近30天
if 'prediction_data' in prediction_result and isinstance(prediction_result['prediction_data'], list):
prediction_data = prediction_result['prediction_data']
if len(prediction_data) > 7:
print(f"保存时预测数据超过7天进行裁剪原始数量: {len(prediction_data)}")
prediction_result['prediction_data'] = prediction_data[:7] # 只保留前7天
# 处理预测结果中可能存在的NumPy类型
def convert_numpy_types(obj):
if isinstance(obj, dict):
return {k: convert_numpy_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_numpy_types(item) for item in obj]
elif isinstance(obj, pd.DataFrame):
return obj.to_dict(orient='records')
elif isinstance(obj, pd.Series):
return obj.to_dict()
elif isinstance(obj, np.generic):
return obj.item() # 将NumPy标量转换为Python原生类型
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif pd.isna(obj):
return None
else:
return obj
# 转换整个预测结果对象
prediction_result = convert_numpy_types(prediction_result)
# 保存预测结果到JSON文件
file_name = f"prediction_{prediction_id}.json"
file_path = os.path.join('static/predictions', file_name)
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(prediction_result, f, ensure_ascii=False, cls=CustomJSONEncoder)
# 将预测记录保存到数据库
try:
conn = get_db_connection()
cursor = conn.cursor()
# 提取并序列化需要存储的数据
predictions_data_json = json.dumps(prediction_result.get('prediction_data', []), cls=CustomJSONEncoder)
# 从分析结果中获取指标,如果分析结果不存在,则使用空字典
analysis_data = prediction_result.get('analysis', {})
metrics_data = analysis_data.get('metrics', {}) if isinstance(analysis_data, dict) else {}
metrics_json = json.dumps(metrics_data, cls=CustomJSONEncoder)
chart_data_json = json.dumps(prediction_result.get('chart_data', {}), cls=CustomJSONEncoder)
analysis_json = json.dumps(analysis_data, cls=CustomJSONEncoder)
cursor.execute('''
INSERT INTO prediction_history (
prediction_id, product_id, product_name, model_type, model_id,
start_date, future_days, created_at, file_path,
predictions_data, metrics, chart_data, analysis
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (
prediction_id, product_id, product_name, model_type, model_id,
start_date if start_date else datetime.now().strftime('%Y-%m-%d'),
future_days, datetime.now().isoformat(), file_path,
predictions_data_json, metrics_json, chart_data_json, analysis_json
))
conn.commit()
conn.close()
except Exception as e:
print(f"保存预测记录到数据库失败: {str(e)}")
traceback.print_exc()
return prediction_id, file_path
except Exception as e:
print(f"保存预测结果失败: {str(e)}")
traceback.print_exc()
return None, None
# 添加模型性能分析接口 # 添加模型性能分析接口
@app.route('/api/models/analyze-metrics', methods=['POST']) @app.route('/api/models/analyze-metrics', methods=['POST'])
@ -4391,13 +4127,10 @@ def test_models_fix():
from utils.model_manager import ModelManager from utils.model_manager import ModelManager
import os import os
# 强制创建新的ModelManager实例 # 修正: 直接使用默认的相对路径
current_dir = os.path.dirname(os.path.abspath(__file__)) manager = ModelManager()
project_root = os.path.dirname(current_dir)
model_dir = os.path.join(project_root, 'saved_models')
manager = ModelManager(model_dir)
models = manager.list_models() models = manager.list_models()['models']
# 简化的响应格式 # 简化的响应格式
test_result = { test_result = {

View File

@ -2178,11 +2178,8 @@ def list_models():
try: try:
from utils.model_manager import ModelManager from utils.model_manager import ModelManager
# 创建新的ModelManager实例以避免缓存问题 # 修正: 直接使用默认的相对路径
current_dir = os.path.dirname(os.path.abspath(__file__)) model_manager = ModelManager()
project_root = os.path.dirname(current_dir) # 向上一级到项目根目录
model_dir = os.path.join(project_root, 'saved_models')
model_manager = ModelManager(model_dir)
logger.info(f"[API] 获取模型列表请求") logger.info(f"[API] 获取模型列表请求")
logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}") logger.info(f"[API] 模型管理器目录: {model_manager.model_dir}")
@ -4465,13 +4462,10 @@ def test_models_fix():
from utils.model_manager import ModelManager from utils.model_manager import ModelManager
import os import os
# 强制创建新的ModelManager实例 # 修正: 直接使用默认的相对路径
current_dir = os.path.dirname(os.path.abspath(__file__)) manager = ModelManager()
project_root = os.path.dirname(current_dir)
model_dir = os.path.join(project_root, 'saved_models')
manager = ModelManager(model_dir)
models = manager.list_models() models = manager.list_models()['models']
# 简化的响应格式 # 简化的响应格式
test_result = { test_result = {

View File

@ -10,13 +10,6 @@ import os
import re import re
import glob import glob
# 项目根目录
# __file__ 是当前文件 (config.py) 的路径
# os.path.dirname(__file__) 是 server/core
# os.path.join(..., '..') 是 server
# os.path.join(..., '..', '..') 是项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
# 解决画图中文显示问题 # 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False plt.rcParams['axes.unicode_minus'] = False
@ -34,8 +27,9 @@ DEVICE = get_device()
# 数据相关配置 # 数据相关配置
# 使用 os.path.join 构造跨平台的路径 # 使用 os.path.join 构造跨平台的路径
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet') # 修正: 改为相对路径
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models') DEFAULT_DATA_PATH = os.path.join('data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = 'saved_models'
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数 # 时间序列参数

View File

@ -225,11 +225,13 @@ class PharmacyPredictor:
result = trainer_function(**valid_args) result = trainer_function(**valid_args)
# 根据返回值的数量解析metrics # 根据返回值的数量解析metrics
if isinstance(result, tuple) and len(result) >= 2: if isinstance(result, tuple) and len(result) >= 3:
metrics = result[1] # 通常第二个返回值是metrics metrics = result[1] # 第二个返回值是metrics
artifacts = result[2] # 第三个返回值是artifacts
else: else:
log_message(f"⚠️ 训练器返回格式未知无法直接提取metrics: {type(result)}", 'warning') log_message(f"⚠️ 训练器返回格式未知无法直接提取metrics和artifacts: {type(result)}", 'warning')
metrics = None metrics = None
artifacts = None
# 检查和打印返回的metrics # 检查和打印返回的metrics
@ -249,7 +251,7 @@ class PharmacyPredictor:
else: else:
log_message(f"⚠️ metrics为空或None", 'warning') log_message(f"⚠️ metrics为空或None", 'warning')
return metrics return metrics, artifacts
except Exception as e: except Exception as e:
log_message(f"模型训练失败: {e}", 'error') log_message(f"模型训练失败: {e}", 'error')

View File

@ -65,9 +65,21 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
# --- 3. 训练循环与早停 --- # --- 3. 训练循环与早停 ---
print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...") print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")
# 版本锁定:在训练开始前确定本次训练的版本号
current_version = model_manager.peek_next_version(
model_type='cnn_bilstm_attention',
product_id=product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
loss_history = {'train': [], 'val': []} loss_history = {'train': [], 'val': []}
best_val_loss = float('inf') best_val_loss = float('inf')
best_model_state = None best_model_state = None
best_model_path = None # 用于存储最佳模型的路径
patience = kwargs.get('patience', 15) patience = kwargs.get('patience', 15)
patience_counter = 0 patience_counter = 0
@ -99,6 +111,31 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
best_model_state = copy.deepcopy(model.state_dict()) best_model_state = copy.deepcopy(model.state_dict())
patience_counter = 0 patience_counter = 0
print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}") print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")
# 立即保存最佳模型
best_model_data = {
'model_state_dict': best_model_state,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'model_type': 'cnn_bilstm_attention',
'input_dim': input_dim,
'output_dim': forecast_horizon,
'sequence_length': sequence_length,
'features': features
},
'epoch': epoch + 1
}
best_model_path, _ = model_manager.save_model(
model_data=best_model_data,
product_id=product_id,
model_type='cnn_bilstm_attention',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version=f"{current_version}_best"
)
else: else:
patience_counter += 1 patience_counter += 1
if patience_counter >= patience: if patience_counter >= patience:
@ -128,17 +165,32 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
print("\n最佳模型评估指标:") print("\n最佳模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%") print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
# 绘制损失曲线 # --- 5. 保存工件 ---
# 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = product_id
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve( loss_curve_path = plot_loss_curve(
loss_history['train'], train_losses=loss_history['train'],
loss_history['val'], val_losses=loss_history['val'],
product_name, model_type='cnn_bilstm_attention',
'cnn_bilstm_attention', scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir model_dir=model_dir
) )
print(f"📈 损失曲线已保存到: {loss_curve_path}") print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# --- 5. 模型保存 --- # 准备要保存的最终模型数据
model_data = { model_data = {
'model_state_dict': best_model_state, # 保存最佳模型的状态 'model_state_dict': best_model_state, # 保存最佳模型的状态
'scaler_X': scaler_X, 'scaler_X': scaler_X,
@ -151,11 +203,11 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
'features': features 'features': features
}, },
'metrics': metrics, 'metrics': metrics,
'loss_history': loss_history, # 保存损失历史 'loss_history': loss_history,
'loss_curve_path': loss_curve_path # 添加损失图路径 'loss_curve_path': loss_curve_path # 直接包含路径
} }
# 保存最终版本模型 # 使用模型管理器保存最终模型
final_model_path, final_version = model_manager.save_model( final_model_path, final_version = model_manager.save_model(
model_data=model_data, model_data=model_data,
product_id=product_id, product_id=product_id,
@ -163,24 +215,20 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
version=current_version # 使用锁定的版本号
) )
print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}") print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")
# 保存最佳版本模型 # 组装返回的工件
best_model_path, best_version = model_manager.save_model( artifacts = {
model_data=model_data, "versioned_model": final_model_path,
product_id=product_id, "loss_curve_plot": loss_curve_path,
model_type='cnn_bilstm_attention', "best_model": best_model_path,
store_id=store_id, "version": final_version
training_mode=training_mode, }
aggregation_method=aggregation_method,
product_name=product_name,
version='best' # 明确指定版本为 'best'
)
print(f"✅ CNN-BiLSTM-Attention 最佳模型已保存,版本: {best_version}")
return model, metrics, final_version, final_model_path return model, metrics, artifacts
# --- 关键步骤: 将训练器注册到系统中 --- # --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention) register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)

View File

@ -165,10 +165,22 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型 # 训练模型
from utils.model_manager import model_manager
model_type_name = 'optimized_kan' if use_optimized else 'kan'
current_version = model_manager.peek_next_version(
model_type=model_type_name,
product_id=model_identifier,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
train_losses = [] train_losses = []
test_losses = [] test_losses = []
start_time = time.time() start_time = time.time()
best_loss = float('inf') best_loss = float('inf')
best_model_path = None
for epoch in range(epochs): for epoch in range(epochs):
model.train() model.train()
@ -253,7 +265,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
# 使用模型管理器保存 'best' 版本 # 使用模型管理器保存 'best' 版本
from utils.model_manager import model_manager from utils.model_manager import model_manager
model_manager.save_model( best_model_path, _ = model_manager.save_model(
model_data=best_model_data, model_data=best_model_data,
product_id=model_identifier, # 修正:使用唯一的标识符 product_id=model_identifier, # 修正:使用唯一的标识符
model_type=model_type_name, model_type=model_type_name,
@ -261,7 +273,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name, product_name=product_name,
version='best' # 显式覆盖版本为'best' version=f"{current_version}_best"
) )
if (epoch + 1) % 10 == 0: if (epoch + 1) % 10 == 0:
@ -271,15 +283,6 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
training_time = time.time() - start_time training_time = time.time() - start_time
# 绘制损失曲线并保存到模型目录 # 绘制损失曲线并保存到模型目录
model_name = 'optimized_kan' if use_optimized else 'kan'
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
product_name,
model_type,
model_dir=model_dir
)
print(f"损失曲线已保存到: {loss_curve_path}")
# 评估模型 # 评估模型
model.eval() model.eval()
@ -307,11 +310,33 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
print(f"MAPE: {metrics['mape']:.2f}%") print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {training_time:.2f}") print(f"训练时间: {training_time:.2f}")
# 使用统一模型管理器保存模型 # --- 5. 保存工件 ---
from utils.model_manager import model_manager
model_type_name = 'optimized_kan' if use_optimized else 'kan' model_type_name = 'optimized_kan' if use_optimized else 'kan'
# 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = model_identifier
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve(
train_losses=train_losses,
val_losses=test_losses,
model_type=model_type_name,
scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir
)
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# 准备要保存的最终模型数据
model_data = { model_data = {
'model_state_dict': model.state_dict(), 'model_state_dict': model.state_dict(),
'scaler_X': scaler_X, 'scaler_X': scaler_X,
@ -332,24 +357,32 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
'test': test_losses, 'test': test_losses,
'epochs': list(range(1, epochs + 1)) 'epochs': list(range(1, epochs + 1))
}, },
'loss_curve_path': loss_curve_path 'loss_curve_path': loss_curve_path # 直接包含路径
} }
# 保存最终模型,让 model_manager 自动处理版本号 # 使用模型管理器保存最终模型
from utils.model_manager import model_manager
final_model_path, final_version = model_manager.save_model( final_model_path, final_version = model_manager.save_model(
model_data=model_data, model_data=model_data,
product_id=model_identifier, # 修正:使用唯一的标识符 product_id=model_identifier,
model_type=model_type_name, model_type=model_type_name,
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
# 注意此处不传递version参数由管理器自动生成 version=current_version # 使用锁定的版本
) )
print(f"{model_type_name} 最终模型已保存,版本: {final_version}")
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}") # 组装返回的工件
artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
"best_model": best_model_path,
"version": final_version
}
return model, metrics return model, metrics, artifacts
# --- 将此训练器注册到系统中 --- # --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer from models.model_registry import register_trainer

View File

@ -255,6 +255,16 @@ def train_product_model_with_mlstm(
emit_progress("数据预处理完成,开始模型训练...", progress=10) emit_progress("数据预处理完成,开始模型训练...", progress=10)
# 训练模型 # 训练模型
# 版本锁定
current_version = model_manager.peek_next_version(
model_type='mlstm',
product_id=model_identifier,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
train_losses = [] train_losses = []
test_losses = [] test_losses = []
start_time = time.time() start_time = time.time()
@ -263,6 +273,7 @@ def train_product_model_with_mlstm(
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次最少每1个epoch checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次最少每1个epoch
best_loss = float('inf') best_loss = float('inf')
epochs_no_improve = 0 epochs_no_improve = 0
best_model_path = None
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}") emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
@ -365,7 +376,7 @@ def train_product_model_with_mlstm(
# 如果是最佳模型,额外保存一份 # 如果是最佳模型,额外保存一份
if test_loss < best_loss: if test_loss < best_loss:
best_loss = test_loss best_loss = test_loss
model_manager.save_model( best_model_path, _ = model_manager.save_model(
model_data=checkpoint_data, model_data=checkpoint_data,
product_id=model_identifier, # 修正:使用唯一的标识符 product_id=model_identifier, # 修正:使用唯一的标识符
model_type='mlstm', model_type='mlstm',
@ -373,7 +384,7 @@ def train_product_model_with_mlstm(
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name, product_name=product_name,
version='best' version=f"{current_version}_best"
) )
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})") emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0 epochs_no_improve = 0
@ -391,37 +402,6 @@ def train_product_model_with_mlstm(
# 计算训练时间 # 计算训练时间
training_time = time.time() - start_time training_time = time.time() - start_time
emit_progress("生成损失曲线...", progress=95)
# 确定模型保存目录(支持多店铺)
if store_id:
# 为特定店铺创建子目录
store_model_dir = os.path.join(model_dir, 'mlstm', store_id)
os.makedirs(store_model_dir, exist_ok=True)
loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
loss_curve_path = os.path.join(store_model_dir, loss_curve_filename)
else:
# 全局模型保存在global目录
global_model_dir = os.path.join(model_dir, 'mlstm', 'global')
os.makedirs(global_model_dir, exist_ok=True)
loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
loss_curve_path = os.path.join(global_model_dir, loss_curve_filename)
# 绘制损失曲线并保存到模型目录
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
title_suffix = f" - {training_scope}" if store_id else " - 全局模型"
plt.title(f'mLSTM 模型训练损失曲线 - {product_name} ({version}){title_suffix}')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.savefig(loss_curve_path, dpi=300, bbox_inches='tight')
plt.close()
print(f"损失曲线已保存到: {loss_curve_path}")
emit_progress("模型评估中...", progress=98) emit_progress("模型评估中...", progress=98)
# 评估模型 # 评估模型
@ -475,7 +455,7 @@ def train_product_model_with_mlstm(
'model_type': 'mlstm' 'model_type': 'mlstm'
}, },
'metrics': metrics, 'metrics': metrics,
'loss_curve_path': loss_curve_path, 'metrics': metrics,
'training_info': { 'training_info': {
'product_id': product_id, 'product_id': product_id,
'product_name': product_name, 'product_name': product_name,
@ -488,6 +468,31 @@ def train_product_model_with_mlstm(
} }
} }
# 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = model_identifier
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve(
train_losses=train_losses,
val_losses=test_losses,
model_type='mlstm',
scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir
)
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# 更新模型数据中的损失图路径
final_model_data['loss_curve_path'] = loss_curve_path
# 保存最终模型,让 model_manager 自动处理版本号 # 保存最终模型,让 model_manager 自动处理版本号
final_model_path, final_version = model_manager.save_model( final_model_path, final_version = model_manager.save_model(
model_data=final_model_data, model_data=final_model_data,
@ -496,7 +501,8 @@ def train_product_model_with_mlstm(
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
version=current_version
) )
# 发送训练完成消息 # 发送训练完成消息
@ -514,7 +520,16 @@ def train_product_model_with_mlstm(
emit_progress(f"✅ mLSTM模型训练完成版本 {final_version} 已保存", progress=100, metrics=final_metrics) emit_progress(f"✅ mLSTM模型训练完成版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path # 组装 artifacts 字典
artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
# 假设 best model 的路径可以从 model_manager 获取或推断
"best_model": best_model_path,
"version": final_version
}
return model, metrics, artifacts
# --- 将此训练器注册到系统中 --- # --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer from models.model_registry import register_trainer

View File

@ -164,8 +164,19 @@ def train_product_model_with_tcn(
test_losses = [] test_losses = []
start_time = time.time() start_time = time.time()
# 版本锁定
current_version = model_manager.peek_next_version(
model_type='tcn',
product_id=model_identifier,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
checkpoint_interval = max(1, epochs // 10) checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf') best_loss = float('inf')
best_model_path = None
progress_manager.set_stage("model_training", 0) progress_manager.set_stage("model_training", 0)
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}") emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
@ -269,7 +280,7 @@ def train_product_model_with_tcn(
if test_loss < best_loss: if test_loss < best_loss:
best_loss = test_loss best_loss = test_loss
model_manager.save_model( best_model_path, _ = model_manager.save_model(
model_data=checkpoint_data, model_data=checkpoint_data,
product_id=model_identifier, # 修正:使用唯一的标识符 product_id=model_identifier, # 修正:使用唯一的标识符
model_type='tcn', model_type='tcn',
@ -277,7 +288,7 @@ def train_product_model_with_tcn(
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name, product_name=product_name,
version='best' version=f"{current_version}_best"
) )
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})") emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
@ -289,14 +300,6 @@ def train_product_model_with_tcn(
progress_manager.set_stage("model_saving", 0) progress_manager.set_stage("model_saving", 0)
emit_progress("训练完成,正在保存模型...") emit_progress("训练完成,正在保存模型...")
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
product_name,
'TCN',
model_dir=model_dir
)
print(f"损失曲线已保存到: {loss_curve_path}")
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -340,7 +343,7 @@ def train_product_model_with_tcn(
'model_type': 'tcn' 'model_type': 'tcn'
}, },
'metrics': metrics, 'metrics': metrics,
'loss_curve_path': loss_curve_path, 'metrics': metrics,
'training_info': { 'training_info': {
'product_id': product_id, 'product_id': product_id,
'product_name': product_name, 'product_name': product_name,
@ -361,7 +364,8 @@ def train_product_model_with_tcn(
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
version=current_version
) )
progress_manager.set_stage("model_saving", 100) progress_manager.set_stage("model_saving", 100)
@ -379,7 +383,40 @@ def train_product_model_with_tcn(
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics) emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
return model, metrics, epochs, final_model_path # 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = model_identifier
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve(
train_losses=train_losses,
val_losses=test_losses,
model_type='tcn',
scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir
)
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# 更新模型数据中的损失图路径
final_model_data['loss_curve_path'] = loss_curve_path
artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
"best_model": best_model_path,
"version": final_version
}
return model, metrics, artifacts
# --- 将此训练器注册到系统中 --- # --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer from models.model_registry import register_trainer

View File

@ -187,9 +187,20 @@ def train_product_model_with_transformer(
test_losses = [] test_losses = []
start_time = time.time() start_time = time.time()
# 版本锁定
current_version = model_manager.peek_next_version(
model_type='transformer',
product_id=model_identifier,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
checkpoint_interval = max(1, epochs // 10) checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf') best_loss = float('inf')
epochs_no_improve = 0 epochs_no_improve = 0
best_model_path = None
progress_manager.set_stage("model_training", 0) progress_manager.set_stage("model_training", 0)
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}") emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
@ -289,7 +300,8 @@ def train_product_model_with_transformer(
if test_loss < best_loss: if test_loss < best_loss:
best_loss = test_loss best_loss = test_loss
model_manager.save_model( # 修正: 保存最佳模型路径
best_model_path, _ = model_manager.save_model(
model_data=checkpoint_data, model_data=checkpoint_data,
product_id=model_identifier, # 修正:使用唯一的标识符 product_id=model_identifier, # 修正:使用唯一的标识符
model_type='transformer', model_type='transformer',
@ -297,7 +309,7 @@ def train_product_model_with_transformer(
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name, product_name=product_name,
version='best' version=f"{current_version}_best"
) )
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})") emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0 epochs_no_improve = 0
@ -316,14 +328,6 @@ def train_product_model_with_transformer(
progress_manager.set_stage("model_saving", 0) progress_manager.set_stage("model_saving", 0)
emit_progress("训练完成,正在保存模型...") emit_progress("训练完成,正在保存模型...")
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
product_name,
'Transformer',
model_dir=model_dir
)
print(f"📈 损失曲线已保存到: {loss_curve_path}", flush=True)
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
@ -366,7 +370,7 @@ def train_product_model_with_transformer(
'model_type': 'transformer' 'model_type': 'transformer'
}, },
'metrics': metrics, 'metrics': metrics,
'loss_curve_path': loss_curve_path, 'metrics': metrics,
'training_info': { 'training_info': {
'product_id': product_id, 'product_id': product_id,
'product_name': product_name, 'product_name': product_name,
@ -387,7 +391,8 @@ def train_product_model_with_transformer(
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
version=current_version
) )
progress_manager.set_stage("model_saving", 100) progress_manager.set_stage("model_saving", 100)
@ -406,7 +411,40 @@ def train_product_model_with_transformer(
'version': final_version 'version': final_version
} }
return model, final_metrics, epochs # 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = model_identifier
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve(
train_losses=train_losses,
val_losses=test_losses,
model_type='transformer',
scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir
)
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# 更新模型数据中的损失图路径
final_model_data['loss_curve_path'] = loss_curve_path
artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
"best_model": best_model_path,
"version": final_version
}
return model, final_metrics, artifacts
# --- 将此训练器注册到系统中 --- # --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer from models.model_registry import register_trainer

View File

@ -76,6 +76,17 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
n_estimators = kwargs.get('n_estimators', 500) n_estimators = kwargs.get('n_estimators', 500)
print("开始训练XGBoost模型 (使用核心xgb.train API)...") print("开始训练XGBoost模型 (使用核心xgb.train API)...")
# 版本锁定
current_version = model_manager.peek_next_version(
model_type='xgboost',
product_id=product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
start_time = time.time() start_time = time.time()
evals_result = {} evals_result = {}
@ -109,16 +120,33 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
# 提取损失并绘制曲线 # 提取损失并绘制曲线
train_losses = evals_result['train']['rmse'] train_losses = evals_result['train']['rmse']
test_losses = evals_result['test']['rmse'] test_losses = evals_result['test']['rmse']
# --- 5. 保存工件 ---
# 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = product_id
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name # 后备方案
# 绘制带有版本号的损失曲线图
loss_curve_path = plot_loss_curve( loss_curve_path = plot_loss_curve(
train_losses, train_losses=train_losses,
test_losses, val_losses=test_losses,
product_name, model_type='xgboost',
'xgboost', scope=scope,
identifier=identifier,
version=current_version, # 使用锁定的版本
model_dir=model_dir model_dir=model_dir
) )
print(f"📈 损失曲线已保存到: {loss_curve_path}") print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# --- 5. 模型保存 (借道 utils.model_manager) --- # 准备要保存的最终模型数据
model_data = { model_data = {
'model_state_dict': model, # 直接保存模型对象 'model_state_dict': model, # 直接保存模型对象
'scaler_X': scaler_X, 'scaler_X': scaler_X,
@ -132,7 +160,7 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
}, },
'metrics': metrics, 'metrics': metrics,
'loss_history': evals_result, 'loss_history': evals_result,
'loss_curve_path': loss_curve_path # 添加损失图路径 'loss_curve_path': loss_curve_path # 直接包含路径
} }
# 保存最终版本模型 # 保存最终版本模型
@ -143,12 +171,14 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
store_id=store_id, store_id=store_id,
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name product_name=product_name,
version=current_version
) )
print(f"✅ XGBoost最终模型已通过统一管理器保存版本: {final_version}, 路径: {final_model_path}") print(f"✅ XGBoost最终模型已通过统一管理器保存版本: {final_version}")
# 保存最佳版本模型 # XGBoost的 `best_model` 就是它自己,因为 `xgb.train` 内部处理了早停。
best_model_path, best_version = model_manager.save_model( # 我们创建一个指向最终模型的 "best" 版本文件,以保持接口一致性。
best_model_path, _ = model_manager.save_model(
model_data=model_data, model_data=model_data,
product_id=product_id, product_id=product_id,
model_type='xgboost', model_type='xgboost',
@ -156,12 +186,19 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
training_mode=training_mode, training_mode=training_mode,
aggregation_method=aggregation_method, aggregation_method=aggregation_method,
product_name=product_name, product_name=product_name,
version='best' # 明确指定版本为 'best' version=f"{current_version}_best"
) )
print(f"✅ XGBoost最佳模型已通过统一管理器保存,版本: {best_version}, 路径: {best_model_path}") print(f"✅ XGBoost最佳模型引用已创建,版本: {current_version}_best")
# 返回值遵循统一格式 # 组装返回的工件
return model, metrics, final_version, final_model_path artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
"best_model": best_model_path,
"version": final_version
}
return model, metrics, artifacts
# --- 将此训练器注册到系统中 --- # --- 将此训练器注册到系统中 ---
register_trainer('xgboost', train_product_model_with_xgboost) register_trainer('xgboost', train_product_model_with_xgboost)

View File

@ -0,0 +1,162 @@
import sqlite3
import json
import uuid
from datetime import datetime
def get_db_connection(db_path: str = 'prediction_history.db'):
    """Open a connection to the project's SQLite database.

    Args:
        db_path: path of the SQLite database file. Defaults to the
            historical hard-coded 'prediction_history.db' so existing
            callers keep working; tests may pass ':memory:'.

    Returns:
        sqlite3.Connection with row_factory set to sqlite3.Row (rows are
        addressable by column name) and check_same_thread disabled so the
        connection may be used from worker threads.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn
def save_model_to_db(model_data: dict):
    """Persist a trained model's metadata into the `models` table.

    Args:
        model_data: metadata dict; recognized keys are 'model_uid',
            'display_name', 'model_type', 'training_mode', 'training_scope',
            'parent_model_id', 'version', 'training_params',
            'performance_metrics' and 'artifacts'. JSON-typed fields are
            serialized with json.dumps before insertion.

    Returns:
        The model UID actually stored on success, or None on failure
        (the transaction is rolled back and the error is printed).
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        # Fix: honor a caller-supplied UID so the UID the caller logged /
        # returned to the frontend matches the DB record; only generate a
        # fresh one when none was provided.
        model_uid = model_data.get('model_uid') or (
            f"{model_data.get('training_mode', 'model')}_"
            f"{model_data.get('model_type', 'default')}_{uuid.uuid4().hex[:8]}"
        )

        cursor.execute(
            """
            INSERT INTO models (
                model_uid, display_name, model_type, training_mode,
                training_scope, parent_model_id, version, status,
                training_params, performance_metrics, artifacts, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                model_uid,
                model_data.get('display_name'),
                model_data.get('model_type'),
                model_data.get('training_mode'),
                json.dumps(model_data.get('training_scope')),
                model_data.get('parent_model_id'),
                model_data.get('version'),
                'active',  # new models are active until explicitly archived
                json.dumps(model_data.get('training_params')),
                json.dumps(model_data.get('performance_metrics')),
                json.dumps(model_data.get('artifacts')),
                datetime.now().isoformat()
            )
        )
        conn.commit()
        print(f"✅ 模型元数据已成功保存到数据库UID: {model_uid}")
        return model_uid
    except Exception as e:
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        conn.rollback()
        return None
    finally:
        conn.close()
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """Query the `models` table with optional filters and pagination.

    Args:
        filters: optional keys 'product_id', 'model_type', 'store_id',
            'training_mode'; each present (truthy) key adds an AND condition.
            product_id / store_id are matched inside the training_scope JSON
            via json_extract.
        page: 1-based page number.
        page_size: number of rows per page.

    Returns:
        {'models': [row dicts], 'pagination': {'total', 'page', 'page_size',
        'total_pages'}} on success; {'models': [], 'pagination': {}} on error.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        conditions = []
        params = []

        if filters.get('product_id'):
            # product id lives inside the training_scope JSON blob
            conditions.append("json_extract(training_scope, '$.product.id') = ?")
            params.append(filters['product_id'])
        if filters.get('model_type'):
            conditions.append("model_type = ?")
            params.append(filters['model_type'])
        if filters.get('store_id'):
            conditions.append("json_extract(training_scope, '$.store.id') = ?")
            params.append(filters['store_id'])
        if filters.get('training_mode'):
            conditions.append("training_mode = ?")
            params.append(filters['training_mode'])

        where_clause = (" WHERE " + " AND ".join(conditions)) if conditions else ""

        # Fix: build the COUNT statement explicitly instead of
        # query.replace("*", "COUNT(*)"), which rewrites EVERY '*' in the
        # SQL string and silently corrupts any condition containing '*'.
        cursor.execute("SELECT COUNT(*) FROM models" + where_clause, params)
        total_count = cursor.fetchone()[0]

        offset = (page - 1) * page_size
        cursor.execute(
            "SELECT * FROM models" + where_clause +
            " ORDER BY created_at DESC LIMIT ? OFFSET ?",
            params + [page_size, offset]
        )
        models = [dict(row) for row in cursor.fetchall()]

        return {
            "models": models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                # ceiling division without importing math
                "total_pages": (total_count + page_size - 1) // page_size
            }
        }
    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()
def save_prediction_to_db(prediction_data: dict):
    """Persist one prediction run into the `prediction_history` table.

    Args:
        prediction_data: dict with 'prediction_uid', 'model_id' (should be
            a model_uid from the `models` table), 'model_type',
            'product_name', 'prediction_scope', 'prediction_params',
            'metrics' and 'result_file_path'. JSON-typed fields are
            serialized with json.dumps.

    Returns:
        True on success, False on failure (fix: the original always
        returned None, so callers could not detect a failed insert).
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute('''
            INSERT INTO prediction_history (
                prediction_uid, model_id, model_type, product_name,
                prediction_scope, prediction_params, metrics, result_file_path
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            prediction_data.get('prediction_uid'),
            prediction_data.get('model_id'),  # This should be the model_uid from the 'models' table
            prediction_data.get('model_type'),
            prediction_data.get('product_name'),
            json.dumps(prediction_data.get('prediction_scope')),
            json.dumps(prediction_data.get('prediction_params')),
            json.dumps(prediction_data.get('metrics')),
            prediction_data.get('result_file_path')
        ))
        conn.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
        return True
    except Exception as e:
        print(f"❌ 保存预测记录到数据库失败: {e}")
        conn.rollback()
        return False
    finally:
        conn.close()
def find_model_by_uid(model_uid: str):
    """Look up a single model record by its public unique identifier.

    Args:
        model_uid: the model_uid stored in the `models` table.

    Returns:
        The matching row as a plain dict, or None when no record matches
        or the lookup fails (the error is printed, never raised).
    """
    conn = get_db_connection()
    try:
        row = conn.execute(
            "SELECT * FROM models WHERE model_uid = ?", (model_uid,)
        ).fetchone()
        return dict(row) if row else None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        conn.close()

View File

@ -42,12 +42,26 @@ class ModelManager:
max_version = 0 max_version = 0
for f in existing_files: for f in existing_files:
match = re.search(r'_v(\d+)\.pth$', os.path.basename(f)) # 修正: 同时匹配 _v1.pth 和 _v1_best.pth 这样的文件名
match = re.search(r'_v(\d+)(_best)?\.pth$', os.path.basename(f))
if match: if match:
max_version = max(max_version, int(match.group(1))) max_version = max(max_version, int(match.group(1)))
return max_version + 1 return max_version + 1
def peek_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> str:
"""
预览下一个版本号字符串 (e.g., 'v3')但不进行任何文件操作
"""
next_version_num = self._get_next_version(
model_type=model_type,
product_id=product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
return f"v{next_version_num}"
def generate_model_filename(self, def generate_model_filename(self,
model_type: str, model_type: str,
version: str, version: str,
@ -92,7 +106,9 @@ class ModelManager:
返回: 返回:
(模型文件路径, 使用的版本号) (模型文件路径, 使用的版本号)
""" """
# 修正: 简化版本处理逻辑,由调用方明确提供版本字符串
if version is None: if version is None:
# 如果未提供版本,则自动生成新版本
next_version_num = self._get_next_version( next_version_num = self._get_next_version(
model_type=model_type, model_type=model_type,
product_id=product_id, product_id=product_id,
@ -102,6 +118,7 @@ class ModelManager:
) )
version_str = f"v{next_version_num}" version_str = f"v{next_version_num}"
else: else:
# 直接使用调用方提供的版本字符串 (e.g., 'v3', 'v3_best')
version_str = version version_str = version
filename = self.generate_model_filename( filename = self.generate_model_filename(
@ -212,7 +229,8 @@ class ModelManager:
# 添加文件信息 # 添加文件信息
model_info['filename'] = filename model_info['filename'] = filename
model_info['file_path'] = model_file # 修正: 确保返回的是相对路径
model_info['file_path'] = os.path.join(self.model_dir, filename)
model_info['file_size'] = os.path.getsize(model_file) model_info['file_size'] = os.path.getsize(model_file)
model_info['modified_at'] = datetime.fromtimestamp( model_info['modified_at'] = datetime.fromtimestamp(
os.path.getmtime(model_file) os.path.getmtime(model_file)
@ -343,9 +361,5 @@ class ModelManager:
# 全局模型管理器实例 # 全局模型管理器实例
# 确保使用项目根目录的saved_models而不是相对于当前工作目录 # 修正: 直接使用在config.py中定义的相对路径
import os model_manager = ModelManager()
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir)) # 向上两级到项目根目录
absolute_model_dir = os.path.join(project_root, 'saved_models')
model_manager = ModelManager(absolute_model_dir)

View File

@ -24,6 +24,7 @@ server_dir = os.path.dirname(current_dir)
sys.path.append(server_dir) sys.path.append(server_dir)
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
from utils.database_utils import save_model_to_db
import numpy as np import numpy as np
def convert_numpy_types(obj): def convert_numpy_types(obj):
@ -45,6 +46,8 @@ class TrainingTask:
training_mode: str training_mode: str
store_id: Optional[str] = None store_id: Optional[str] = None
epochs: int = 100 epochs: int = 100
training_scope: Optional[Dict[str, Any]] = None # 新增
artifacts: Optional[Dict[str, Any]] = None # 新增
status: str = "pending" # pending, running, completed, failed status: str = "pending" # pending, running, completed, failed
start_time: Optional[str] = None start_time: Optional[str] = None
end_time: Optional[str] = None end_time: Optional[str] = None
@ -138,7 +141,7 @@ class TrainingWorker:
training_logger.error(f"进度回调失败: {e}") training_logger.error(f"进度回调失败: {e}")
# 执行真正的训练,传递进度回调 # 执行真正的训练,传递进度回调
metrics = predictor.train_model( metrics, artifacts = predictor.train_model(
product_id=task.product_id, product_id=task.product_id,
model_type=task.model_type, model_type=task.model_type,
epochs=task.epochs, epochs=task.epochs,
@ -181,6 +184,7 @@ class TrainingWorker:
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S') task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.progress = 100.0 task.progress = 100.0
task.metrics = metrics task.metrics = metrics
task.artifacts = artifacts if artifacts else {}
task.message = "训练完成" task.message = "训练完成"
training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}") training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
@ -282,7 +286,7 @@ class TrainingProcessManager:
self.logger.info("✅ 训练进程管理器已停止") self.logger.info("✅ 训练进程管理器已停止")
def submit_task(self, product_id: str, model_type: str, training_mode: str = "product", def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
store_id: str = None, epochs: int = 100, **kwargs) -> str: store_id: str = None, epochs: int = 100, training_scope: dict = None, **kwargs) -> str:
"""提交训练任务""" """提交训练任务"""
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
@ -292,7 +296,8 @@ class TrainingProcessManager:
model_type=model_type, model_type=model_type,
training_mode=training_mode, training_mode=training_mode,
store_id=store_id, store_id=store_id,
epochs=epochs epochs=epochs,
training_scope=training_scope
) )
with self.lock: with self.lock:
@ -371,6 +376,34 @@ class TrainingProcessManager:
'product_id': serializable_task_data.get('product_id'), 'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type') 'model_type': serializable_task_data.get('model_type')
}) })
# 在此处调用函数,将模型元数据保存到数据库
# 注意我们需要从task_data中构建一个符合save_model_to_db期望的字典
# 这是一个简化的示例,实际可能需要传递更多参数
# 构建一个更完整的模型数据字典用于保存
# 从 artifacts 中获取版本号
version = serializable_task_data.get('artifacts', {}).get('version')
# 构建带有版本号的 display_name
base_display_name = f"{serializable_task_data.get('product_id', 'N/A')} - {serializable_task_data.get('model_type')}"
display_name_with_version = f"{base_display_name} ({version})" if version else base_display_name
model_to_save = {
'model_uid': f"{serializable_task_data.get('training_mode')}_{serializable_task_data.get('model_type')}_{str(uuid.uuid4())[:8]}",
'display_name': display_name_with_version,
'model_type': serializable_task_data.get('model_type'),
'training_mode': serializable_task_data.get('training_mode'),
'training_scope': serializable_task_data.get('training_scope'),
'version': version,
'status': 'active',
'training_params': {
'epochs': serializable_task_data.get('epochs')
},
'performance_metrics': serializable_task_data.get('metrics'),
'artifacts': serializable_task_data.get('artifacts')
}
save_model_to_db(model_to_save)
elif action == 'error': elif action == 'error':
# 训练失败 # 训练失败
self.websocket_callback('training_update', { self.websocket_callback('training_update', {

View File

@ -16,37 +16,45 @@ except ImportError:
# 后备方案:使用默认值 # 后备方案:使用默认值
DEFAULT_MODEL_DIR = "saved_models" DEFAULT_MODEL_DIR = "saved_models"
def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR): def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
""" """
绘制训练和验证损失曲线 绘制训练和验证损失曲线并根据scope和identifier生成标准化的文件名
参数: 参数:
train_losses: 训练损失列表 train_losses: 训练损失列表
val_losses: 验证损失列表 val_losses: 验证损失列表
product_name: 产品名称 model_type: 模型类型 (e.g., 'xgboost')
model_type: 模型类型 scope: 训练范围 ('product', 'store', 'global')
identifier: 范围对应的标识符 (产品名, 店铺ID, 或聚合方法)
version: (可选) 模型版本号用于生成唯一的文件名
save_path: 保存路径如果为None则自动生成路径 save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR model_dir: 模型保存目录
""" """
plt.figure(figsize=(10, 5)) plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='训练损失') plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失') plt.plot(val_losses, label='验证损失')
plt.title(f'{product_name} - {model_type}模型训练和验证损失')
# 动态生成标题
title_identifier = identifier.replace('_', ' ')
title = f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失'
plt.title(title)
plt.xlabel('Epoch') plt.xlabel('Epoch')
plt.ylabel('Loss') plt.ylabel('Loss')
plt.legend() plt.legend()
plt.grid(True) plt.grid(True)
if save_path: if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path full_path = save_path
else: else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_loss_curve.png # 标准化文件名生成逻辑,与 ModelManager 对齐
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png" version_str = f"_{version}" if version else ""
# 清理标识符中的非法字符
safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_')
filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png"
full_path = os.path.join(model_dir, filename) full_path = os.path.join(model_dir, filename)
plt.savefig(full_path) plt.savefig(full_path)

54
xz数据库2025_07_24.md Normal file
View File

@ -0,0 +1,54 @@
# 数据库设计方案 (最终版)
本文档定义了项目重构后的核心数据库结构,旨在实现对机器学习模型和预测历史的统一、高效管理。该设计方案结合了业务需求、前端UI逻辑、后端代码实现以及现有数据库结构,并经过了多次迭代确认。
## 设计原则
- **模型集中化**: 所有模型的元数据统一存储,摆脱对文件系统的依赖。
- **数据一致性**: 通过逻辑外键确保预测记录与模型版本精确关联。
- **查询高性能**: 通过冗余关键字段,避免在查询列表时进行不必要的`JOIN`操作。
- **可扩展性**: 使用JSON字段存储灵活、复杂的范围定义和文件路径适应未来业务变化。
- **结构清晰**: 物理文件(模型、日志、预测结果)与数据库记录分离,数据库只存元数据和路径,保持自身轻量。
---
## 表结构定义
### 表1`models`
这张表是模型管理的核心,负责存储所有模型版本的全生命周期信息。
| 字段名 | 类型 | 确认理由与说明 |
| :--- | :--- | :--- |
| `id` | INTEGER | 主键,系统内部使用。 |
| `model_uid` | TEXT | **[关键]** 用户可见的唯一ID由后端生成用于API调用和逻辑关联。 |
| `display_name` | TEXT | **[建议新增]** 用户可自定义的别名,如“夏季促销模型”,提升易用性。 |
| `model_type` | TEXT | **已确认.** 核心参数,如 `mlstm`, `kan`。 |
| `training_mode` | TEXT | **已确认.** 模型的训练范围: `product`, `store`, `global`。 |
| `training_scope` | TEXT (JSON) | **[最终版]** **精确描述训练范围,并包含中文名。**<br/>- **按药品训练**: `{"product": {"id": "P001", "name": "阿莫西林"}, "stores": "all"}``{"product": {"id": "P001", "name": "阿莫西林"}, "stores": [{"id": "S001", "name": "城西店"}]}`<br/>- **按店铺训练**: `{"store": {"id": "S001", "name": "城西店"}, "products": "all"}``{"store": {"id": "S001", "name": "城西店"}, "products": [{"id": "P001", "name": "阿莫西林"}, {"id": "P002", "name": "布洛芬"}]}`<br/>- **全局训练**: `{"stores": "all", "products": "all"}``{"stores": [{"id": "S001", "name": "城西店"}], "products": [{"id": "P001", "name": "阿莫西林"}]}` |
| `parent_model_id` | INTEGER | **已确认.** 外键,指向自身 `id`,用于实现“继续训练”功能,形成版本链。 |
| `version` | TEXT | **已确认.** 模型版本号,如 `v1`, `v2`。 |
| `status` | TEXT | **已确认.** 模型状态,如 `active`, `archived`,用于控制模型是否可用。 |
| `training_params` | TEXT (JSON) | **已确认.** 存储训练时的超参数,如 `{"epochs": 50, "aggregation_method": "sum"}`。 |
| `performance_metrics` | TEXT (JSON) | **已确认.** 存储性能指标,如 `{"R2": 0.85, "RMSE": ...}`。 |
| `artifacts` | TEXT (JSON) | **[采纳建议]** **存储与模型相关的所有文件路径。** 所有文件均采用**扁平化结构**存放在 `saved_models/` 目录下。示例:<br/>`{"best_model": "saved_models/product_P001_mlstm_best.pth", "versioned_model": "saved_models/product_P001_mlstm_v1.pth", "loss_curve_plot": "saved_models/product_P001_mlstm_v1_loss.png", "loss_curve_data": "saved_models/product_P001_mlstm_v1_history.json"}` |
| `created_at` | DATETIME | **已确认.** 模型的创建时间。 |
---
### 表2`prediction_history`
这张表用于记录每一次预测任务的结果,其结构在现有基础上进行了优化,以实现高效查询和完整追溯。
| 字段名 | 类型 | 确认理由与说明 |
| :--- | :--- | :--- |
| `id` | INTEGER | **已确认.** 主键,自增。 |
| `prediction_uid` | TEXT | **已确认.** 唯一的预测ID (UUID)用于API调用。 |
| `model_id` | TEXT | **[核心变更]** **使用的模型的唯一标识符。** 在重构后,此字段的值应该与 `models` 表中的某条记录的 `model_uid` 相对应,从而建立起逻辑关联。 |
| `model_type` | TEXT | **已确认 (冗余).** 冗余存储模型类型(如 `mlstm`),用于在不关联查询的情况下快速筛选历史记录。 |
| `product_name` | TEXT | **已确认 (冗余).** **冗余存储产品/店铺的中文名**,用于在列表页快速展示,极大提升查询性能。 |
| `prediction_scope` | TEXT (JSON) | **[新增]** 描述本次预测的范围,与 `models` 表的 `training_scope` 结构类似,指明是为哪个产品/店铺做的预测。 |
| `prediction_params` | TEXT (JSON) | **已确认.** 存储预测参数,如 `{"start_date": "...", "future_days": 7}`。 |
| `metrics` | TEXT (JSON) | **已确认 (冗余).** 缓存的性能指标,用于列表展示和排序。 |
| `result_file_path` | TEXT | **[已采纳您的规范]** 指向预测结果JSON文件的**相对路径**。文件存储在 `saved_predictions/` 目录下,并根据模型名和时间戳命名,例如:`saved_predictions/cnn_bilstm_attention_global_sum_v6_pred_20250724111600.json`。 |
| `created_at` | DATETIME | **已确认.** 记录的创建时间。 |