ShopTRAINING/server/utils/database_utils.py

164 lines
5.6 KiB
Python
Raw Normal View History

import sqlite3
import json
import uuid
from datetime import datetime
def get_db_connection():
    """Open a connection to the prediction-history SQLite database.

    Returns:
        A ``sqlite3.Connection`` to ``prediction_history.db`` (a relative
        path, so it resolves against the current working directory), opened
        with ``check_same_thread=False`` and using ``sqlite3.Row`` as the row
        factory so result rows can be read by column name.
    """
    connection = sqlite3.connect('prediction_history.db', check_same_thread=False)
    # Row objects allow both index and column-name access and convert
    # cleanly to dicts.
    connection.row_factory = sqlite3.Row
    return connection
def save_model_to_db(model_data: dict):
    """Persist the metadata of a trained model as a new row in ``models``.

    A model UID of the form ``<training_mode>_<model_type>_<8 hex chars>`` is
    generated; the first two parts fall back to ``'model'`` and ``'default'``
    when the corresponding keys are absent from ``model_data``. The
    ``training_scope``, ``training_params``, ``performance_metrics`` and
    ``artifacts`` values are stored as JSON text, ``status`` is always
    ``'active'`` and ``created_at`` is ``datetime.now()`` in ISO format.

    Args:
        model_data: Mapping with the model metadata. Keys that are missing
            are stored as NULL (JSON ``null`` for the JSON-encoded fields).

    Returns:
        The generated model UID on success, or ``None`` when saving failed
        (the error is printed and the transaction is rolled back).
    """
    insert_sql = """
        INSERT INTO models (
        model_uid, display_name, model_type, training_mode,
        training_scope, parent_model_id, version, status,
        training_params, performance_metrics, artifacts, created_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        def as_json(key):
            # Structured fields are serialised to JSON text for storage.
            return json.dumps(model_data.get(key))

        # Generate a unique model UID.
        model_uid = f"{model_data.get('training_mode', 'model')}_{model_data.get('model_type', 'default')}_{uuid.uuid4().hex[:8]}"

        # Values in the same order as the column list of the INSERT statement.
        row_values = (
            model_uid,
            model_data.get('display_name'),
            model_data.get('model_type'),
            model_data.get('training_mode'),
            as_json('training_scope'),
            model_data.get('parent_model_id'),
            model_data.get('version'),
            'active',
            as_json('training_params'),
            as_json('performance_metrics'),
            as_json('artifacts'),
            datetime.now().isoformat(),
        )
        cursor.execute(insert_sql, row_values)
        conn.commit()
        print(f"✅ 模型元数据已成功保存到数据库UID: {model_uid}")
        return model_uid
    except Exception as e:
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        conn.rollback()
        return None
    finally:
        conn.close()
def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10):
    """Query the ``models`` table with optional filters and pagination.

    Args:
        filters: Optional filter values. Recognised keys are ``product_id``
            and ``store_id`` (compared with ``$.product.id`` / ``$.store.id``
            inside the JSON ``training_scope`` column) and ``model_type`` /
            ``training_mode`` (compared with the columns of the same name).
            Missing or falsy values are ignored; ``None`` means "no filters".
        page: 1-based page number. Values below 1 are treated as 1.
        page_size: Number of rows per page. Values below 1 are treated as 1.

    Returns:
        ``{"models": [...], "pagination": {...}}`` where ``models`` holds one
        dict per matching row, ordered by ``created_at`` descending, and
        ``pagination`` carries ``total``, ``page``, ``page_size`` and
        ``total_pages``. If the query fails, the error is printed and
        ``{"models": [], "pagination": {}}`` is returned.
    """
    # (filter key, SQL condition) pairs; every condition takes one parameter.
    filter_clauses = (
        ('product_id', "json_extract(training_scope, '$.product.id') = ?"),
        ('model_type', "model_type = ?"),
        ('store_id', "json_extract(training_scope, '$.store.id') = ?"),
        ('training_mode', "training_mode = ?"),
    )

    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        filters = filters or {}
        # Clamp paging values: a page_size of 0 would divide by zero when
        # computing total_pages, and negative values would turn into an
        # unbounded LIMIT / negative OFFSET.
        page = max(page, 1)
        page_size = max(page_size, 1)

        conditions = []
        params = []
        for key, clause in filter_clauses:
            value = filters.get(key)
            if value:
                conditions.append(clause)
                params.append(value)

        where_sql = ""
        if conditions:
            where_sql = " WHERE " + " AND ".join(conditions)

        # Count with its own statement instead of string-replacing "*" in the
        # row query, which would also rewrite any "*" inside a condition.
        cursor.execute("SELECT COUNT(*) FROM models" + where_sql, params)
        total_count = cursor.fetchone()[0]

        # Fetch the requested page, newest models first.
        offset = (page - 1) * page_size
        cursor.execute(
            "SELECT * FROM models" + where_sql
            + " ORDER BY created_at DESC LIMIT ? OFFSET ?",
            params + [page_size, offset],
        )
        models = [dict(row) for row in cursor.fetchall()]

        return {
            "models": models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                # Ceiling division: a partially filled last page still counts.
                "total_pages": (total_count + page_size - 1) // page_size,
            },
        }
    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()
def save_prediction_to_db(prediction_data: dict):
    """Insert a single prediction record into ``prediction_history``.

    ``prediction_scope``, ``prediction_params`` and ``metrics`` are stored as
    JSON text; all other fields are written as given. On failure the error is
    printed and the transaction is rolled back.

    Args:
        prediction_data: Mapping with the record fields. Keys that are
            missing are stored as NULL (JSON ``null`` for the JSON fields).

    Returns:
        None, whether or not the insert succeeded.
    """
    insert_sql = '''
        INSERT INTO prediction_history (
        prediction_uid, model_id, model_type, product_name, model_version,
        prediction_scope, prediction_params, metrics, result_file_path
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
        '''
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        field = prediction_data.get
        # Values in the same order as the column list of the INSERT statement.
        record = (
            field('prediction_uid'),
            field('model_id'),  # This should be the model_uid from the 'models' table
            field('model_type'),
            field('product_name'),
            field('model_version'),
            json.dumps(field('prediction_scope')),
            json.dumps(field('prediction_params')),
            json.dumps(field('metrics')),
            field('result_file_path'),
        )
        cursor.execute(insert_sql, record)
        conn.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
    except Exception as e:
        print(f"❌ 保存预测记录到数据库失败: {e}")
        conn.rollback()
    finally:
        conn.close()
def find_model_by_uid(model_uid: str):
    """Look up a single model row by its ``model_uid``.

    Args:
        model_uid: Value to match against the ``model_uid`` column of the
            ``models`` table.

    Returns:
        The matching row as a dict, or ``None`` when no row matches or the
        lookup fails (in which case the error is printed).
    """
    conn = get_db_connection()
    cursor = conn.cursor()
    try:
        cursor.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
        row = cursor.fetchone()
        return dict(row) if row else None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        conn.close()