import json
import sqlite3
import uuid
from datetime import datetime
from typing import Optional

# Default SQLite database file. Kept as a module-level setting so deployments and
# tests can point the helpers below at a different file without editing them.
DB_PATH = 'prediction_history.db'

# Filters supported by query_models_from_db, in the order their predicates are
# appended. The product/store filters read nested ids out of the JSON-encoded
# training_scope column via SQLite's json_extract.
_MODEL_FILTERS = (
    ('product_id', "json_extract(training_scope, '$.product.id') = ?"),
    ('model_type', "model_type = ?"),
    ('store_id', "json_extract(training_scope, '$.store.id') = ?"),
    ('training_mode', "training_mode = ?"),
)


def get_db_connection(db_path: Optional[str] = None) -> sqlite3.Connection:
    """Open a connection to the prediction-history SQLite database.

    Args:
        db_path: Database file to open. Defaults to the module-level ``DB_PATH``,
            looked up at call time so reassigning ``DB_PATH`` takes effect.

    Returns:
        A connection whose rows are ``sqlite3.Row`` objects (addressable by column
        name and convertible with ``dict(row)``). ``check_same_thread=False`` allows
        use from a thread other than the creating one. The caller must close it.
    """
    conn = sqlite3.connect(
        DB_PATH if db_path is None else db_path, check_same_thread=False
    )
    conn.row_factory = sqlite3.Row
    return conn


def save_model_to_db(model_data: dict) -> Optional[str]:
    """Insert the metadata of a finished training run into the ``models`` table.

    Args:
        model_data: Reads ``display_name``, ``model_type``, ``training_mode``,
            ``training_scope``, ``parent_model_id``, ``version``,
            ``training_params``, ``performance_metrics`` and ``artifacts``.
            ``training_scope``, ``training_params``, ``performance_metrics`` and
            ``artifacts`` are stored JSON-encoded (a missing value becomes "null").

    Returns:
        The generated ``model_uid`` (``<training_mode>_<model_type>_<8 hex chars>``)
        on success. Returns ``None`` on any failure, after printing the error and
        rolling back.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        # `or` instead of a .get() default so an explicit None/"" value falls back
        # too, rather than leaking the text "None" into the UID.
        mode_part = model_data.get('training_mode') or 'model'
        type_part = model_data.get('model_type') or 'default'
        model_uid = f"{mode_part}_{type_part}_{uuid.uuid4().hex[:8]}"

        cursor.execute(
            """
            INSERT INTO models (
                model_uid, display_name, model_type, training_mode,
                training_scope, parent_model_id, version, status,
                training_params, performance_metrics, artifacts, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                model_uid,
                model_data.get('display_name'),
                model_data.get('model_type'),
                model_data.get('training_mode'),
                json.dumps(model_data.get('training_scope')),
                model_data.get('parent_model_id'),
                model_data.get('version'),
                'active',
                json.dumps(model_data.get('training_params')),
                json.dumps(model_data.get('performance_metrics')),
                json.dumps(model_data.get('artifacts')),
                # Naive local time in ISO format; query ordering relies on this
                # string sorting chronologically.
                datetime.now().isoformat(),
            ),
        )
        conn.commit()
        print(f"✅ 模型元数据已成功保存到数据库,UID: {model_uid}")
        return model_uid
    except Exception as e:
        print(f"❌ 保存模型元数据到数据库失败: {e}")
        conn.rollback()
        return None
    finally:
        conn.close()


def _build_model_filter_clause(filters: dict):
    """Return ``(" WHERE ..." or "", params)`` for the supported model filters."""
    conditions = []
    params = []
    for key, predicate in _MODEL_FILTERS:
        value = filters.get(key)
        if value:  # falsy values (None, "", 0) mean "do not filter on this key"
            conditions.append(predicate)
            params.append(value)
    where_clause = " WHERE " + " AND ".join(conditions) if conditions else ""
    return where_clause, params


def query_models_from_db(filters: Optional[dict], page: int = 1, page_size: int = 10):
    """List models from the ``models`` table with optional filtering and paging.

    Args:
        filters: Optional keys ``product_id``, ``store_id`` (matched against
            ``training_scope`` JSON paths ``$.product.id`` / ``$.store.id``),
            ``model_type`` and ``training_mode``. Falsy values are ignored;
            ``None`` means no filters.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.

    Returns:
        ``{"models": [row dicts], "pagination": {"total", "page", "page_size",
        "total_pages"}}``, newest first. On any error the message is printed and
        ``{"models": [], "pagination": {}}`` is returned.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        where_clause, params = _build_model_filter_clause(filters or {})

        # Build the count query explicitly: rewriting "*" in the SELECT text would
        # also rewrite any "*" that appears inside a filter predicate.
        cursor.execute("SELECT COUNT(*) FROM models" + where_clause, params)
        total_count = cursor.fetchone()[0]

        # Clamp so we never divide by zero or send SQLite a negative LIMIT
        # (which it treats as "no limit").
        page = max(int(page), 1)
        page_size = max(int(page_size), 1)
        offset = (page - 1) * page_size

        # model_uid breaks created_at ties so page boundaries are deterministic.
        cursor.execute(
            "SELECT * FROM models"
            + where_clause
            + " ORDER BY created_at DESC, model_uid ASC LIMIT ? OFFSET ?",
            params + [page_size, offset],
        )
        models = [dict(row) for row in cursor.fetchall()]

        return {
            "models": models,
            "pagination": {
                "total": total_count,
                "page": page,
                "page_size": page_size,
                "total_pages": (total_count + page_size - 1) // page_size,
            },
        }
    except Exception as e:
        print(f"❌ 从数据库查询模型失败: {e}")
        return {"models": [], "pagination": {}}
    finally:
        conn.close()


def save_prediction_to_db(prediction_data: dict) -> None:
    """Insert one prediction record into the ``prediction_history`` table.

    Args:
        prediction_data: Reads ``prediction_uid``, ``model_id``, ``model_type``,
            ``product_name``, ``prediction_scope``, ``prediction_params``,
            ``metrics`` and ``result_file_path``. ``prediction_scope``,
            ``prediction_params`` and ``metrics`` are stored JSON-encoded.

    Failures are printed and rolled back, not raised.
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute(
            '''
            INSERT INTO prediction_history (
                prediction_uid, model_id, model_type, product_name,
                prediction_scope, prediction_params, metrics, result_file_path
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            ''',
            (
                prediction_data.get('prediction_uid'),
                # This should be the model_uid from the 'models' table.
                prediction_data.get('model_id'),
                prediction_data.get('model_type'),
                prediction_data.get('product_name'),
                json.dumps(prediction_data.get('prediction_scope')),
                json.dumps(prediction_data.get('prediction_params')),
                json.dumps(prediction_data.get('metrics')),
                prediction_data.get('result_file_path'),
            ),
        )
        conn.commit()
        print(f"✅ 预测记录 {prediction_data.get('prediction_uid')} 已保存到数据库。")
    except Exception as e:
        print(f"❌ 保存预测记录到数据库失败: {e}")
        conn.rollback()
    finally:
        conn.close()


def find_model_by_uid(model_uid: str) -> Optional[dict]:
    """Return the ``models`` row with the given ``model_uid`` as a dict.

    Returns ``None`` when no row matches or when the lookup fails (the error is
    printed).
    """
    conn = get_db_connection()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM models WHERE model_uid = ?", (model_uid,))
        model_record = cursor.fetchone()
        if model_record:
            return dict(model_record)
        return None
    except Exception as e:
        print(f"❌ 根据UID查找模型失败: {e}")
        return None
    finally:
        conn.close()