import sqlite3
import json
import uuid
from datetime import datetime
def get_db_connection(db_path: str = 'prediction_history.db'):
    """Return a connection to the prediction-history SQLite database.

    Args:
        db_path: Path to the SQLite database file. Defaults to the
            project-wide 'prediction_history.db', so existing callers
            are unaffected; tests may pass ':memory:'.

    Returns:
        sqlite3.Connection with ``row_factory`` set to ``sqlite3.Row``
        (rows are addressable by column name) and ``check_same_thread``
        disabled because the connection may be used off the creating
        thread.
    """
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn
def save_model_to_db(model_data: dict):
    """Persist the metadata of a trained model into the ``models`` table.

    Args:
        model_data: Metadata dict. Recognised keys: 'display_name',
            'model_type', 'training_mode', 'training_scope',
            'parent_model_id', 'version', 'training_params',
            'performance_metrics', 'artifacts'. Missing keys are stored
            as NULL (JSON columns get JSON ``null``).

    Returns:
        The generated model_uid on success, or None if the insert failed
        (the error is printed and the transaction rolled back).
    """
    connection = get_db_connection()
    cur = connection.cursor()

    try:
        # UID layout: <training_mode>_<model_type>_<8 hex chars>.
        mode = model_data.get('training_mode', 'model')
        mtype = model_data.get('model_type', 'default')
        model_uid = f"{mode}_{mtype}_{uuid.uuid4().hex[:8]}"

        record = (
            model_uid,
            model_data.get('display_name'),
            model_data.get('model_type'),
            model_data.get('training_mode'),
            json.dumps(model_data.get('training_scope')),
            model_data.get('parent_model_id'),
            model_data.get('version'),
            'active',  # freshly saved models are immediately usable
            json.dumps(model_data.get('training_params')),
            json.dumps(model_data.get('performance_metrics')),
            json.dumps(model_data.get('artifacts')),
            datetime.now().isoformat(),
        )

        cur.execute(
            """
            INSERT INTO models (
                model_uid, display_name, model_type, training_mode,
                training_scope, parent_model_id, version, status,
                training_params, performance_metrics, artifacts, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            record,
        )
        connection.commit()
        print(f"✅ 模型元数据已成功保存到数据库,UID: {model_uid}")
        return model_uid
    except Exception as e:
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        connection.rollback()
        return None
    finally:
        connection.close()
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """Query the ``models`` table with optional filters and pagination.

    Args:
        filters: Optional filter keys: 'product_id' and 'store_id'
            (matched inside the training_scope JSON via json_extract),
            'model_type', 'training_mode'. Falsy values are ignored.
        page: 1-based page number.
        page_size: Records per page; non-positive yields 0 total_pages.

    Returns:
        ``{"models": [...], "pagination": {"total", "page", "page_size",
        "total_pages"}}``. On failure the model list is empty and the
        pagination dict is zeroed but has the SAME keys as the success
        path, so callers never need to special-case the shape.

    Note:
        Filtering happens in SQL, but pagination happens in memory:
        models whose artifacts contain a 'best_model' entry produce an
        extra synthesized "best" record, so page counts are only known
        after enhancement.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        query = "SELECT * FROM models"
        conditions = []
        params = []

        # JSON-valued columns are filtered via SQLite's json_extract.
        if filters.get('product_id'):
            conditions.append("json_extract(training_scope, '$.product.id') = ?")
            params.append(filters['product_id'])
        if filters.get('model_type'):
            conditions.append("model_type = ?")
            params.append(filters['model_type'])
        if filters.get('store_id'):
            conditions.append("json_extract(training_scope, '$.store.id') = ?")
            params.append(filters['store_id'])
        if filters.get('training_mode'):
            conditions.append("training_mode = ?")
            params.append(filters['training_mode'])

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # Fetch ALL matching rows first; paging must run after the "best"
        # records are synthesized or the totals would be wrong.
        query += " ORDER BY created_at DESC"
        cursor.execute(query, params)
        all_models = [dict(row) for row in cursor.fetchall()]

        # For every model that also stores a best checkpoint, append an
        # extra record that exposes only the best-model artifacts.
        enhanced_models = []
        for model in all_models:
            # The original version record always appears.
            enhanced_models.append(model)

            try:
                # The artifacts column may be NULL; treat that explicitly
                # as "no artifacts" instead of relying on json.loads(None)
                # raising a swallowed TypeError.
                artifacts = json.loads(model.get('artifacts') or '{}')
                if artifacts and artifacts.get('best_model'):
                    original_version = model.get('version', 'N/A')
                    best_model_record = model.copy()
                    best_model_record['version'] = f"best({original_version})"
                    best_model_record['model_uid'] = f"{model['model_uid']}_best"
                    best_model_record['model_type'] = f"{model['model_type']}(best)"

                    # Keep only the artifacts relevant to the best checkpoint.
                    new_artifacts = {
                        'best_model': artifacts.get('best_model'),
                        'scaler_X': artifacts.get('scaler_X'),
                        'scaler_y': artifacts.get('scaler_y'),
                    }
                    best_model_record['artifacts'] = json.dumps(new_artifacts)
                    enhanced_models.append(best_model_record)
            except (json.JSONDecodeError, TypeError):
                # Malformed artifacts JSON: keep the base record, skip "best".
                continue

        # In-memory pagination over the enhanced list.
        total_count = len(enhanced_models)
        offset = (page - 1) * page_size
        paginated_models = enhanced_models[offset:offset + page_size]

        return {
            "models": paginated_models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                "total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0
            }
        }
    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        # Same shape as the success path (previously an empty dict, which
        # broke callers that index pagination["total"] etc.).
        return {
            "models": [],
            "pagination": {
                "total": 0,
                "page": page,
                "page_size": page_size,
                "total_pages": 0
            }
        }
    finally:
        conn.close()
def save_prediction_to_db(prediction_data: dict):
    """Persist a single prediction run into the ``prediction_history`` table.

    Args:
        prediction_data: Dict with keys 'prediction_uid', 'model_id'
            (expected to be a model_uid from the ``models`` table),
            'model_type', 'product_name', 'model_version',
            'prediction_scope', 'prediction_params', 'metrics',
            'result_file_path'. Missing keys are stored as NULL.

    Returns:
        None. Failures are printed and the transaction is rolled back.
    """
    connection = get_db_connection()
    cur = connection.cursor()
    try:
        record = (
            prediction_data.get('prediction_uid'),
            # model_id should be the model_uid from the 'models' table.
            prediction_data.get('model_id'),
            prediction_data.get('model_type'),
            prediction_data.get('product_name'),
            prediction_data.get('model_version'),
            json.dumps(prediction_data.get('prediction_scope')),
            json.dumps(prediction_data.get('prediction_params')),
            json.dumps(prediction_data.get('metrics')),
            prediction_data.get('result_file_path'),
        )
        cur.execute('''
            INSERT INTO prediction_history (
                prediction_uid, model_id, model_type, product_name, model_version,
                prediction_scope, prediction_params, metrics, result_file_path
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', record)
        connection.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
    except Exception as e:
        print(f"❌ 保存预测记录到数据库失败: {e}")
        connection.rollback()
    finally:
        connection.close()
def find_model_by_uid(model_uid: str):
    """Look up a single model row by its model_uid.

    Args:
        model_uid: The unique model identifier stored in ``models``.

    Returns:
        The matching row as a plain dict, or None when no row matches
        or the query fails (the error is printed).
    """
    connection = get_db_connection()
    cur = connection.cursor()
    try:
        cur.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
        row = cur.fetchone()
        return dict(row) if row else None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        connection.close()