
1. 修复前端图表日期排序问题: - 改进 PredictionView.vue 和 HistoryView.vue 中的图表渲染逻辑 - 确保历史数据和预测数据按照正确的日期顺序显示 2. 修复后端API处理: - 解决 optimized_kan 模型类型的路径映射问题 - 添加 JSON 序列化器处理 Pandas Timestamp 对象 - 改进预测数据与历史数据的衔接处理 3. 优化图表样式和用户体验
2549 lines
90 KiB
Python
2549 lines
90 KiB
Python
import sys
|
||
import os
|
||
|
||
# 获取当前脚本所在目录的绝对路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
|
||
# 将当前目录添加到系统路径
|
||
sys.path.append(current_dir)
|
||
|
||
# 或者添加父目录
|
||
#parent_dir = os.path.dirname(current_dir)
|
||
#sys.path.append(parent_dir)
|
||
|
||
|
||
|
||
|
||
import json
|
||
import pandas as pd
|
||
import numpy as np
|
||
import torch
|
||
import matplotlib.pyplot as plt
|
||
import io
|
||
import base64
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
from flask import Flask, jsonify, request, send_file, redirect, send_from_directory, Response
|
||
from flask_cors import CORS
|
||
from flasgger import Swagger
|
||
from werkzeug.utils import secure_filename
|
||
import sqlite3
|
||
import traceback
|
||
|
||
|
||
|
||
from pharmacy_predictor import (
|
||
train_product_model_with_mlstm,
|
||
train_product_model_with_kan,
|
||
train_product_model_with_transformer,
|
||
load_model_and_predict
|
||
)
|
||
import threading
|
||
import base64
|
||
import matplotlib.pyplot as plt
|
||
from io import BytesIO
|
||
import json
|
||
from flasgger import Swagger, swag_from
|
||
import argparse
|
||
from datetime import datetime, timedelta
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from threading import Lock
|
||
import traceback
|
||
import torch
|
||
import sqlite3
|
||
import numpy as np
|
||
import io
|
||
from werkzeug.utils import secure_filename
|
||
|
||
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
|
||
class CustomJSONEncoder(json.JSONEncoder):
|
||
def default(self, obj):
|
||
if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)):
|
||
return obj.strftime('%Y-%m-%d')
|
||
elif isinstance(obj, np.integer):
|
||
return int(obj)
|
||
elif isinstance(obj, np.floating):
|
||
return float(obj)
|
||
elif isinstance(obj, np.ndarray):
|
||
return obj.tolist()
|
||
elif pd.isna(obj): # 处理NaN值
|
||
return None
|
||
return super(CustomJSONEncoder, self).default(obj)
|
||
|
||
app = Flask(__name__)
|
||
# 设置自定义JSON编码器
|
||
app.json_encoder = CustomJSONEncoder
|
||
CORS(app) # 启用CORS支持
|
||
|
||
# Swagger配置
|
||
swagger_config = {
|
||
"headers": [],
|
||
"specs": [
|
||
{
|
||
"endpoint": "apispec",
|
||
"route": "/apispec.json",
|
||
"rule_filter": lambda rule: True, # 包含所有路由
|
||
"model_filter": lambda tag: True, # 包含所有模型
|
||
}
|
||
],
|
||
"static_url_path": "/flasgger_static",
|
||
"swagger_ui": True,
|
||
"specs_route": "/swagger/"
|
||
}
|
||
|
||
swagger_template = {
|
||
"swagger": "2.0",
|
||
"info": {
|
||
"title": "药店销售预测系统API",
|
||
"description": "用于药店销售预测的RESTful API",
|
||
"version": "1.0.0",
|
||
"contact": {
|
||
"name": "API开发团队",
|
||
"email": "support@example.com"
|
||
}
|
||
},
|
||
"tags": [
|
||
{
|
||
"name": "数据管理",
|
||
"description": "数据上传和查询相关接口"
|
||
},
|
||
{
|
||
"name": "模型训练",
|
||
"description": "模型训练相关接口"
|
||
},
|
||
{
|
||
"name": "模型预测",
|
||
"description": "预测销售数据相关接口"
|
||
},
|
||
{
|
||
"name": "模型管理",
|
||
"description": "模型查询、导出和删除接口"
|
||
}
|
||
]
|
||
}
|
||
|
||
swagger = Swagger(app, config=swagger_config, template=swagger_template)
|
||
|
||
# 存储训练任务状态
|
||
training_tasks = {}
|
||
tasks_lock = Lock()
|
||
|
||
# 线程池用于后台训练
|
||
executor = ThreadPoolExecutor(max_workers=2)
|
||
|
||
# 辅助函数:将图像转换为Base64
|
||
def fig_to_base64(fig):
|
||
buf = BytesIO()
|
||
fig.savefig(buf, format='png')
|
||
buf.seek(0)
|
||
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
||
return img_str
|
||
|
||
# 根路由 - 重定向到UI界面
|
||
@app.route('/')
|
||
def index():
|
||
"""重定向到UI界面"""
|
||
return redirect('/ui/')
|
||
|
||
# Swagger UI路由
|
||
@app.route('/swagger')
|
||
def swagger_ui():
|
||
"""重定向到Swagger UI文档页面"""
|
||
return redirect('/swagger/')
|
||
|
||
# 1. 数据管理API
|
||
@app.route('/api/products', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取所有产品列表',
|
||
'description': '返回系统中所有产品的ID和名称',
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取产品列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_products():
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
products = df[['product_id', 'product_name']].drop_duplicates().to_dict('records')
|
||
return jsonify({"status": "success", "data": products})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/products/<product_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取单个产品详情',
|
||
'description': '返回指定产品ID的详细信息',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '产品ID,例如P001'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取产品详情',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'data_points': {'type': 'integer'},
|
||
'date_range': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'start': {'type': 'string'},
|
||
'end': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '产品不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_product(product_id):
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id]
|
||
|
||
if product_df.empty:
|
||
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
||
|
||
product_info = {
|
||
"product_id": product_id,
|
||
"product_name": product_df['product_name'].iloc[0],
|
||
"data_points": len(product_df),
|
||
"date_range": {
|
||
"start": product_df['date'].min().strftime('%Y-%m-%d'),
|
||
"end": product_df['date'].max().strftime('%Y-%m-%d')
|
||
}
|
||
}
|
||
|
||
return jsonify({"status": "success", "data": product_info})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/products/<product_id>/sales', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取产品销售数据',
|
||
'description': '返回指定产品在特定日期范围内的销售数据',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '产品ID,例如P001'
|
||
},
|
||
{
|
||
'name': 'start_date',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '开始日期,格式为YYYY-MM-DD'
|
||
},
|
||
{
|
||
'name': 'end_date',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '结束日期,格式为YYYY-MM-DD'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取销售数据',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object'
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '产品不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_product_sales(product_id):
|
||
try:
|
||
start_date = request.args.get('start_date')
|
||
end_date = request.args.get('end_date')
|
||
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
|
||
if product_df.empty:
|
||
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
||
|
||
# 如果提供了日期范围,进行过滤
|
||
if start_date:
|
||
product_df = product_df[product_df['date'] >= pd.to_datetime(start_date)]
|
||
if end_date:
|
||
product_df = product_df[product_df['date'] <= pd.to_datetime(end_date)]
|
||
|
||
# 转换日期为字符串以便JSON序列化
|
||
product_df['date'] = product_df['date'].dt.strftime('%Y-%m-%d')
|
||
|
||
sales_data = product_df.to_dict('records')
|
||
return jsonify({"status": "success", "data": sales_data})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/data/upload', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '上传销售数据',
|
||
'description': '上传新的销售数据文件(Excel格式)',
|
||
'consumes': ['multipart/form-data'],
|
||
'parameters': [
|
||
{
|
||
'name': 'file',
|
||
'in': 'formData',
|
||
'type': 'file',
|
||
'required': True,
|
||
'description': 'Excel文件(.xlsx),包含销售数据'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '数据上传成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'products': {'type': 'integer'},
|
||
'rows': {'type': 'integer'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,可能是文件格式不正确或缺少必要字段'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def upload_data():
|
||
try:
|
||
if 'file' not in request.files:
|
||
return jsonify({"status": "error", "message": "没有上传文件"}), 400
|
||
|
||
file = request.files['file']
|
||
if file.filename == '':
|
||
return jsonify({"status": "error", "message": "没有选择文件"}), 400
|
||
|
||
if file and file.filename.endswith('.xlsx'):
|
||
file_path = 'uploaded_data.xlsx'
|
||
file.save(file_path)
|
||
|
||
# 验证数据格式
|
||
try:
|
||
df = pd.read_excel(file_path)
|
||
required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
|
||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||
|
||
if missing_columns:
|
||
return jsonify({
|
||
"status": "error",
|
||
"message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
|
||
}), 400
|
||
|
||
# 合并到现有数据或替换现有数据
|
||
existing_df = pd.read_excel('pharmacy_sales.xlsx')
|
||
# 这里可以实现数据合并逻辑,例如按日期和产品ID去重后合并
|
||
|
||
# 简单示例:保存上传的数据
|
||
df.to_excel('pharmacy_sales.xlsx', index=False)
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"message": "数据上传成功",
|
||
"data": {
|
||
"products": len(df['product_id'].unique()),
|
||
"rows": len(df)
|
||
}
|
||
})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
|
||
else:
|
||
return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
# 2. 模型训练API
|
||
@app.route('/api/training', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '获取所有训练任务列表',
|
||
'description': '返回所有正在进行、已完成或失败的训练任务',
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取任务列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'task_id': {'type': 'string'},
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'status': {'type': 'string'},
|
||
'start_time': {'type': 'string'},
|
||
'metrics': {'type': 'object'},
|
||
'error': {'type': 'string'},
|
||
'model_path': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
})
|
||
def get_all_training_tasks():
|
||
"""获取所有训练任务的状态"""
|
||
try:
|
||
# 为了方便前端使用,我们将任务ID也包含在每个任务信息中
|
||
tasks_with_id = []
|
||
for task_id, task_info in training_tasks.items():
|
||
task_copy = task_info.copy()
|
||
task_copy['task_id'] = task_id
|
||
tasks_with_id.append(task_copy)
|
||
|
||
# 按开始时间降序排序,最新的任务在前面
|
||
sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True)
|
||
|
||
return jsonify({"status": "success", "data": sorted_tasks})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/training', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '启动模型训练任务',
|
||
'description': '为指定产品启动一个新的模型训练任务',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string', 'description': '例如 P001'},
|
||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan'], 'description': '要训练的模型类型'},
|
||
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
||
},
|
||
'required': ['product_id', 'model_type']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '训练任务已启动',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'message': {'type': 'string'},
|
||
'task_id': {'type': 'string'}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误'
|
||
}
|
||
}
|
||
})
|
||
def start_training():
|
||
"""
|
||
启动模型训练
|
||
---
|
||
post:
|
||
...
|
||
"""
|
||
data = request.get_json()
|
||
product_id = data.get('product_id')
|
||
model_type = data.get('model_type')
|
||
epochs = data.get('epochs', 50) # 默认为50轮
|
||
|
||
if not product_id or not model_type:
|
||
return jsonify({'error': '缺少product_id或model_type'}), 400
|
||
|
||
global training_tasks
|
||
task_id = str(uuid.uuid4())
|
||
|
||
model_train_functions = {
|
||
'mlstm': train_product_model_with_mlstm,
|
||
'kan': train_product_model_with_kan,
|
||
'transformer': train_product_model_with_transformer,
|
||
'optimized_kan': lambda product_id, epochs: train_product_model_with_kan(product_id, epochs, use_optimized=True)
|
||
}
|
||
|
||
if model_type not in model_train_functions:
|
||
return jsonify({'error': '无效的模型类型'}), 400
|
||
|
||
train_function = model_train_functions[model_type]
|
||
|
||
def train_task(product_id, epochs, model_type):
|
||
global training_tasks
|
||
try:
|
||
print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
|
||
# 这里的 train_function 会返回 (model, metrics)
|
||
_, metrics = train_function(product_id, epochs)
|
||
training_tasks[task_id]['status'] = 'completed'
|
||
training_tasks[task_id]['metrics'] = metrics
|
||
# 保存模型路径
|
||
model_path = f'models/{model_type}/{product_id}_model.pt'
|
||
training_tasks[task_id]['model_path'] = model_path
|
||
print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
print(f"任务 {task_id}: 训练失败。错误: {e}")
|
||
training_tasks[task_id]['status'] = 'failed'
|
||
training_tasks[task_id]['error'] = str(e)
|
||
|
||
thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
|
||
thread.start()
|
||
|
||
training_tasks[task_id] = {
|
||
'status': 'running',
|
||
'product_id': product_id,
|
||
'model_type': model_type,
|
||
'start_time': datetime.now().isoformat(),
|
||
'metrics': None,
|
||
'error': None,
|
||
'model_path': None
|
||
}
|
||
|
||
return jsonify({'message': '模型训练已开始', 'task_id': task_id})
|
||
|
||
@app.route('/api/training/<task_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '查询训练任务状态',
|
||
'description': '获取特定训练任务的当前状态和详情',
|
||
'parameters': [
|
||
{
|
||
'name': 'task_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '训练任务ID'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取任务状态',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'parameters': {'type': 'object'},
|
||
'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
|
||
'created_at': {'type': 'string'},
|
||
'model_path': {'type': 'string'},
|
||
'metrics': {'type': 'object'},
|
||
'model_details_url': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '任务不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_training_status(task_id):
|
||
try:
|
||
if task_id not in training_tasks:
|
||
return jsonify({"status": "error", "message": "任务不存在"}), 404
|
||
|
||
task_info = training_tasks[task_id].copy()
|
||
|
||
# 如果任务已完成,添加模型详情链接
|
||
if task_info['status'] == 'completed':
|
||
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": task_info
|
||
})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
# 3. 模型预测API
|
||
@app.route('/api/prediction', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型预测'],
|
||
'summary': '使用模型进行预测',
|
||
'description': '使用指定模型预测未来销售数据',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan']},
|
||
'version': {'type': 'string'},
|
||
'future_days': {'type': 'integer'},
|
||
'include_visualization': {'type': 'boolean'},
|
||
'start_date': {'type': 'string', 'description': '预测起始日期,格式为YYYY-MM-DD'}
|
||
},
|
||
'required': ['product_id', 'model_type']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '预测成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'predictions': {'type': 'array'},
|
||
'visualization': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,缺少必要参数或参数格式错误'
|
||
},
|
||
404: {
|
||
'description': '产品或模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def predict():
|
||
"""
|
||
使用指定的模型进行预测
|
||
---
|
||
tags:
|
||
- 模型预测
|
||
parameters:
|
||
- in: body
|
||
name: body
|
||
schema:
|
||
type: object
|
||
required:
|
||
- product_id
|
||
- model_type
|
||
properties:
|
||
product_id:
|
||
type: string
|
||
description: 产品ID
|
||
model_type:
|
||
type: string
|
||
description: 模型类型 (mlstm, kan, transformer)
|
||
future_days:
|
||
type: integer
|
||
description: 预测未来天数
|
||
default: 7
|
||
start_date:
|
||
type: string
|
||
description: 预测起始日期,格式为YYYY-MM-DD
|
||
default: ''
|
||
responses:
|
||
200:
|
||
description: 预测成功
|
||
400:
|
||
description: 请求参数错误
|
||
404:
|
||
description: 模型文件未找到
|
||
"""
|
||
try:
|
||
data = request.json
|
||
product_id = data.get('product_id')
|
||
model_type = data.get('model_type')
|
||
future_days = int(data.get('future_days', 7))
|
||
start_date = data.get('start_date', '')
|
||
include_visualization = data.get('include_visualization', False)
|
||
|
||
print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
|
||
|
||
if not product_id or not model_type:
|
||
return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
|
||
|
||
# 获取产品名称
|
||
product_name = get_product_name(product_id)
|
||
if not product_name:
|
||
product_name = product_id
|
||
|
||
# 获取最新模型ID - 此处不需要处理模型类型映射,因为 get_latest_model_id 和 load_model_and_predict 会处理
|
||
model_id = get_latest_model_id(model_type, product_id)
|
||
if not model_id:
|
||
return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404
|
||
|
||
# 执行预测
|
||
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date)
|
||
|
||
# 如果需要可视化,添加图表数据
|
||
if include_visualization:
|
||
# 添加图表数据
|
||
chart_data = prepare_chart_data(prediction_result)
|
||
prediction_result['chart_data'] = chart_data
|
||
|
||
# 添加分析结果
|
||
analysis_result = analyze_prediction(prediction_result)
|
||
prediction_result['analysis'] = analysis_result
|
||
|
||
# 保存预测结果到文件和数据库
|
||
prediction_id, file_path = save_prediction_result(
|
||
prediction_result,
|
||
product_id,
|
||
product_name,
|
||
model_type,
|
||
model_id,
|
||
start_date,
|
||
future_days
|
||
)
|
||
|
||
# 添加预测ID到结果中
|
||
prediction_result['prediction_id'] = prediction_id
|
||
|
||
return jsonify(prediction_result)
|
||
except Exception as e:
|
||
print(f"预测失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
@app.route('/api/prediction/compare', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型预测'],
|
||
'summary': '比较不同模型预测结果',
|
||
'description': '比较不同模型对同一产品的预测结果',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_types': {'type': 'array', 'items': {'type': 'string'}},
|
||
'versions': {'type': 'array', 'items': {'type': 'string'}},
|
||
'include_visualization': {'type': 'boolean'}
|
||
},
|
||
'required': ['product_id', 'model_types']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '比较成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_types': {'type': 'array'},
|
||
'comparison': {'type': 'array'},
|
||
'visualization': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,缺少必要参数或参数格式错误'
|
||
},
|
||
404: {
|
||
'description': '产品或模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def compare_predictions():
|
||
try:
|
||
data = request.json
|
||
|
||
product_id = data.get('product_id')
|
||
model_types = data.get('model_types')
|
||
|
||
if not product_id or not model_types:
|
||
return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400
|
||
|
||
all_predictions = {}
|
||
plt.figure(figsize=(12, 8))
|
||
|
||
# 加载历史数据用于绘图
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
product_name = product_df['product_name'].iloc[0]
|
||
history_days = 30
|
||
history_dates = product_df['date'].iloc[-history_days:].values
|
||
history_sales = product_df['sales'].iloc[-history_days:].values
|
||
plt.plot(history_dates, history_sales, 'b-', label='历史销量')
|
||
|
||
comparison_data = []
|
||
future_dates = None
|
||
|
||
# 创建比较结果目录
|
||
compare_dir = f'predictions/compare'
|
||
os.makedirs(compare_dir, exist_ok=True)
|
||
|
||
for model_type in model_types:
|
||
result = load_model_and_predict(product_id, model_type)
|
||
if result is not None:
|
||
predictions_df = result['predictions_df']
|
||
|
||
if future_dates is None:
|
||
future_dates = predictions_df['date']
|
||
comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates]
|
||
|
||
plt.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测')
|
||
|
||
preds = predictions_df['predicted_sales'].tolist()
|
||
for i in range(len(comparison_data)):
|
||
comparison_data[i][f'{model_type}_prediction'] = preds[i] if i < len(preds) else None
|
||
|
||
plt.title(f'{product_name} - 多模型预测比较')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45)
|
||
plt.tight_layout()
|
||
|
||
chart_path = f'{compare_dir}/{product_id}_model_comparison.png'
|
||
plt.savefig(chart_path)
|
||
plt.close() # 关闭图表,释放内存
|
||
|
||
# 获取当前服务器主机和端口
|
||
host_url = request.host_url.rstrip('/') # 移除末尾的斜杠
|
||
|
||
# 生成带时间戳的URL以避免缓存
|
||
timestamp = datetime.now().timestamp()
|
||
image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}"
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": {
|
||
"product_id": product_id,
|
||
"product_name": product_name,
|
||
"model_types": model_types,
|
||
"comparison": comparison_data,
|
||
"visualization_url": image_url
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500
|
||
|
||
@app.route('/api/prediction/analyze', methods=['POST'])
|
||
def analyze_prediction():
|
||
"""
|
||
分析预测结果,提供详细解释
|
||
---
|
||
tags:
|
||
- 预测分析
|
||
parameters:
|
||
- name: body
|
||
in: body
|
||
required: true
|
||
schema:
|
||
type: object
|
||
properties:
|
||
product_id:
|
||
type: string
|
||
example: P001
|
||
model_type:
|
||
type: string
|
||
enum: [mlstm, transformer, kan, optimized_kan]
|
||
future_days:
|
||
type: integer
|
||
default: 7
|
||
start_date:
|
||
type: string
|
||
description: 预测起始日期,格式为'YYYY-MM-DD'
|
||
responses:
|
||
200:
|
||
description: 预测结果及其分析
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: object
|
||
properties:
|
||
predictions:
|
||
type: array
|
||
items:
|
||
type: number
|
||
dates:
|
||
type: array
|
||
items:
|
||
type: string
|
||
analysis:
|
||
type: object
|
||
properties:
|
||
trend:
|
||
type: string
|
||
statistics:
|
||
type: object
|
||
historical_comparison:
|
||
type: object
|
||
factors:
|
||
type: array
|
||
items:
|
||
type: object
|
||
explanation:
|
||
type: string
|
||
"""
|
||
try:
|
||
data = request.json
|
||
product_id = data.get('product_id')
|
||
model_type = data.get('model_type')
|
||
future_days = data.get('future_days', 7)
|
||
start_date = data.get('start_date')
|
||
|
||
# 验证参数
|
||
if not product_id or not model_type:
|
||
return jsonify({"status": "error", "error": "缺少必要参数"}), 400
|
||
|
||
if model_type not in ['mlstm', 'transformer', 'kan', 'optimized_kan']:
|
||
return jsonify({"status": "error", "error": "不支持的模型类型"}), 400
|
||
|
||
# 获取预测结果和分析
|
||
result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date, analyze_result=True)
|
||
|
||
if result_tuple is None:
|
||
return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404
|
||
|
||
# 解包元组
|
||
result, analysis = result_tuple
|
||
|
||
if result is None:
|
||
return jsonify({"status": "error", "error": "预测失败"}), 500
|
||
|
||
# 从result中获取预测值
|
||
predictions_df = result['predictions_df']
|
||
prediction_data = predictions_df[predictions_df['data_type'] == '预测销量']
|
||
predictions_list = prediction_data['sales'].tolist()
|
||
|
||
# 生成日期列表
|
||
if start_date:
|
||
start_date_obj = datetime.strptime(start_date, '%Y-%m-%d')
|
||
else:
|
||
# 获取最后一条数据的日期并加1天
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
last_date = product_df['date'].iloc[-1]
|
||
start_date_obj = last_date + timedelta(days=1)
|
||
|
||
dates = [(start_date_obj + timedelta(days=i)).strftime('%Y-%m-%d') for i in range(future_days)]
|
||
|
||
# 为前端ECharts准备数据
|
||
chart_data = {
|
||
"dates": dates,
|
||
"values": predictions_list,
|
||
"day_over_day_changes": analysis["statistics"]["day_over_day_changes"] if "day_over_day_changes" in analysis["statistics"] else []
|
||
}
|
||
|
||
# 如果有历史对比数据,也添加到图表数据中
|
||
if analysis["historical_comparison"]["has_historical_data"]:
|
||
chart_data["historical"] = {
|
||
"mean": analysis["historical_comparison"]["historical_mean"],
|
||
"max": analysis["historical_comparison"]["historical_max"],
|
||
"min": analysis["historical_comparison"]["historical_min"]
|
||
}
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": {
|
||
"product_id": product_id,
|
||
"model_type": model_type,
|
||
"predictions": predictions_list,
|
||
"dates": dates,
|
||
"chart_data": chart_data,
|
||
"analysis": analysis
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
# 4. 模型管理API
|
||
@app.route('/api/models', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '获取模型列表',
|
||
'description': '获取系统中的模型列表,可按产品ID和模型类型筛选',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '按产品ID筛选'
|
||
},
|
||
{
|
||
'name': 'model_type',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '按模型类型筛选 (mlstm, kan, transformer)'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取模型列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'model_id': {'type': 'string'},
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'created_at': {'type': 'string'},
|
||
'metrics': {'type': 'object'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def list_models():
|
||
"""
|
||
列出所有可用的模型
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: product_id
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按产品ID筛选模型
|
||
- name: model_type
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按模型类型筛选 (mlstm, kan, transformer)
|
||
responses:
|
||
200:
|
||
description: 模型列表
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: array
|
||
items:
|
||
type: object
|
||
properties:
|
||
model_id:
|
||
type: string
|
||
product_id:
|
||
type: string
|
||
product_name:
|
||
type: string
|
||
model_type:
|
||
type: string
|
||
created_at:
|
||
type: string
|
||
metrics:
|
||
type: object
|
||
"""
|
||
models_dir = 'models'
|
||
model_types = ['mlstm', 'kan', 'transformer', 'kan_optimized']
|
||
model_type_mapping = {
|
||
'mlstm': 'mlstm',
|
||
'kan': 'kan',
|
||
'transformer': 'transformer',
|
||
'kan_optimized': 'optimized_kan' # 将kan_optimized目录映射为optimized_kan类型
|
||
}
|
||
available_models = []
|
||
|
||
product_id_filter = request.args.get('product_id')
|
||
model_type_filter = request.args.get('model_type')
|
||
|
||
if not os.path.exists(models_dir):
|
||
return jsonify({"status": "success", "data": []})
|
||
|
||
for model_type in model_types:
|
||
# 如果指定了模型类型过滤器,检查是否匹配当前类型或其映射类型
|
||
if model_type_filter:
|
||
# 如果过滤器是optimized_kan,我们需要查找kan_optimized目录
|
||
if model_type_filter == 'optimized_kan' and model_type != 'kan_optimized':
|
||
continue
|
||
# 如果过滤器不是optimized_kan,但当前目录是kan_optimized,跳过
|
||
elif model_type_filter != 'optimized_kan' and model_type_filter != model_type:
|
||
continue
|
||
|
||
type_dir = os.path.join(models_dir, model_type)
|
||
if not os.path.exists(type_dir):
|
||
continue
|
||
|
||
for file_name in os.listdir(type_dir):
|
||
if file_name.endswith('_log.json'):
|
||
product_id = file_name.replace('_log.json', '')
|
||
|
||
if product_id_filter and product_id_filter != product_id:
|
||
continue
|
||
|
||
log_path = os.path.join(type_dir, file_name)
|
||
try:
|
||
with open(log_path, 'r', encoding='utf-8') as f:
|
||
log_data = json.load(f)
|
||
|
||
# 使用映射表获取显示的模型类型
|
||
display_model_type = model_type_mapping.get(model_type, model_type)
|
||
|
||
model_info = {
|
||
"model_id": f"{display_model_type}_{product_id}",
|
||
"product_id": log_data.get('product_id'),
|
||
"product_name": log_data.get('product_name'),
|
||
"model_type": display_model_type, # 使用映射后的模型类型
|
||
"created_at": log_data.get('training_completed_at'),
|
||
"metrics": log_data.get('metrics'),
|
||
"file_path": log_data.get('file_path')
|
||
}
|
||
available_models.append(model_info)
|
||
except Exception as e:
|
||
print(f"读取日志文件 {log_path} 失败: {e}")
|
||
|
||
# 按创建时间降序排序
|
||
available_models.sort(key=lambda x: x.get('created_at', ''), reverse=True)
|
||
|
||
return jsonify({"status": "success", "data": available_models})
|
||
|
||
@app.route('/api/models/<model_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '获取模型详情',
|
||
'description': '获取特定模型的详细信息',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取模型详情',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'version': {'type': 'string'},
|
||
'created_at': {'type': 'string'},
|
||
'file_path': {'type': 'string'},
|
||
'file_size': {'type': 'string'},
|
||
'features': {'type': 'array'},
|
||
'look_back': {'type': 'integer'},
|
||
'T': {'type': 'integer'},
|
||
'metrics': {'type': 'object'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
404: {
|
||
'description': '模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_model_details(model_id):
|
||
"""
|
||
获取单个模型的详细信息
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: model_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 模型的唯一标识符 (格式: model_type_product_id)
|
||
responses:
|
||
200:
|
||
description: 模型的详细信息
|
||
404:
|
||
description: 未找到模型
|
||
"""
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
|
||
# 处理优化版KAN模型的路径
|
||
actual_model_type = model_type
|
||
if model_type == 'optimized_kan':
|
||
actual_model_type = 'kan_optimized'
|
||
print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_log.json'")
|
||
|
||
log_path = os.path.join('models', actual_model_type, f'{product_id}_log.json')
|
||
|
||
if not os.path.exists(log_path):
|
||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||
|
||
with open(log_path, 'r', encoding='utf-8') as f:
|
||
log_data = json.load(f)
|
||
|
||
# 确保返回的模型类型是optimized_kan而不是kan_optimized
|
||
if actual_model_type == 'kan_optimized':
|
||
log_data['model_type'] = 'optimized_kan'
|
||
|
||
return jsonify({"status": "success", "data": log_data})
|
||
except ValueError:
|
||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/<model_id>', methods=['DELETE'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '删除模型',
|
||
'description': '删除特定模型',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型删除成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def delete_model(model_id):
|
||
"""
|
||
删除一个模型及其关联文件
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: model_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 要删除的模型的ID (格式: model_type_product_id)
|
||
responses:
|
||
200:
|
||
description: 模型删除成功
|
||
404:
|
||
description: 模型未找到
|
||
"""
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
|
||
# 处理优化版KAN模型的路径
|
||
actual_model_type = model_type
|
||
if model_type == 'optimized_kan':
|
||
actual_model_type = 'kan_optimized'
|
||
print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'")
|
||
|
||
model_dir = os.path.join('models', actual_model_type)
|
||
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
|
||
log_path = os.path.join(model_dir, f'{product_id}_log.json')
|
||
|
||
if not os.path.exists(model_path) and not os.path.exists(log_path):
|
||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||
|
||
if os.path.exists(model_path):
|
||
os.remove(model_path)
|
||
if os.path.exists(log_path):
|
||
os.remove(log_path)
|
||
|
||
return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
|
||
except ValueError:
|
||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/<model_id>/export', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '导出模型',
|
||
'description': '导出特定模型文件',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型文件下载',
|
||
'content': {
|
||
'application/octet-stream': {}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def export_model(model_id):
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
|
||
# 处理优化版KAN模型的路径
|
||
actual_model_type = model_type
|
||
if model_type == 'optimized_kan':
|
||
actual_model_type = 'kan_optimized'
|
||
print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'")
|
||
|
||
model_path = os.path.join('models', actual_model_type, f'{product_id}_model.pt')
|
||
|
||
if not os.path.exists(model_path):
|
||
return jsonify({"status": "error", "error": "模型文件未找到"}), 404
|
||
|
||
return send_file(
|
||
model_path,
|
||
as_attachment=True,
|
||
download_name=f'{model_id}.pt',
|
||
mimetype='application/octet-stream'
|
||
)
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/import', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '导入模型',
|
||
'description': '导入模型文件',
|
||
'consumes': ['multipart/form-data'],
|
||
'parameters': [
|
||
{
|
||
'name': 'file',
|
||
'in': 'formData',
|
||
'type': 'file',
|
||
'required': True,
|
||
'description': 'PyTorch模型文件(.pt)'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型导入成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'model_path': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,文件格式不正确或缺少文件'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def import_model():
|
||
try:
|
||
if 'file' not in request.files:
|
||
return jsonify({"status": "error", "message": "没有上传文件"}), 400
|
||
|
||
file = request.files['file']
|
||
if file.filename == '' or not file.filename.endswith('.pt'):
|
||
return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400
|
||
|
||
# 从文件名解析 product_id, model_type
|
||
# 假设文件名格式为 `transformer_P001.pt`
|
||
try:
|
||
model_type, product_id_ext = file.filename.split('_', 1)
|
||
product_id = os.path.splitext(product_id_ext)[0]
|
||
model_types = ['mlstm', 'kan', 'transformer']
|
||
if model_type not in model_types:
|
||
raise ValueError("无效的模型类型")
|
||
except ValueError:
|
||
return jsonify({
|
||
"status": "error",
|
||
"message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)"
|
||
}), 400
|
||
|
||
# 创建目标目录并保存文件
|
||
model_dir = os.path.join('models', model_type)
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
|
||
|
||
# 检查是否需要创建关联的日志文件
|
||
log_path = os.path.join(model_dir, f'{product_id}_log.json')
|
||
if not os.path.exists(log_path):
|
||
# 尝试从.pt文件加载信息(如果可能)
|
||
try:
|
||
checkpoint = torch.load(file, map_location='cpu')
|
||
# 重新定位文件指针到开头,以便 `file.save` 正常工作
|
||
file.seek(0)
|
||
except Exception:
|
||
checkpoint = {} # 如果加载失败,创建一个空的checkpoint
|
||
|
||
# 创建一个基础的log文件
|
||
log_data = {
|
||
'product_id': product_id,
|
||
'product_name': f"导入的产品 {product_id}", # 可能需要用户后续编辑
|
||
'model_type': model_type,
|
||
'training_completed_at': datetime.now().isoformat(),
|
||
'epochs': checkpoint.get('epochs', 'N/A'),
|
||
'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}),
|
||
'file_path': model_path
|
||
}
|
||
with open(log_path, 'w', encoding='utf-8') as f:
|
||
json.dump(log_data, f, indent=4, ensure_ascii=False)
|
||
|
||
# 保存模型文件
|
||
file.save(model_path)
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"message": "模型已成功导入",
|
||
"data": {"model_path": model_path}
|
||
})
|
||
|
||
except Exception as e:
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/plots/<filename>')
|
||
def get_plot(filename):
|
||
"""Serve a plot file from the root directory."""
|
||
try:
|
||
return send_from_directory(app.root_path, filename)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "Plot not found"}), 404
|
||
|
||
@app.route('/api/csv/<filename>')
|
||
def get_csv(filename):
|
||
"""Serve a CSV file from the root directory."""
|
||
try:
|
||
return send_from_directory(app.root_path, filename, as_attachment=True)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "CSV file not found"}), 404
|
||
|
||
# 添加静态UI路由,将/ui路径映射到wwwroot目录
|
||
@app.route('/ui/', defaults={'path': 'index.html'})
|
||
@app.route('/ui/<path:path>')
|
||
def serve_ui(path):
|
||
"""提供UI静态文件服务,将/ui路径映射到wwwroot目录"""
|
||
try:
|
||
# 从wwwroot目录提供静态文件
|
||
return send_from_directory('wwwroot', path)
|
||
except FileNotFoundError:
|
||
# 如果是子路径请求(例如/ui/about)但文件不存在,尝试返回index.html以支持SPA路由
|
||
if path != 'index.html':
|
||
try:
|
||
return send_from_directory('wwwroot', 'index.html')
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "UI files not found"}), 404
|
||
return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404
|
||
|
||
@app.route('/api/predictions/<model_type>/<product_id>/<filename>')
|
||
def get_prediction_file(model_type, product_id, filename):
|
||
"""按模型类型和产品ID获取预测文件"""
|
||
try:
|
||
# 构建文件路径
|
||
file_path = os.path.join('predictions', model_type, product_id)
|
||
|
||
# 如果是图片文件,添加防缓存头
|
||
if filename.endswith('.png'):
|
||
response = send_from_directory(file_path, filename)
|
||
|
||
# 强制浏览器不缓存图片
|
||
response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0'
|
||
response.headers['Pragma'] = 'no-cache'
|
||
response.headers['Expires'] = '0'
|
||
response.headers['Last-Modified'] = datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||
response.headers['Vary'] = '*'
|
||
|
||
print(f"提供图片文件 {filename} 并添加防缓存头")
|
||
return response
|
||
else:
|
||
return send_from_directory(file_path, filename)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": f"预测文件 {filename} 未找到"}), 404
|
||
|
||
@app.route('/api/predictions/compare/<filename>')
|
||
def get_compare_file(filename):
|
||
"""
|
||
提供比较结果文件的下载
|
||
"""
|
||
directory = os.path.join(current_dir, 'predictions', 'compare')
|
||
return send_from_directory(directory, filename)
|
||
|
||
@app.route('/api/models/<model_type>/<product_id>/details', methods=['GET'])
|
||
def get_model_details_extended(model_type, product_id):
|
||
"""
|
||
获取模型详情,包括训练损失曲线、预测效果图等
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: model_type
|
||
in: path
|
||
required: true
|
||
type: string
|
||
description: 模型类型
|
||
- name: product_id
|
||
in: path
|
||
required: true
|
||
type: string
|
||
description: 产品ID
|
||
responses:
|
||
200:
|
||
description: 模型详情
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: object
|
||
properties:
|
||
model_info:
|
||
type: object
|
||
description: 模型基本信息
|
||
training_metrics:
|
||
type: object
|
||
description: 训练评估指标
|
||
chart_data:
|
||
type: object
|
||
description: 图表数据
|
||
"""
|
||
try:
|
||
# 处理优化版KAN模型的路径
|
||
actual_model_path = model_type
|
||
if model_type == 'optimized_kan':
|
||
actual_model_path = 'kan_optimized'
|
||
|
||
model_dir = f'models/{actual_model_path}'
|
||
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
|
||
log_path = os.path.join(model_dir, f'{product_id}_log.json')
|
||
|
||
# 检查模型文件是否存在
|
||
if not os.path.exists(model_path):
|
||
return jsonify({
|
||
"status": "error",
|
||
"error": f"模型 '{model_type}/{product_id}' 不存在"
|
||
}), 404
|
||
|
||
# 读取模型日志
|
||
model_log = {}
|
||
if os.path.exists(log_path):
|
||
with open(log_path, 'r', encoding='utf-8') as f:
|
||
model_log = json.load(f)
|
||
|
||
# 加载模型检查点,获取训练损失数据
|
||
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
|
||
|
||
# 获取训练和测试损失
|
||
train_losses = checkpoint.get('train_loss', [])
|
||
test_losses = checkpoint.get('test_loss', [])
|
||
|
||
# 获取产品名称
|
||
product_name = ""
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id]
|
||
if not product_df.empty:
|
||
product_name = product_df['product_name'].iloc[0]
|
||
except Exception as e:
|
||
print(f"获取产品名称时出错: {e}")
|
||
|
||
# 检查是否有损失曲线图和预测效果图
|
||
loss_curve_path = f'{product_id}_{model_type}_loss_curve.png'
|
||
prediction_path = f'{product_id}_{model_type}_prediction.png'
|
||
|
||
# 获取当前服务器主机和端口
|
||
host_url = request.host_url.rstrip('/')
|
||
|
||
# 准备图表数据
|
||
chart_data = {
|
||
"loss_chart": {
|
||
"epochs": list(range(1, len(train_losses) + 1)),
|
||
"train_loss": train_losses,
|
||
"test_loss": test_losses if test_losses else []
|
||
}
|
||
}
|
||
|
||
# 准备模型信息
|
||
model_info = {
|
||
"model_id": f"{model_type}_{product_id}",
|
||
"model_type": model_type,
|
||
"product_id": product_id,
|
||
"product_name": product_name,
|
||
"created_at": model_log.get("training_completed_at", ""),
|
||
"file_path": model_path
|
||
}
|
||
|
||
# 准备训练指标
|
||
training_metrics = model_log.get("metrics", {})
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": {
|
||
"model_info": model_info,
|
||
"training_metrics": training_metrics,
|
||
"chart_data": chart_data
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({
|
||
"status": "error",
|
||
"error": str(e)
|
||
}), 500
|
||
|
||
# 创建SQLite数据库连接函数
|
||
def get_db_connection():
|
||
"""获取SQLite数据库连接"""
|
||
conn = sqlite3.connect('prediction_history.db')
|
||
conn.row_factory = sqlite3.Row
|
||
return conn
|
||
|
||
# 初始化数据库
|
||
def init_db():
|
||
"""初始化SQLite数据库,创建必要的表"""
|
||
conn = get_db_connection()
|
||
cursor = conn.cursor()
|
||
|
||
# 创建历史预测记录表
|
||
cursor.execute('''
|
||
CREATE TABLE IF NOT EXISTS prediction_history (
|
||
id TEXT PRIMARY KEY,
|
||
product_id TEXT NOT NULL,
|
||
product_name TEXT NOT NULL,
|
||
model_type TEXT NOT NULL,
|
||
model_id TEXT NOT NULL,
|
||
start_date TEXT NOT NULL,
|
||
future_days INTEGER NOT NULL,
|
||
created_at TEXT NOT NULL,
|
||
file_path TEXT NOT NULL
|
||
)
|
||
''')
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
# 在应用启动时初始化数据库
|
||
init_db()
|
||
|
||
# 添加保存预测结果的函数
|
||
def save_prediction_result(prediction_data, product_id, product_name, model_type, model_id, start_date, future_days):
|
||
"""保存预测结果到文件和数据库"""
|
||
# 生成唯一ID
|
||
prediction_id = str(uuid.uuid4())
|
||
|
||
# 创建历史预测目录
|
||
history_dir = os.path.join('predictions', 'history')
|
||
os.makedirs(history_dir, exist_ok=True)
|
||
|
||
# 创建模型类型子目录
|
||
model_dir = os.path.join(history_dir, model_type)
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
# 创建文件名
|
||
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
|
||
filename = f"{product_id}_{start_date}_{future_days}days_{timestamp}.json"
|
||
file_path = os.path.join(model_dir, filename)
|
||
|
||
# 创建一个JSON序列化器来处理Pandas Timestamp对象
|
||
class DateTimeEncoder(json.JSONEncoder):
|
||
def default(self, obj):
|
||
if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)):
|
||
return obj.strftime('%Y-%m-%d')
|
||
elif isinstance(obj, np.integer):
|
||
return int(obj)
|
||
elif isinstance(obj, np.floating):
|
||
return float(obj)
|
||
elif isinstance(obj, np.ndarray):
|
||
return obj.tolist()
|
||
elif pd.isna(obj): # 处理NaN值
|
||
return None
|
||
return super(DateTimeEncoder, self).default(obj)
|
||
|
||
# 保存预测结果为JSON文件
|
||
with open(file_path, 'w', encoding='utf-8') as f:
|
||
json.dump(prediction_data, f, ensure_ascii=False, indent=2, cls=DateTimeEncoder)
|
||
|
||
# 将记录保存到数据库
|
||
conn = get_db_connection()
|
||
cursor = conn.cursor()
|
||
|
||
cursor.execute(
|
||
'''
|
||
INSERT INTO prediction_history
|
||
(id, product_id, product_name, model_type, model_id, start_date, future_days, created_at, file_path)
|
||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||
''',
|
||
(
|
||
prediction_id,
|
||
product_id,
|
||
product_name,
|
||
model_type,
|
||
model_id,
|
||
start_date,
|
||
future_days,
|
||
datetime.now().isoformat(),
|
||
file_path
|
||
)
|
||
)
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
return prediction_id, file_path
|
||
|
||
# 添加获取历史预测列表的API
|
||
@app.route('/api/prediction/history', methods=['GET'])
|
||
def get_prediction_history():
|
||
"""
|
||
获取历史预测列表
|
||
---
|
||
tags:
|
||
- 预测分析
|
||
parameters:
|
||
- name: product_id
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按产品ID筛选
|
||
- name: model_type
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按模型类型筛选
|
||
- name: page
|
||
in: query
|
||
type: integer
|
||
required: false
|
||
default: 1
|
||
description: 页码
|
||
- name: page_size
|
||
in: query
|
||
type: integer
|
||
required: false
|
||
default: 10
|
||
description: 每页记录数
|
||
responses:
|
||
200:
|
||
description: 历史预测列表
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: array
|
||
items:
|
||
type: object
|
||
properties:
|
||
id:
|
||
type: string
|
||
example: 550e8400-e29b-41d4-a716-446655440000
|
||
product_id:
|
||
type: string
|
||
example: P001
|
||
product_name:
|
||
type: string
|
||
example: 阿司匹林
|
||
model_type:
|
||
type: string
|
||
example: mlstm
|
||
model_id:
|
||
type: string
|
||
example: mlstm_P001_20230101
|
||
start_date:
|
||
type: string
|
||
example: 2023-01-01
|
||
future_days:
|
||
type: integer
|
||
example: 7
|
||
created_at:
|
||
type: string
|
||
example: 2023-01-01T12:00:00
|
||
total:
|
||
type: integer
|
||
example: 100
|
||
page:
|
||
type: integer
|
||
example: 1
|
||
page_size:
|
||
type: integer
|
||
example: 10
|
||
"""
|
||
try:
|
||
# 获取查询参数
|
||
product_id = request.args.get('product_id')
|
||
model_type = request.args.get('model_type')
|
||
page = int(request.args.get('page', 1))
|
||
page_size = int(request.args.get('page_size', 10))
|
||
|
||
# 构建SQL查询
|
||
query = "SELECT * FROM prediction_history"
|
||
params = []
|
||
conditions = []
|
||
|
||
if product_id:
|
||
conditions.append("product_id = ?")
|
||
params.append(product_id)
|
||
|
||
if model_type:
|
||
conditions.append("model_type = ?")
|
||
params.append(model_type)
|
||
|
||
if conditions:
|
||
query += " WHERE " + " AND ".join(conditions)
|
||
|
||
# 添加排序和分页
|
||
query += " ORDER BY created_at DESC"
|
||
|
||
# 获取总记录数
|
||
conn = get_db_connection()
|
||
count_result = conn.execute(query, params).fetchall()
|
||
total_count = len(count_result)
|
||
|
||
# 添加分页限制
|
||
query += f" LIMIT {page_size} OFFSET {(page - 1) * page_size}"
|
||
|
||
# 执行查询
|
||
result = conn.execute(query, params).fetchall()
|
||
|
||
# 转换为JSON格式
|
||
history_list = []
|
||
for row in result:
|
||
history_list.append({
|
||
'id': row['id'],
|
||
'product_id': row['product_id'],
|
||
'product_name': row['product_name'],
|
||
'model_type': row['model_type'],
|
||
'model_id': row['model_id'],
|
||
'start_date': row['start_date'],
|
||
'future_days': row['future_days'],
|
||
'created_at': row['created_at'],
|
||
'file_path': row['file_path']
|
||
})
|
||
|
||
conn.close()
|
||
|
||
return jsonify({
|
||
'status': 'success',
|
||
'data': history_list,
|
||
'total': total_count,
|
||
'page': page,
|
||
'page_size': page_size
|
||
})
|
||
except Exception as e:
|
||
print(f"获取历史预测列表失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
# 添加获取历史预测详情的API
|
||
@app.route('/api/prediction/history/<prediction_id>', methods=['GET'])
|
||
def get_prediction_history_detail(prediction_id):
|
||
"""
|
||
获取历史预测详情
|
||
---
|
||
tags:
|
||
- 预测分析
|
||
parameters:
|
||
- name: prediction_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 预测ID
|
||
responses:
|
||
200:
|
||
description: 预测详情
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: object
|
||
"""
|
||
try:
|
||
# 查询数据库获取文件路径
|
||
conn = get_db_connection()
|
||
result = conn.execute("SELECT * FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone()
|
||
conn.close()
|
||
|
||
if not result:
|
||
return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404
|
||
|
||
file_path = result['file_path']
|
||
|
||
# 读取预测结果文件
|
||
try:
|
||
with open(file_path, 'r', encoding='utf-8') as f:
|
||
prediction_data = json.load(f)
|
||
|
||
return jsonify({
|
||
'status': 'success',
|
||
'data': prediction_data,
|
||
'meta': {
|
||
'id': result['id'],
|
||
'product_id': result['product_id'],
|
||
'product_name': result['product_name'],
|
||
'model_type': result['model_type'],
|
||
'model_id': result['model_id'],
|
||
'start_date': result['start_date'],
|
||
'future_days': result['future_days'],
|
||
'created_at': result['created_at']
|
||
}
|
||
})
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "预测结果文件不存在"}), 404
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"读取预测结果文件失败: {str(e)}"}), 500
|
||
except Exception as e:
|
||
print(f"获取历史预测详情失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
# 添加删除历史预测记录的API
|
||
@app.route('/api/prediction/history/<prediction_id>', methods=['DELETE'])
|
||
def delete_prediction_history(prediction_id):
|
||
"""
|
||
删除历史预测记录
|
||
---
|
||
tags:
|
||
- 预测分析
|
||
parameters:
|
||
- name: prediction_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 预测ID
|
||
responses:
|
||
200:
|
||
description: 删除结果
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
"""
|
||
try:
|
||
# 查询数据库获取文件路径
|
||
conn = get_db_connection()
|
||
result = conn.execute("SELECT file_path FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone()
|
||
|
||
if not result:
|
||
conn.close()
|
||
return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404
|
||
|
||
file_path = result['file_path']
|
||
|
||
# 删除数据库记录
|
||
conn.execute("DELETE FROM prediction_history WHERE id = ?", (prediction_id,))
|
||
conn.commit()
|
||
conn.close()
|
||
|
||
# 删除预测结果文件
|
||
try:
|
||
if os.path.exists(file_path):
|
||
os.remove(file_path)
|
||
except Exception as e:
|
||
print(f"删除预测结果文件失败: {str(e)}")
|
||
|
||
return jsonify({'status': 'success'})
|
||
except Exception as e:
|
||
print(f"删除历史预测记录失败: {str(e)}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
# 获取产品名称的辅助函数
|
||
def get_product_name(product_id):
|
||
"""根据产品ID获取产品名称"""
|
||
try:
|
||
# 从产品列表中查找产品名称
|
||
products_file = 'data/products.json'
|
||
if os.path.exists(products_file):
|
||
with open(products_file, 'r', encoding='utf-8') as f:
|
||
products = json.load(f)
|
||
|
||
for product in products:
|
||
if product['product_id'] == product_id:
|
||
return product['product_name']
|
||
|
||
return None
|
||
except Exception as e:
|
||
print(f"获取产品名称失败: {str(e)}")
|
||
return None
|
||
|
||
# 获取最新模型ID的辅助函数
|
||
def get_latest_model_id(model_type, product_id):
|
||
"""根据模型类型和产品ID获取最新的模型ID"""
|
||
try:
|
||
# 处理优化版KAN模型的路径
|
||
actual_model_path = model_type
|
||
if model_type == 'optimized_kan':
|
||
actual_model_path = 'kan_optimized'
|
||
print(f"优化版KAN模型: 当查找最新模型ID时,使用路径 'models/{actual_model_path}/{product_id}_model.pt'")
|
||
|
||
# 查找模型目录中的模型文件
|
||
model_dir = os.path.join('models', actual_model_path)
|
||
if not os.path.exists(model_dir):
|
||
print(f"模型目录不存在: {model_dir}")
|
||
return None
|
||
|
||
# 查找匹配的模型文件
|
||
model_files = [f for f in os.listdir(model_dir) if f.startswith(f"{product_id}_") and f.endswith('.pt')]
|
||
if not model_files:
|
||
print(f"在目录 {model_dir} 中未找到产品 {product_id} 的模型文件")
|
||
return None
|
||
|
||
# 按照文件修改时间排序,获取最新的模型文件
|
||
model_files.sort(key=lambda x: os.path.getmtime(os.path.join(model_dir, x)), reverse=True)
|
||
latest_model_file = model_files[0]
|
||
|
||
# 从文件名中提取模型ID
|
||
model_id = latest_model_file.replace('.pt', '')
|
||
|
||
print(f"找到最新模型: {model_id} 在目录 {model_dir}")
|
||
return model_id
|
||
except Exception as e:
|
||
print(f"获取最新模型ID失败: {str(e)}")
|
||
return None
|
||
|
||
# 执行预测的辅助函数
|
||
def run_prediction(model_type, product_id, model_id, future_days, start_date):
|
||
"""执行模型预测"""
|
||
try:
|
||
# 处理优化版KAN模型的路径在 load_model_and_predict 中已经实现,无需在此处理
|
||
# 直接使用原始的 model_type 调用函数
|
||
|
||
# 解包返回的元组为result和analysis
|
||
result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date=start_date, analyze_result=True)
|
||
|
||
# 处理返回值可能是None的情况
|
||
if result_tuple is None:
|
||
raise Exception(f"模型 '{model_type}' 类型的模型文件未找到或加载失败")
|
||
|
||
# 解包元组 - result_tuple可能是(result, analysis)格式或者只是result
|
||
if isinstance(result_tuple, tuple):
|
||
result = result_tuple[0]
|
||
analysis = result_tuple[1] if len(result_tuple) > 1 else None
|
||
else:
|
||
result = result_tuple
|
||
analysis = None
|
||
|
||
if result is None:
|
||
raise Exception(f"模型预测失败")
|
||
|
||
# 获取预测数据
|
||
predictions_df = result['predictions_df']
|
||
|
||
# 确保日期是日期时间格式以便处理
|
||
if 'date' in predictions_df.columns and not pd.api.types.is_datetime64_any_dtype(predictions_df['date']):
|
||
predictions_df['date'] = pd.to_datetime(predictions_df['date'])
|
||
|
||
# 分离历史数据和预测数据
|
||
history_df = predictions_df[predictions_df['data_type'] == '历史销量']
|
||
prediction_df = predictions_df[predictions_df['data_type'] == '预测销量']
|
||
|
||
# 如果历史数据和预测数据的日期有重叠,调整预测数据的日期
|
||
if not history_df.empty and not prediction_df.empty:
|
||
last_history_date = history_df['date'].max()
|
||
|
||
# 检查预测数据是否与历史数据重叠
|
||
if any(prediction_df['date'] <= last_history_date):
|
||
print(f"检测到预测数据与历史数据日期重叠,调整预测数据日期...")
|
||
|
||
# 获取预测数据的起始日期,确保它在历史数据的最后一天之后
|
||
prediction_start_date = last_history_date + pd.Timedelta(days=1)
|
||
|
||
# 创建新的日期序列
|
||
new_dates = pd.date_range(start=prediction_start_date, periods=len(prediction_df), freq='D')
|
||
|
||
# 更新预测数据的日期
|
||
prediction_df['date'] = new_dates
|
||
|
||
# 更新原始DataFrame
|
||
predictions_df.loc[predictions_df['data_type'] == '预测销量', 'date'] = prediction_df['date'].values
|
||
|
||
print(f"预测数据日期已调整为从 {prediction_start_date} 开始")
|
||
|
||
# 将处理后的DataFrame放回结果
|
||
result['predictions_df'] = predictions_df
|
||
|
||
chart_path = result.get('chart_path')
|
||
csv_path = result.get('csv_path')
|
||
|
||
# 将DataFrame转换为JSON
|
||
predictions_json = predictions_df.to_dict(orient='records')
|
||
|
||
# 准备响应数据
|
||
response_data = {
|
||
"status": "success",
|
||
"data": predictions_json, # 包含历史和预测数据
|
||
"history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'),
|
||
"prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records'),
|
||
}
|
||
|
||
return response_data
|
||
except Exception as e:
|
||
print(f"执行预测失败: {str(e)}")
|
||
traceback.print_exc()
|
||
raise e
|
||
|
||
# 准备图表数据的辅助函数
|
||
def prepare_chart_data(prediction_result):
|
||
"""准备图表数据"""
|
||
try:
|
||
# 从预测结果中提取数据
|
||
predictions_df = pd.DataFrame(prediction_result['data'])
|
||
|
||
# 确保日期列是日期时间格式,以便正确排序
|
||
if 'date' in predictions_df.columns:
|
||
# 先转换为日期时间类型以便排序
|
||
if isinstance(predictions_df['date'][0], str):
|
||
predictions_df['date'] = pd.to_datetime(predictions_df['date'])
|
||
|
||
# 按日期排序
|
||
predictions_df = predictions_df.sort_values('date')
|
||
|
||
# 重置索引
|
||
predictions_df = predictions_df.reset_index(drop=True)
|
||
|
||
# 最后再转换回字符串格式用于JSON
|
||
predictions_df['date'] = predictions_df['date'].dt.strftime('%Y-%m-%d')
|
||
|
||
# 分离历史数据和预测数据
|
||
history_df = predictions_df[predictions_df['data_type'] == '历史销量']
|
||
prediction_df = predictions_df[predictions_df['data_type'] == '预测销量']
|
||
|
||
# 准备图表数据
|
||
chart_data = {
|
||
"dates": predictions_df['date'].tolist(),
|
||
"sales": predictions_df['sales'].tolist(),
|
||
"types": predictions_df['data_type'].tolist()
|
||
}
|
||
|
||
# 为前端debug提供额外信息
|
||
chart_data["debug"] = {
|
||
"history_dates": history_df['date'].tolist() if not history_df.empty else [],
|
||
"history_sales": history_df['sales'].tolist() if not history_df.empty else [],
|
||
"prediction_dates": prediction_df['date'].tolist() if not prediction_df.empty else [],
|
||
"prediction_sales": prediction_df['sales'].tolist() if not prediction_df.empty else [],
|
||
}
|
||
|
||
print(f"历史数据日期范围: {history_df['date'].min() if not history_df.empty else 'N/A'} 到 {history_df['date'].max() if not history_df.empty else 'N/A'}")
|
||
print(f"预测数据日期范围: {prediction_df['date'].min() if not prediction_df.empty else 'N/A'} 到 {prediction_df['date'].max() if not prediction_df.empty else 'N/A'}")
|
||
|
||
return chart_data
|
||
except Exception as e:
|
||
print(f"准备图表数据失败: {str(e)}")
|
||
traceback.print_exc()
|
||
return {}
|
||
|
||
# 分析预测结果的辅助函数
|
||
def analyze_prediction(prediction_result):
|
||
"""分析预测结果"""
|
||
try:
|
||
# 从预测结果中提取数据
|
||
prediction_data = prediction_result.get('prediction_data', [])
|
||
history_data = prediction_result.get('history_data', [])
|
||
|
||
if not prediction_data:
|
||
return None
|
||
|
||
# 转换为DataFrame以便分析
|
||
prediction_df = pd.DataFrame(prediction_data)
|
||
|
||
# 确保日期列是日期时间格式
|
||
if 'date' in prediction_df.columns:
|
||
if isinstance(prediction_df['date'][0], str):
|
||
prediction_df['date'] = pd.to_datetime(prediction_df['date'])
|
||
|
||
# 计算统计数据
|
||
sales = prediction_df['sales'].values
|
||
mean_sales = np.mean(sales)
|
||
max_sales = np.max(sales)
|
||
min_sales = np.min(sales)
|
||
std_sales = np.std(sales)
|
||
|
||
# 计算日环比变化
|
||
day_over_day_changes = []
|
||
for i in range(1, len(sales)):
|
||
if sales[i-1] == 0:
|
||
day_over_day_changes.append(0)
|
||
else:
|
||
change_pct = ((sales[i] - sales[i-1]) / sales[i-1]) * 100
|
||
day_over_day_changes.append(change_pct)
|
||
|
||
# 确定趋势
|
||
if len(sales) < 2:
|
||
trend = "unknown"
|
||
else:
|
||
# 计算简单线性回归的斜率
|
||
x = np.arange(len(sales))
|
||
slope = np.polyfit(x, sales, 1)[0]
|
||
|
||
# 计算变化的标准差
|
||
changes = np.diff(sales)
|
||
changes_std = np.std(changes)
|
||
|
||
if abs(slope) < 0.1 * mean_sales:
|
||
if changes_std > 0.2 * mean_sales:
|
||
trend = "fluctuating"
|
||
else:
|
||
trend = "stable"
|
||
elif slope > 0:
|
||
trend = "increasing"
|
||
else:
|
||
trend = "decreasing"
|
||
|
||
# 历史数据对比
|
||
has_historical_data = len(history_data) > 0
|
||
historical_comparison = {
|
||
"has_historical_data": has_historical_data,
|
||
"mean_difference_pct": 0
|
||
}
|
||
|
||
if has_historical_data:
|
||
history_df = pd.DataFrame(history_data)
|
||
history_mean = history_df['sales'].mean()
|
||
prediction_mean = mean_sales
|
||
|
||
if history_mean > 0:
|
||
mean_difference_pct = ((prediction_mean - history_mean) / history_mean) * 100
|
||
historical_comparison["mean_difference_pct"] = mean_difference_pct
|
||
|
||
# 模拟影响因素
|
||
factors = [
|
||
{"name": "季节性", "importance": "high", "description": "季节变化对销量有显著影响"},
|
||
{"name": "促销活动", "importance": "medium", "description": "促销活动可能会短期提升销量"},
|
||
{"name": "市场竞争", "importance": "low", "description": "市场竞争对销量有轻微影响"}
|
||
]
|
||
|
||
# 生成解释文本
|
||
if trend == "increasing":
|
||
explanation = f"预测期内销量整体呈上升趋势,平均日销量为{mean_sales:.2f},相比历史数据"
|
||
if has_historical_data and historical_comparison["mean_difference_pct"] > 0:
|
||
explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。"
|
||
elif has_historical_data:
|
||
explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。"
|
||
else:
|
||
explanation += "无法比较。"
|
||
explanation += "建议适当增加库存以应对销量增长。"
|
||
elif trend == "decreasing":
|
||
explanation = f"预测期内销量整体呈下降趋势,平均日销量为{mean_sales:.2f},相比历史数据"
|
||
if has_historical_data and historical_comparison["mean_difference_pct"] > 0:
|
||
explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。"
|
||
elif has_historical_data:
|
||
explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。"
|
||
else:
|
||
explanation += "无法比较。"
|
||
explanation += "建议控制库存以避免积压。"
|
||
elif trend == "fluctuating":
|
||
explanation = f"预测期内销量波动较大,平均日销量为{mean_sales:.2f},标准差为{std_sales:.2f}。建议密切关注销售情况,灵活调整库存。"
|
||
else:
|
||
explanation = f"预测期内销量保持稳定,平均日销量为{mean_sales:.2f},最高销量为{max_sales:.2f},最低销量为{min_sales:.2f}。"
|
||
|
||
# 构建分析结果
|
||
analysis = {
|
||
"explanation": explanation,
|
||
"trend": trend,
|
||
"statistics": {
|
||
"mean": mean_sales,
|
||
"max": max_sales,
|
||
"min": min_sales,
|
||
"std": std_sales,
|
||
"day_over_day_changes": day_over_day_changes
|
||
},
|
||
"historical_comparison": historical_comparison,
|
||
"factors": factors
|
||
}
|
||
|
||
return analysis
|
||
except Exception as e:
|
||
print(f"分析预测结果失败: {str(e)}")
|
||
return None
|
||
|
||
# 新增模型性能分析接口
|
||
@app.route('/api/models/analyze-metrics', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '分析模型评估指标',
|
||
'description': '接收一组模型指标,并返回详细的文字解读和评级。',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'description': '模型的评估指标',
|
||
'example': {
|
||
'R2': 0.7067,
|
||
'RMSE': 6.8670,
|
||
'MAE': 4.4062,
|
||
'MAPE': 18.14
|
||
}
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '分析成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {'type': 'object'}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,缺少指标数据'
|
||
}
|
||
}
|
||
})
|
||
def analyze_model_metrics():
|
||
"""
|
||
分析模型的评估指标并提供解读
|
||
"""
|
||
try:
|
||
metrics = request.json
|
||
if not metrics:
|
||
return jsonify({"status": "error", "error": "缺少指标数据"}), 400
|
||
|
||
analysis = {}
|
||
|
||
# 1. 分析 R²
|
||
r2 = metrics.get('R2')
|
||
if r2 is not None:
|
||
if r2 >= 0.9:
|
||
r2_rating = "优秀"
|
||
r2_desc = f"R²值为{r2:.4f},表现非常出色。模型能够解释超过90%的销售数据波动,意味着它与历史数据的拟合度极高,预测结果非常可靠。"
|
||
elif r2 >= 0.7:
|
||
r2_rating = "良好"
|
||
r2_desc = f"R²值为{r2:.4f},表现良好。模型能解释{(r2*100):.1f}%的销售数据变化,说明模型捕捉到了大部分关键的销售模式,具备很高的实用价值。"
|
||
elif r2 >= 0.5:
|
||
r2_rating = "中等"
|
||
r2_desc = f"R²值为{r2:.4f},表现中等。模型解释了约一半的销售数据变化,说明它掌握了一定的规律,但可能忽略了一些次要因素。预测结果可作为参考,但需结合其他信息判断。"
|
||
else:
|
||
r2_rating = "较弱"
|
||
r2_desc = f"R²值为{r2:.4f},表现较弱。模型对销售数据变化的解释能力有限,预测的准确性可能不高。建议尝试优化模型或增加更多有效特征。"
|
||
analysis['R2'] = {"value": r2, "rating": r2_rating, "description": r2_desc}
|
||
|
||
# 2. 分析 MAPE
|
||
mape = metrics.get('MAPE')
|
||
if mape is not None:
|
||
if mape <= 10:
|
||
mape_rating = "优秀"
|
||
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率极低,预测精度非常高。"
|
||
elif mape <= 20:
|
||
mape_rating = "良好"
|
||
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率在可接受的范围内,表明模型的预测结果在大多数情况下与真实值偏差不大。"
|
||
elif mape <= 30:
|
||
mape_rating = "中等"
|
||
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率中等。在销量波动较大的场景下可以接受,但对于追求高精度预测的场景,仍有优化空间。"
|
||
else:
|
||
mape_rating = "较弱"
|
||
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率偏高。模型的预测值与真实值偏差较大,建议谨慎使用其预测结果。"
|
||
analysis['MAPE'] = {"value": mape, "rating": mape_rating, "description": mape_desc}
|
||
|
||
# 3. 分析 RMSE 和 MAE
|
||
rmse = metrics.get('RMSE')
|
||
mae = metrics.get('MAE')
|
||
if rmse is not None and mae is not None:
|
||
rmse_desc = f"均方根误差为{rmse:.4f}。这个值衡量了预测误差的典型大小。因为它对较大的误差值更敏感,所以可以反映模型是否存在'离谱'的预测。"
|
||
mae_desc = f"平均绝对误差为{mae:.4f}。这个值直观地表示了模型平均预测会偏离真实销售量多少个单位(如'件')。"
|
||
|
||
# 比较RMSE和MAE
|
||
if rmse > mae * 1.5: # 经验阈值
|
||
comparison_desc = f"RMSE ({rmse:.4f}) 明显大于 MAE ({mae:.4f}),这通常意味着模型在某些数据点上存在较大的预测误差(离群点)。虽然总体平均误差不大,但偶尔可能会有'猜错得离谱'的情况。"
|
||
else:
|
||
comparison_desc = f"RMSE ({rmse:.4f}) 与 MAE ({mae:.4f}) 的值较为接近,表明模型的误差分布比较均匀,没有出现极端异常的预测错误。"
|
||
|
||
analysis['RMSE'] = {"value": rmse, "rating": "参考指标", "description": rmse_desc}
|
||
analysis['MAE'] = {"value": mae, "rating": "参考指标", "description": mae_desc}
|
||
analysis['RMSE_MAE_COMP'] = {"description": comparison_desc}
|
||
|
||
# 4. 形成总体结论
|
||
overall_ratings = [a.get('rating') for a in analysis.values() if a.get('rating') in ["优秀", "良好", "中等", "较弱"]]
|
||
if "较弱" in overall_ratings:
|
||
overall_summary = "该模型的综合性能表现较弱,预测结果可能存在较大偏差,建议进行优化或谨慎使用。"
|
||
elif "中等" in overall_ratings:
|
||
overall_summary = "该模型的综合性能表现中等,具备一定的预测能力,但仍有提升空间。其预测结果可作为重要参考。"
|
||
elif "优秀" in overall_ratings and overall_ratings.count("良好") == 0:
|
||
overall_summary = "该模型的综合性能表现非常优秀,各项指标均显示其预测精度高、稳定性好,预测结果高度可信。"
|
||
else: # 主要为良好
|
||
overall_summary = "该模型的综合性能表现良好,能够可靠地预测销售趋势,误差在可接受范围内,是决策的有力支持。"
|
||
|
||
analysis['overall_summary'] = overall_summary
|
||
|
||
return jsonify({"status": "success", "data": analysis})
|
||
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": str(e)}), 500
|
||
|
||
# 添加一个主函数入口点,用于直接运行API服务器
|
||
if __name__ == '__main__':
|
||
# 初始化数据库
|
||
init_db()
|
||
|
||
# 使用waitress作为生产环境的WSGI服务器,比Flask默认的开发服务器更健壮
|
||
# from waitress import serve
|
||
# serve(app, host="0.0.0.0", port=5000)
|
||
|
||
# 或者,为了方便调试,仍然使用Flask内置的开发服务器
|
||
# 注意:debug=True 模式在生产环境中非常不安全,请仅在开发阶段使用
|
||
app.run(host="0.0.0.0", port=5000, debug=True) |