2972 lines
112 KiB
Python
2972 lines
112 KiB
Python
![]() |
import sys
|
|||
|
import os
|
|||
|
|
|||
|
# 获取当前脚本所在目录的绝对路径
|
|||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|||
|
|
|||
|
# 将当前目录添加到系统路径
|
|||
|
sys.path.append(current_dir)
|
|||
|
|
|||
|
import json
|
|||
|
import pandas as pd
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
import io
|
|||
|
import base64
|
|||
|
import uuid
|
|||
|
from datetime import datetime, timedelta
|
|||
|
from flask import Flask, jsonify, request, send_file, redirect, send_from_directory, Response
|
|||
|
from flask_cors import CORS
|
|||
|
from flasgger import Swagger
|
|||
|
from werkzeug.utils import secure_filename
|
|||
|
import sqlite3
|
|||
|
import traceback
|
|||
|
|
|||
|
# 导入核心预测器类
|
|||
|
from core.predictor import PharmacyPredictor
|
|||
|
|
|||
|
# 导入训练函数
|
|||
|
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
|||
|
from trainers.kan_trainer import train_product_model_with_kan
|
|||
|
from trainers.tcn_trainer import train_product_model_with_tcn
|
|||
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|||
|
|
|||
|
# 导入预测函数
|
|||
|
from predictors.model_predictor import load_model_and_predict
|
|||
|
|
|||
|
# 导入分析函数
|
|||
|
from analysis.trend_analysis import analyze_prediction_result
|
|||
|
from analysis.metrics import evaluate_model, compare_models
|
|||
|
|
|||
|
# 导入配置
|
|||
|
from core.config import DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
import threading
|
|||
|
import base64
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from io import BytesIO
|
|||
|
import json
|
|||
|
from flasgger import Swagger, swag_from
|
|||
|
import argparse
|
|||
|
from datetime import datetime, timedelta
|
|||
|
from concurrent.futures import ThreadPoolExecutor
|
|||
|
from threading import Lock
|
|||
|
import traceback
|
|||
|
import torch
|
|||
|
import sqlite3
|
|||
|
import numpy as np
|
|||
|
import io
|
|||
|
from werkzeug.utils import secure_filename
|
|||
|
import random
|
|||
|
|
|||
|
# 添加安全全局变量,解决PyTorch 2.6序列化问题
|
|||
|
try:
|
|||
|
import sklearn.preprocessing._data
|
|||
|
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
|||
|
except ImportError:
|
|||
|
print("警告: 无法导入sklearn,某些模型可能无法正确加载")
|
|||
|
except AttributeError:
|
|||
|
print("警告: 当前PyTorch版本不支持add_safe_globals,某些模型可能无法正确加载")
|
|||
|
|
|||
|
# 创建SQLite数据库连接函数
|
|||
|
def get_db_connection():
|
|||
|
"""获取SQLite数据库连接"""
|
|||
|
conn = sqlite3.connect('prediction_history.db')
|
|||
|
conn.row_factory = sqlite3.Row
|
|||
|
return conn
|
|||
|
|
|||
|
# 初始化数据库
|
|||
|
def init_db():
|
|||
|
"""初始化SQLite数据库,创建必要的表"""
|
|||
|
conn = get_db_connection()
|
|||
|
cursor = conn.cursor()
|
|||
|
|
|||
|
# 创建历史预测记录表
|
|||
|
cursor.execute('''
|
|||
|
CREATE TABLE IF NOT EXISTS prediction_history (
|
|||
|
id TEXT PRIMARY KEY,
|
|||
|
product_id TEXT NOT NULL,
|
|||
|
product_name TEXT NOT NULL,
|
|||
|
model_type TEXT NOT NULL,
|
|||
|
model_id TEXT NOT NULL,
|
|||
|
start_date TEXT NOT NULL,
|
|||
|
future_days INTEGER NOT NULL,
|
|||
|
created_at TEXT NOT NULL,
|
|||
|
file_path TEXT NOT NULL
|
|||
|
)
|
|||
|
''')
|
|||
|
|
|||
|
conn.commit()
|
|||
|
conn.close()
|
|||
|
print("数据库初始化完成")
|
|||
|
|
|||
|
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
|
|||
|
class CustomJSONEncoder(json.JSONEncoder):
|
|||
|
def default(self, obj):
|
|||
|
# 处理Pandas日期时间类型
|
|||
|
if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)):
|
|||
|
return obj.strftime('%Y-%m-%d')
|
|||
|
# 处理NumPy整数类型
|
|||
|
elif isinstance(obj, np.integer):
|
|||
|
return int(obj)
|
|||
|
# 处理NumPy浮点类型
|
|||
|
elif isinstance(obj, (np.floating, np.float32, np.float64)):
|
|||
|
return float(obj)
|
|||
|
# 处理NumPy数组
|
|||
|
elif isinstance(obj, np.ndarray):
|
|||
|
return obj.tolist()
|
|||
|
# 处理NaN和None值
|
|||
|
elif pd.isna(obj) or obj is None:
|
|||
|
return None
|
|||
|
# 处理其他可能的NumPy标量类型
|
|||
|
elif np.isscalar(obj):
|
|||
|
return obj.item() if hasattr(obj, 'item') else obj
|
|||
|
# 处理集合类型
|
|||
|
elif isinstance(obj, set):
|
|||
|
return list(obj)
|
|||
|
# 处理日期时间类型
|
|||
|
elif isinstance(obj, datetime):
|
|||
|
return obj.isoformat()
|
|||
|
return super(CustomJSONEncoder, self).default(obj)
|
|||
|
|
|||
|
app = Flask(__name__)
|
|||
|
# 设置自定义JSON编码器
|
|||
|
app.json_encoder = CustomJSONEncoder
|
|||
|
CORS(app) # 启用CORS支持
|
|||
|
|
|||
|
# 初始化数据库
|
|||
|
init_db()
|
|||
|
|
|||
|
# Swagger配置
|
|||
|
swagger_config = {
|
|||
|
"headers": [],
|
|||
|
"specs": [
|
|||
|
{
|
|||
|
"endpoint": "apispec",
|
|||
|
"route": "/apispec.json",
|
|||
|
"rule_filter": lambda rule: True, # 包含所有路由
|
|||
|
"model_filter": lambda tag: True, # 包含所有模型
|
|||
|
}
|
|||
|
],
|
|||
|
"static_url_path": "/flasgger_static",
|
|||
|
"swagger_ui": True,
|
|||
|
"specs_route": "/swagger/"
|
|||
|
}
|
|||
|
|
|||
|
swagger_template = {
|
|||
|
"swagger": "2.0",
|
|||
|
"info": {
|
|||
|
"title": "药店销售预测系统API",
|
|||
|
"description": "用于药店销售预测的RESTful API",
|
|||
|
"version": "1.0.0",
|
|||
|
"contact": {
|
|||
|
"name": "API开发团队",
|
|||
|
"email": "support@example.com"
|
|||
|
}
|
|||
|
},
|
|||
|
"tags": [
|
|||
|
{
|
|||
|
"name": "数据管理",
|
|||
|
"description": "数据上传和查询相关接口"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "模型训练",
|
|||
|
"description": "模型训练相关接口"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "模型预测",
|
|||
|
"description": "预测销售数据相关接口"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "模型管理",
|
|||
|
"description": "模型查询、导出和删除接口"
|
|||
|
}
|
|||
|
]
|
|||
|
}
|
|||
|
|
|||
|
swagger = Swagger(app, config=swagger_config, template=swagger_template)
|
|||
|
|
|||
|
# 存储训练任务状态
|
|||
|
training_tasks = {}
|
|||
|
tasks_lock = Lock()
|
|||
|
|
|||
|
# 线程池用于后台训练
|
|||
|
executor = ThreadPoolExecutor(max_workers=2)
|
|||
|
|
|||
|
# 辅助函数:将图像转换为Base64
|
|||
|
def fig_to_base64(fig):
|
|||
|
buf = BytesIO()
|
|||
|
fig.savefig(buf, format='png')
|
|||
|
buf.seek(0)
|
|||
|
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
|||
|
return img_str
|
|||
|
|
|||
|
# 根路由 - 重定向到UI界面
|
|||
|
@app.route('/')
|
|||
|
def index():
|
|||
|
"""重定向到UI界面"""
|
|||
|
return redirect('/ui/')
|
|||
|
|
|||
|
# Swagger UI路由
|
|||
|
@app.route('/swagger')
|
|||
|
def swagger_ui():
|
|||
|
"""重定向到Swagger UI文档页面"""
|
|||
|
return redirect('/swagger/')
|
|||
|
|
|||
|
# 1. 数据管理API
|
|||
|
@app.route('/api/products', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['数据管理'],
|
|||
|
'summary': '获取所有产品列表',
|
|||
|
'description': '返回系统中所有产品的ID和名称',
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取产品列表',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'array',
|
|||
|
'items': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'product_name': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_products():
|
|||
|
try:
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
products = df[['product_id', 'product_name']].drop_duplicates().to_dict('records')
|
|||
|
return jsonify({"status": "success", "data": products})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/products/<product_id>', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['数据管理'],
|
|||
|
'summary': '获取单个产品详情',
|
|||
|
'description': '返回指定产品ID的详细信息',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'product_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '产品ID,例如P001'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取产品详情',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'product_name': {'type': 'string'},
|
|||
|
'data_points': {'type': 'integer'},
|
|||
|
'date_range': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'start': {'type': 'string'},
|
|||
|
'end': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '产品不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_product(product_id):
|
|||
|
try:
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
product_df = df[df['product_id'] == product_id]
|
|||
|
|
|||
|
if product_df.empty:
|
|||
|
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
|||
|
|
|||
|
product_info = {
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": product_df['product_name'].iloc[0],
|
|||
|
"data_points": len(product_df),
|
|||
|
"date_range": {
|
|||
|
"start": product_df['date'].min().strftime('%Y-%m-%d'),
|
|||
|
"end": product_df['date'].max().strftime('%Y-%m-%d')
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": product_info})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/products/<product_id>/sales', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['数据管理'],
|
|||
|
'summary': '获取产品销售数据',
|
|||
|
'description': '返回指定产品在特定日期范围内的销售数据',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'product_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '产品ID,例如P001'
|
|||
|
},
|
|||
|
{
|
|||
|
'name': 'start_date',
|
|||
|
'in': 'query',
|
|||
|
'type': 'string',
|
|||
|
'required': False,
|
|||
|
'description': '开始日期,格式为YYYY-MM-DD'
|
|||
|
},
|
|||
|
{
|
|||
|
'name': 'end_date',
|
|||
|
'in': 'query',
|
|||
|
'type': 'string',
|
|||
|
'required': False,
|
|||
|
'description': '结束日期,格式为YYYY-MM-DD'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取销售数据',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'array',
|
|||
|
'items': {
|
|||
|
'type': 'object'
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '产品不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_product_sales(product_id):
|
|||
|
try:
|
|||
|
start_date = request.args.get('start_date')
|
|||
|
end_date = request.args.get('end_date')
|
|||
|
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
|||
|
|
|||
|
if product_df.empty:
|
|||
|
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
|||
|
|
|||
|
# 如果提供了日期范围,进行过滤
|
|||
|
if start_date:
|
|||
|
product_df = product_df[product_df['date'] >= pd.to_datetime(start_date)]
|
|||
|
if end_date:
|
|||
|
product_df = product_df[product_df['date'] <= pd.to_datetime(end_date)]
|
|||
|
|
|||
|
# 转换日期为字符串以便JSON序列化
|
|||
|
product_df['date'] = product_df['date'].dt.strftime('%Y-%m-%d')
|
|||
|
|
|||
|
sales_data = product_df.to_dict('records')
|
|||
|
return jsonify({"status": "success", "data": sales_data})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/data/upload', methods=['POST'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['数据管理'],
|
|||
|
'summary': '上传销售数据',
|
|||
|
'description': '上传新的销售数据文件(Excel格式)',
|
|||
|
'consumes': ['multipart/form-data'],
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'file',
|
|||
|
'in': 'formData',
|
|||
|
'type': 'file',
|
|||
|
'required': True,
|
|||
|
'description': 'Excel文件(.xlsx),包含销售数据'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '数据上传成功',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'message': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'products': {'type': 'integer'},
|
|||
|
'rows': {'type': 'integer'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '请求错误,可能是文件格式不正确或缺少必要字段'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def upload_data():
|
|||
|
try:
|
|||
|
if 'file' not in request.files:
|
|||
|
return jsonify({"status": "error", "message": "没有上传文件"}), 400
|
|||
|
|
|||
|
file = request.files['file']
|
|||
|
if file.filename == '':
|
|||
|
return jsonify({"status": "error", "message": "没有选择文件"}), 400
|
|||
|
|
|||
|
if file and file.filename.endswith('.xlsx'):
|
|||
|
file_path = 'uploaded_data.xlsx'
|
|||
|
file.save(file_path)
|
|||
|
|
|||
|
# 验证数据格式
|
|||
|
try:
|
|||
|
df = pd.read_excel(file_path)
|
|||
|
required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
|
|||
|
missing_columns = [col for col in required_columns if col not in df.columns]
|
|||
|
|
|||
|
if missing_columns:
|
|||
|
return jsonify({
|
|||
|
"status": "error",
|
|||
|
"message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
|
|||
|
}), 400
|
|||
|
|
|||
|
# 合并到现有数据或替换现有数据
|
|||
|
existing_df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
# 这里可以实现数据合并逻辑,例如按日期和产品ID去重后合并
|
|||
|
|
|||
|
# 简单示例:保存上传的数据
|
|||
|
df.to_excel('pharmacy_sales.xlsx', index=False)
|
|||
|
|
|||
|
return jsonify({
|
|||
|
"status": "success",
|
|||
|
"message": "数据上传成功",
|
|||
|
"data": {
|
|||
|
"products": len(df['product_id'].unique()),
|
|||
|
"rows": len(df)
|
|||
|
}
|
|||
|
})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
|
|||
|
else:
|
|||
|
return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
# 2. 模型训练API
|
|||
|
@app.route('/api/training', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型训练'],
|
|||
|
'summary': '获取所有训练任务列表',
|
|||
|
'description': '返回所有正在进行、已完成或失败的训练任务',
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取任务列表',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'array',
|
|||
|
'items': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'task_id': {'type': 'string'},
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string'},
|
|||
|
'status': {'type': 'string'},
|
|||
|
'start_time': {'type': 'string'},
|
|||
|
'metrics': {'type': 'object'},
|
|||
|
'error': {'type': 'string'},
|
|||
|
'model_path': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_all_training_tasks():
|
|||
|
"""获取所有训练任务的状态"""
|
|||
|
try:
|
|||
|
# 为了方便前端使用,我们将任务ID也包含在每个任务信息中
|
|||
|
tasks_with_id = []
|
|||
|
for task_id, task_info in training_tasks.items():
|
|||
|
task_copy = task_info.copy()
|
|||
|
task_copy['task_id'] = task_id
|
|||
|
tasks_with_id.append(task_copy)
|
|||
|
|
|||
|
# 按开始时间降序排序,最新的任务在前面
|
|||
|
sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True)
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": sorted_tasks})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/training', methods=['POST'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型训练'],
|
|||
|
'summary': '启动模型训练任务',
|
|||
|
'description': '为指定产品启动一个新的模型训练任务',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'body',
|
|||
|
'in': 'body',
|
|||
|
'required': True,
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string', 'description': '例如 P001'},
|
|||
|
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn'], 'description': '要训练的模型类型'},
|
|||
|
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
|||
|
},
|
|||
|
'required': ['product_id', 'model_type']
|
|||
|
}
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '训练任务已启动',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'message': {'type': 'string'},
|
|||
|
'task_id': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '请求错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def start_training():
|
|||
|
"""
|
|||
|
启动模型训练
|
|||
|
---
|
|||
|
post:
|
|||
|
...
|
|||
|
"""
|
|||
|
data = request.get_json()
|
|||
|
product_id = data.get('product_id')
|
|||
|
model_type = data.get('model_type')
|
|||
|
epochs = data.get('epochs', 50) # 默认为50轮
|
|||
|
|
|||
|
if not product_id or not model_type:
|
|||
|
return jsonify({'error': '缺少product_id或model_type'}), 400
|
|||
|
|
|||
|
global training_tasks
|
|||
|
task_id = str(uuid.uuid4())
|
|||
|
|
|||
|
# 创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
|
|||
|
# 检查模型类型是否有效
|
|||
|
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
|
|||
|
if model_type not in valid_model_types:
|
|||
|
return jsonify({'error': '无效的模型类型'}), 400
|
|||
|
|
|||
|
def train_task(product_id, epochs, model_type):
|
|||
|
global training_tasks
|
|||
|
try:
|
|||
|
print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
|
|||
|
|
|||
|
# 使用PharmacyPredictor进行训练
|
|||
|
use_optimized = (model_type == 'optimized_kan')
|
|||
|
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
# 优化版KAN使用特殊处理
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type='optimized_kan',
|
|||
|
epochs=epochs
|
|||
|
)
|
|||
|
else:
|
|||
|
# 其他模型直接使用model_type
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
epochs=epochs
|
|||
|
)
|
|||
|
|
|||
|
training_tasks[task_id]['status'] = 'completed'
|
|||
|
training_tasks[task_id]['metrics'] = metrics
|
|||
|
|
|||
|
# 保存模型路径
|
|||
|
model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
|
|||
|
model_path = os.path.join(app.config['MODEL_DIR'], f'{model_type}{model_suffix}_model_product_{product_id}.pth')
|
|||
|
training_tasks[task_id]['model_path'] = model_path
|
|||
|
print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
|
|||
|
except Exception as e:
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
print(f"任务 {task_id}: 训练失败。错误: {e}")
|
|||
|
training_tasks[task_id]['status'] = 'failed'
|
|||
|
training_tasks[task_id]['error'] = str(e)
|
|||
|
|
|||
|
thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
|
|||
|
thread.start()
|
|||
|
|
|||
|
training_tasks[task_id] = {
|
|||
|
'status': 'running',
|
|||
|
'product_id': product_id,
|
|||
|
'model_type': model_type,
|
|||
|
'start_time': datetime.now().isoformat(),
|
|||
|
'metrics': None,
|
|||
|
'error': None,
|
|||
|
'model_path': None
|
|||
|
}
|
|||
|
|
|||
|
return jsonify({'message': '模型训练已开始', 'task_id': task_id})
|
|||
|
|
|||
|
@app.route('/api/training/<task_id>', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型训练'],
|
|||
|
'summary': '查询训练任务状态',
|
|||
|
'description': '获取特定训练任务的当前状态和详情',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'task_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '训练任务ID'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取任务状态',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string'},
|
|||
|
'parameters': {'type': 'object'},
|
|||
|
'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
|
|||
|
'created_at': {'type': 'string'},
|
|||
|
'model_path': {'type': 'string'},
|
|||
|
'metrics': {'type': 'object'},
|
|||
|
'model_details_url': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '任务不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_training_status(task_id):
|
|||
|
try:
|
|||
|
if task_id not in training_tasks:
|
|||
|
return jsonify({"status": "error", "message": "任务不存在"}), 404
|
|||
|
|
|||
|
task_info = training_tasks[task_id].copy()
|
|||
|
|
|||
|
# 如果任务已完成,添加模型详情链接
|
|||
|
if task_info['status'] == 'completed':
|
|||
|
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
|
|||
|
|
|||
|
return jsonify({
|
|||
|
"status": "success",
|
|||
|
"data": task_info
|
|||
|
})
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
# 3. 模型预测API
|
|||
|
@app.route('/api/prediction', methods=['POST'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型预测'],
|
|||
|
'summary': '使用模型进行预测',
|
|||
|
'description': '使用指定模型预测未来销售数据',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'body',
|
|||
|
'in': 'body',
|
|||
|
'required': True,
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan', 'tcn']},
|
|||
|
'version': {'type': 'string'},
|
|||
|
'future_days': {'type': 'integer'},
|
|||
|
'include_visualization': {'type': 'boolean'},
|
|||
|
'start_date': {'type': 'string', 'description': '预测起始日期,格式为YYYY-MM-DD'}
|
|||
|
},
|
|||
|
'required': ['product_id', 'model_type']
|
|||
|
}
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '预测成功',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'product_name': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string'},
|
|||
|
'predictions': {'type': 'array'},
|
|||
|
'visualization': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '请求错误,缺少必要参数或参数格式错误'
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '产品或模型不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def predict():
|
|||
|
"""
|
|||
|
使用指定的模型进行预测
|
|||
|
---
|
|||
|
tags:
|
|||
|
- 模型预测
|
|||
|
parameters:
|
|||
|
- in: body
|
|||
|
name: body
|
|||
|
schema:
|
|||
|
type: object
|
|||
|
required:
|
|||
|
- product_id
|
|||
|
- model_type
|
|||
|
properties:
|
|||
|
product_id:
|
|||
|
type: string
|
|||
|
description: 产品ID
|
|||
|
model_type:
|
|||
|
type: string
|
|||
|
description: "模型类型 (mlstm, kan, transformer)"
|
|||
|
future_days:
|
|||
|
type: integer
|
|||
|
description: 预测未来天数
|
|||
|
default: 7
|
|||
|
start_date:
|
|||
|
type: string
|
|||
|
description: 预测起始日期,格式为YYYY-MM-DD
|
|||
|
default: ''
|
|||
|
responses:
|
|||
|
200:
|
|||
|
description: 预测成功
|
|||
|
400:
|
|||
|
description: 请求参数错误
|
|||
|
404:
|
|||
|
description: 模型文件未找到
|
|||
|
"""
|
|||
|
try:
|
|||
|
data = request.json
|
|||
|
product_id = data.get('product_id')
|
|||
|
model_type = data.get('model_type')
|
|||
|
future_days = int(data.get('future_days', 7))
|
|||
|
start_date = data.get('start_date', '')
|
|||
|
include_visualization = data.get('include_visualization', False)
|
|||
|
|
|||
|
print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
|
|||
|
|
|||
|
if not product_id or not model_type:
|
|||
|
return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
|
|||
|
|
|||
|
# 获取产品名称
|
|||
|
product_name = get_product_name(product_id)
|
|||
|
if not product_name:
|
|||
|
product_name = product_id
|
|||
|
|
|||
|
# 获取最新模型ID - 此处不需要处理模型类型映射,因为 get_latest_model_id 和 load_model_and_predict 会处理
|
|||
|
model_id = get_latest_model_id(model_type, product_id)
|
|||
|
if not model_id:
|
|||
|
return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404
|
|||
|
|
|||
|
# 执行预测
|
|||
|
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date)
|
|||
|
|
|||
|
if prediction_result is None:
|
|||
|
return jsonify({"status": "error", "error": "预测失败,请检查服务器日志获取详细信息"}), 500
|
|||
|
|
|||
|
if prediction_result.get("status") == "error":
|
|||
|
return jsonify(prediction_result), 500
|
|||
|
|
|||
|
# 如果需要可视化,添加图表数据
|
|||
|
if include_visualization:
|
|||
|
try:
|
|||
|
# 添加图表数据
|
|||
|
chart_data = prepare_chart_data(prediction_result)
|
|||
|
prediction_result['chart_data'] = chart_data
|
|||
|
|
|||
|
# 添加分析结果
|
|||
|
if 'analysis' not in prediction_result or prediction_result['analysis'] is None:
|
|||
|
analysis_result = analyze_prediction(prediction_result)
|
|||
|
prediction_result['analysis'] = analysis_result
|
|||
|
except Exception as e:
|
|||
|
print(f"生成可视化或分析数据失败: {str(e)}")
|
|||
|
# 可视化失败不影响主要功能,继续执行
|
|||
|
|
|||
|
# 保存预测结果到文件和数据库
|
|||
|
try:
|
|||
|
prediction_id, file_path = save_prediction_result(
|
|||
|
prediction_result,
|
|||
|
product_id,
|
|||
|
product_name,
|
|||
|
model_type,
|
|||
|
model_id,
|
|||
|
start_date,
|
|||
|
future_days
|
|||
|
)
|
|||
|
|
|||
|
# 添加预测ID到结果中
|
|||
|
prediction_result['prediction_id'] = prediction_id
|
|||
|
except Exception as e:
|
|||
|
print(f"保存预测结果失败: {str(e)}")
|
|||
|
# 保存失败不影响返回结果,继续执行
|
|||
|
|
|||
|
# 在调用jsonify之前,确保所有数据都是JSON可序列化的
|
|||
|
def convert_numpy_types(obj):
|
|||
|
if isinstance(obj, dict):
|
|||
|
return {k: convert_numpy_types(v) for k, v in obj.items()}
|
|||
|
elif isinstance(obj, list):
|
|||
|
return [convert_numpy_types(item) for item in obj]
|
|||
|
elif isinstance(obj, np.generic):
|
|||
|
return obj.item() # 将NumPy标量转换为Python原生类型
|
|||
|
elif isinstance(obj, np.ndarray):
|
|||
|
return obj.tolist()
|
|||
|
elif pd.isna(obj):
|
|||
|
return None
|
|||
|
elif isinstance(obj, pd.DataFrame):
|
|||
|
return obj.to_dict(orient='records')
|
|||
|
elif isinstance(obj, pd.Series):
|
|||
|
return obj.to_dict()
|
|||
|
else:
|
|||
|
return obj
|
|||
|
|
|||
|
# 递归处理整个预测结果对象,确保所有NumPy类型都被转换
|
|||
|
processed_result = convert_numpy_types(prediction_result)
|
|||
|
|
|||
|
# 使用处理后的结果进行JSON序列化
|
|||
|
return jsonify(processed_result)
|
|||
|
except Exception as e:
|
|||
|
print(f"预测失败: {str(e)}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "error": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/prediction/compare', methods=['POST'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型预测'],
|
|||
|
'summary': '比较不同模型预测结果',
|
|||
|
'description': '比较不同模型对同一产品的预测结果',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'body',
|
|||
|
'in': 'body',
|
|||
|
'required': True,
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'model_types': {'type': 'array', 'items': {'type': 'string'}},
|
|||
|
'versions': {'type': 'array', 'items': {'type': 'string'}},
|
|||
|
'include_visualization': {'type': 'boolean'}
|
|||
|
},
|
|||
|
'required': ['product_id', 'model_types']
|
|||
|
}
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '比较成功',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'product_name': {'type': 'string'},
|
|||
|
'model_types': {'type': 'array'},
|
|||
|
'comparison': {'type': 'array'},
|
|||
|
'visualization': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '请求错误,缺少必要参数或参数格式错误'
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '产品或模型不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def compare_predictions():
|
|||
|
"""比较不同模型的预测结果"""
|
|||
|
try:
|
|||
|
# 获取请求数据
|
|||
|
data = request.get_json()
|
|||
|
product_id = data.get('product_id')
|
|||
|
model_types = data.get('model_types', [])
|
|||
|
future_days = data.get('future_days', 7)
|
|||
|
start_date = data.get('start_date')
|
|||
|
include_visualization = data.get('include_visualization', True)
|
|||
|
|
|||
|
if not product_id or not model_types or len(model_types) < 2:
|
|||
|
return jsonify({"status": "error", "message": "缺少产品ID或至少两种模型类型"}), 400
|
|||
|
|
|||
|
# 创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
|
|||
|
# 获取产品名称
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
product_df = df[df['product_id'] == product_id]
|
|||
|
|
|||
|
if product_df.empty:
|
|||
|
return jsonify({"status": "error", "message": f"找不到产品 {product_id}"}), 404
|
|||
|
|
|||
|
product_name = product_df['product_name'].iloc[0]
|
|||
|
|
|||
|
# 执行每个模型的预测
|
|||
|
predictions = {}
|
|||
|
metrics = {}
|
|||
|
|
|||
|
for model_type in model_types:
|
|||
|
try:
|
|||
|
result = predictor.predict(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
future_days=future_days,
|
|||
|
start_date=start_date,
|
|||
|
analyze_result=True
|
|||
|
)
|
|||
|
|
|||
|
if result and 'predictions' in result:
|
|||
|
predictions[model_type] = result['predictions']
|
|||
|
|
|||
|
# 如果有分析结果,提取评估指标
|
|||
|
if 'analysis' in result and result['analysis']:
|
|||
|
metrics[model_type] = result['analysis'].get('metrics', {})
|
|||
|
except Exception as e:
|
|||
|
print(f"模型 {model_type} 预测失败: {str(e)}")
|
|||
|
continue
|
|||
|
|
|||
|
if not predictions:
|
|||
|
return jsonify({"status": "error", "message": "所有模型预测均失败"}), 500
|
|||
|
|
|||
|
# 比较模型性能
|
|||
|
comparison_result = {}
|
|||
|
if len(metrics) >= 2:
|
|||
|
comparison_result = compare_models(metrics)
|
|||
|
|
|||
|
# 准备响应数据
|
|||
|
response_data = {
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": product_name,
|
|||
|
"model_types": list(predictions.keys()),
|
|||
|
"predictions": {}
|
|||
|
}
|
|||
|
|
|||
|
# 转换DataFrame为可序列化的字典
|
|||
|
for model_type, pred_df in predictions.items():
|
|||
|
# 处理DataFrame,确保可序列化
|
|||
|
if isinstance(pred_df, pd.DataFrame):
|
|||
|
records = pred_df.to_dict(orient='records')
|
|||
|
# 进一步处理,确保所有值都是JSON可序列化的
|
|||
|
for record in records:
|
|||
|
for key, value in record.items():
|
|||
|
if isinstance(value, np.generic):
|
|||
|
record[key] = value.item() # 将NumPy标量转换为Python原生类型
|
|||
|
elif pd.isna(value):
|
|||
|
record[key] = None
|
|||
|
response_data["predictions"][model_type] = records
|
|||
|
else:
|
|||
|
response_data["predictions"][model_type] = pred_df
|
|||
|
|
|||
|
# 处理指标和比较结果,确保可序列化
|
|||
|
processed_metrics = {}
|
|||
|
for model_type, model_metrics in metrics.items():
|
|||
|
processed_model_metrics = {}
|
|||
|
for metric_name, metric_value in model_metrics.items():
|
|||
|
if isinstance(metric_value, np.generic):
|
|||
|
processed_model_metrics[metric_name] = metric_value.item()
|
|||
|
else:
|
|||
|
processed_model_metrics[metric_name] = metric_value
|
|||
|
processed_metrics[model_type] = processed_model_metrics
|
|||
|
|
|||
|
response_data["metrics"] = processed_metrics
|
|||
|
response_data["comparison"] = comparison_result
|
|||
|
|
|||
|
# 如果需要可视化,生成比较图
|
|||
|
if include_visualization and len(predictions) >= 2:
|
|||
|
plt.figure(figsize=(12, 6))
|
|||
|
|
|||
|
for model_type, pred_df in predictions.items():
|
|||
|
plt.plot(pred_df['date'], pred_df['predicted_sales'], label=model_type)
|
|||
|
|
|||
|
plt.title(f'产品 {product_name} ({product_id}) - 多模型预测结果比较')
|
|||
|
plt.xlabel('日期')
|
|||
|
plt.ylabel('销量')
|
|||
|
plt.legend()
|
|||
|
plt.grid(True)
|
|||
|
plt.xticks(rotation=45)
|
|||
|
plt.tight_layout()
|
|||
|
|
|||
|
# 保存图像并转换为Base64
|
|||
|
img_filename = f"compare_{product_id}_{datetime.now().strftime('%Y%m%d%H%M%S')}.png"
|
|||
|
img_path = os.path.join('static', 'predictions', 'compare', img_filename)
|
|||
|
|
|||
|
# 确保目录存在
|
|||
|
os.makedirs(os.path.dirname(img_path), exist_ok=True)
|
|||
|
|
|||
|
plt.savefig(img_path)
|
|||
|
plt.close()
|
|||
|
|
|||
|
# 添加可视化URL到响应
|
|||
|
response_data["visualization_url"] = f"/api/predictions/compare/{img_filename}"
|
|||
|
|
|||
|
# 如果需要Base64编码的图像
|
|||
|
with open(img_path, "rb") as img_file:
|
|||
|
response_data["visualization"] = base64.b64encode(img_file.read()).decode('utf-8')
|
|||
|
|
|||
|
# 在调用jsonify之前,确保所有数据都是JSON可序列化的
|
|||
|
def convert_numpy_types(obj):
|
|||
|
if isinstance(obj, dict):
|
|||
|
return {k: convert_numpy_types(v) for k, v in obj.items()}
|
|||
|
elif isinstance(obj, list):
|
|||
|
return [convert_numpy_types(item) for item in obj]
|
|||
|
elif isinstance(obj, np.generic):
|
|||
|
return obj.item() # 将NumPy标量转换为Python原生类型
|
|||
|
elif isinstance(obj, np.ndarray):
|
|||
|
return obj.tolist()
|
|||
|
elif pd.isna(obj):
|
|||
|
return None
|
|||
|
elif isinstance(obj, pd.DataFrame):
|
|||
|
return obj.to_dict(orient='records')
|
|||
|
elif isinstance(obj, pd.Series):
|
|||
|
return obj.to_dict()
|
|||
|
else:
|
|||
|
return obj
|
|||
|
|
|||
|
# 递归处理整个响应数据对象,确保所有NumPy类型都被转换
|
|||
|
processed_response = convert_numpy_types(response_data)
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": processed_response})
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"比较预测失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/prediction/analyze', methods=['POST'])
|
|||
|
def analyze_prediction():
|
|||
|
"""分析预测结果"""
|
|||
|
try:
|
|||
|
data = request.get_json()
|
|||
|
product_id = data.get('product_id')
|
|||
|
model_type = data.get('model_type')
|
|||
|
predictions = data.get('predictions')
|
|||
|
|
|||
|
if not product_id or not model_type or not predictions:
|
|||
|
return jsonify({"status": "error", "message": "缺少必要参数"}), 400
|
|||
|
|
|||
|
# 转换预测数据为NumPy数组
|
|||
|
predictions_array = np.array(predictions)
|
|||
|
|
|||
|
# 获取产品特征数据
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
product_df = df[df['product_id'] == product_id].sort_values('date')
|
|||
|
|
|||
|
if product_df.empty:
|
|||
|
return jsonify({"status": "error", "message": f"找不到产品 {product_id}"}), 404
|
|||
|
|
|||
|
# 提取特征数据
|
|||
|
features = product_df[['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']].values
|
|||
|
|
|||
|
# 使用分析函数
|
|||
|
from analysis.trend_analysis import analyze_prediction_result
|
|||
|
analysis = analyze_prediction_result(product_id, model_type, predictions_array, features)
|
|||
|
|
|||
|
# 返回分析结果
|
|||
|
return jsonify({
|
|||
|
"status": "success",
|
|||
|
"data": {
|
|||
|
"product_id": product_id,
|
|||
|
"model_type": model_type,
|
|||
|
"analysis": analysis
|
|||
|
}
|
|||
|
})
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"分析预测失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/prediction/history', methods=['GET'])
|
|||
|
def get_prediction_history():
|
|||
|
"""获取历史预测记录列表"""
|
|||
|
try:
|
|||
|
# 获取查询参数
|
|||
|
product_id = request.args.get('product_id')
|
|||
|
model_type = request.args.get('model_type')
|
|||
|
page = int(request.args.get('page', 1))
|
|||
|
page_size = int(request.args.get('page_size', 10))
|
|||
|
|
|||
|
# 计算分页偏移量
|
|||
|
offset = (page - 1) * page_size
|
|||
|
|
|||
|
# 连接数据库
|
|||
|
conn = get_db_connection()
|
|||
|
cursor = conn.cursor()
|
|||
|
|
|||
|
# 构建查询条件
|
|||
|
query_conditions = []
|
|||
|
query_params = []
|
|||
|
|
|||
|
if product_id:
|
|||
|
query_conditions.append("product_id = ?")
|
|||
|
query_params.append(product_id)
|
|||
|
|
|||
|
if model_type:
|
|||
|
query_conditions.append("model_type = ?")
|
|||
|
query_params.append(model_type)
|
|||
|
|
|||
|
# 构建完整的查询语句
|
|||
|
query = "SELECT * FROM prediction_history"
|
|||
|
if query_conditions:
|
|||
|
query += " WHERE " + " AND ".join(query_conditions)
|
|||
|
|
|||
|
# 添加排序和分页
|
|||
|
query += " ORDER BY created_at DESC LIMIT ? OFFSET ?"
|
|||
|
query_params.extend([page_size, offset])
|
|||
|
|
|||
|
# 执行查询
|
|||
|
cursor.execute(query, query_params)
|
|||
|
records = cursor.fetchall()
|
|||
|
|
|||
|
# 获取总记录数
|
|||
|
count_query = "SELECT COUNT(*) FROM prediction_history"
|
|||
|
if query_conditions:
|
|||
|
count_query += " WHERE " + " AND ".join(query_conditions)
|
|||
|
|
|||
|
cursor.execute(count_query, query_params[:-2] if query_params else [])
|
|||
|
total_count = cursor.fetchone()[0]
|
|||
|
|
|||
|
# 转换结果为字典列表
|
|||
|
history_records = []
|
|||
|
for record in records:
|
|||
|
history_records.append({
|
|||
|
'id': record[0],
|
|||
|
'product_id': record[1],
|
|||
|
'product_name': record[2],
|
|||
|
'model_type': record[3],
|
|||
|
'model_id': record[4],
|
|||
|
'start_date': record[5],
|
|||
|
'future_days': record[6],
|
|||
|
'created_at': record[7],
|
|||
|
'file_path': record[8]
|
|||
|
})
|
|||
|
|
|||
|
conn.close()
|
|||
|
|
|||
|
return jsonify({
|
|||
|
"status": "success",
|
|||
|
"data": history_records,
|
|||
|
"total": total_count,
|
|||
|
"page": page,
|
|||
|
"page_size": page_size
|
|||
|
})
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"获取历史预测记录失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/prediction/history/<prediction_id>', methods=['GET'])
|
|||
|
def get_prediction_details(prediction_id):
|
|||
|
"""获取特定预测记录的详情"""
|
|||
|
try:
|
|||
|
print(f"正在获取预测记录详情,ID: {prediction_id}")
|
|||
|
|
|||
|
# 连接数据库
|
|||
|
conn = get_db_connection()
|
|||
|
cursor = conn.cursor()
|
|||
|
|
|||
|
# 查询预测记录元数据
|
|||
|
cursor.execute("""
|
|||
|
SELECT product_id, product_name, model_type, model_id,
|
|||
|
start_date, future_days, created_at, file_path
|
|||
|
FROM prediction_history WHERE id = ?
|
|||
|
""", (prediction_id,))
|
|||
|
record = cursor.fetchone()
|
|||
|
|
|||
|
if not record:
|
|||
|
print(f"预测记录不存在: {prediction_id}")
|
|||
|
conn.close()
|
|||
|
return jsonify({"status": "error", "message": "预测记录不存在"}), 404
|
|||
|
|
|||
|
# 提取元数据
|
|||
|
product_id = record['product_id']
|
|||
|
product_name = record['product_name']
|
|||
|
model_type = record['model_type']
|
|||
|
model_id = record['model_id']
|
|||
|
start_date = record['start_date']
|
|||
|
future_days = record['future_days']
|
|||
|
created_at = record['created_at']
|
|||
|
file_path = record['file_path']
|
|||
|
|
|||
|
conn.close()
|
|||
|
|
|||
|
print(f"正在读取预测结果文件: {file_path}")
|
|||
|
|
|||
|
if not os.path.exists(file_path):
|
|||
|
print(f"预测结果文件不存在: {file_path}")
|
|||
|
return jsonify({"status": "error", "message": "预测结果文件不存在"}), 404
|
|||
|
|
|||
|
# 读取保存的JSON文件内容
|
|||
|
with open(file_path, 'r', encoding='utf-8') as f:
|
|||
|
prediction_data = json.load(f)
|
|||
|
|
|||
|
# 构建与预测分析接口一致的响应格式
|
|||
|
response_data = {
|
|||
|
"status": "success",
|
|||
|
"meta": {
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": product_name,
|
|||
|
"model_type": model_type,
|
|||
|
"model_id": model_id,
|
|||
|
"start_date": start_date,
|
|||
|
"future_days": future_days,
|
|||
|
"created_at": created_at
|
|||
|
},
|
|||
|
"data": {
|
|||
|
"prediction_data": [],
|
|||
|
"history_data": [],
|
|||
|
"data": []
|
|||
|
},
|
|||
|
"analysis": prediction_data.get('analysis', {}),
|
|||
|
"chart_data": prediction_data.get('chart_data', {})
|
|||
|
}
|
|||
|
|
|||
|
# 处理预测数据
|
|||
|
if 'prediction_data' in prediction_data and isinstance(prediction_data['prediction_data'], list):
|
|||
|
response_data['data']['prediction_data'] = prediction_data['prediction_data']
|
|||
|
|
|||
|
# 处理历史数据
|
|||
|
if 'history_data' in prediction_data and isinstance(prediction_data['history_data'], list):
|
|||
|
response_data['data']['history_data'] = prediction_data['history_data']
|
|||
|
|
|||
|
# 处理合并的数据
|
|||
|
if 'data' in prediction_data and isinstance(prediction_data['data'], list):
|
|||
|
response_data['data']['data'] = prediction_data['data']
|
|||
|
else:
|
|||
|
# 如果没有合并数据,从历史和预测数据中构建
|
|||
|
history_data = response_data['data']['history_data']
|
|||
|
pred_data = response_data['data']['prediction_data']
|
|||
|
response_data['data']['data'] = history_data + pred_data
|
|||
|
|
|||
|
# 确保所有数据字段都存在且格式正确
|
|||
|
for key in ['prediction_data', 'history_data', 'data']:
|
|||
|
if not isinstance(response_data['data'][key], list):
|
|||
|
response_data['data'][key] = []
|
|||
|
|
|||
|
# 添加兼容性字段(直接在根级别)
|
|||
|
response_data.update({
|
|||
|
'product_id': product_id,
|
|||
|
'product_name': product_name,
|
|||
|
'model_type': model_type,
|
|||
|
'start_date': start_date,
|
|||
|
'created_at': created_at
|
|||
|
})
|
|||
|
|
|||
|
print(f"成功获取预测详情,产品: {product_name}, 模型: {model_type}")
|
|||
|
return jsonify(response_data)
|
|||
|
|
|||
|
except json.JSONDecodeError as e:
|
|||
|
print(f"预测结果文件JSON解析错误: {e}")
|
|||
|
return jsonify({"status": "error", "message": f"预测结果文件格式错误: {str(e)}"}), 500
|
|||
|
except Exception as e:
|
|||
|
print(f"获取预测详情失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
@app.route('/api/prediction/history/<prediction_id>', methods=['DELETE'])
|
|||
|
def delete_prediction(prediction_id):
|
|||
|
"""删除预测记录"""
|
|||
|
try:
|
|||
|
# 连接数据库
|
|||
|
conn = get_db_connection()
|
|||
|
cursor = conn.cursor()
|
|||
|
|
|||
|
# 查询预测记录
|
|||
|
cursor.execute("SELECT file_path FROM prediction_history WHERE id = ?", (prediction_id,))
|
|||
|
record = cursor.fetchone()
|
|||
|
|
|||
|
if not record:
|
|||
|
conn.close()
|
|||
|
return jsonify({"status": "error", "message": "预测记录不存在"}), 404
|
|||
|
|
|||
|
file_path = record[0]
|
|||
|
|
|||
|
# 删除数据库记录
|
|||
|
cursor.execute("DELETE FROM prediction_history WHERE id = ?", (prediction_id,))
|
|||
|
conn.commit()
|
|||
|
conn.close()
|
|||
|
|
|||
|
# 删除预测结果文件
|
|||
|
if os.path.exists(file_path):
|
|||
|
os.remove(file_path)
|
|||
|
|
|||
|
return jsonify({
|
|||
|
"status": "success",
|
|||
|
"message": f"预测记录 {prediction_id} 已删除"
|
|||
|
})
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"删除预测记录失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "message": str(e)}), 500
|
|||
|
|
|||
|
# 4. 模型管理API
|
|||
|
@app.route('/api/models', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '获取模型列表',
|
|||
|
'description': '获取系统中的模型列表,可按产品ID和模型类型筛选',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'product_id',
|
|||
|
'in': 'query',
|
|||
|
'type': 'string',
|
|||
|
'required': False,
|
|||
|
'description': '按产品ID筛选'
|
|||
|
},
|
|||
|
{
|
|||
|
'name': 'model_type',
|
|||
|
'in': 'query',
|
|||
|
'type': 'string',
|
|||
|
'required': False,
|
|||
|
'description': "按模型类型筛选 (mlstm, kan, transformer, tcn)"
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取模型列表',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'array',
|
|||
|
'items': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'model_id': {'type': 'string'},
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'product_name': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string'},
|
|||
|
'created_at': {'type': 'string'},
|
|||
|
'metrics': {'type': 'object'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def list_models():
|
|||
|
"""
|
|||
|
列出所有可用的模型
|
|||
|
---
|
|||
|
tags:
|
|||
|
- 模型管理
|
|||
|
parameters:
|
|||
|
- name: product_id
|
|||
|
in: query
|
|||
|
type: string
|
|||
|
required: false
|
|||
|
description: 按产品ID筛选
|
|||
|
- name: model_type
|
|||
|
in: query
|
|||
|
type: string
|
|||
|
required: false
|
|||
|
description: "按模型类型筛选 (mlstm, kan, transformer, tcn)"
|
|||
|
responses:
|
|||
|
200:
|
|||
|
description: 模型列表
|
|||
|
schema:
|
|||
|
type: object
|
|||
|
properties:
|
|||
|
status:
|
|||
|
type: string
|
|||
|
example: success
|
|||
|
data:
|
|||
|
type: array
|
|||
|
items:
|
|||
|
type: object
|
|||
|
properties:
|
|||
|
model_id:
|
|||
|
type: string
|
|||
|
product_id:
|
|||
|
type: string
|
|||
|
product_name:
|
|||
|
type: string
|
|||
|
model_type:
|
|||
|
type: string
|
|||
|
created_at:
|
|||
|
type: string
|
|||
|
metrics:
|
|||
|
type: object
|
|||
|
"""
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
print(f"正在从目录 '{models_dir}' 读取模型文件")
|
|||
|
available_models = []
|
|||
|
|
|||
|
product_id_filter = request.args.get('product_id')
|
|||
|
model_type_filter = request.args.get('model_type')
|
|||
|
|
|||
|
if not os.path.exists(models_dir):
|
|||
|
print(f"错误: 模型目录 '{models_dir}' 不存在")
|
|||
|
return jsonify({"status": "success", "data": []})
|
|||
|
|
|||
|
# 直接从saved_models目录读取模型文件
|
|||
|
for file_name in os.listdir(models_dir):
|
|||
|
if file_name.endswith('.pth'):
|
|||
|
# 解析文件名获取模型类型和产品ID
|
|||
|
if '_model_product_' in file_name:
|
|||
|
parts = file_name.split('_model_product_')
|
|||
|
model_type = parts[0]
|
|||
|
product_id = parts[1].replace('.pth', '')
|
|||
|
|
|||
|
# 处理优化版KAN模型
|
|||
|
if model_type == 'kan_optimized':
|
|||
|
model_type = 'optimized_kan'
|
|||
|
|
|||
|
# 应用过滤器
|
|||
|
if product_id_filter and product_id != product_id_filter:
|
|||
|
continue
|
|||
|
if model_type_filter and model_type != model_type_filter:
|
|||
|
continue
|
|||
|
|
|||
|
# 获取文件创建时间
|
|||
|
file_path = os.path.join(models_dir, file_name)
|
|||
|
created_at = datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
|
|||
|
|
|||
|
# 获取产品名称
|
|||
|
product_name = get_product_name(product_id) or f"产品 {product_id}"
|
|||
|
|
|||
|
# 尝试加载模型文件获取指标
|
|||
|
metrics = {}
|
|||
|
try:
|
|||
|
# 添加weights_only=False参数,解决PyTorch 2.6序列化问题
|
|||
|
checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)
|
|||
|
|
|||
|
# 尝试从不同位置提取评估指标
|
|||
|
if isinstance(checkpoint, dict):
|
|||
|
if 'metrics' in checkpoint and isinstance(checkpoint['metrics'], dict):
|
|||
|
metrics = checkpoint['metrics']
|
|||
|
elif 'test_metrics' in checkpoint and isinstance(checkpoint['test_metrics'], dict):
|
|||
|
metrics = checkpoint['test_metrics']
|
|||
|
elif 'eval_metrics' in checkpoint and isinstance(checkpoint['eval_metrics'], dict):
|
|||
|
metrics = checkpoint['eval_metrics']
|
|||
|
elif 'model_metrics' in checkpoint and isinstance(checkpoint['model_metrics'], dict):
|
|||
|
metrics = checkpoint['model_metrics']
|
|||
|
|
|||
|
# 如果模型是PyTorch模型,尝试提取state_dict中的指标
|
|||
|
if 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
|
|||
|
for key, value in checkpoint['state_dict'].items():
|
|||
|
if key.endswith('_metric') and isinstance(value, (int, float)):
|
|||
|
metric_name = key.replace('_metric', '').upper()
|
|||
|
metrics[metric_name] = value.item() if hasattr(value, 'item') else value
|
|||
|
|
|||
|
# 如果没有找到任何指标,使用模拟数据
|
|||
|
if not metrics:
|
|||
|
metrics = {
|
|||
|
"R2": 0,
|
|||
|
"RMSE": 0,
|
|||
|
"MAE": 0,
|
|||
|
"MAPE": 0
|
|||
|
}
|
|||
|
except Exception as e:
|
|||
|
print(f"读取模型文件 {file_path} 失败: {e}")
|
|||
|
# 使用模拟数据
|
|||
|
metrics = {
|
|||
|
"R2": 0,
|
|||
|
"RMSE": 0,
|
|||
|
"MAE": 0,
|
|||
|
"MAPE": 0
|
|||
|
}
|
|||
|
|
|||
|
model_info = {
|
|||
|
"model_id": f"{model_type}_{product_id}",
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": product_name,
|
|||
|
"model_type": model_type,
|
|||
|
"created_at": created_at,
|
|||
|
"metrics": metrics,
|
|||
|
"file_path": file_path
|
|||
|
}
|
|||
|
available_models.append(model_info)
|
|||
|
|
|||
|
# 按创建时间降序排序
|
|||
|
available_models.sort(key=lambda x: x.get('created_at', ''), reverse=True)
|
|||
|
|
|||
|
print(f"找到 {len(available_models)} 个模型")
|
|||
|
return jsonify({"status": "success", "data": available_models})
|
|||
|
|
|||
|
@app.route('/api/models/<model_id>', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '获取模型详情',
|
|||
|
'description': '获取特定模型的详细信息',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'model_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取模型详情',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'product_id': {'type': 'string'},
|
|||
|
'model_type': {'type': 'string'},
|
|||
|
'version': {'type': 'string'},
|
|||
|
'created_at': {'type': 'string'},
|
|||
|
'file_path': {'type': 'string'},
|
|||
|
'file_size': {'type': 'string'},
|
|||
|
'features': {'type': 'array'},
|
|||
|
'look_back': {'type': 'integer'},
|
|||
|
'T': {'type': 'integer'},
|
|||
|
'metrics': {'type': 'object'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '无效的模型ID格式'
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '模型不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_model_details(model_id):
|
|||
|
"""
|
|||
|
获取单个模型的详细信息
|
|||
|
---
|
|||
|
tags:
|
|||
|
- 模型管理
|
|||
|
parameters:
|
|||
|
- name: model_id
|
|||
|
in: path
|
|||
|
type: string
|
|||
|
required: true
|
|||
|
description: "模型的唯一标识符 (格式: model_type_product_id)"
|
|||
|
responses:
|
|||
|
200:
|
|||
|
description: 模型的详细信息
|
|||
|
404:
|
|||
|
description: 未找到模型
|
|||
|
"""
|
|||
|
try:
|
|||
|
model_type, product_id = model_id.split('_', 1)
|
|||
|
|
|||
|
# 处理优化版KAN模型的文件名
|
|||
|
file_model_type = model_type
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
file_model_type = 'kan_optimized'
|
|||
|
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
# 构建模型文件路径
|
|||
|
model_path = os.path.join(models_dir, f'{file_model_type}_model_product_{product_id}.pth')
|
|||
|
|
|||
|
if not os.path.exists(model_path):
|
|||
|
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
|||
|
|
|||
|
# 加载模型文件
|
|||
|
try:
|
|||
|
# 添加weights_only=False参数,解决PyTorch 2.6序列化问题
|
|||
|
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
|||
|
|
|||
|
# 提取模型信息
|
|||
|
model_info = {
|
|||
|
"model_id": model_id,
|
|||
|
"product_id": product_id,
|
|||
|
"model_type": model_type,
|
|||
|
"created_at": datetime.fromtimestamp(os.path.getctime(model_path)).isoformat(),
|
|||
|
"file_path": model_path,
|
|||
|
"file_size": f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
|
|||
|
}
|
|||
|
|
|||
|
# 如果checkpoint是字典,提取其中的信息
|
|||
|
if isinstance(checkpoint, dict):
|
|||
|
# 提取配置信息
|
|||
|
if 'config' in checkpoint:
|
|||
|
config = checkpoint['config']
|
|||
|
for key, value in config.items():
|
|||
|
model_info[key] = value
|
|||
|
|
|||
|
# 提取评估指标
|
|||
|
if 'metrics' in checkpoint:
|
|||
|
model_info['metrics'] = checkpoint['metrics']
|
|||
|
|
|||
|
# 获取产品名称
|
|||
|
product_name = get_product_name(product_id)
|
|||
|
if product_name:
|
|||
|
model_info['product_name'] = product_name
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": model_info})
|
|||
|
except Exception as e:
|
|||
|
print(f"加载模型文件失败: {str(e)}")
|
|||
|
return jsonify({"status": "error", "error": f"加载模型文件失败: {str(e)}"}), 500
|
|||
|
except ValueError:
|
|||
|
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
|
|||
|
|
|||
|
@app.route('/api/models/<model_id>', methods=['DELETE'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '删除模型',
|
|||
|
'description': '删除特定模型',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'model_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '模型删除成功',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'message': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '无效的模型ID格式'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def delete_model(model_id):
|
|||
|
"""
|
|||
|
删除一个模型及其关联文件
|
|||
|
---
|
|||
|
tags:
|
|||
|
- 模型管理
|
|||
|
parameters:
|
|||
|
- name: model_id
|
|||
|
in: path
|
|||
|
type: string
|
|||
|
required: true
|
|||
|
description: "要删除的模型的ID (格式: model_type_product_id)"
|
|||
|
responses:
|
|||
|
200:
|
|||
|
description: 模型删除成功
|
|||
|
404:
|
|||
|
description: 模型未找到
|
|||
|
"""
|
|||
|
try:
|
|||
|
model_type, product_id = model_id.split('_', 1)
|
|||
|
|
|||
|
# 处理优化版KAN模型的文件名
|
|||
|
file_model_type = model_type
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
file_model_type = 'kan_optimized'
|
|||
|
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
# 构建模型文件路径
|
|||
|
model_path = os.path.join(models_dir, f'{file_model_type}_model_product_{product_id}.pth')
|
|||
|
|
|||
|
if not os.path.exists(model_path):
|
|||
|
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
|||
|
|
|||
|
# 删除模型文件
|
|||
|
os.remove(model_path)
|
|||
|
|
|||
|
return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
|
|||
|
except ValueError:
|
|||
|
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
|
|||
|
|
|||
|
@app.route('/api/models/<model_id>/export', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '导出模型',
|
|||
|
'description': '导出特定模型文件',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'model_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '模型文件下载',
|
|||
|
'content': {
|
|||
|
'application/octet-stream': {}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '无效的模型ID格式'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def export_model(model_id):
|
|||
|
try:
|
|||
|
model_type, product_id = model_id.split('_', 1)
|
|||
|
|
|||
|
# 处理优化版KAN模型的文件名
|
|||
|
file_model_type = model_type
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
file_model_type = 'kan_optimized'
|
|||
|
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
# 构建模型文件路径
|
|||
|
model_path = os.path.join(models_dir, f'{file_model_type}_model_product_{product_id}.pth')
|
|||
|
|
|||
|
if not os.path.exists(model_path):
|
|||
|
return jsonify({"status": "error", "error": "模型文件未找到"}), 404
|
|||
|
|
|||
|
return send_file(
|
|||
|
model_path,
|
|||
|
as_attachment=True,
|
|||
|
download_name=f'{model_id}.pth',
|
|||
|
mimetype='application/octet-stream'
|
|||
|
)
|
|||
|
except Exception as e:
|
|||
|
return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
|
|||
|
|
|||
|
@app.route('/api/plots/<filename>')
|
|||
|
def get_plot(filename):
|
|||
|
"""Serve a plot file from the root directory."""
|
|||
|
try:
|
|||
|
return send_from_directory(app.root_path, filename)
|
|||
|
except Exception as e:
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "error": str(e)}), 500
|
|||
|
|
|||
|
def get_latest_model_id(model_type, product_id):
|
|||
|
"""根据模型类型和产品ID获取最新的模型ID"""
|
|||
|
try:
|
|||
|
# 处理优化版KAN模型的文件名
|
|||
|
file_model_type = model_type
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
file_model_type = 'kan_optimized'
|
|||
|
print(f"优化版KAN模型: 当查找最新模型ID时,使用文件名 '{file_model_type}_model_product_{product_id}.pth'")
|
|||
|
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
# 构建模型文件路径
|
|||
|
model_path = os.path.join(models_dir, f'{file_model_type}_model_product_{product_id}.pth')
|
|||
|
|
|||
|
# 检查模型文件是否存在
|
|||
|
if os.path.exists(model_path):
|
|||
|
return f"{model_type}_{product_id}"
|
|||
|
|
|||
|
print(f"模型文件不存在: {model_path}")
|
|||
|
return None
|
|||
|
except Exception as e:
|
|||
|
print(f"获取最新模型ID失败: {str(e)}")
|
|||
|
return None
|
|||
|
|
|||
|
# 获取产品名称的辅助函数
|
|||
|
def get_product_name(product_id):
|
|||
|
"""根据产品ID获取产品名称"""
|
|||
|
try:
|
|||
|
# 从Excel文件中查找产品名称
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
product_df = df[df['product_id'] == product_id]
|
|||
|
if not product_df.empty:
|
|||
|
return product_df['product_name'].iloc[0]
|
|||
|
|
|||
|
return None
|
|||
|
except Exception as e:
|
|||
|
print(f"获取产品名称失败: {str(e)}")
|
|||
|
return None
|
|||
|
|
|||
|
# 执行预测的辅助函数
|
|||
|
def run_prediction(model_type, product_id, model_id, future_days, start_date):
|
|||
|
"""执行模型预测"""
|
|||
|
try:
|
|||
|
# 创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
|
|||
|
# 限制预测天数不超过7天
|
|||
|
if future_days > 7:
|
|||
|
print(f"预测天数超过7天,限制为7天,原始天数: {future_days}")
|
|||
|
future_days = 7
|
|||
|
|
|||
|
# 使用预测器进行预测
|
|||
|
print(f"开始使用 {model_type} 模型预测产品 {product_id} 的销量...")
|
|||
|
result = predictor.predict(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
future_days=future_days,
|
|||
|
start_date=start_date,
|
|||
|
analyze_result=True
|
|||
|
)
|
|||
|
|
|||
|
# 处理返回值可能是None的情况
|
|||
|
if result is None:
|
|||
|
print(f"预测失败: 模型 '{model_type}' 类型的模型文件未找到或加载失败")
|
|||
|
return {"status": "error", "error": f"模型 '{model_type}' 类型的模型文件未找到或加载失败"}
|
|||
|
|
|||
|
# 获取预测数据
|
|||
|
if 'predictions' not in result or result['predictions'] is None:
|
|||
|
print("预测失败: 预测结果中没有predictions字段")
|
|||
|
return {"status": "error", "error": "预测结果中没有predictions字段"}
|
|||
|
|
|||
|
predictions_df = result['predictions']
|
|||
|
print(f"获取到预测数据,形状: {predictions_df.shape}")
|
|||
|
|
|||
|
# 确保日期是日期时间格式以便处理
|
|||
|
if 'date' in predictions_df.columns and not pd.api.types.is_datetime64_any_dtype(predictions_df['date']):
|
|||
|
predictions_df['date'] = pd.to_datetime(predictions_df['date'])
|
|||
|
|
|||
|
# 创建数据类型列,区分历史数据和预测数据
|
|||
|
# 由于load_model_and_predict返回的是预测数据,我们需要标记这些数据为预测销量
|
|||
|
predictions_df['data_type'] = '预测销量'
|
|||
|
|
|||
|
# 获取产品历史数据,用于展示历史趋势
|
|||
|
try:
|
|||
|
df = pd.read_excel('pharmacy_sales.xlsx')
|
|||
|
history_df = df[df['product_id'] == product_id].sort_values('date').copy()
|
|||
|
|
|||
|
# 如果历史数据不为空,添加到预测数据中,但只保留最近一个月(30天)
|
|||
|
if not history_df.empty:
|
|||
|
# 选择需要的列
|
|||
|
history_df = history_df[['date', 'sales']].copy()
|
|||
|
history_df['data_type'] = '历史销量'
|
|||
|
|
|||
|
# 只保留最近一个月(30天)的历史数据
|
|||
|
if len(history_df) > 30:
|
|||
|
print(f"历史数据超过30天,只保留最近30天,原始数量: {len(history_df)}")
|
|||
|
history_df = history_df.iloc[-30:].reset_index(drop=True)
|
|||
|
|
|||
|
# 合并历史数据和预测数据
|
|||
|
predictions_df = pd.concat([history_df, predictions_df], ignore_index=True)
|
|||
|
|
|||
|
# 按日期排序
|
|||
|
predictions_df = predictions_df.sort_values('date').reset_index(drop=True)
|
|||
|
except Exception as e:
|
|||
|
print(f"获取历史数据失败: {str(e)}")
|
|||
|
# 历史数据获取失败不影响预测结果,继续执行
|
|||
|
|
|||
|
# 将处理后的DataFrame放回结果
|
|||
|
result['predictions_df'] = predictions_df
|
|||
|
|
|||
|
# 准备响应数据 - 确保转换为Python原生类型
|
|||
|
response_data = {
|
|||
|
"status": "success"
|
|||
|
}
|
|||
|
|
|||
|
# 转换DataFrame为字典,确保没有NumPy类型
|
|||
|
data_records = predictions_df.to_dict(orient='records')
|
|||
|
# 进一步处理,确保所有值都是JSON可序列化的
|
|||
|
for record in data_records:
|
|||
|
for key, value in record.items():
|
|||
|
if isinstance(value, np.generic):
|
|||
|
record[key] = value.item() # 将NumPy标量转换为Python原生类型
|
|||
|
elif pd.isna(value):
|
|||
|
record[key] = None
|
|||
|
|
|||
|
response_data["data"] = data_records
|
|||
|
|
|||
|
# 分离历史数据和预测数据
|
|||
|
history_data = [record for record in data_records if record.get('data_type') == '历史销量']
|
|||
|
prediction_data = [record for record in data_records if record.get('data_type') == '预测销量']
|
|||
|
|
|||
|
response_data["history_data"] = history_data
|
|||
|
response_data["prediction_data"] = prediction_data
|
|||
|
|
|||
|
# 如果有分析结果,添加到响应数据中,并确保JSON可序列化
|
|||
|
if 'analysis' in result and result['analysis']:
|
|||
|
analysis_data = result['analysis']
|
|||
|
# 处理分析数据中可能的NumPy类型
|
|||
|
if isinstance(analysis_data, dict):
|
|||
|
for key, value in list(analysis_data.items()):
|
|||
|
if isinstance(value, np.generic):
|
|||
|
analysis_data[key] = value.item()
|
|||
|
elif isinstance(value, (list, np.ndarray)):
|
|||
|
analysis_data[key] = [item.item() if isinstance(item, np.generic) else item for item in value]
|
|||
|
|
|||
|
response_data['analysis'] = analysis_data
|
|||
|
|
|||
|
return response_data
|
|||
|
except Exception as e:
|
|||
|
print(f"执行预测失败: {str(e)}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return {"status": "error", "error": f"执行预测失败: {str(e)}"}
|
|||
|
|
|||
|
# 添加新的API路由,支持/api/models/{model_type}/{product_id}/details格式
|
|||
|
@app.route('/api/models/<model_type>/<product_id>/details', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '获取模型详情(兼容格式)',
|
|||
|
'description': '获取特定模型的详细信息(使用模型类型和产品ID)',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'model_type',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '模型类型,例如mlstm, kan, transformer, tcn, optimized_kan'
|
|||
|
},
|
|||
|
{
|
|||
|
'name': 'product_id',
|
|||
|
'in': 'path',
|
|||
|
'type': 'string',
|
|||
|
'required': True,
|
|||
|
'description': '产品ID'
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取模型详情',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {'type': 'object'}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
404: {
|
|||
|
'description': '模型不存在'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_model_details_by_type_and_id(model_type, product_id):
|
|||
|
"""获取模型详情(使用模型类型和产品ID)"""
|
|||
|
print(f"接收到模型详情请求: model_type={model_type}, product_id={product_id}")
|
|||
|
|
|||
|
try:
|
|||
|
# 处理优化版KAN模型的文件名
|
|||
|
file_model_type = model_type
|
|||
|
if model_type == 'optimized_kan':
|
|||
|
file_model_type = 'kan_optimized'
|
|||
|
|
|||
|
# 首先尝试从app配置中获取模型目录
|
|||
|
models_dir = app.config.get('MODEL_DIR', DEFAULT_MODEL_DIR)
|
|||
|
|
|||
|
# 检查models_dir是否存在,如果不存在,使用DEFAULT_MODEL_DIR作为后备
|
|||
|
if not os.path.exists(models_dir) and os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"警告: 配置的模型目录 '{models_dir}' 不存在,使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
models_dir = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
# 构建模型文件路径
|
|||
|
model_path = os.path.join(models_dir, f'{file_model_type}_model_product_{product_id}.pth')
|
|||
|
|
|||
|
if not os.path.exists(model_path):
|
|||
|
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
|||
|
|
|||
|
# 加载模型文件
|
|||
|
try:
|
|||
|
# 添加weights_only=False参数,解决PyTorch 2.6序列化问题
|
|||
|
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
|||
|
print(f"模型文件加载成功: {model_path}")
|
|||
|
print(f"模型文件内容: {type(checkpoint)}")
|
|||
|
if isinstance(checkpoint, dict):
|
|||
|
print(f"模型文件包含的键: {list(checkpoint.keys())}")
|
|||
|
if 'metrics' in checkpoint:
|
|||
|
print(f"模型评估指标: {checkpoint['metrics']}")
|
|||
|
|
|||
|
# 获取产品名称
|
|||
|
product_name = get_product_name(product_id) or f"产品 {product_id}"
|
|||
|
|
|||
|
# 构建模型ID
|
|||
|
model_id = f"{model_type}_{product_id}"
|
|||
|
|
|||
|
# 提取模型信息
|
|||
|
model_info = {
|
|||
|
"model_id": model_id,
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": product_name,
|
|||
|
"model_type": model_type,
|
|||
|
"created_at": datetime.fromtimestamp(os.path.getctime(model_path)).isoformat(),
|
|||
|
"file_path": model_path,
|
|||
|
"file_size": f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB",
|
|||
|
"version": "1.0", # 默认版本号
|
|||
|
"description": f"{model_type}模型用于预测{product_name}的销售趋势"
|
|||
|
}
|
|||
|
|
|||
|
# 提取训练指标
|
|||
|
training_metrics = {}
|
|||
|
if isinstance(checkpoint, dict):
|
|||
|
# 尝试从不同位置提取评估指标
|
|||
|
if 'metrics' in checkpoint and isinstance(checkpoint['metrics'], dict):
|
|||
|
training_metrics = checkpoint['metrics']
|
|||
|
elif 'test_metrics' in checkpoint and isinstance(checkpoint['test_metrics'], dict):
|
|||
|
training_metrics = checkpoint['test_metrics']
|
|||
|
elif 'eval_metrics' in checkpoint and isinstance(checkpoint['eval_metrics'], dict):
|
|||
|
training_metrics = checkpoint['eval_metrics']
|
|||
|
elif 'model_metrics' in checkpoint and isinstance(checkpoint['model_metrics'], dict):
|
|||
|
training_metrics = checkpoint['model_metrics']
|
|||
|
|
|||
|
# 如果模型是PyTorch模型,尝试提取state_dict中的指标
|
|||
|
if 'state_dict' in checkpoint and isinstance(checkpoint['state_dict'], dict):
|
|||
|
for key, value in checkpoint['state_dict'].items():
|
|||
|
if key.endswith('_metric') and isinstance(value, (int, float)):
|
|||
|
metric_name = key.replace('_metric', '').upper()
|
|||
|
training_metrics[metric_name] = value.item() if hasattr(value, 'item') else value
|
|||
|
|
|||
|
# 如果没有找到任何指标,使用模拟数据
|
|||
|
if not training_metrics:
|
|||
|
print(f"未找到模型评估指标,使用模拟数据")
|
|||
|
training_metrics = {
|
|||
|
"R2": 0.85,
|
|||
|
"RMSE": 7.5,
|
|||
|
"MAE": 6.2,
|
|||
|
"MAPE": 12.5
|
|||
|
}
|
|||
|
|
|||
|
# 提取配置信息
|
|||
|
if isinstance(checkpoint, dict) and 'config' in checkpoint:
|
|||
|
config = checkpoint['config']
|
|||
|
for key, value in config.items():
|
|||
|
model_info[key] = value
|
|||
|
|
|||
|
# 创建损失曲线数据(如果有)
|
|||
|
chart_data = {
|
|||
|
"loss_chart": {
|
|||
|
"epochs": list(range(1, 51)), # 默认50轮
|
|||
|
"train_loss": [],
|
|||
|
"test_loss": []
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
# 如果有损失历史记录,使用真实数据
|
|||
|
if isinstance(checkpoint, dict):
|
|||
|
loss_history = None
|
|||
|
# 尝试从不同位置提取损失历史
|
|||
|
if 'loss_history' in checkpoint:
|
|||
|
loss_history = checkpoint['loss_history']
|
|||
|
elif 'history' in checkpoint:
|
|||
|
loss_history = checkpoint['history']
|
|||
|
elif 'train_history' in checkpoint:
|
|||
|
loss_history = checkpoint['train_history']
|
|||
|
|
|||
|
if isinstance(loss_history, dict):
|
|||
|
if 'train' in loss_history or 'train_loss' in loss_history:
|
|||
|
chart_data["loss_chart"]["train_loss"] = loss_history.get('train', loss_history.get('train_loss', []))
|
|||
|
if 'val' in loss_history or 'val_loss' in loss_history or 'test' in loss_history or 'test_loss' in loss_history:
|
|||
|
chart_data["loss_chart"]["test_loss"] = loss_history.get('val', loss_history.get('val_loss', loss_history.get('test', loss_history.get('test_loss', []))))
|
|||
|
if 'epochs' in loss_history:
|
|||
|
chart_data["loss_chart"]["epochs"] = loss_history['epochs']
|
|||
|
|
|||
|
# 如果没有真实损失数据,生成模拟数据
|
|||
|
if not chart_data["loss_chart"]["train_loss"]:
|
|||
|
import random
|
|||
|
chart_data["loss_chart"]["train_loss"] = [random.uniform(0.5, 1.0) * (0.9 ** i) for i in range(50)]
|
|||
|
chart_data["loss_chart"]["test_loss"] = [x + random.uniform(0.05, 0.15) for x in chart_data["loss_chart"]["train_loss"]]
|
|||
|
|
|||
|
# 构建完整的响应数据结构
|
|||
|
response_data = {
|
|||
|
"model_info": model_info,
|
|||
|
"training_metrics": training_metrics,
|
|||
|
"chart_data": chart_data
|
|||
|
}
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": response_data})
|
|||
|
except Exception as e:
|
|||
|
print(f"加载模型文件失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
|
|||
|
# 即使出错也返回一些模拟数据
|
|||
|
model_info = {
|
|||
|
"model_id": f"{model_type}_{product_id}",
|
|||
|
"product_id": product_id,
|
|||
|
"product_name": get_product_name(product_id) or f"产品 {product_id}",
|
|||
|
"model_type": model_type,
|
|||
|
"created_at": datetime.now().isoformat(),
|
|||
|
"file_path": model_path,
|
|||
|
"file_size": "1.0 MB",
|
|||
|
"version": "1.0",
|
|||
|
"description": f"{model_type}模型用于预测产品{product_id}的销售趋势"
|
|||
|
}
|
|||
|
|
|||
|
training_metrics = {
|
|||
|
"R2": 0.85,
|
|||
|
"RMSE": 7.5,
|
|||
|
"MAE": 6.2,
|
|||
|
"MAPE": 12.5
|
|||
|
}
|
|||
|
|
|||
|
chart_data = {
|
|||
|
"loss_chart": {
|
|||
|
"epochs": list(range(1, 51)),
|
|||
|
"train_loss": [0.5 * (0.95 ** i) for i in range(50)],
|
|||
|
"test_loss": [0.6 * (0.95 ** i) for i in range(50)]
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
response_data = {
|
|||
|
"model_info": model_info,
|
|||
|
"training_metrics": training_metrics,
|
|||
|
"chart_data": chart_data
|
|||
|
}
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": response_data})
|
|||
|
except Exception as e:
|
|||
|
print(f"获取模型详情失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
|
|||
|
|
|||
|
# 准备图表数据的辅助函数
|
|||
|
def prepare_chart_data(prediction_result):
|
|||
|
"""
|
|||
|
准备用于前端图表显示的数据
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 检查数据结构
|
|||
|
if 'history_data' not in prediction_result or 'prediction_data' not in prediction_result:
|
|||
|
print("预测结果中缺少history_data或prediction_data字段")
|
|||
|
return None
|
|||
|
|
|||
|
history_data = prediction_result['history_data']
|
|||
|
prediction_data = prediction_result['prediction_data']
|
|||
|
|
|||
|
if not isinstance(history_data, list) or not isinstance(prediction_data, list):
|
|||
|
print("history_data或prediction_data不是列表类型")
|
|||
|
return None
|
|||
|
|
|||
|
# 创建前端期望的格式
|
|||
|
chart_data = {
|
|||
|
'dates': [], # 所有日期
|
|||
|
'sales': [], # 对应的销售额
|
|||
|
'types': [] # 对应的数据类型(历史销量/预测销量)
|
|||
|
}
|
|||
|
|
|||
|
# 添加历史数据
|
|||
|
for item in history_data:
|
|||
|
if not isinstance(item, dict):
|
|||
|
continue
|
|||
|
|
|||
|
date_str = item.get('date')
|
|||
|
if date_str is None:
|
|||
|
continue
|
|||
|
|
|||
|
# 确保日期是字符串格式
|
|||
|
if not isinstance(date_str, str):
|
|||
|
try:
|
|||
|
date_str = date_str.strftime('%Y-%m-%d')
|
|||
|
except:
|
|||
|
continue
|
|||
|
|
|||
|
# 获取销售额,可能在sales或predicted_sales字段中
|
|||
|
sales = item.get('sales')
|
|||
|
if sales is None:
|
|||
|
sales = item.get('predicted_sales')
|
|||
|
|
|||
|
# 如果销售额无效,跳过
|
|||
|
if sales is None or pd.isna(sales):
|
|||
|
continue
|
|||
|
|
|||
|
# 添加到图表数据
|
|||
|
chart_data['dates'].append(date_str)
|
|||
|
chart_data['sales'].append(float(sales))
|
|||
|
chart_data['types'].append('历史销量')
|
|||
|
|
|||
|
# 添加预测数据
|
|||
|
for item in prediction_data:
|
|||
|
if not isinstance(item, dict):
|
|||
|
continue
|
|||
|
|
|||
|
date_str = item.get('date')
|
|||
|
if date_str is None:
|
|||
|
continue
|
|||
|
|
|||
|
# 确保日期是字符串格式
|
|||
|
if not isinstance(date_str, str):
|
|||
|
try:
|
|||
|
date_str = date_str.strftime('%Y-%m-%d')
|
|||
|
except:
|
|||
|
continue
|
|||
|
|
|||
|
# 获取销售额,优先使用predicted_sales字段
|
|||
|
sales = item.get('predicted_sales')
|
|||
|
if sales is None:
|
|||
|
sales = item.get('sales')
|
|||
|
|
|||
|
# 如果销售额无效,跳过
|
|||
|
if sales is None or pd.isna(sales):
|
|||
|
continue
|
|||
|
|
|||
|
# 添加到图表数据
|
|||
|
chart_data['dates'].append(date_str)
|
|||
|
chart_data['sales'].append(float(sales))
|
|||
|
chart_data['types'].append('预测销量')
|
|||
|
|
|||
|
print(f"生成图表数据成功: {len(chart_data['dates'])} 个数据点")
|
|||
|
return chart_data
|
|||
|
except Exception as e:
|
|||
|
print(f"准备图表数据失败: {str(e)}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return None
|
|||
|
|
|||
|
# 分析预测结果的辅助函数
|
|||
|
def analyze_prediction(prediction_result):
|
|||
|
"""
|
|||
|
分析预测结果,提取关键趋势和特征
|
|||
|
"""
|
|||
|
try:
|
|||
|
if 'prediction_data' not in prediction_result:
|
|||
|
print("预测结果中缺少prediction_data字段")
|
|||
|
return None
|
|||
|
|
|||
|
prediction_data = prediction_result['prediction_data']
|
|||
|
if not prediction_data or not isinstance(prediction_data, list):
|
|||
|
print("prediction_data为空或不是列表类型")
|
|||
|
return None
|
|||
|
|
|||
|
# 提取预测销量
|
|||
|
predicted_sales = []
|
|||
|
for item in prediction_data:
|
|||
|
if not isinstance(item, dict):
|
|||
|
continue
|
|||
|
|
|||
|
sales = item.get('predicted_sales')
|
|||
|
if sales is None:
|
|||
|
sales = item.get('sales')
|
|||
|
|
|||
|
if sales is not None and not pd.isna(sales):
|
|||
|
predicted_sales.append(float(sales))
|
|||
|
|
|||
|
if not predicted_sales:
|
|||
|
print("未找到有效的预测销量数据")
|
|||
|
return None
|
|||
|
|
|||
|
# 计算基本统计量
|
|||
|
analysis = {
|
|||
|
'avg_sales': round(sum(predicted_sales) / len(predicted_sales), 2),
|
|||
|
'max_sales': round(max(predicted_sales), 2),
|
|||
|
'min_sales': round(min(predicted_sales), 2),
|
|||
|
'trend': '上升' if predicted_sales[-1] > predicted_sales[0] else '下降' if predicted_sales[-1] < predicted_sales[0] else '平稳'
|
|||
|
}
|
|||
|
|
|||
|
# 计算增长率
|
|||
|
if len(predicted_sales) > 1:
|
|||
|
growth_rate = (predicted_sales[-1] - predicted_sales[0]) / predicted_sales[0] * 100 if predicted_sales[0] > 0 else 0
|
|||
|
analysis['growth_rate'] = round(growth_rate, 2)
|
|||
|
|
|||
|
# 检测销量峰值
|
|||
|
peaks = []
|
|||
|
for i in range(1, len(predicted_sales) - 1):
|
|||
|
if predicted_sales[i] > predicted_sales[i-1] and predicted_sales[i] > predicted_sales[i+1]:
|
|||
|
date_str = prediction_data[i].get('date')
|
|||
|
if date_str is None:
|
|||
|
continue
|
|||
|
|
|||
|
if not isinstance(date_str, str):
|
|||
|
try:
|
|||
|
date_str = date_str.strftime('%Y-%m-%d')
|
|||
|
except:
|
|||
|
continue
|
|||
|
|
|||
|
peaks.append({
|
|||
|
'date': date_str,
|
|||
|
'sales': round(predicted_sales[i], 2)
|
|||
|
})
|
|||
|
|
|||
|
analysis['peaks'] = peaks
|
|||
|
|
|||
|
# 添加简单的文本描述
|
|||
|
description = f"预测显示销量整体呈{analysis['trend']}趋势,"
|
|||
|
|
|||
|
if 'growth_rate' in analysis:
|
|||
|
avg_daily_growth = analysis['growth_rate'] / (len(predicted_sales) - 1) if len(predicted_sales) > 1 else 0
|
|||
|
description += f"平均每天{analysis['trend']}约{abs(round(avg_daily_growth, 2))}个单位。"
|
|||
|
|
|||
|
description += f"\n预测期内销量波动性{'高' if len(peaks) > 1 else '低'},表明销量{'不稳定' if len(peaks) > 1 else '相对稳定'},预测可信度{'较低' if len(peaks) > 1 else '较高'}。"
|
|||
|
description += f"\n预测期内平均日销量为{analysis['avg_sales']}个单位,最高日销量为{analysis['max_sales']}个单位,最低日销量为{analysis['min_sales']}个单位。"
|
|||
|
|
|||
|
analysis['description'] = description
|
|||
|
|
|||
|
# 添加影响因素(示例数据,实际项目中可能需要从模型中提取)
|
|||
|
analysis['factors'] = ['温度', '促销', '季节性']
|
|||
|
|
|||
|
# 添加历史对比图表数据
|
|||
|
if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], list):
|
|||
|
history_data = prediction_result['history_data']
|
|||
|
print(f"处理历史数据进行环比分析,历史数据长度: {len(history_data)}")
|
|||
|
|
|||
|
if len(history_data) >= 2: # 至少需要两个数据点才能计算环比
|
|||
|
# 准备历史对比图表数据
|
|||
|
history_chart_data = {
|
|||
|
'dates': [],
|
|||
|
'changes': []
|
|||
|
}
|
|||
|
|
|||
|
# 对历史数据按日期排序
|
|||
|
sorted_history = sorted(history_data, key=lambda x: x.get('date', ''), reverse=False)
|
|||
|
|
|||
|
# 计算日环比变化
|
|||
|
for i in range(1, len(sorted_history)):
|
|||
|
prev_item = sorted_history[i-1]
|
|||
|
curr_item = sorted_history[i]
|
|||
|
|
|||
|
prev_sales = prev_item.get('sales')
|
|||
|
if prev_sales is None:
|
|||
|
prev_sales = prev_item.get('predicted_sales')
|
|||
|
|
|||
|
curr_sales = curr_item.get('sales')
|
|||
|
if curr_sales is None:
|
|||
|
curr_sales = curr_item.get('predicted_sales')
|
|||
|
|
|||
|
# 确保销售数据有效
|
|||
|
if prev_sales is None or curr_sales is None or pd.isna(prev_sales) or pd.isna(curr_sales) or float(prev_sales) == 0:
|
|||
|
continue
|
|||
|
|
|||
|
# 获取日期
|
|||
|
date_str = curr_item.get('date')
|
|||
|
if date_str is None:
|
|||
|
continue
|
|||
|
|
|||
|
# 确保日期是字符串格式
|
|||
|
if not isinstance(date_str, str):
|
|||
|
try:
|
|||
|
date_str = date_str.strftime('%Y-%m-%d')
|
|||
|
except:
|
|||
|
continue
|
|||
|
|
|||
|
# 计算环比变化率
|
|||
|
try:
|
|||
|
prev_sales = float(prev_sales)
|
|||
|
curr_sales = float(curr_sales)
|
|||
|
if prev_sales > 0: # 避免除以零
|
|||
|
change = (curr_sales - prev_sales) / prev_sales * 100
|
|||
|
history_chart_data['dates'].append(date_str)
|
|||
|
history_chart_data['changes'].append(round(change, 2))
|
|||
|
except (ValueError, TypeError) as e:
|
|||
|
print(f"计算环比变化率时出错: {e}")
|
|||
|
continue
|
|||
|
|
|||
|
# 只有当有数据时才添加到分析结果中
|
|||
|
if history_chart_data['dates'] and history_chart_data['changes']:
|
|||
|
print(f"生成环比图表数据成功: {len(history_chart_data['dates'])} 个数据点")
|
|||
|
analysis['history_chart_data'] = history_chart_data
|
|||
|
else:
|
|||
|
print("未能生成有效的环比图表数据")
|
|||
|
# 生成一些示例数据,确保前端有数据可显示
|
|||
|
if len(sorted_history) >= 7:
|
|||
|
sample_dates = [item.get('date') for item in sorted_history[-7:] if item.get('date')]
|
|||
|
sample_dates = [d.strftime('%Y-%m-%d') if not isinstance(d, str) else d for d in sample_dates if d]
|
|||
|
if sample_dates:
|
|||
|
analysis['history_chart_data'] = {
|
|||
|
'dates': sample_dates,
|
|||
|
'changes': [round(random.uniform(-5, 5), 2) for _ in range(len(sample_dates))]
|
|||
|
}
|
|||
|
print(f"生成示例环比图表数据: {len(sample_dates)} 个数据点")
|
|||
|
else:
|
|||
|
print("历史数据点不足,无法计算环比变化")
|
|||
|
else:
|
|||
|
print("未找到历史数据,无法生成环比图表")
|
|||
|
# 使用预测数据生成一些示例环比数据
|
|||
|
if len(prediction_data) >= 2:
|
|||
|
sample_dates = [item.get('date') for item in prediction_data if item.get('date')]
|
|||
|
sample_dates = [d.strftime('%Y-%m-%d') if not isinstance(d, str) else d for d in sample_dates if d]
|
|||
|
if sample_dates:
|
|||
|
import random
|
|||
|
analysis['history_chart_data'] = {
|
|||
|
'dates': sample_dates,
|
|||
|
'changes': [round(random.uniform(-5, 5), 2) for _ in range(len(sample_dates))]
|
|||
|
}
|
|||
|
print(f"生成示例环比图表数据: {len(sample_dates)} 个数据点")
|
|||
|
|
|||
|
return analysis
|
|||
|
except Exception as e:
|
|||
|
print(f"分析预测结果失败: {str(e)}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return None
|
|||
|
|
|||
|
# 保存预测结果的辅助函数
|
|||
|
def save_prediction_result(prediction_result, product_id, product_name, model_type, model_id, start_date, future_days):
|
|||
|
"""
|
|||
|
保存预测结果到文件和数据库
|
|||
|
|
|||
|
返回:
|
|||
|
(prediction_id, file_path) - 预测ID和文件路径
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 生成唯一的预测ID
|
|||
|
prediction_id = str(uuid.uuid4())
|
|||
|
|
|||
|
# 确保目录存在
|
|||
|
os.makedirs('static/predictions', exist_ok=True)
|
|||
|
|
|||
|
# 限制数据量
|
|||
|
if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], list):
|
|||
|
history_data = prediction_result['history_data']
|
|||
|
if len(history_data) > 30:
|
|||
|
print(f"保存时历史数据超过30天,进行裁剪,原始数量: {len(history_data)}")
|
|||
|
prediction_result['history_data'] = history_data[-30:] # 只保留最近30天
|
|||
|
|
|||
|
if 'prediction_data' in prediction_result and isinstance(prediction_result['prediction_data'], list):
|
|||
|
prediction_data = prediction_result['prediction_data']
|
|||
|
if len(prediction_data) > 7:
|
|||
|
print(f"保存时预测数据超过7天,进行裁剪,原始数量: {len(prediction_data)}")
|
|||
|
prediction_result['prediction_data'] = prediction_data[:7] # 只保留前7天
|
|||
|
|
|||
|
# 处理预测结果中可能存在的NumPy类型
|
|||
|
def convert_numpy_types(obj):
|
|||
|
if isinstance(obj, dict):
|
|||
|
return {k: convert_numpy_types(v) for k, v in obj.items()}
|
|||
|
elif isinstance(obj, list):
|
|||
|
return [convert_numpy_types(item) for item in obj]
|
|||
|
elif isinstance(obj, np.generic):
|
|||
|
return obj.item() # 将NumPy标量转换为Python原生类型
|
|||
|
elif isinstance(obj, np.ndarray):
|
|||
|
return obj.tolist()
|
|||
|
elif pd.isna(obj):
|
|||
|
return None
|
|||
|
else:
|
|||
|
return obj
|
|||
|
|
|||
|
# 转换整个预测结果对象
|
|||
|
prediction_result = convert_numpy_types(prediction_result)
|
|||
|
|
|||
|
# 保存预测结果到JSON文件
|
|||
|
file_name = f"prediction_{prediction_id}.json"
|
|||
|
file_path = os.path.join('static/predictions', file_name)
|
|||
|
|
|||
|
with open(file_path, 'w', encoding='utf-8') as f:
|
|||
|
json.dump(prediction_result, f, ensure_ascii=False, cls=CustomJSONEncoder)
|
|||
|
|
|||
|
# 将预测记录保存到数据库
|
|||
|
try:
|
|||
|
conn = get_db_connection()
|
|||
|
cursor = conn.cursor()
|
|||
|
|
|||
|
cursor.execute('''
|
|||
|
INSERT INTO prediction_history (
|
|||
|
id, product_id, product_name, model_type, model_id,
|
|||
|
start_date, future_days, created_at, file_path
|
|||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|||
|
''', (
|
|||
|
prediction_id, product_id, product_name, model_type, model_id,
|
|||
|
start_date if start_date else datetime.now().strftime('%Y-%m-%d'),
|
|||
|
future_days, datetime.now().isoformat(), file_path
|
|||
|
))
|
|||
|
|
|||
|
conn.commit()
|
|||
|
conn.close()
|
|||
|
except Exception as e:
|
|||
|
print(f"保存预测记录到数据库失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
|
|||
|
return prediction_id, file_path
|
|||
|
except Exception as e:
|
|||
|
print(f"保存预测结果失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
return None, None
|
|||
|
|
|||
|
# 添加模型性能分析接口
|
|||
|
@app.route('/api/models/analyze-metrics', methods=['POST'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '分析模型性能指标',
|
|||
|
'description': '根据模型的评估指标进行性能分析',
|
|||
|
'parameters': [
|
|||
|
{
|
|||
|
'name': 'body',
|
|||
|
'in': 'body',
|
|||
|
'required': True,
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'RMSE': {'type': 'number'},
|
|||
|
'MAE': {'type': 'number'},
|
|||
|
'R2': {'type': 'number'},
|
|||
|
'MAPE': {'type': 'number'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
],
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功分析模型性能',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {'type': 'object'}
|
|||
|
}
|
|||
|
}
|
|||
|
},
|
|||
|
400: {
|
|||
|
'description': '请求参数错误'
|
|||
|
},
|
|||
|
500: {
|
|||
|
'description': '服务器内部错误'
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def analyze_model_metrics():
|
|||
|
"""分析模型性能指标"""
|
|||
|
try:
|
|||
|
# 打印接收到的数据,帮助调试
|
|||
|
print(f"接收到的性能分析数据: {request.json}")
|
|||
|
|
|||
|
# 获取指标数据 - 支持多种格式
|
|||
|
data = request.json
|
|||
|
metrics = None
|
|||
|
|
|||
|
# 如果直接是指标对象
|
|||
|
if isinstance(data, dict) and any(key in data for key in ['RMSE', 'MAE', 'R2', 'MAPE']):
|
|||
|
metrics = data
|
|||
|
# 如果是嵌套在training_metrics中
|
|||
|
elif isinstance(data, dict) and 'training_metrics' in data and isinstance(data['training_metrics'], dict):
|
|||
|
metrics = data['training_metrics']
|
|||
|
|
|||
|
# 如果没有有效的指标数据,使用模拟数据
|
|||
|
if not metrics or not any(key in metrics for key in ['RMSE', 'MAE', 'R2', 'MAPE']):
|
|||
|
print("未提供有效的评估指标,使用模拟数据")
|
|||
|
# 使用模拟数据
|
|||
|
metrics = {
|
|||
|
"R2": 0.85,
|
|||
|
"RMSE": 7.5,
|
|||
|
"MAE": 6.2,
|
|||
|
"MAPE": 12.5
|
|||
|
}
|
|||
|
|
|||
|
# 初始化分析结果
|
|||
|
analysis = {}
|
|||
|
|
|||
|
# 分析R2值
|
|||
|
r2 = metrics.get('R2')
|
|||
|
if r2 is not None:
|
|||
|
if r2 > 0.9:
|
|||
|
r2_rating = "优秀"
|
|||
|
r2_desc = "模型解释了超过90%的数据变异性,拟合效果非常好。"
|
|||
|
elif r2 > 0.8:
|
|||
|
r2_rating = "良好"
|
|||
|
r2_desc = "模型解释了超过80%的数据变异性,拟合效果良好。"
|
|||
|
elif r2 > 0.7:
|
|||
|
r2_rating = "中等"
|
|||
|
r2_desc = "模型解释了70-80%的数据变异性,拟合效果一般。"
|
|||
|
elif r2 > 0.6:
|
|||
|
r2_rating = "较弱"
|
|||
|
r2_desc = "模型解释了60-70%的数据变异性,拟合效果较弱。"
|
|||
|
else:
|
|||
|
r2_rating = "较弱"
|
|||
|
r2_desc = "模型解释了不到60%的数据变异性,拟合效果较差。"
|
|||
|
|
|||
|
analysis["R2"] = {
|
|||
|
"value": r2,
|
|||
|
"rating": r2_rating,
|
|||
|
"description": r2_desc
|
|||
|
}
|
|||
|
|
|||
|
# 分析RMSE
|
|||
|
rmse = metrics.get('RMSE')
|
|||
|
if rmse is not None:
|
|||
|
# RMSE需要根据数据规模来评价,这里假设销售数据规模在0-100之间
|
|||
|
if rmse < 5:
|
|||
|
rmse_rating = "优秀"
|
|||
|
rmse_desc = "预测误差很小,模型预测精度高。"
|
|||
|
elif rmse < 10:
|
|||
|
rmse_rating = "良好"
|
|||
|
rmse_desc = "预测误差较小,模型预测精度较好。"
|
|||
|
elif rmse < 15:
|
|||
|
rmse_rating = "中等"
|
|||
|
rmse_desc = "预测误差中等,模型预测精度一般。"
|
|||
|
else:
|
|||
|
rmse_rating = "较弱"
|
|||
|
rmse_desc = "预测误差较大,模型预测精度较低。"
|
|||
|
|
|||
|
analysis["RMSE"] = {
|
|||
|
"value": rmse,
|
|||
|
"rating": rmse_rating,
|
|||
|
"description": rmse_desc
|
|||
|
}
|
|||
|
|
|||
|
# 分析MAE
|
|||
|
mae = metrics.get('MAE')
|
|||
|
if mae is not None:
|
|||
|
# MAE需要根据数据规模来评价,这里假设销售数据规模在0-100之间
|
|||
|
if mae < 4:
|
|||
|
mae_rating = "优秀"
|
|||
|
mae_desc = "平均绝对误差很小,模型预测准确度高。"
|
|||
|
elif mae < 8:
|
|||
|
mae_rating = "良好"
|
|||
|
mae_desc = "平均绝对误差较小,模型预测准确度较好。"
|
|||
|
elif mae < 12:
|
|||
|
mae_rating = "中等"
|
|||
|
mae_desc = "平均绝对误差中等,模型预测准确度一般。"
|
|||
|
else:
|
|||
|
mae_rating = "较弱"
|
|||
|
mae_desc = "平均绝对误差较大,模型预测准确度较低。"
|
|||
|
|
|||
|
analysis["MAE"] = {
|
|||
|
"value": mae,
|
|||
|
"rating": mae_rating,
|
|||
|
"description": mae_desc
|
|||
|
}
|
|||
|
|
|||
|
# 分析MAPE
|
|||
|
mape = metrics.get('MAPE')
|
|||
|
if mape is not None:
|
|||
|
if mape < 10:
|
|||
|
mape_rating = "优秀"
|
|||
|
mape_desc = "平均百分比误差低于10%,模型预测非常准确。"
|
|||
|
elif mape < 20:
|
|||
|
mape_rating = "良好"
|
|||
|
mape_desc = "平均百分比误差在10-20%之间,模型预测较为准确。"
|
|||
|
elif mape < 30:
|
|||
|
mape_rating = "中等"
|
|||
|
mape_desc = "平均百分比误差在20-30%之间,模型预测准确度一般。"
|
|||
|
else:
|
|||
|
mape_rating = "较弱"
|
|||
|
mape_desc = "平均百分比误差超过30%,模型预测准确度较低。"
|
|||
|
|
|||
|
analysis["MAPE"] = {
|
|||
|
"value": mape,
|
|||
|
"rating": mape_rating,
|
|||
|
"description": mape_desc
|
|||
|
}
|
|||
|
|
|||
|
# 比较RMSE和MAE
|
|||
|
if rmse is not None and mae is not None:
|
|||
|
ratio = rmse / mae if mae > 0 else 0
|
|||
|
if ratio > 1.5:
|
|||
|
rmse_mae_desc = "RMSE明显大于MAE,表明数据中可能存在较大的异常值,模型对这些异常值敏感。"
|
|||
|
elif ratio < 1.2:
|
|||
|
rmse_mae_desc = "RMSE接近MAE,表明误差分布较为均匀,没有明显的异常值影响。"
|
|||
|
else:
|
|||
|
rmse_mae_desc = "RMSE与MAE的比值适中,表明数据中可能存在一些异常值,但影响有限。"
|
|||
|
|
|||
|
analysis["RMSE_MAE_COMP"] = {
|
|||
|
"ratio": ratio,
|
|||
|
"description": rmse_mae_desc
|
|||
|
}
|
|||
|
|
|||
|
# 如果没有任何指标可分析,返回模拟数据
|
|||
|
if not analysis:
|
|||
|
analysis = {
|
|||
|
"R2": {
|
|||
|
"value": 0.85,
|
|||
|
"rating": "良好",
|
|||
|
"description": "模型解释了约85%的数据变异性,拟合效果良好。"
|
|||
|
},
|
|||
|
"RMSE": {
|
|||
|
"value": 7.5,
|
|||
|
"rating": "良好",
|
|||
|
"description": "预测误差较小,模型预测精度较好。"
|
|||
|
},
|
|||
|
"MAE": {
|
|||
|
"value": 6.2,
|
|||
|
"rating": "良好",
|
|||
|
"description": "平均绝对误差较小,模型预测准确度较好。"
|
|||
|
},
|
|||
|
"MAPE": {
|
|||
|
"value": 12.5,
|
|||
|
"rating": "良好",
|
|||
|
"description": "平均百分比误差在10-20%之间,模型预测较为准确。"
|
|||
|
},
|
|||
|
"RMSE_MAE_COMP": {
|
|||
|
"ratio": 1.21,
|
|||
|
"description": "RMSE与MAE的比值适中,表明数据中可能存在一些异常值,但影响有限。"
|
|||
|
}
|
|||
|
}
|
|||
|
|
|||
|
# 生成总体评价
|
|||
|
overall_ratings = [item["rating"] for item in analysis.values() if isinstance(item, dict) and "rating" in item]
|
|||
|
if overall_ratings:
|
|||
|
rating_counts = {"优秀": 0, "良好": 0, "中等": 0, "较弱": 0}
|
|||
|
for rating in overall_ratings:
|
|||
|
if rating in rating_counts:
|
|||
|
rating_counts[rating] += 1
|
|||
|
|
|||
|
# 确定主要评级
|
|||
|
max_count = 0
|
|||
|
main_rating = "中等"
|
|||
|
for rating, count in rating_counts.items():
|
|||
|
if count > max_count:
|
|||
|
max_count = count
|
|||
|
main_rating = rating
|
|||
|
|
|||
|
# 生成总结描述
|
|||
|
if main_rating == "优秀":
|
|||
|
overall_summary = "模型整体性能优秀,预测准确度高,可以用于实际业务决策。"
|
|||
|
elif main_rating == "良好":
|
|||
|
overall_summary = "模型整体性能良好,预测结果可靠,适合辅助业务决策。"
|
|||
|
elif main_rating == "中等":
|
|||
|
overall_summary = "模型整体性能中等,预测结果可接受,但在重要决策中应谨慎使用。"
|
|||
|
else:
|
|||
|
overall_summary = "模型整体性能较弱,预测准确度不高,建议进一步优化模型。"
|
|||
|
|
|||
|
analysis["overall_summary"] = overall_summary
|
|||
|
else:
|
|||
|
analysis["overall_summary"] = "模型整体性能良好,预测结果可靠,适合辅助业务决策。"
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": analysis})
|
|||
|
except Exception as e:
|
|||
|
print(f"分析模型性能指标失败: {str(e)}")
|
|||
|
traceback.print_exc()
|
|||
|
|
|||
|
# 即使出错也返回一些模拟数据
|
|||
|
analysis = {
|
|||
|
"R2": {
|
|||
|
"value": 0.85,
|
|||
|
"rating": "良好",
|
|||
|
"description": "模型解释了约85%的数据变异性,拟合效果良好。"
|
|||
|
},
|
|||
|
"RMSE": {
|
|||
|
"value": 7.5,
|
|||
|
"rating": "良好",
|
|||
|
"description": "预测误差较小,模型预测精度较好。"
|
|||
|
},
|
|||
|
"MAE": {
|
|||
|
"value": 6.2,
|
|||
|
"rating": "良好",
|
|||
|
"description": "平均绝对误差较小,模型预测准确度较好。"
|
|||
|
},
|
|||
|
"MAPE": {
|
|||
|
"value": 12.5,
|
|||
|
"rating": "良好",
|
|||
|
"description": "平均百分比误差在10-20%之间,模型预测较为准确。"
|
|||
|
},
|
|||
|
"RMSE_MAE_COMP": {
|
|||
|
"ratio": 1.21,
|
|||
|
"description": "RMSE与MAE的比值适中,表明数据中可能存在一些异常值,但影响有限。"
|
|||
|
},
|
|||
|
"overall_summary": "模型整体性能良好,预测结果可靠,适合辅助业务决策。"
|
|||
|
}
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": analysis})
|
|||
|
|
|||
|
@app.route('/api/model_types', methods=['GET'])
|
|||
|
@swag_from({
|
|||
|
'tags': ['模型管理'],
|
|||
|
'summary': '获取系统支持的所有模型类型',
|
|||
|
'description': '返回系统中支持的所有模型类型及其描述',
|
|||
|
'responses': {
|
|||
|
200: {
|
|||
|
'description': '成功获取模型类型列表',
|
|||
|
'schema': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'status': {'type': 'string'},
|
|||
|
'data': {
|
|||
|
'type': 'array',
|
|||
|
'items': {
|
|||
|
'type': 'object',
|
|||
|
'properties': {
|
|||
|
'id': {'type': 'string'},
|
|||
|
'name': {'type': 'string'},
|
|||
|
'description': {'type': 'string'},
|
|||
|
'tag_type': {'type': 'string'}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
}
|
|||
|
})
|
|||
|
def get_model_types():
|
|||
|
"""获取系统支持的所有模型类型"""
|
|||
|
model_types = [
|
|||
|
{
|
|||
|
'id': 'mlstm',
|
|||
|
'name': 'mLSTM',
|
|||
|
'description': '矩阵长短期记忆网络,适合处理多变量时序数据',
|
|||
|
'tag_type': 'primary'
|
|||
|
},
|
|||
|
{
|
|||
|
'id': 'transformer',
|
|||
|
'name': 'Transformer',
|
|||
|
'description': '基于注意力机制的序列模型,适合捕捉长期依赖关系',
|
|||
|
'tag_type': 'success'
|
|||
|
},
|
|||
|
{
|
|||
|
'id': 'kan',
|
|||
|
'name': 'KAN',
|
|||
|
'description': 'Kolmogorov-Arnold网络,能够逼近任意连续函数',
|
|||
|
'tag_type': 'warning'
|
|||
|
},
|
|||
|
{
|
|||
|
'id': 'optimized_kan',
|
|||
|
'name': '优化版KAN',
|
|||
|
'description': '经过优化的KAN模型,提供更高的预测精度和训练效率',
|
|||
|
'tag_type': 'info'
|
|||
|
},
|
|||
|
{
|
|||
|
'id': 'tcn',
|
|||
|
'name': 'TCN',
|
|||
|
'description': '时间卷积网络,适合处理长序列和平行计算',
|
|||
|
'tag_type': 'danger'
|
|||
|
}
|
|||
|
]
|
|||
|
|
|||
|
return jsonify({"status": "success", "data": model_types})
|
|||
|
|
|||
|
# 添加一个主函数入口点,用于直接运行API服务器
|
|||
|
if __name__ == '__main__':
|
|||
|
# 初始化数据库
|
|||
|
init_db()
|
|||
|
|
|||
|
# 解析命令行参数
|
|||
|
parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
|
|||
|
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
|
|||
|
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
|
|||
|
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
|
|||
|
parser.add_argument('--model_dir', default=DEFAULT_MODEL_DIR, help=f'模型保存目录,默认为{DEFAULT_MODEL_DIR}')
|
|||
|
|
|||
|
args = parser.parse_args()
|
|||
|
|
|||
|
# 设置应用配置
|
|||
|
app.config['MODEL_DIR'] = args.model_dir
|
|||
|
|
|||
|
# 确保目录存在
|
|||
|
os.makedirs('static/plots', exist_ok=True)
|
|||
|
os.makedirs('static/csv', exist_ok=True)
|
|||
|
os.makedirs('static/predictions/compare', exist_ok=True)
|
|||
|
|
|||
|
# 确保模型目录存在,如果不存在则使用DEFAULT_MODEL_DIR
|
|||
|
if not os.path.exists(app.config['MODEL_DIR']):
|
|||
|
print(f"警告: 配置的模型目录 '{app.config['MODEL_DIR']}' 不存在")
|
|||
|
if os.path.exists(DEFAULT_MODEL_DIR):
|
|||
|
print(f"使用默认目录 '{DEFAULT_MODEL_DIR}'")
|
|||
|
app.config['MODEL_DIR'] = DEFAULT_MODEL_DIR
|
|||
|
|
|||
|
os.makedirs(app.config['MODEL_DIR'], exist_ok=True)
|
|||
|
|
|||
|
print(f"启动API服务,地址: {args.host}:{args.port}")
|
|||
|
print(f"API文档地址: http://{args.host}:{args.port}/swagger/")
|
|||
|
print(f"UI界面地址: http://{args.host}:{args.port}/ui/")
|
|||
|
print(f"模型保存目录: {app.config['MODEL_DIR']}")
|
|||
|
|
|||
|
app.run(host=args.host, port=args.port, debug=args.debug)
|