diff --git a/UI/src/views/prediction/GlobalPredictionView.vue b/UI/src/views/prediction/GlobalPredictionView.vue index 93c05a4..5f69a79 100644 --- a/UI/src/views/prediction/GlobalPredictionView.vue +++ b/UI/src/views/prediction/GlobalPredictionView.vue @@ -116,7 +116,7 @@ const filters = reactive({ const pagination = reactive({ currentPage: 1, - pageSize: 8 + pageSize: 12 }) const filteredModelList = computed(() => { diff --git a/UI/src/views/prediction/ProductPredictionView.vue b/UI/src/views/prediction/ProductPredictionView.vue index 92b632d..0effe54 100644 --- a/UI/src/views/prediction/ProductPredictionView.vue +++ b/UI/src/views/prediction/ProductPredictionView.vue @@ -124,7 +124,7 @@ const filters = reactive({ const pagination = reactive({ currentPage: 1, - pageSize: 8 + pageSize: 12 }) const filteredModelList = computed(() => { diff --git a/UI/src/views/prediction/StorePredictionView.vue b/UI/src/views/prediction/StorePredictionView.vue index e17179c..1dedee1 100644 --- a/UI/src/views/prediction/StorePredictionView.vue +++ b/UI/src/views/prediction/StorePredictionView.vue @@ -126,7 +126,7 @@ const filters = reactive({ const pagination = reactive({ currentPage: 1, - pageSize: 8 + pageSize: 12 }) const filteredModelList = computed(() => { diff --git a/prediction_history.db b/prediction_history.db index ee8e5b0..0f731fe 100644 Binary files a/prediction_history.db and b/prediction_history.db differ diff --git a/server/api.py b/server/api.py index dcd7db2..cd0a028 100644 --- a/server/api.py +++ b/server/api.py @@ -1329,10 +1329,17 @@ def predict(): if not model_uid: return jsonify({"status": "error", "message": "缺少 'model_uid' 参数"}), 400 + # 智能处理带 '_best' 后缀的UID,以解决前端key冲突问题 (兼容旧版Python) + is_best_request = model_uid.endswith('_best') + if is_best_request: + db_query_uid = model_uid[:-len('_best')] + else: + db_query_uid = model_uid + # 从数据库查找模型记录 - model_record = find_model_by_uid(model_uid) + model_record = find_model_by_uid(db_query_uid) if not model_record: - return jsonify({"status": "error", "message": f"模型UID '{model_uid}' 不存在"}), 404 + return jsonify({"status": "error", "message": f"找不到基础模型UID '{db_query_uid}'"}), 404 # 解析必要的模型元数据 model_type = model_record.get('model_type') @@ -1341,7 +1348,17 @@ def predict(): # 解析 artifacts 找到模型文件路径 artifacts = json.loads(model_record.get('artifacts', '{}')) - model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model') + + # 根据请求类型选择正确的模型路径 + if is_best_request: + model_file_path = artifacts.get('best_model') + else: + # 对于 v1, v2 等版本,使用 versioned_model + model_file_path = artifacts.get('versioned_model') + + # 如果特定版本路径不存在,提供一个后备方案,增加鲁棒性 + if not model_file_path: + model_file_path = artifacts.get('best_model') or artifacts.get('versioned_model') # 修正路径问题:将相对路径转换为绝对路径以进行可靠的文件检查 project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -1412,10 +1429,13 @@ def predict(): analysis_result = prediction_result.get('analysis', {}) if prediction_result else {} metrics_result = analysis_result.get('metrics', {}) if analysis_result else {} + # 准备要保存到历史记录的model_type,确保与前端显示一致 + history_model_type = f"{model_type}(best)" if is_best_request else model_type + db_payload = { "prediction_uid": prediction_uid, "model_id": model_uid, - "model_type": model_type, + "model_type": history_model_type, # 使用处理后的一致性名称 "product_name": product_name_to_save, # 使用修正后的名称 "model_version": model_record.get('display_name'), # 将模型信息保存到新字段 "prediction_scope": {"product_id": product_id, "store_id": store_id}, diff --git a/server/predictors/model_predictor.py b/server/predictors/model_predictor.py index 2463543..c9aee19 100644 --- a/server/predictors/model_predictor.py +++ b/server/predictors/model_predictor.py @@ -86,6 +86,11 @@ register_predictor('default', default_pytorch_predictor) register_predictor('xgboost', default_pytorch_predictor) # 将新模型也注册给默认预测器 register_predictor('cnn_bilstm_attention', default_pytorch_predictor) +register_predictor('transformer', default_pytorch_predictor) +register_predictor('mlstm', default_pytorch_predictor) +register_predictor('kan', default_pytorch_predictor) +register_predictor('optimized_kan', default_pytorch_predictor) +register_predictor('tcn', default_pytorch_predictor) def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30): diff --git a/server/utils/database_utils.py b/server/utils/database_utils.py index 764eb26..adde8a9 100644 --- a/server/utils/database_utils.py +++ b/server/utils/database_utils.py @@ -84,26 +84,48 @@ def query_models_from_db(filters: dict, page: int = 1, page_size: int = 10): if conditions: query += " WHERE " + " AND ".join(conditions) - # 获取总数 - count_query = query.replace("*", "COUNT(*)") - cursor.execute(count_query, params) - total_count = cursor.fetchone()[0] - - # 添加分页 - offset = (page - 1) * page_size - query += " ORDER BY created_at DESC LIMIT ? OFFSET ?" - params.extend([page_size, offset]) - + # 修复分页逻辑:先查询所有符合条件的模型,再在内存中进行增强和分页 + query += " ORDER BY created_at DESC" cursor.execute(query, params) - models = [dict(row) for row in cursor.fetchall()] + all_models = [dict(row) for row in cursor.fetchall()] + # 增强逻辑:如果模型同时有best版本,则额外生成一条记录 + enhanced_models = [] + for model in all_models: + # 首先添加原始的版本记录 + enhanced_models.append(model) + + try: + artifacts = json.loads(model.get('artifacts', '{}')) + if artifacts and artifacts.get('best_model'): + original_version = model.get('version', 'N/A') + best_model_record = model.copy() + best_model_record['version'] = f"best({original_version})" + best_model_record['model_uid'] = f"{model['model_uid']}_best" + best_model_record['model_type'] = f"{model['model_type']}(best)" + + new_artifacts = { + 'best_model': artifacts.get('best_model'), + 'scaler_X': artifacts.get('scaler_X'), + 'scaler_y': artifacts.get('scaler_y') + } + best_model_record['artifacts'] = json.dumps(new_artifacts) + enhanced_models.append(best_model_record) + except (json.JSONDecodeError, TypeError): + continue + + # 在内存中进行分页 + total_count = len(enhanced_models) + offset = (page - 1) * page_size + paginated_models = enhanced_models[offset : offset + page_size] + return { - "models": models, + "models": paginated_models, "pagination": { "total": total_count, "page": page, "page_size": page_size, - "total_pages": (total_count + page_size - 1) // page_size + "total_pages": (total_count + page_size - 1) // page_size if page_size > 0 else 0 } }