ShopTRAINING/server/utils/database_utils.py

import sqlite3
import json
import uuid
from datetime import datetime


def get_db_connection():
    """Open a connection to the prediction-history SQLite database."""
    conn = sqlite3.connect('prediction_history.db', check_same_thread=False)
    conn.row_factory = sqlite3.Row
    return conn


def save_model_to_db(model_data: dict):
    """
    Save the metadata of a completed training run to the `models` table.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        # Generate a unique model UID
        model_uid = f"{model_data.get('training_mode', 'model')}_{model_data.get('model_type', 'default')}_{uuid.uuid4().hex[:8]}"
        cursor.execute(
            """
            INSERT INTO models (
                model_uid, display_name, model_type, training_mode,
                training_scope, parent_model_id, version, status,
                training_params, performance_metrics, artifacts, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                model_uid,
                model_data.get('display_name'),
                model_data.get('model_type'),
                model_data.get('training_mode'),
                json.dumps(model_data.get('training_scope')),
                model_data.get('parent_model_id'),
                model_data.get('version'),
                'active',  # newly saved models are marked active by default
                json.dumps(model_data.get('training_params')),
                json.dumps(model_data.get('performance_metrics')),
                json.dumps(model_data.get('artifacts')),
                datetime.now().isoformat()
            )
        )
        conn.commit()
        print(f"✅ Model metadata saved to database, UID: {model_uid}")
        return model_uid
    except Exception as e:
        print(f"❌ Failed to save model metadata to database: {e}")
        conn.rollback()
        return None
    finally:
        conn.close()
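
# Example call (hypothetical values): the keys below mirror the columns written
# by the INSERT above; JSON-valued fields may be passed as plain dicts, since
# they are serialized here.
#
#   model_uid = save_model_to_db({
#       'display_name': 'TCN - product P001',
#       'model_type': 'tcn',
#       'training_mode': 'product',
#       'training_scope': {'product': {'id': 'P001'}},
#       'parent_model_id': None,
#       'version': 'v1',
#       'training_params': {'epochs': 50},
#       'performance_metrics': {'rmse': 1.23},
#       'artifacts': {'best_model': 'saved_models/best.pth'},
#   })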


def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """
    Query the list of models from the database, with filtering and pagination support.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        query = "SELECT * FROM models"
        conditions = []
        params = []

        product_id_filter = filters.get('product_id')
        if product_id_filter:
            if product_id_filter.lower() == 'global':
                conditions.append("training_mode = ?")
                params.append('global')
            elif product_id_filter.startswith('S'):
                conditions.append("json_extract(training_scope, '$.store.id') = ?")
                params.append(product_id_filter)
            else:
                conditions.append("(json_extract(training_scope, '$.product.id') = ? OR display_name LIKE ?)")
                params.extend([product_id_filter, f"%{product_id_filter}%"])

        if filters.get('model_type'):
            conditions.append("model_type LIKE ?")
            params.append(f"{filters['model_type']}%")
        if filters.get('store_id'):
            conditions.append("json_extract(training_scope, '$.store.id') = ?")
            params.append(filters['store_id'])
        if filters.get('training_mode'):
            conditions.append("training_mode = ?")
            params.append(filters['training_mode'])

        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        # Pagination fix: fetch every matching model first, then enhance and
        # paginate in memory.
        query += " ORDER BY created_at DESC"
        cursor.execute(query, params)
        all_models = [dict(row) for row in cursor.fetchall()]

        # Enhancement: if a model also has a "best" checkpoint, emit an extra
        # record for that best version.
        enhanced_models = []
        for model in all_models:
            # Always keep the original version record.
            enhanced_models.append(model)
            try:
                artifacts = json.loads(model.get('artifacts', '{}'))
                if artifacts and artifacts.get('best_model'):
                    original_version = model.get('version', 'N/A')
                    best_model_record = model.copy()
                    best_model_record['version'] = f"best({original_version})"
                    best_model_record['model_uid'] = f"{model['model_uid']}_best"
                    best_model_record['model_type'] = f"{model['model_type']}(best)"
                    new_artifacts = {
                        'best_model': artifacts.get('best_model'),
                        'scaler_X': artifacts.get('scaler_X'),
                        'scaler_y': artifacts.get('scaler_y')
                    }
                    best_model_record['artifacts'] = json.dumps(new_artifacts)
                    enhanced_models.append(best_model_record)
            except (json.JSONDecodeError, TypeError):
                continue

        # Paginate in memory.
        total_count = len(enhanced_models)
        offset = (page - 1) * page_size
        paginated_models = enhanced_models[offset : offset + page_size]

        return {
            "models": paginated_models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                "total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0
            }
        }
    except Exception as e:
        print(f"❌ Failed to query models from database: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()
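
# Example call (hypothetical filter values): first page of models trained for
# store "S001" whose type starts with "tcn". Filter keys follow the checks above;
# unrecognized keys are ignored.
#
#   result = query_models_from_db({'model_type': 'tcn', 'store_id': 'S001'},
#                                 page=1, page_size=20)
#   for m in result['models']:
#       print(m['model_uid'], m['version'])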


def save_prediction_to_db(prediction_data: dict):
    """
    Save a single prediction record to the database.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute('''
            INSERT INTO prediction_history (
                prediction_uid, model_id, model_type, product_name, model_version,
                prediction_scope, prediction_params, metrics, result_file_path
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            prediction_data.get('prediction_uid'),
            prediction_data.get('model_id'),  # This should be the model_uid from the 'models' table
            prediction_data.get('model_type'),
            prediction_data.get('product_name'),
            prediction_data.get('model_version'),
            json.dumps(prediction_data.get('prediction_scope')),
            json.dumps(prediction_data.get('prediction_params')),
            json.dumps(prediction_data.get('metrics')),
            prediction_data.get('result_file_path')
        ))
        conn.commit()
        print(f"✅ Prediction record {prediction_data.get('prediction_uid')} saved to database.")
    except Exception as e:
        print(f"❌ Failed to save prediction record to database: {e}")
        conn.rollback()
    finally:
        conn.close()
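
# Example call (hypothetical values): `model_id` should be the model_uid returned
# by save_model_to_db, and `prediction_uid` is generated by the caller.
#
#   save_prediction_to_db({
#       'prediction_uid': str(uuid.uuid4()),
#       'model_id': 'product_tcn_1a2b3c4d',
#       'model_type': 'tcn',
#       'product_name': 'P001',
#       'model_version': 'v1',
#       'prediction_scope': {'product': {'id': 'P001'}},
#       'prediction_params': {'future_days': 7},
#       'metrics': None,
#       'result_file_path': 'predictions/result.json',
#   })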


def find_model_by_uid(model_uid: str):
    """
    Look up a single model in the database by its model_uid.
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
        model_record = cursor.fetchone()
        if model_record:
            return dict(model_record)
        return None
    except Exception as e:
        print(f"❌ Failed to look up model by UID: {e}")
        return None
    finally:
        conn.close()
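

# Minimal smoke test, assuming prediction_history.db and its `models` table
# already exist; the UID below is a placeholder and will normally not match any
# row, so the lookup is expected to report that no model was found.
if __name__ == '__main__':
    demo_uid = 'product_tcn_deadbeef'  # hypothetical UID for illustration only
    record = find_model_by_uid(demo_uid)
    print(record if record else f"No model found for UID: {demo_uid}")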