From 5d505b37af5a69f300cc9023e00bf76672609033 Mon Sep 17 00:00:00 2001 From: gdtiti Date: Sun, 15 Jun 2025 00:00:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9B=BE=E8=A1=A8=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E5=92=8C=E6=95=B0=E6=8D=AE=E5=A4=84=E7=90=86=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 修复前端图表日期排序问题: - 改进 PredictionView.vue 和 HistoryView.vue 中的图表渲染逻辑 - 确保历史数据和预测数据按照正确的日期顺序显示 2. 修复后端API处理: - 解决 optimized_kan 模型类型的路径映射问题 - 添加 JSON 序列化器处理 Pandas Timestamp 对象 - 改进预测数据与历史数据的衔接处理 3. 优化图表样式和用户体验 --- UI/src/views/HistoryView.vue | 425 ++++++++ UI/src/views/PredictionView.vue | 1696 ++++++++++++++++++++++++++++--- api.py | 1333 +++++++++++++++++++++--- 3 files changed, 3174 insertions(+), 280 deletions(-) create mode 100644 UI/src/views/HistoryView.vue diff --git a/UI/src/views/HistoryView.vue b/UI/src/views/HistoryView.vue new file mode 100644 index 0000000..bc0ff07 --- /dev/null +++ b/UI/src/views/HistoryView.vue @@ -0,0 +1,425 @@ + + + + + \ No newline at end of file diff --git a/UI/src/views/PredictionView.vue b/UI/src/views/PredictionView.vue index ed6079f..9c24447 100644 --- a/UI/src/views/PredictionView.vue +++ b/UI/src/views/PredictionView.vue @@ -24,6 +24,7 @@ + @@ -116,16 +117,33 @@ - + @@ -151,111 +169,8 @@ -
- - - - 全屏查看预测结果 - - - - - - - -
-

预测趋势图

- -
-
- -
-

同期销量对比

- -
- -
-
-
-
- - - - -

预测结果数据

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-
-
- + + @@ -285,45 +200,13 @@ - +

预测趋势图

- -
-

预测图表生成失败或不可用。

-
-
-
- - - -
-

同期销量对比

- -
- -

请尝试选择其他预测日期,或确保历史数据中存在相应的同期记录。

-
+ +
+
@@ -429,6 +312,106 @@ + + + + +

预测分析

+ + +
{{ predictionAnalysis.explanation }}
+
+ + + + + +
+ + {{ getTrendLabel(predictionAnalysis.trend) }} + +
+
+
+ + + + + + +
+
平均销量
+
{{ predictionAnalysis.statistics.mean.toFixed(2) }}
+
+
+ +
+
最高销量
+
{{ predictionAnalysis.statistics.max.toFixed(2) }}
+
+
+ +
+
最低销量
+
{{ predictionAnalysis.statistics.min.toFixed(2) }}
+
+
+ +
+
标准差
+
{{ predictionAnalysis.statistics.std.toFixed(2) }}
+
+
+
+
+
+
+ + + + + +
+
+
+
+ + + + + + + + + + + +
+
+ + + +
+ + {{ modelDetails.model_info.model_id }} + + + {{ modelDetails.model_info.model_type }} + + + {{ modelDetails.model_info.product_name }} + {{ formatDateTime(modelDetails.model_info.created_at) }} + + + 训练评估指标 + + + + +
+
{{ getMetricLabel(key) }}
+
{{ formatMetricValue(key, value) }}
+
+
+
+
+ + 损失曲线 + +
+
+
+
+
+ +
+
+ +
+
+ + + +
+ + +
+ + +

{{ metric.description }}

+
+
+ + + +

{{ modelAnalysisResult.RMSE_MAE_COMP.description }}

+
+ +
+
+ +
+ +
\ No newline at end of file + + + \ No newline at end of file diff --git a/api.py b/api.py index cd39e0d..66d2606 100644 --- a/api.py +++ b/api.py @@ -1,5 +1,5 @@ -import os import sys +import os # 获取当前脚本所在目录的绝对路径 current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -12,11 +12,26 @@ sys.path.append(current_dir) #sys.path.append(parent_dir) -from flask import Flask, request, jsonify, send_file, redirect, send_from_directory -from flask_cors import CORS -import os -import uuid + + +import json import pandas as pd +import numpy as np +import torch +import matplotlib.pyplot as plt +import io +import base64 +import uuid +from datetime import datetime, timedelta +from flask import Flask, jsonify, request, send_file, redirect, send_from_directory, Response +from flask_cors import CORS +from flasgger import Swagger +from werkzeug.utils import secure_filename +import sqlite3 +import traceback + + + from pharmacy_predictor import ( train_product_model_with_mlstm, train_product_model_with_kan, @@ -30,15 +45,34 @@ from io import BytesIO import json from flasgger import Swagger, swag_from import argparse -from datetime import datetime +from datetime import datetime, timedelta from concurrent.futures import ThreadPoolExecutor from threading import Lock import traceback import torch +import sqlite3 +import numpy as np +import io +from werkzeug.utils import secure_filename - +# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型 +class CustomJSONEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)): + return obj.strftime('%Y-%m-%d') + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif pd.isna(obj): # 处理NaN值 + return None + return super(CustomJSONEncoder, self).default(obj) app = Flask(__name__) +# 设置自定义JSON编码器 +app.json_encoder = CustomJSONEncoder CORS(app) # 启用CORS支持 # Swagger配置 @@ -455,7 +489,7 @@ def get_all_training_tasks(): 'type': 'object', 'properties': { 'product_id': {'type': 'string', 'description': '例如 P001'}, - 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan'], 'description': '要训练的模型类型'}, + 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan'], 'description': '要训练的模型类型'}, 'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'} }, 'required': ['product_id', 'model_type'] @@ -499,7 +533,8 @@ def start_training(): model_train_functions = { 'mlstm': train_product_model_with_mlstm, 'kan': train_product_model_with_kan, - 'transformer': train_product_model_with_transformer + 'transformer': train_product_model_with_transformer, + 'optimized_kan': lambda product_id, epochs: train_product_model_with_kan(product_id, epochs, use_optimized=True) } if model_type not in model_train_functions: @@ -619,7 +654,7 @@ def get_training_status(task_id): 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, - 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan']}, + 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan']}, 'version': {'type': 'string'}, 'future_days': {'type': 'integer'}, 'include_visualization': {'type': 'boolean'}, @@ -697,98 +732,62 @@ def predict(): 404: description: 模型文件未找到 """ - data = request.json - product_id = data.get('product_id') - model_type = data.get('model_type') - future_days = data.get('future_days', 7) - start_date = data.get('start_date', '') - - print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}") - - if not product_id or not model_type: - return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400 - try: - result = load_model_and_predict(product_id, model_type, future_days, start_date=start_date) - if result is None: - return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404 - - # 获取预测数据 - predictions_df = result['predictions_df'] - chart_path = result['chart_path'] - csv_path = result['csv_path'] - - # 将DataFrame转换为JSON - predictions_json = predictions_df.to_dict(orient='records') - - # 获取当前服务器主机和端口 - host_url = request.host_url.rstrip('/') # 移除末尾的斜杠 - - response_data = { - "status": "success", - "data": predictions_json, # 现在包含历史和预测数据 - "history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'), - "prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records') - } - - # 只有在文件确实存在时才添加URL - timestamp = datetime.now().timestamp() - - # 获取文件名中的日期和预测天数部分 + data = request.json + product_id = data.get('product_id') + model_type = data.get('model_type') + future_days = int(data.get('future_days', 7)) start_date = data.get('start_date', '') - future_days = data.get('future_days', 7) + include_visualization = data.get('include_visualization', False) - if start_date: - try: - start_date_obj = datetime.strptime(start_date, '%Y-%m-%d') - start_date_str = start_date_obj.strftime('%Y%m%d') - except: - start_date_str = datetime.now().strftime('%Y%m%d') - else: - # 如果未提供日期,使用当前日期 - start_date_str = datetime.now().strftime('%Y%m%d') - - # 构建文件名 - png_filename = f"forecast_{start_date_str}_days{future_days}.png" - csv_filename = f"forecast_{start_date_str}_days{future_days}.csv" - - # 构建图片路径 - png_path = os.path.join('predictions', model_type, product_id, png_filename) - csv_path = os.path.join('predictions', model_type, product_id, csv_filename) - history_png_filename = f"history_{start_date_str}.png" - history_png_path = os.path.join('predictions', model_type, product_id, history_png_filename) - - # 检查预测图表 - if os.path.exists(png_path): - # 构建完整的URL,包含主机名和端口 - response_data["image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{png_filename}?t={timestamp}" - print(f"图表URL: {response_data['image_url']}") - else: - response_data["image_url"] = None - print(f"警告: 预测图表文件未生成或不存在: {png_path}") - - # 检查历史图表 - if os.path.exists(history_png_path): - response_data["history_image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{history_png_filename}?t={timestamp}" - print(f"历史图表URL: {response_data['history_image_url']}") - else: - response_data["history_image_url"] = None - if result and 'history_chart_path' in result and result['history_chart_path'] is None: - print(f"警告: 历史图表生成过程中出现错误,请检查服务器日志") - else: - print(f"警告: 历史图表文件未生成或不存在: {history_png_path}") - - # 检查CSV文件 - if os.path.exists(csv_path): - response_data["csv_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{csv_filename}" - else: - response_data["csv_url"] = None - - return jsonify(response_data) + print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}") + if not product_id or not model_type: + return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400 + + # 获取产品名称 + product_name = get_product_name(product_id) + if not product_name: + product_name = product_id + + # 获取最新模型ID - 此处不需要处理模型类型映射,因为 get_latest_model_id 和 load_model_and_predict 会处理 + model_id = get_latest_model_id(model_type, product_id) + if not model_id: + return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404 + + # 执行预测 + prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date) + + # 如果需要可视化,添加图表数据 + if include_visualization: + # 添加图表数据 + chart_data = prepare_chart_data(prediction_result) + prediction_result['chart_data'] = chart_data + + # 添加分析结果 + analysis_result = analyze_prediction(prediction_result) + prediction_result['analysis'] = analysis_result + + # 保存预测结果到文件和数据库 + prediction_id, file_path = save_prediction_result( + prediction_result, + product_id, + product_name, + model_type, + model_id, + start_date, + future_days + ) + + # 添加预测ID到结果中 + prediction_result['prediction_id'] = prediction_id + + return jsonify(prediction_result) except Exception as e: + print(f"预测失败: {str(e)}") + import traceback traceback.print_exc() - return jsonify({"status": "error", "error": f"预测过程中发生错误: {str(e)}"}), 500 + return jsonify({"status": "error", "error": str(e)}), 500 @app.route('/api/prediction/compare', methods=['POST']) @swag_from({ @@ -921,6 +920,143 @@ def compare_predictions(): traceback.print_exc() return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500 +@app.route('/api/prediction/analyze', methods=['POST']) +def analyze_prediction(): + """ + 分析预测结果,提供详细解释 + --- + tags: + - 预测分析 + parameters: + - name: body + in: body + required: true + schema: + type: object + properties: + product_id: + type: string + example: P001 + model_type: + type: string + enum: [mlstm, transformer, kan, optimized_kan] + future_days: + type: integer + default: 7 + start_date: + type: string + description: 预测起始日期,格式为'YYYY-MM-DD' + responses: + 200: + description: 预测结果及其分析 + schema: + type: object + properties: + status: + type: string + example: success + data: + type: object + properties: + predictions: + type: array + items: + type: number + dates: + type: array + items: + type: string + analysis: + type: object + properties: + trend: + type: string + statistics: + type: object + historical_comparison: + type: object + factors: + type: array + items: + type: object + explanation: + type: string + """ + try: + data = request.json + product_id = data.get('product_id') + model_type = data.get('model_type') + future_days = data.get('future_days', 7) + start_date = data.get('start_date') + + # 验证参数 + if not product_id or not model_type: + return jsonify({"status": "error", "error": "缺少必要参数"}), 400 + + if model_type not in ['mlstm', 'transformer', 'kan', 'optimized_kan']: + return jsonify({"status": "error", "error": "不支持的模型类型"}), 400 + + # 获取预测结果和分析 + result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date, analyze_result=True) + + if result_tuple is None: + return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404 + + # 解包元组 + result, analysis = result_tuple + + if result is None: + return jsonify({"status": "error", "error": "预测失败"}), 500 + + # 从result中获取预测值 + predictions_df = result['predictions_df'] + prediction_data = predictions_df[predictions_df['data_type'] == '预测销量'] + predictions_list = prediction_data['sales'].tolist() + + # 生成日期列表 + if start_date: + start_date_obj = datetime.strptime(start_date, '%Y-%m-%d') + else: + # 获取最后一条数据的日期并加1天 + df = pd.read_excel('pharmacy_sales.xlsx') + product_df = df[df['product_id'] == product_id].sort_values('date') + last_date = product_df['date'].iloc[-1] + start_date_obj = last_date + timedelta(days=1) + + dates = [(start_date_obj + timedelta(days=i)).strftime('%Y-%m-%d') for i in range(future_days)] + + # 为前端ECharts准备数据 + chart_data = { + "dates": dates, + "values": predictions_list, + "day_over_day_changes": analysis["statistics"]["day_over_day_changes"] if "day_over_day_changes" in analysis["statistics"] else [] + } + + # 如果有历史对比数据,也添加到图表数据中 + if analysis["historical_comparison"]["has_historical_data"]: + chart_data["historical"] = { + "mean": analysis["historical_comparison"]["historical_mean"], + "max": analysis["historical_comparison"]["historical_max"], + "min": analysis["historical_comparison"]["historical_min"] + } + + return jsonify({ + "status": "success", + "data": { + "product_id": product_id, + "model_type": model_type, + "predictions": predictions_list, + "dates": dates, + "chart_data": chart_data, + "analysis": analysis + } + }) + + except Exception as e: + import traceback + traceback.print_exc() + return jsonify({"status": "error", "error": str(e)}), 500 + # 4. 模型管理API @app.route('/api/models', methods=['GET']) @swag_from({ @@ -940,7 +1076,7 @@ def compare_predictions(): 'in': 'query', 'type': 'string', 'required': False, - 'description': '按模型类型筛选' + 'description': '按模型类型筛选 (mlstm, kan, transformer)' } ], 'responses': { @@ -1017,7 +1153,13 @@ def list_models(): type: object """ models_dir = 'models' - model_types = ['mlstm', 'kan', 'transformer'] + model_types = ['mlstm', 'kan', 'transformer', 'kan_optimized'] + model_type_mapping = { + 'mlstm': 'mlstm', + 'kan': 'kan', + 'transformer': 'transformer', + 'kan_optimized': 'optimized_kan' # 将kan_optimized目录映射为optimized_kan类型 + } available_models = [] product_id_filter = request.args.get('product_id') @@ -1027,8 +1169,14 @@ def list_models(): return jsonify({"status": "success", "data": []}) for model_type in model_types: - if model_type_filter and model_type_filter != model_type: - continue + # 如果指定了模型类型过滤器,检查是否匹配当前类型或其映射类型 + if model_type_filter: + # 如果过滤器是optimized_kan,我们需要查找kan_optimized目录 + if model_type_filter == 'optimized_kan' and model_type != 'kan_optimized': + continue + # 如果过滤器不是optimized_kan,但当前目录是kan_optimized,跳过 + elif model_type_filter != 'optimized_kan' and model_type_filter != model_type: + continue type_dir = os.path.join(models_dir, model_type) if not os.path.exists(type_dir): @@ -1045,11 +1193,15 @@ def list_models(): try: with open(log_path, 'r', encoding='utf-8') as f: log_data = json.load(f) + + # 使用映射表获取显示的模型类型 + display_model_type = model_type_mapping.get(model_type, model_type) + model_info = { - "model_id": f"{model_type}_{product_id}", + "model_id": f"{display_model_type}_{product_id}", "product_id": log_data.get('product_id'), "product_name": log_data.get('product_name'), - "model_type": log_data.get('model_type'), + "model_type": display_model_type, # 使用映射后的模型类型 "created_at": log_data.get('training_completed_at'), "metrics": log_data.get('metrics'), "file_path": log_data.get('file_path') @@ -1133,7 +1285,14 @@ def get_model_details(model_id): """ try: model_type, product_id = model_id.split('_', 1) - log_path = os.path.join('models', model_type, f'{product_id}_log.json') + + # 处理优化版KAN模型的路径 + actual_model_type = model_type + if model_type == 'optimized_kan': + actual_model_type = 'kan_optimized' + print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_log.json'") + + log_path = os.path.join('models', actual_model_type, f'{product_id}_log.json') if not os.path.exists(log_path): return jsonify({"status": "error", "error": "模型未找到"}), 404 @@ -1141,10 +1300,10 @@ def get_model_details(model_id): with open(log_path, 'r', encoding='utf-8') as f: log_data = json.load(f) - # 还可以添加从 .pt 文件读取更多信息的逻辑 - # checkpoint = torch.load(log_data['file_path'], map_location='cpu') - # log_data['details_from_pt'] = ... - + # 确保返回的模型类型是optimized_kan而不是kan_optimized + if actual_model_type == 'kan_optimized': + log_data['model_type'] = 'optimized_kan' + return jsonify({"status": "success", "data": log_data}) except ValueError: return jsonify({"status": "error", "error": "无效的model_id格式"}), 400 @@ -1204,7 +1363,14 @@ def delete_model(model_id): """ try: model_type, product_id = model_id.split('_', 1) - model_dir = os.path.join('models', model_type) + + # 处理优化版KAN模型的路径 + actual_model_type = model_type + if model_type == 'optimized_kan': + actual_model_type = 'kan_optimized' + print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'") + + model_dir = os.path.join('models', actual_model_type) model_path = os.path.join(model_dir, f'{product_id}_model.pt') log_path = os.path.join(model_dir, f'{product_id}_log.json') @@ -1254,7 +1420,14 @@ def delete_model(model_id): def export_model(model_id): try: model_type, product_id = model_id.split('_', 1) - model_path = os.path.join('models', model_type, f'{product_id}_model.pt') + + # 处理优化版KAN模型的路径 + actual_model_type = model_type + if model_type == 'optimized_kan': + actual_model_type = 'kan_optimized' + print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'") + + model_path = os.path.join('models', actual_model_type, f'{product_id}_model.pt') if not os.path.exists(model_path): return jsonify({"status": "error", "error": "模型文件未找到"}), 404 @@ -1440,13 +1613,937 @@ def get_compare_file(filename): directory = os.path.join(current_dir, 'predictions', 'compare') return send_from_directory(directory, filename) -if __name__ == '__main__': - # 检查可用的设备 - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - print(f"使用设备: {device}") +@app.route('/api/models///details', methods=['GET']) +def get_model_details_extended(model_type, product_id): + """ + 获取模型详情,包括训练损失曲线、预测效果图等 + --- + tags: + - 模型管理 + parameters: + - name: model_type + in: path + required: true + type: string + description: 模型类型 + - name: product_id + in: path + required: true + type: string + description: 产品ID + responses: + 200: + description: 模型详情 + schema: + type: object + properties: + status: + type: string + example: success + data: + type: object + properties: + model_info: + type: object + description: 模型基本信息 + training_metrics: + type: object + description: 训练评估指标 + chart_data: + type: object + description: 图表数据 + """ + try: + # 处理优化版KAN模型的路径 + actual_model_path = model_type + if model_type == 'optimized_kan': + actual_model_path = 'kan_optimized' + + model_dir = f'models/{actual_model_path}' + model_path = os.path.join(model_dir, f'{product_id}_model.pt') + log_path = os.path.join(model_dir, f'{product_id}_log.json') + + # 检查模型文件是否存在 + if not os.path.exists(model_path): + return jsonify({ + "status": "error", + "error": f"模型 '{model_type}/{product_id}' 不存在" + }), 404 + + # 读取模型日志 + model_log = {} + if os.path.exists(log_path): + with open(log_path, 'r', encoding='utf-8') as f: + model_log = json.load(f) + + # 加载模型检查点,获取训练损失数据 + checkpoint = torch.load(model_path, map_location='cpu', weights_only=False) + + # 获取训练和测试损失 + train_losses = checkpoint.get('train_loss', []) + test_losses = checkpoint.get('test_loss', []) + + # 获取产品名称 + product_name = "" + try: + df = pd.read_excel('pharmacy_sales.xlsx') + product_df = df[df['product_id'] == product_id] + if not product_df.empty: + product_name = product_df['product_name'].iloc[0] + except Exception as e: + print(f"获取产品名称时出错: {e}") + + # 检查是否有损失曲线图和预测效果图 + loss_curve_path = f'{product_id}_{model_type}_loss_curve.png' + prediction_path = f'{product_id}_{model_type}_prediction.png' + + # 获取当前服务器主机和端口 + host_url = request.host_url.rstrip('/') + + # 准备图表数据 + chart_data = { + "loss_chart": { + "epochs": list(range(1, len(train_losses) + 1)), + "train_loss": train_losses, + "test_loss": test_losses if test_losses else [] + } + } + + # 准备模型信息 + model_info = { + "model_id": f"{model_type}_{product_id}", + "model_type": model_type, + "product_id": product_id, + "product_name": product_name, + "created_at": model_log.get("training_completed_at", ""), + "file_path": model_path + } + + # 准备训练指标 + training_metrics = model_log.get("metrics", {}) + + return jsonify({ + "status": "success", + "data": { + "model_info": model_info, + "training_metrics": training_metrics, + "chart_data": chart_data + } + }) + + except Exception as e: + import traceback + traceback.print_exc() + return jsonify({ + "status": "error", + "error": str(e) + }), 500 - # 运行Flask应用 - # 在生产环境中,应使用Gunicorn或uWSGI等WSGI服务器 - # 例如: gunicorn --workers 4 --bind 0.0.0.0:5000 api:app - # 使用--host=0.0.0.0可以使服务在局域网内可访问 - app.run(host='0.0.0.0', port=5000, debug=True) \ No newline at end of file +# 创建SQLite数据库连接函数 +def get_db_connection(): + """获取SQLite数据库连接""" + conn = sqlite3.connect('prediction_history.db') + conn.row_factory = sqlite3.Row + return conn + +# 初始化数据库 +def init_db(): + """初始化SQLite数据库,创建必要的表""" + conn = get_db_connection() + cursor = conn.cursor() + + # 创建历史预测记录表 + cursor.execute(''' + CREATE TABLE IF NOT EXISTS prediction_history ( + id TEXT PRIMARY KEY, + product_id TEXT NOT NULL, + product_name TEXT NOT NULL, + model_type TEXT NOT NULL, + model_id TEXT NOT NULL, + start_date TEXT NOT NULL, + future_days INTEGER NOT NULL, + created_at TEXT NOT NULL, + file_path TEXT NOT NULL + ) + ''') + + conn.commit() + conn.close() + +# 在应用启动时初始化数据库 +init_db() + +# 添加保存预测结果的函数 +def save_prediction_result(prediction_data, product_id, product_name, model_type, model_id, start_date, future_days): + """保存预测结果到文件和数据库""" + # 生成唯一ID + prediction_id = str(uuid.uuid4()) + + # 创建历史预测目录 + history_dir = os.path.join('predictions', 'history') + os.makedirs(history_dir, exist_ok=True) + + # 创建模型类型子目录 + model_dir = os.path.join(history_dir, model_type) + os.makedirs(model_dir, exist_ok=True) + + # 创建文件名 + timestamp = datetime.now().strftime('%Y%m%d%H%M%S') + filename = f"{product_id}_{start_date}_{future_days}days_{timestamp}.json" + file_path = os.path.join(model_dir, filename) + + # 创建一个JSON序列化器来处理Pandas Timestamp对象 + class DateTimeEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)): + return obj.strftime('%Y-%m-%d') + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif pd.isna(obj): # 处理NaN值 + return None + return super(DateTimeEncoder, self).default(obj) + + # 保存预测结果为JSON文件 + with open(file_path, 'w', encoding='utf-8') as f: + json.dump(prediction_data, f, ensure_ascii=False, indent=2, cls=DateTimeEncoder) + + # 将记录保存到数据库 + conn = get_db_connection() + cursor = conn.cursor() + + cursor.execute( + ''' + INSERT INTO prediction_history + (id, product_id, product_name, model_type, model_id, start_date, future_days, created_at, file_path) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + ''', + ( + prediction_id, + product_id, + product_name, + model_type, + model_id, + start_date, + future_days, + datetime.now().isoformat(), + file_path + ) + ) + + conn.commit() + conn.close() + + return prediction_id, file_path + +# 添加获取历史预测列表的API +@app.route('/api/prediction/history', methods=['GET']) +def get_prediction_history(): + """ + 获取历史预测列表 + --- + tags: + - 预测分析 + parameters: + - name: product_id + in: query + type: string + required: false + description: 按产品ID筛选 + - name: model_type + in: query + type: string + required: false + description: 按模型类型筛选 + - name: page + in: query + type: integer + required: false + default: 1 + description: 页码 + - name: page_size + in: query + type: integer + required: false + default: 10 + description: 每页记录数 + responses: + 200: + description: 历史预测列表 + schema: + type: object + properties: + status: + type: string + example: success + data: + type: array + items: + type: object + properties: + id: + type: string + example: 550e8400-e29b-41d4-a716-446655440000 + product_id: + type: string + example: P001 + product_name: + type: string + example: 阿司匹林 + model_type: + type: string + example: mlstm + model_id: + type: string + example: mlstm_P001_20230101 + start_date: + type: string + example: 2023-01-01 + future_days: + type: integer + example: 7 + created_at: + type: string + example: 2023-01-01T12:00:00 + total: + type: integer + example: 100 + page: + type: integer + example: 1 + page_size: + type: integer + example: 10 + """ + try: + # 获取查询参数 + product_id = request.args.get('product_id') + model_type = request.args.get('model_type') + page = int(request.args.get('page', 1)) + page_size = int(request.args.get('page_size', 10)) + + # 构建SQL查询 + query = "SELECT * FROM prediction_history" + params = [] + conditions = [] + + if product_id: + conditions.append("product_id = ?") + params.append(product_id) + + if model_type: + conditions.append("model_type = ?") + params.append(model_type) + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + # 添加排序和分页 + query += " ORDER BY created_at DESC" + + # 获取总记录数 + conn = get_db_connection() + count_result = conn.execute(query, params).fetchall() + total_count = len(count_result) + + # 添加分页限制 + query += f" LIMIT {page_size} OFFSET {(page - 1) * page_size}" + + # 执行查询 + result = conn.execute(query, params).fetchall() + + # 转换为JSON格式 + history_list = [] + for row in result: + history_list.append({ + 'id': row['id'], + 'product_id': row['product_id'], + 'product_name': row['product_name'], + 'model_type': row['model_type'], + 'model_id': row['model_id'], + 'start_date': row['start_date'], + 'future_days': row['future_days'], + 'created_at': row['created_at'], + 'file_path': row['file_path'] + }) + + conn.close() + + return jsonify({ + 'status': 'success', + 'data': history_list, + 'total': total_count, + 'page': page, + 'page_size': page_size + }) + except Exception as e: + print(f"获取历史预测列表失败: {str(e)}") + import traceback + traceback.print_exc() + return jsonify({"status": "error", "error": str(e)}), 500 + +# 添加获取历史预测详情的API +@app.route('/api/prediction/history/', methods=['GET']) +def get_prediction_history_detail(prediction_id): + """ + 获取历史预测详情 + --- + tags: + - 预测分析 + parameters: + - name: prediction_id + in: path + type: string + required: true + description: 预测ID + responses: + 200: + description: 预测详情 + schema: + type: object + properties: + status: + type: string + example: success + data: + type: object + """ + try: + # 查询数据库获取文件路径 + conn = get_db_connection() + result = conn.execute("SELECT * FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone() + conn.close() + + if not result: + return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404 + + file_path = result['file_path'] + + # 读取预测结果文件 + try: + with open(file_path, 'r', encoding='utf-8') as f: + prediction_data = json.load(f) + + return jsonify({ + 'status': 'success', + 'data': prediction_data, + 'meta': { + 'id': result['id'], + 'product_id': result['product_id'], + 'product_name': result['product_name'], + 'model_type': result['model_type'], + 'model_id': result['model_id'], + 'start_date': result['start_date'], + 'future_days': result['future_days'], + 'created_at': result['created_at'] + } + }) + except FileNotFoundError: + return jsonify({"status": "error", "error": "预测结果文件不存在"}), 404 + except Exception as e: + return jsonify({"status": "error", "error": f"读取预测结果文件失败: {str(e)}"}), 500 + except Exception as e: + print(f"获取历史预测详情失败: {str(e)}") + import traceback + traceback.print_exc() + return jsonify({"status": "error", "error": str(e)}), 500 + +# 添加删除历史预测记录的API +@app.route('/api/prediction/history/', methods=['DELETE']) +def delete_prediction_history(prediction_id): + """ + 删除历史预测记录 + --- + tags: + - 预测分析 + parameters: + - name: prediction_id + in: path + type: string + required: true + description: 预测ID + responses: + 200: + description: 删除结果 + schema: + type: object + properties: + status: + type: string + example: success + """ + try: + # 查询数据库获取文件路径 + conn = get_db_connection() + result = conn.execute("SELECT file_path FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone() + + if not result: + conn.close() + return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404 + + file_path = result['file_path'] + + # 删除数据库记录 + conn.execute("DELETE FROM prediction_history WHERE id = ?", (prediction_id,)) + conn.commit() + conn.close() + + # 删除预测结果文件 + try: + if os.path.exists(file_path): + os.remove(file_path) + except Exception as e: + print(f"删除预测结果文件失败: {str(e)}") + + return jsonify({'status': 'success'}) + except Exception as e: + print(f"删除历史预测记录失败: {str(e)}") + import traceback + traceback.print_exc() + return jsonify({"status": "error", "error": str(e)}), 500 + +# 获取产品名称的辅助函数 +def get_product_name(product_id): + """根据产品ID获取产品名称""" + try: + # 从产品列表中查找产品名称 + products_file = 'data/products.json' + if os.path.exists(products_file): + with open(products_file, 'r', encoding='utf-8') as f: + products = json.load(f) + + for product in products: + if product['product_id'] == product_id: + return product['product_name'] + + return None + except Exception as e: + print(f"获取产品名称失败: {str(e)}") + return None + +# 获取最新模型ID的辅助函数 +def get_latest_model_id(model_type, product_id): + """根据模型类型和产品ID获取最新的模型ID""" + try: + # 处理优化版KAN模型的路径 + actual_model_path = model_type + if model_type == 'optimized_kan': + actual_model_path = 'kan_optimized' + print(f"优化版KAN模型: 当查找最新模型ID时,使用路径 'models/{actual_model_path}/{product_id}_model.pt'") + + # 查找模型目录中的模型文件 + model_dir = os.path.join('models', actual_model_path) + if not os.path.exists(model_dir): + print(f"模型目录不存在: {model_dir}") + return None + + # 查找匹配的模型文件 + model_files = [f for f in os.listdir(model_dir) if f.startswith(f"{product_id}_") and f.endswith('.pt')] + if not model_files: + print(f"在目录 {model_dir} 中未找到产品 {product_id} 的模型文件") + return None + + # 按照文件修改时间排序,获取最新的模型文件 + model_files.sort(key=lambda x: os.path.getmtime(os.path.join(model_dir, x)), reverse=True) + latest_model_file = model_files[0] + + # 从文件名中提取模型ID + model_id = latest_model_file.replace('.pt', '') + + print(f"找到最新模型: {model_id} 在目录 {model_dir}") + return model_id + except Exception as e: + print(f"获取最新模型ID失败: {str(e)}") + return None + +# 执行预测的辅助函数 +def run_prediction(model_type, product_id, model_id, future_days, start_date): + """执行模型预测""" + try: + # 处理优化版KAN模型的路径在 load_model_and_predict 中已经实现,无需在此处理 + # 直接使用原始的 model_type 调用函数 + + # 解包返回的元组为result和analysis + result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date=start_date, analyze_result=True) + + # 处理返回值可能是None的情况 + if result_tuple is None: + raise Exception(f"模型 '{model_type}' 类型的模型文件未找到或加载失败") + + # 解包元组 - result_tuple可能是(result, analysis)格式或者只是result + if isinstance(result_tuple, tuple): + result = result_tuple[0] + analysis = result_tuple[1] if len(result_tuple) > 1 else None + else: + result = result_tuple + analysis = None + + if result is None: + raise Exception(f"模型预测失败") + + # 获取预测数据 + predictions_df = result['predictions_df'] + + # 确保日期是日期时间格式以便处理 + if 'date' in predictions_df.columns and not pd.api.types.is_datetime64_any_dtype(predictions_df['date']): + predictions_df['date'] = pd.to_datetime(predictions_df['date']) + + # 分离历史数据和预测数据 + history_df = predictions_df[predictions_df['data_type'] == '历史销量'] + prediction_df = predictions_df[predictions_df['data_type'] == '预测销量'] + + # 如果历史数据和预测数据的日期有重叠,调整预测数据的日期 + if not history_df.empty and not prediction_df.empty: + last_history_date = history_df['date'].max() + + # 检查预测数据是否与历史数据重叠 + if any(prediction_df['date'] <= last_history_date): + print(f"检测到预测数据与历史数据日期重叠,调整预测数据日期...") + + # 获取预测数据的起始日期,确保它在历史数据的最后一天之后 + prediction_start_date = last_history_date + pd.Timedelta(days=1) + + # 创建新的日期序列 + new_dates = pd.date_range(start=prediction_start_date, periods=len(prediction_df), freq='D') + + # 更新预测数据的日期 + prediction_df['date'] = new_dates + + # 更新原始DataFrame + predictions_df.loc[predictions_df['data_type'] == '预测销量', 'date'] = prediction_df['date'].values + + print(f"预测数据日期已调整为从 {prediction_start_date} 开始") + + # 将处理后的DataFrame放回结果 + result['predictions_df'] = predictions_df + + chart_path = result.get('chart_path') + csv_path = result.get('csv_path') + + # 将DataFrame转换为JSON + predictions_json = predictions_df.to_dict(orient='records') + + # 准备响应数据 + response_data = { + "status": "success", + "data": predictions_json, # 包含历史和预测数据 + "history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'), + "prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records'), + } + + return response_data + except Exception as e: + print(f"执行预测失败: {str(e)}") + traceback.print_exc() + raise e + +# 准备图表数据的辅助函数 +def prepare_chart_data(prediction_result): + """准备图表数据""" + try: + # 从预测结果中提取数据 + predictions_df = pd.DataFrame(prediction_result['data']) + + # 确保日期列是日期时间格式,以便正确排序 + if 'date' in predictions_df.columns: + # 先转换为日期时间类型以便排序 + if isinstance(predictions_df['date'][0], str): + predictions_df['date'] = pd.to_datetime(predictions_df['date']) + + # 按日期排序 + predictions_df = predictions_df.sort_values('date') + + # 重置索引 + predictions_df = predictions_df.reset_index(drop=True) + + # 最后再转换回字符串格式用于JSON + predictions_df['date'] = predictions_df['date'].dt.strftime('%Y-%m-%d') + + # 分离历史数据和预测数据 + history_df = predictions_df[predictions_df['data_type'] == '历史销量'] + prediction_df = predictions_df[predictions_df['data_type'] == '预测销量'] + + # 准备图表数据 + chart_data = { + "dates": predictions_df['date'].tolist(), + "sales": predictions_df['sales'].tolist(), + "types": predictions_df['data_type'].tolist() + } + + # 为前端debug提供额外信息 + chart_data["debug"] = { + "history_dates": history_df['date'].tolist() if not history_df.empty else [], + "history_sales": history_df['sales'].tolist() if not history_df.empty else [], + "prediction_dates": prediction_df['date'].tolist() if not prediction_df.empty else [], + "prediction_sales": prediction_df['sales'].tolist() if not prediction_df.empty else [], + } + + print(f"历史数据日期范围: {history_df['date'].min() if not history_df.empty else 'N/A'} 到 {history_df['date'].max() if not history_df.empty else 'N/A'}") + print(f"预测数据日期范围: {prediction_df['date'].min() if not prediction_df.empty else 'N/A'} 到 {prediction_df['date'].max() if not prediction_df.empty else 'N/A'}") + + return chart_data + except Exception as e: + print(f"准备图表数据失败: {str(e)}") + traceback.print_exc() + return {} + +# 分析预测结果的辅助函数 +def analyze_prediction(prediction_result): + """分析预测结果""" + try: + # 从预测结果中提取数据 + prediction_data = prediction_result.get('prediction_data', []) + history_data = prediction_result.get('history_data', []) + + if not prediction_data: + return None + + # 转换为DataFrame以便分析 + prediction_df = pd.DataFrame(prediction_data) + + # 确保日期列是日期时间格式 + if 'date' in prediction_df.columns: + if isinstance(prediction_df['date'][0], str): + prediction_df['date'] = pd.to_datetime(prediction_df['date']) + + # 计算统计数据 + sales = prediction_df['sales'].values + mean_sales = np.mean(sales) + max_sales = np.max(sales) + min_sales = np.min(sales) + std_sales = np.std(sales) + + # 计算日环比变化 + day_over_day_changes = [] + for i in range(1, len(sales)): + if sales[i-1] == 0: + day_over_day_changes.append(0) + else: + change_pct = ((sales[i] - sales[i-1]) / sales[i-1]) * 100 + day_over_day_changes.append(change_pct) + + # 确定趋势 + if len(sales) < 2: + trend = "unknown" + else: + # 计算简单线性回归的斜率 + x = np.arange(len(sales)) + slope = np.polyfit(x, sales, 1)[0] + + # 计算变化的标准差 + changes = np.diff(sales) + changes_std = np.std(changes) + + if abs(slope) < 0.1 * mean_sales: + if changes_std > 0.2 * mean_sales: + trend = "fluctuating" + else: + trend = "stable" + elif slope > 0: + trend = "increasing" + else: + trend = "decreasing" + + # 历史数据对比 + has_historical_data = len(history_data) > 0 + historical_comparison = { + "has_historical_data": has_historical_data, + "mean_difference_pct": 0 + } + + if has_historical_data: + history_df = pd.DataFrame(history_data) + history_mean = history_df['sales'].mean() + prediction_mean = mean_sales + + if history_mean > 0: + mean_difference_pct = ((prediction_mean - history_mean) / history_mean) * 100 + historical_comparison["mean_difference_pct"] = mean_difference_pct + + # 模拟影响因素 + factors = [ + {"name": "季节性", "importance": "high", "description": "季节变化对销量有显著影响"}, + {"name": "促销活动", "importance": "medium", "description": "促销活动可能会短期提升销量"}, + {"name": "市场竞争", "importance": "low", "description": "市场竞争对销量有轻微影响"} + ] + + # 生成解释文本 + if trend == "increasing": + explanation = f"预测期内销量整体呈上升趋势,平均日销量为{mean_sales:.2f},相比历史数据" + if has_historical_data and historical_comparison["mean_difference_pct"] > 0: + explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。" + elif has_historical_data: + explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。" + else: + explanation += "无法比较。" + explanation += "建议适当增加库存以应对销量增长。" + elif trend == "decreasing": + explanation = f"预测期内销量整体呈下降趋势,平均日销量为{mean_sales:.2f},相比历史数据" + if has_historical_data and historical_comparison["mean_difference_pct"] > 0: + explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。" + elif has_historical_data: + explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。" + else: + explanation += "无法比较。" + explanation += "建议控制库存以避免积压。" + elif trend == "fluctuating": + explanation = f"预测期内销量波动较大,平均日销量为{mean_sales:.2f},标准差为{std_sales:.2f}。建议密切关注销售情况,灵活调整库存。" + else: + explanation = f"预测期内销量保持稳定,平均日销量为{mean_sales:.2f},最高销量为{max_sales:.2f},最低销量为{min_sales:.2f}。" + + # 构建分析结果 + analysis = { + "explanation": explanation, + "trend": trend, + "statistics": { + "mean": mean_sales, + "max": max_sales, + "min": min_sales, + "std": std_sales, + "day_over_day_changes": day_over_day_changes + }, + "historical_comparison": historical_comparison, + "factors": factors + } + + return analysis + except Exception as e: + print(f"分析预测结果失败: {str(e)}") + return None + +# 新增模型性能分析接口 +@app.route('/api/models/analyze-metrics', methods=['POST']) +@swag_from({ + 'tags': ['模型管理'], + 'summary': '分析模型评估指标', + 'description': '接收一组模型指标,并返回详细的文字解读和评级。', + 'parameters': [ + { + 'name': 'body', + 'in': 'body', + 'required': True, + 'schema': { + 'type': 'object', + 'description': '模型的评估指标', + 'example': { + 'R2': 0.7067, + 'RMSE': 6.8670, + 'MAE': 4.4062, + 'MAPE': 18.14 + } + } + } + ], + 'responses': { + 200: { + 'description': '分析成功', + 'schema': { + 'type': 'object', + 'properties': { + 'status': {'type': 'string'}, + 'data': {'type': 'object'} + } + } + }, + 400: { + 'description': '请求错误,缺少指标数据' + } + } +}) +def analyze_model_metrics(): + """ + 分析模型的评估指标并提供解读 + """ + try: + metrics = request.json + if not metrics: + return jsonify({"status": "error", "error": "缺少指标数据"}), 400 + + analysis = {} + + # 1. 分析 R² + r2 = metrics.get('R2') + if r2 is not None: + if r2 >= 0.9: + r2_rating = "优秀" + r2_desc = f"R²值为{r2:.4f},表现非常出色。模型能够解释超过90%的销售数据波动,意味着它与历史数据的拟合度极高,预测结果非常可靠。" + elif r2 >= 0.7: + r2_rating = "良好" + r2_desc = f"R²值为{r2:.4f},表现良好。模型能解释{(r2*100):.1f}%的销售数据变化,说明模型捕捉到了大部分关键的销售模式,具备很高的实用价值。" + elif r2 >= 0.5: + r2_rating = "中等" + r2_desc = f"R²值为{r2:.4f},表现中等。模型解释了约一半的销售数据变化,说明它掌握了一定的规律,但可能忽略了一些次要因素。预测结果可作为参考,但需结合其他信息判断。" + else: + r2_rating = "较弱" + r2_desc = f"R²值为{r2:.4f},表现较弱。模型对销售数据变化的解释能力有限,预测的准确性可能不高。建议尝试优化模型或增加更多有效特征。" + analysis['R2'] = {"value": r2, "rating": r2_rating, "description": r2_desc} + + # 2. 分析 MAPE + mape = metrics.get('MAPE') + if mape is not None: + if mape <= 10: + mape_rating = "优秀" + mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率极低,预测精度非常高。" + elif mape <= 20: + mape_rating = "良好" + mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率在可接受的范围内,表明模型的预测结果在大多数情况下与真实值偏差不大。" + elif mape <= 30: + mape_rating = "中等" + mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率中等。在销量波动较大的场景下可以接受,但对于追求高精度预测的场景,仍有优化空间。" + else: + mape_rating = "较弱" + mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率偏高。模型的预测值与真实值偏差较大,建议谨慎使用其预测结果。" + analysis['MAPE'] = {"value": mape, "rating": mape_rating, "description": mape_desc} + + # 3. 分析 RMSE 和 MAE + rmse = metrics.get('RMSE') + mae = metrics.get('MAE') + if rmse is not None and mae is not None: + rmse_desc = f"均方根误差为{rmse:.4f}。这个值衡量了预测误差的典型大小。因为它对较大的误差值更敏感,所以可以反映模型是否存在'离谱'的预测。" + mae_desc = f"平均绝对误差为{mae:.4f}。这个值直观地表示了模型平均预测会偏离真实销售量多少个单位(如'件')。" + + # 比较RMSE和MAE + if rmse > mae * 1.5: # 经验阈值 + comparison_desc = f"RMSE ({rmse:.4f}) 明显大于 MAE ({mae:.4f}),这通常意味着模型在某些数据点上存在较大的预测误差(离群点)。虽然总体平均误差不大,但偶尔可能会有'猜错得离谱'的情况。" + else: + comparison_desc = f"RMSE ({rmse:.4f}) 与 MAE ({mae:.4f}) 的值较为接近,表明模型的误差分布比较均匀,没有出现极端异常的预测错误。" + + analysis['RMSE'] = {"value": rmse, "rating": "参考指标", "description": rmse_desc} + analysis['MAE'] = {"value": mae, "rating": "参考指标", "description": mae_desc} + analysis['RMSE_MAE_COMP'] = {"description": comparison_desc} + + # 4. 形成总体结论 + overall_ratings = [a.get('rating') for a in analysis.values() if a.get('rating') in ["优秀", "良好", "中等", "较弱"]] + if "较弱" in overall_ratings: + overall_summary = "该模型的综合性能表现较弱,预测结果可能存在较大偏差,建议进行优化或谨慎使用。" + elif "中等" in overall_ratings: + overall_summary = "该模型的综合性能表现中等,具备一定的预测能力,但仍有提升空间。其预测结果可作为重要参考。" + elif "优秀" in overall_ratings and overall_ratings.count("良好") == 0: + overall_summary = "该模型的综合性能表现非常优秀,各项指标均显示其预测精度高、稳定性好,预测结果高度可信。" + else: # 主要为良好 + overall_summary = "该模型的综合性能表现良好,能够可靠地预测销售趋势,误差在可接受范围内,是决策的有力支持。" + + analysis['overall_summary'] = overall_summary + + return jsonify({"status": "success", "data": analysis}) + + except Exception as e: + import traceback + traceback.print_exc() + return jsonify({"status": "error", "error": str(e)}), 500 + +# 添加一个主函数入口点,用于直接运行API服务器 +if __name__ == '__main__': + # 初始化数据库 + init_db() + + # 使用waitress作为生产环境的WSGI服务器,比Flask默认的开发服务器更健壮 + # from waitress import serve + # serve(app, host="0.0.0.0", port=5000) + + # 或者,为了方便调试,仍然使用Flask内置的开发服务器 + # 注意:debug=True 模式在生产环境中非常不安全,请仅在开发阶段使用 + app.run(host="0.0.0.0", port=5000, debug=True) \ No newline at end of file