194 lines
7.3 KiB
Python
194 lines
7.3 KiB
Python
import sqlite3
|
||
import json
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
def get_db_connection(db_path: str = 'prediction_history.db') -> sqlite3.Connection:
    """Open a SQLite connection whose rows can be read by column name.

    Args:
        db_path: Database to open. Defaults to 'prediction_history.db', the
            location used when the function is called without arguments.
            Any value accepted by ``sqlite3.connect`` works, including
            ':memory:'.

    Returns:
        A ``sqlite3.Connection`` with ``row_factory`` set to ``sqlite3.Row``.
        The caller is responsible for closing it.
    """
    # check_same_thread=False turns off sqlite3's check that a connection is
    # only used from the thread that created it.
    conn = sqlite3.connect(db_path, check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn
|
||
|
||
def save_model_to_db(model_data: dict):
    """Insert the metadata of a trained model as one row of the ``models`` table.

    A fresh ``model_uid`` is generated from the training mode, the model type
    and eight random hex characters. ``training_scope``, ``training_params``,
    ``performance_metrics`` and ``artifacts`` are stored as JSON text; the
    status is always written as 'active' and ``created_at`` is the current
    local time in ISO format.

    Args:
        model_data: Mapping holding the model's metadata; missing keys are
            stored as NULL (or JSON ``null`` for the JSON columns).

    Returns:
        The generated ``model_uid`` on success, or ``None`` if anything went
        wrong (the error is printed and the transaction rolled back).
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    insert_sql = """
            INSERT INTO models (
                model_uid, display_name, model_type, training_mode,
                training_scope, parent_model_id, version, status,
                training_params, performance_metrics, artifacts, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """

    try:
        # Build a unique model UID: <training_mode>_<model_type>_<8 hex chars>.
        mode_part = model_data.get('training_mode', 'model')
        type_part = model_data.get('model_type', 'default')
        random_part = uuid.uuid4().hex[:8]
        model_uid = f"{mode_part}_{type_part}_{random_part}"

        # Values are listed in the same order as the columns in insert_sql.
        row_values = (
            model_uid,
            model_data.get('display_name'),
            model_data.get('model_type'),
            model_data.get('training_mode'),
            json.dumps(model_data.get('training_scope')),
            model_data.get('parent_model_id'),
            model_data.get('version'),
            'active',
            json.dumps(model_data.get('training_params')),
            json.dumps(model_data.get('performance_metrics')),
            json.dumps(model_data.get('artifacts')),
            datetime.now().isoformat(),
        )
        cursor.execute(insert_sql, row_values)
        conn.commit()
        print(f"✅ 模型元数据已成功保存到数据库,UID: {model_uid}")
        return model_uid
    except Exception as e:
        # Best-effort persistence: report the failure, undo the insert and
        # signal it to the caller with None instead of raising.
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        conn.rollback()
        return None
    finally:
        conn.close()
|
||
|
||
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """Query the ``models`` table with optional filters and in-memory pagination.

    All matching rows are fetched newest first (``created_at DESC``). Every
    row whose ``artifacts`` JSON contains a ``best_model`` entry is then
    followed by an extra synthetic "best" record, and pagination is applied
    to that expanded list so the total and the pages agree with what is
    returned.

    Args:
        filters: Optional filter values; ``None`` or an empty dict means no
            filtering. Recognised keys:
            ``product_id`` - 'global' (any case) selects global training
            mode; a value starting with 'S' is matched against the store id
            in ``training_scope``; anything else is matched against the
            product id in ``training_scope`` or as a substring of
            ``display_name``.
            ``model_type`` - prefix match on ``model_type``.
            ``store_id`` - exact match on the store id in ``training_scope``.
            ``training_mode`` - exact match on ``training_mode``.
        page: 1-based page number. Values below 1 are treated as 1.
        page_size: Records per page. A value of 0 or less yields an empty
            page and ``total_pages`` of 0.

    Returns:
        ``{"models": [...], "pagination": {"total", "page", "page_size",
        "total_pages"}}``. If the query fails the error is printed and
        ``{"models": [], "pagination": {}}`` is returned.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    try:
        query, params = _build_models_query(filters or {})
        cursor.execute(query, params)
        all_models = [dict(row) for row in cursor.fetchall()]

        # Expand first, paginate second: the synthetic "best" records count
        # towards the total, so paging has to happen on the expanded list.
        enhanced_models = _expand_best_versions(all_models)

        total_count = len(enhanced_models)
        # A page below 1 would give a negative offset, and a negative slice
        # start silently returns rows counted from the END of the list.
        page = max(page, 1)
        if page_size > 0:
            offset = (page - 1) * page_size
            paginated_models = enhanced_models[offset:offset + page_size]
            total_pages = (total_count + page_size - 1) // page_size
        else:
            paginated_models = []
            total_pages = 0

        return {
            "models": paginated_models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                "total_pages": total_pages,
            },
        }

    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()


def _build_models_query(filters: dict):
    """Turn the filter dict into a parameterised SELECT and its parameter list."""
    conditions = []
    params = []

    product_id_filter = filters.get('product_id')
    if product_id_filter:
        if product_id_filter.lower() == 'global':
            conditions.append("training_mode = ?")
            params.append('global')
        elif product_id_filter.startswith('S'):
            # Values starting with 'S' are matched against the store id.
            conditions.append("json_extract(training_scope, '$.store.id') = ?")
            params.append(product_id_filter)
        else:
            conditions.append(
                "(json_extract(training_scope, '$.product.id') = ? OR display_name LIKE ?)"
            )
            params.extend([product_id_filter, f"%{product_id_filter}%"])

    if filters.get('model_type'):
        # Prefix match on model_type.
        conditions.append("model_type LIKE ?")
        params.append(f"{filters['model_type']}%")

    if filters.get('store_id'):
        conditions.append("json_extract(training_scope, '$.store.id') = ?")
        params.append(filters['store_id'])

    if filters.get('training_mode'):
        conditions.append("training_mode = ?")
        params.append(filters['training_mode'])

    query = "SELECT * FROM models"
    if conditions:
        query += " WHERE " + " AND ".join(conditions)
    query += " ORDER BY created_at DESC"
    return query, params


def _expand_best_versions(models: list) -> list:
    """Return the rows, adding a synthetic "best" record after each model that has a best_model artifact."""
    expanded = []
    for model in models:
        # The original version record always comes first.
        expanded.append(model)

        try:
            artifacts = json.loads(model.get('artifacts', '{}'))
        except (json.JSONDecodeError, TypeError):
            # Missing (NULL) or malformed artifacts: keep only the original.
            continue

        # json.loads may legitimately return a list, string or number; only
        # a JSON object can carry a 'best_model' entry. Without this guard a
        # single odd row raised AttributeError and emptied the whole result.
        if not isinstance(artifacts, dict) or not artifacts.get('best_model'):
            continue

        original_version = model.get('version', 'N/A')
        best_model_record = model.copy()
        best_model_record['version'] = f"best({original_version})"
        best_model_record['model_uid'] = f"{model['model_uid']}_best"
        best_model_record['model_type'] = f"{model['model_type']}(best)"
        # The "best" record only exposes the best model and the two scalers.
        best_model_record['artifacts'] = json.dumps({
            'best_model': artifacts.get('best_model'),
            'scaler_X': artifacts.get('scaler_X'),
            'scaler_y': artifacts.get('scaler_y'),
        })
        expanded.append(best_model_record)
    return expanded
|
||
|
||
def save_prediction_to_db(prediction_data: dict):
    """Insert one prediction record into the ``prediction_history`` table.

    ``prediction_scope``, ``prediction_params`` and ``metrics`` are stored as
    JSON text; the remaining fields are written as given. Failures are
    printed and rolled back rather than raised.

    Args:
        prediction_data: Mapping holding the prediction's fields; missing
            keys are stored as NULL (or JSON ``null`` for the JSON columns).

    Returns:
        None, whether or not the insert succeeded.
    """
    conn = get_db_connection()
    cursor = conn.cursor()

    insert_sql = '''
            INSERT INTO prediction_history (
                prediction_uid, model_id, model_type, product_name, model_version,
                prediction_scope, prediction_params, metrics, result_file_path
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        '''

    try:
        # Values are listed in the same order as the columns in insert_sql.
        record = (
            prediction_data.get('prediction_uid'),
            # model_id should be the model_uid from the 'models' table.
            prediction_data.get('model_id'),
            prediction_data.get('model_type'),
            prediction_data.get('product_name'),
            prediction_data.get('model_version'),
            json.dumps(prediction_data.get('prediction_scope')),
            json.dumps(prediction_data.get('prediction_params')),
            json.dumps(prediction_data.get('metrics')),
            prediction_data.get('result_file_path'),
        )
        cursor.execute(insert_sql, record)
        conn.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
    except Exception as e:
        # Best-effort persistence: report the failure and undo the insert.
        print(f"❌ 保存预测记录到数据库失败: {e}")
        conn.rollback()
    finally:
        conn.close()
|
||
|
||
def find_model_by_uid(model_uid: str):
    """Look up a single model in the ``models`` table by its ``model_uid``.

    Args:
        model_uid: Unique identifier of the model to fetch.

    Returns:
        The matching row as a plain ``dict``, or ``None`` when no row matches
        or the lookup fails (the error is printed).
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
        row = cursor.fetchone()
        return dict(row) if row else None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        conn.close()
|