ShopTRAINING/server/utils/database_utils.py

194 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sqlite3
import json
import uuid
from datetime import datetime
def get_db_connection():
    """Open a connection to the local SQLite history database.

    Rows come back as :class:`sqlite3.Row` so callers can address columns
    by name; ``check_same_thread=False`` permits use from worker threads.
    """
    connection = sqlite3.connect('prediction_history.db', check_same_thread=False)
    connection.row_factory = sqlite3.Row
    return connection
def save_model_to_db(model_data: dict):
    """Persist a trained model's metadata into the ``models`` table.

    A unique UID of the form ``<training_mode>_<model_type>_<8 hex>`` is
    generated for the row; nested fields (scope, params, metrics,
    artifacts) are stored as JSON strings.

    Returns the generated model UID on success, or None when the insert
    fails (the transaction is rolled back either way).
    """
    # Same connection settings as get_db_connection().
    connection = sqlite3.connect('prediction_history.db', check_same_thread=False)
    connection.row_factory = sqlite3.Row
    try:
        uid = "{}_{}_{}".format(
            model_data.get('training_mode', 'model'),
            model_data.get('model_type', 'default'),
            uuid.uuid4().hex[:8],
        )
        row = (
            uid,
            model_data.get('display_name'),
            model_data.get('model_type'),
            model_data.get('training_mode'),
            json.dumps(model_data.get('training_scope')),
            model_data.get('parent_model_id'),
            model_data.get('version'),
            'active',  # newly saved models are immediately active
            json.dumps(model_data.get('training_params')),
            json.dumps(model_data.get('performance_metrics')),
            json.dumps(model_data.get('artifacts')),
            datetime.now().isoformat(),
        )
        connection.execute(
            "INSERT INTO models ("
            " model_uid, display_name, model_type, training_mode,"
            " training_scope, parent_model_id, version, status,"
            " training_params, performance_metrics, artifacts, created_at"
            ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            row,
        )
        connection.commit()
        print(f"✅ 模型元数据已成功保存到数据库UID: {uid}")
        return uid
    except Exception as e:
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        connection.rollback()
        return None
    finally:
        connection.close()
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """Query the ``models`` table with optional filters and pagination.

    Args:
        filters: optional keys ``product_id``, ``model_type``, ``store_id``,
            ``training_mode``. ``product_id`` is interpreted as:
            ``'global'`` -> global training mode; an id starting with ``'S'``
            -> store id inside ``training_scope``; anything else -> product
            id match or display-name substring match.
        page: 1-based page number.
        page_size: rows per page.

    Returns:
        ``{"models": [...], "pagination": {...}}``. For every model whose
        artifacts contain a ``best_model`` checkpoint, an extra synthetic
        ``*_best`` record is appended right after it. On any failure the
        legacy shape ``{"models": [], "pagination": {}}`` is returned so
        existing callers keep working.
    """
    # Same connection settings as get_db_connection().
    conn = sqlite3.connect('prediction_history.db', check_same_thread=False)
    conn.row_factory = sqlite3.Row
    cursor = conn.cursor()
    try:
        query = "SELECT * FROM models"
        conditions = []
        params = []
        product_id_filter = filters.get('product_id')
        if product_id_filter:
            if product_id_filter.lower() == 'global':
                # 'global' selects globally-trained models regardless of scope.
                conditions.append("training_mode = ?")
                params.append('global')
            elif product_id_filter.startswith('S'):
                # Ids starting with 'S' are store ids kept inside the JSON scope.
                conditions.append("json_extract(training_scope, '$.store.id') = ?")
                params.append(product_id_filter)
            else:
                # Otherwise match a product id or a display-name substring.
                conditions.append("(json_extract(training_scope, '$.product.id') = ? OR display_name LIKE ?)")
                params.extend([product_id_filter, f"%{product_id_filter}%"])
        if filters.get('model_type'):
            conditions.append("model_type LIKE ?")
            params.append(f"{filters['model_type']}%")
        if filters.get('store_id'):
            conditions.append("json_extract(training_scope, '$.store.id') = ?")
            params.append(filters['store_id'])
        if filters.get('training_mode'):
            conditions.append("training_mode = ?")
            params.append(filters['training_mode'])
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        # Fetch every matching row first: the synthetic "best" records below
        # are generated in memory, so pagination must happen after enhancement.
        query += " ORDER BY created_at DESC"
        cursor.execute(query, params)
        all_models = [dict(row) for row in cursor.fetchall()]
        # For each model that also stores a best checkpoint, emit an extra
        # record whose artifacts point at the best-model files.
        enhanced_models = []
        for model in all_models:
            enhanced_models.append(model)
            try:
                # FIX: the artifacts column may be NULL, in which case
                # model.get('artifacts', '{}') returns None (the key exists),
                # and the old code relied on json.loads(None) raising.
                # Handle the NULL explicitly instead.
                artifacts = json.loads(model.get('artifacts') or '{}')
            except (json.JSONDecodeError, TypeError):
                continue
            if artifacts and artifacts.get('best_model'):
                original_version = model.get('version', 'N/A')
                best_model_record = model.copy()
                best_model_record['version'] = f"best({original_version})"
                best_model_record['model_uid'] = f"{model['model_uid']}_best"
                best_model_record['model_type'] = f"{model['model_type']}(best)"
                new_artifacts = {
                    'best_model': artifacts.get('best_model'),
                    'scaler_X': artifacts.get('scaler_X'),
                    'scaler_y': artifacts.get('scaler_y'),
                }
                best_model_record['artifacts'] = json.dumps(new_artifacts)
                enhanced_models.append(best_model_record)
        # Paginate in memory over the enhanced list.
        total_count = len(enhanced_models)
        offset = (page - 1) * page_size
        paginated_models = enhanced_models[offset : offset + page_size]
        return {
            "models": paginated_models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                "total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0
            }
        }
    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()
def save_prediction_to_db(prediction_data: dict):
    """Insert one prediction run into the ``prediction_history`` table.

    Structured fields (scope, params, metrics) are serialized to JSON
    before storage. Failures are logged and rolled back; the function
    returns nothing either way.
    """
    # Same connection settings as get_db_connection().
    conn = sqlite3.connect('prediction_history.db', check_same_thread=False)
    conn.row_factory = sqlite3.Row
    try:
        values = (
            prediction_data.get('prediction_uid'),
            # model_id references models.model_uid in the 'models' table.
            prediction_data.get('model_id'),
            prediction_data.get('model_type'),
            prediction_data.get('product_name'),
            prediction_data.get('model_version'),
            json.dumps(prediction_data.get('prediction_scope')),
            json.dumps(prediction_data.get('prediction_params')),
            json.dumps(prediction_data.get('metrics')),
            prediction_data.get('result_file_path'),
        )
        conn.execute(
            "INSERT INTO prediction_history ("
            " prediction_uid, model_id, model_type, product_name, model_version,"
            " prediction_scope, prediction_params, metrics, result_file_path"
            ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )
        conn.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
    except Exception as e:
        print(f"❌ 保存预测记录到数据库失败: {e}")
        conn.rollback()
    finally:
        conn.close()
def find_model_by_uid(model_uid: str):
    """Look up a single model row by its unique ``model_uid``.

    Returns the matching row as a plain dict, or None when no such model
    exists (or when the lookup itself fails).
    """
    # Same connection settings as get_db_connection().
    conn = sqlite3.connect('prediction_history.db', check_same_thread=False)
    conn.row_factory = sqlite3.Row
    try:
        row = conn.execute(
            "SELECT * FROM models WHERE model_uid = ?", (model_uid,)
        ).fetchone()
        return dict(row) if row else None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        conn.close()