ShopTRAINING/api.py
gdtiti 5d505b37af Fix chart display and data processing issues
1. Fix front-end chart date-ordering issues:
   - Improve the chart rendering logic in PredictionView.vue and HistoryView.vue
   - Ensure historical and predicted data are displayed in the correct date order

2. Fix back-end API handling:
   - Resolve the path-mapping issue for the optimized_kan model type
   - Add a JSON serializer to handle Pandas Timestamp objects
   - Improve how predicted data is joined to historical data

3. Polish chart styling and user experience
2025-06-15 00:00:50 +08:00


import sys
import os

# 获取当前脚本所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 将当前目录添加到系统路径
sys.path.append(current_dir)
# 或者添加父目录
# parent_dir = os.path.dirname(current_dir)
# sys.path.append(parent_dir)

import argparse
import base64
import io
import json
import sqlite3
import threading
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from io import BytesIO
from threading import Lock

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from flasgger import Swagger, swag_from
from flask import Flask, jsonify, request, send_file, redirect, send_from_directory, Response
from flask_cors import CORS
from werkzeug.utils import secure_filename

from pharmacy_predictor import (
    train_product_model_with_mlstm,
    train_product_model_with_kan,
    train_product_model_with_transformer,
    load_model_and_predict
)
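
# 本文件实现药店销售预测系统的 Flask REST API:
# 包含数据管理、模型训练、模型预测与模型管理四组接口,
# 并通过 flasgger 在 /swagger/ 提供交互式 API 文档。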
# 自定义JSON编码器来处理Pandas的Timestamp和NumPy类型
class CustomJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (pd.Timestamp, pd.DatetimeIndex)):
            return obj.strftime('%Y-%m-%d')
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        elif pd.isna(obj):  # 处理NaN值
            return None
        return super(CustomJSONEncoder, self).default(obj)


app = Flask(__name__)
# 设置自定义JSON编码器
app.json_encoder = CustomJSONEncoder
CORS(app)  # 启用CORS支持
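
# 示例(示意用法):含 Timestamp / NumPy 类型的数据可直接序列化,
#   json.dumps({'date': pd.Timestamp('2025-06-15'), 'sales': np.int64(3)}, cls=CustomJSONEncoder)
#   -> '{"date": "2025-06-15", "sales": 3}'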
# Swagger配置
swagger_config = {
"headers": [],
"specs": [
{
"endpoint": "apispec",
"route": "/apispec.json",
"rule_filter": lambda rule: True, # 包含所有路由
"model_filter": lambda tag: True, # 包含所有模型
}
],
"static_url_path": "/flasgger_static",
"swagger_ui": True,
"specs_route": "/swagger/"
}
swagger_template = {
"swagger": "2.0",
"info": {
"title": "药店销售预测系统API",
"description": "用于药店销售预测的RESTful API",
"version": "1.0.0",
"contact": {
"name": "API开发团队",
"email": "support@example.com"
}
},
"tags": [
{
"name": "数据管理",
"description": "数据上传和查询相关接口"
},
{
"name": "模型训练",
"description": "模型训练相关接口"
},
{
"name": "模型预测",
"description": "预测销售数据相关接口"
},
{
"name": "模型管理",
"description": "模型查询、导出和删除接口"
}
]
}
swagger = Swagger(app, config=swagger_config, template=swagger_template)
# 存储训练任务状态
training_tasks = {}
tasks_lock = Lock()
# 线程池用于后台训练
executor = ThreadPoolExecutor(max_workers=2)
# 辅助函数将图像转换为Base64
def fig_to_base64(fig):
buf = BytesIO()
fig.savefig(buf, format='png')
buf.seek(0)
img_str = base64.b64encode(buf.read()).decode('utf-8')
return img_str
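# 示意:返回的 img_str 可直接拼接为 data URI 使用,例如 f"data:image/png;base64,{fig_to_base64(fig)}"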
# 根路由 - 重定向到UI界面
@app.route('/')
def index():
"""重定向到UI界面"""
return redirect('/ui/')
# Swagger UI路由
@app.route('/swagger')
def swagger_ui():
"""重定向到Swagger UI文档页面"""
return redirect('/swagger/')
# 1. 数据管理API
@app.route('/api/products', methods=['GET'])
@swag_from({
'tags': ['数据管理'],
'summary': '获取所有产品列表',
'description': '返回系统中所有产品的ID和名称',
'responses': {
200: {
'description': '成功获取产品列表',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'product_name': {'type': 'string'}
}
}
}
}
}
},
500: {
'description': '服务器内部错误'
}
}
})
def get_products():
try:
df = pd.read_excel('pharmacy_sales.xlsx')
products = df[['product_id', 'product_name']].drop_duplicates().to_dict('records')
return jsonify({"status": "success", "data": products})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/products/<product_id>', methods=['GET'])
@swag_from({
'tags': ['数据管理'],
'summary': '获取单个产品详情',
'description': '返回指定产品ID的详细信息',
'parameters': [
{
'name': 'product_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '产品ID例如P001'
}
],
'responses': {
200: {
'description': '成功获取产品详情',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'product_name': {'type': 'string'},
'data_points': {'type': 'integer'},
'date_range': {
'type': 'object',
'properties': {
'start': {'type': 'string'},
'end': {'type': 'string'}
}
}
}
}
}
}
},
404: {
'description': '产品不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def get_product(product_id):
try:
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id]
if product_df.empty:
return jsonify({"status": "error", "message": "产品不存在"}), 404
product_info = {
"product_id": product_id,
"product_name": product_df['product_name'].iloc[0],
"data_points": len(product_df),
"date_range": {
"start": product_df['date'].min().strftime('%Y-%m-%d'),
"end": product_df['date'].max().strftime('%Y-%m-%d')
}
}
return jsonify({"status": "success", "data": product_info})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/products/<product_id>/sales', methods=['GET'])
@swag_from({
'tags': ['数据管理'],
'summary': '获取产品销售数据',
'description': '返回指定产品在特定日期范围内的销售数据',
'parameters': [
{
'name': 'product_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '产品ID例如P001'
},
{
'name': 'start_date',
'in': 'query',
'type': 'string',
'required': False,
'description': '开始日期格式为YYYY-MM-DD'
},
{
'name': 'end_date',
'in': 'query',
'type': 'string',
'required': False,
'description': '结束日期格式为YYYY-MM-DD'
}
],
'responses': {
200: {
'description': '成功获取销售数据',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'array',
'items': {
'type': 'object'
}
}
}
}
},
404: {
'description': '产品不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def get_product_sales(product_id):
try:
start_date = request.args.get('start_date')
end_date = request.args.get('end_date')
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
if product_df.empty:
return jsonify({"status": "error", "message": "产品不存在"}), 404
# 如果提供了日期范围,进行过滤
if start_date:
product_df = product_df[product_df['date'] >= pd.to_datetime(start_date)]
if end_date:
product_df = product_df[product_df['date'] <= pd.to_datetime(end_date)]
# 转换日期为字符串以便JSON序列化
product_df['date'] = product_df['date'].dt.strftime('%Y-%m-%d')
sales_data = product_df.to_dict('records')
return jsonify({"status": "success", "data": sales_data})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/data/upload', methods=['POST'])
@swag_from({
'tags': ['数据管理'],
'summary': '上传销售数据',
'description': '上传新的销售数据文件(Excel格式)',
'consumes': ['multipart/form-data'],
'parameters': [
{
'name': 'file',
'in': 'formData',
'type': 'file',
'required': True,
'description': 'Excel文件(.xlsx),包含销售数据'
}
],
'responses': {
200: {
'description': '数据上传成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'message': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'products': {'type': 'integer'},
'rows': {'type': 'integer'}
}
}
}
}
},
400: {
'description': '请求错误,可能是文件格式不正确或缺少必要字段'
},
500: {
'description': '服务器内部错误'
}
}
})
def upload_data():
try:
if 'file' not in request.files:
return jsonify({"status": "error", "message": "没有上传文件"}), 400
file = request.files['file']
if file.filename == '':
return jsonify({"status": "error", "message": "没有选择文件"}), 400
if file and file.filename.endswith('.xlsx'):
file_path = 'uploaded_data.xlsx'
file.save(file_path)
# 验证数据格式
try:
df = pd.read_excel(file_path)
required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
return jsonify({
"status": "error",
"message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
}), 400
                # 合并到现有数据或替换现有数据
                # 这里可以实现数据合并逻辑,例如按日期和产品ID去重后合并
                # 简单示例:直接用上传的数据覆盖现有数据
                df.to_excel('pharmacy_sales.xlsx', index=False)
return jsonify({
"status": "success",
"message": "数据上传成功",
"data": {
"products": len(df['product_id'].unique()),
"rows": len(df)
}
})
except Exception as e:
return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
else:
return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
# 2. 模型训练API
@app.route('/api/training', methods=['GET'])
@swag_from({
'tags': ['模型训练'],
'summary': '获取所有训练任务列表',
'description': '返回所有正在进行、已完成或失败的训练任务',
'responses': {
200: {
'description': '成功获取任务列表',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'task_id': {'type': 'string'},
'product_id': {'type': 'string'},
'model_type': {'type': 'string'},
'status': {'type': 'string'},
'start_time': {'type': 'string'},
'metrics': {'type': 'object'},
'error': {'type': 'string'},
'model_path': {'type': 'string'}
}
}
}
}
}
}
}
})
def get_all_training_tasks():
"""获取所有训练任务的状态"""
try:
# 为了方便前端使用我们将任务ID也包含在每个任务信息中
tasks_with_id = []
for task_id, task_info in training_tasks.items():
task_copy = task_info.copy()
task_copy['task_id'] = task_id
tasks_with_id.append(task_copy)
# 按开始时间降序排序,最新的任务在前面
sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True)
return jsonify({"status": "success", "data": sorted_tasks})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/training', methods=['POST'])
@swag_from({
'tags': ['模型训练'],
'summary': '启动模型训练任务',
'description': '为指定产品启动一个新的模型训练任务',
'parameters': [
{
'name': 'body',
'in': 'body',
'required': True,
'schema': {
'type': 'object',
'properties': {
'product_id': {'type': 'string', 'description': '例如 P001'},
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan'], 'description': '要训练的模型类型'},
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
},
'required': ['product_id', 'model_type']
}
}
],
'responses': {
200: {
'description': '训练任务已启动',
'schema': {
'type': 'object',
'properties': {
'message': {'type': 'string'},
'task_id': {'type': 'string'}
}
}
},
400: {
'description': '请求错误'
}
}
})
def start_training():
"""
启动模型训练
---
post:
...
"""
data = request.get_json()
product_id = data.get('product_id')
model_type = data.get('model_type')
epochs = data.get('epochs', 50) # 默认为50轮
if not product_id or not model_type:
return jsonify({'error': '缺少product_id或model_type'}), 400
global training_tasks
task_id = str(uuid.uuid4())
model_train_functions = {
'mlstm': train_product_model_with_mlstm,
'kan': train_product_model_with_kan,
'transformer': train_product_model_with_transformer,
'optimized_kan': lambda product_id, epochs: train_product_model_with_kan(product_id, epochs, use_optimized=True)
}
if model_type not in model_train_functions:
return jsonify({'error': '无效的模型类型'}), 400
train_function = model_train_functions[model_type]
def train_task(product_id, epochs, model_type):
global training_tasks
try:
print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
# 这里的 train_function 会返回 (model, metrics)
_, metrics = train_function(product_id, epochs)
training_tasks[task_id]['status'] = 'completed'
training_tasks[task_id]['metrics'] = metrics
# 保存模型路径
model_path = f'models/{model_type}/{product_id}_model.pt'
training_tasks[task_id]['model_path'] = model_path
print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
except Exception as e:
import traceback
traceback.print_exc()
print(f"任务 {task_id}: 训练失败。错误: {e}")
training_tasks[task_id]['status'] = 'failed'
training_tasks[task_id]['error'] = str(e)
thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
thread.start()
training_tasks[task_id] = {
'status': 'running',
'product_id': product_id,
'model_type': model_type,
'start_time': datetime.now().isoformat(),
'metrics': None,
'error': None,
'model_path': None
}
return jsonify({'message': '模型训练已开始', 'task_id': task_id})
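# 示例(示意调用,假设服务运行在 http://localhost:5000):
#   POST /api/training  body: {"product_id": "P001", "model_type": "mlstm", "epochs": 50}
#   返回 {"message": "模型训练已开始", "task_id": "<uuid>"},之后可轮询 GET /api/training/<task_id> 查询进度。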
@app.route('/api/training/<task_id>', methods=['GET'])
@swag_from({
'tags': ['模型训练'],
'summary': '查询训练任务状态',
'description': '获取特定训练任务的当前状态和详情',
'parameters': [
{
'name': 'task_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '训练任务ID'
}
],
'responses': {
200: {
'description': '成功获取任务状态',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'model_type': {'type': 'string'},
'parameters': {'type': 'object'},
'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
'created_at': {'type': 'string'},
'model_path': {'type': 'string'},
'metrics': {'type': 'object'},
'model_details_url': {'type': 'string'}
}
}
}
}
},
404: {
'description': '任务不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def get_training_status(task_id):
try:
if task_id not in training_tasks:
return jsonify({"status": "error", "message": "任务不存在"}), 404
task_info = training_tasks[task_id].copy()
# 如果任务已完成,添加模型详情链接
if task_info['status'] == 'completed':
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
return jsonify({
"status": "success",
"data": task_info
})
except Exception as e:
return jsonify({"status": "error", "message": str(e)}), 500
# 3. 模型预测API
@app.route('/api/prediction', methods=['POST'])
@swag_from({
'tags': ['模型预测'],
'summary': '使用模型进行预测',
'description': '使用指定模型预测未来销售数据',
'parameters': [
{
'name': 'body',
'in': 'body',
'required': True,
'schema': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan', 'optimized_kan']},
'version': {'type': 'string'},
'future_days': {'type': 'integer'},
'include_visualization': {'type': 'boolean'},
'start_date': {'type': 'string', 'description': '预测起始日期格式为YYYY-MM-DD'}
},
'required': ['product_id', 'model_type']
}
}
],
'responses': {
200: {
'description': '预测成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'product_name': {'type': 'string'},
'model_type': {'type': 'string'},
'predictions': {'type': 'array'},
'visualization': {'type': 'string'}
}
}
}
}
},
400: {
'description': '请求错误,缺少必要参数或参数格式错误'
},
404: {
'description': '产品或模型不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def predict():
"""
使用指定的模型进行预测
---
tags:
- 模型预测
parameters:
- in: body
name: body
schema:
type: object
required:
- product_id
- model_type
properties:
product_id:
type: string
description: 产品ID
model_type:
type: string
description: 模型类型 (mlstm, kan, transformer)
future_days:
type: integer
description: 预测未来天数
default: 7
start_date:
type: string
description: 预测起始日期格式为YYYY-MM-DD
default: ''
responses:
200:
description: 预测成功
400:
description: 请求参数错误
404:
description: 模型文件未找到
"""
try:
data = request.json
product_id = data.get('product_id')
model_type = data.get('model_type')
future_days = int(data.get('future_days', 7))
start_date = data.get('start_date', '')
include_visualization = data.get('include_visualization', False)
print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
if not product_id or not model_type:
return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
# 获取产品名称
product_name = get_product_name(product_id)
if not product_name:
product_name = product_id
# 获取最新模型ID - 此处不需要处理模型类型映射,因为 get_latest_model_id 和 load_model_and_predict 会处理
model_id = get_latest_model_id(model_type, product_id)
if not model_id:
return jsonify({"status": "error", "error": f"未找到产品 {product_id}{model_type} 类型模型"}), 404
# 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date)
# 如果需要可视化,添加图表数据
if include_visualization:
# 添加图表数据
chart_data = prepare_chart_data(prediction_result)
prediction_result['chart_data'] = chart_data
# 添加分析结果
analysis_result = analyze_prediction(prediction_result)
prediction_result['analysis'] = analysis_result
# 保存预测结果到文件和数据库
prediction_id, file_path = save_prediction_result(
prediction_result,
product_id,
product_name,
model_type,
model_id,
start_date,
future_days
)
# 添加预测ID到结果中
prediction_result['prediction_id'] = prediction_id
return jsonify(prediction_result)
except Exception as e:
print(f"预测失败: {str(e)}")
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
@app.route('/api/prediction/compare', methods=['POST'])
@swag_from({
'tags': ['模型预测'],
'summary': '比较不同模型预测结果',
'description': '比较不同模型对同一产品的预测结果',
'parameters': [
{
'name': 'body',
'in': 'body',
'required': True,
'schema': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'model_types': {'type': 'array', 'items': {'type': 'string'}},
'versions': {'type': 'array', 'items': {'type': 'string'}},
'include_visualization': {'type': 'boolean'}
},
'required': ['product_id', 'model_types']
}
}
],
'responses': {
200: {
'description': '比较成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'product_name': {'type': 'string'},
'model_types': {'type': 'array'},
'comparison': {'type': 'array'},
'visualization': {'type': 'string'}
}
}
}
}
},
400: {
'description': '请求错误,缺少必要参数或参数格式错误'
},
404: {
'description': '产品或模型不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def compare_predictions():
try:
data = request.json
product_id = data.get('product_id')
model_types = data.get('model_types')
if not product_id or not model_types:
return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400
all_predictions = {}
plt.figure(figsize=(12, 8))
# 加载历史数据用于绘图
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        if product_df.empty:
            plt.close()
            return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的销售数据"}), 404
        product_name = product_df['product_name'].iloc[0]

        history_days = 30
        history_dates = product_df['date'].iloc[-history_days:].values
        history_sales = product_df['sales'].iloc[-history_days:].values
        plt.plot(history_dates, history_sales, 'b-', label='历史销量')
comparison_data = []
future_dates = None
# 创建比较结果目录
compare_dir = f'predictions/compare'
os.makedirs(compare_dir, exist_ok=True)
for model_type in model_types:
result = load_model_and_predict(product_id, model_type)
if result is not None:
predictions_df = result['predictions_df']
if future_dates is None:
future_dates = predictions_df['date']
comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates]
plt.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测')
preds = predictions_df['predicted_sales'].tolist()
for i in range(len(comparison_data)):
comparison_data[i][f'{model_type}_prediction'] = preds[i] if i < len(preds) else None
plt.title(f'{product_name} - 多模型预测比较')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
chart_path = f'{compare_dir}/{product_id}_model_comparison.png'
plt.savefig(chart_path)
plt.close() # 关闭图表,释放内存
# 获取当前服务器主机和端口
host_url = request.host_url.rstrip('/') # 移除末尾的斜杠
# 生成带时间戳的URL以避免缓存
timestamp = datetime.now().timestamp()
image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}"
return jsonify({
"status": "success",
"data": {
"product_id": product_id,
"product_name": product_name,
"model_types": model_types,
"comparison": comparison_data,
"visualization_url": image_url
}
})
except Exception as e:
traceback.print_exc()
return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500
@app.route('/api/prediction/analyze', methods=['POST'])
def analyze_prediction():
"""
分析预测结果,提供详细解释
---
tags:
- 预测分析
parameters:
- name: body
in: body
required: true
schema:
type: object
properties:
product_id:
type: string
example: P001
model_type:
type: string
enum: [mlstm, transformer, kan, optimized_kan]
future_days:
type: integer
default: 7
start_date:
type: string
description: 预测起始日期,格式为'YYYY-MM-DD'
responses:
200:
description: 预测结果及其分析
schema:
type: object
properties:
status:
type: string
example: success
data:
type: object
properties:
predictions:
type: array
items:
type: number
dates:
type: array
items:
type: string
analysis:
type: object
properties:
trend:
type: string
statistics:
type: object
historical_comparison:
type: object
factors:
type: array
items:
type: object
explanation:
type: string
"""
try:
data = request.json
product_id = data.get('product_id')
model_type = data.get('model_type')
future_days = data.get('future_days', 7)
start_date = data.get('start_date')
# 验证参数
if not product_id or not model_type:
return jsonify({"status": "error", "error": "缺少必要参数"}), 400
if model_type not in ['mlstm', 'transformer', 'kan', 'optimized_kan']:
return jsonify({"status": "error", "error": "不支持的模型类型"}), 400
# 获取预测结果和分析
result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date, analyze_result=True)
if result_tuple is None:
return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404
# 解包元组
result, analysis = result_tuple
if result is None:
return jsonify({"status": "error", "error": "预测失败"}), 500
# 从result中获取预测值
predictions_df = result['predictions_df']
prediction_data = predictions_df[predictions_df['data_type'] == '预测销量']
predictions_list = prediction_data['sales'].tolist()
# 生成日期列表
if start_date:
start_date_obj = datetime.strptime(start_date, '%Y-%m-%d')
else:
# 获取最后一条数据的日期并加1天
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
last_date = product_df['date'].iloc[-1]
start_date_obj = last_date + timedelta(days=1)
dates = [(start_date_obj + timedelta(days=i)).strftime('%Y-%m-%d') for i in range(future_days)]
# 为前端ECharts准备数据
chart_data = {
"dates": dates,
"values": predictions_list,
"day_over_day_changes": analysis["statistics"]["day_over_day_changes"] if "day_over_day_changes" in analysis["statistics"] else []
}
# 如果有历史对比数据,也添加到图表数据中
if analysis["historical_comparison"]["has_historical_data"]:
chart_data["historical"] = {
"mean": analysis["historical_comparison"]["historical_mean"],
"max": analysis["historical_comparison"]["historical_max"],
"min": analysis["historical_comparison"]["historical_min"]
}
return jsonify({
"status": "success",
"data": {
"product_id": product_id,
"model_type": model_type,
"predictions": predictions_list,
"dates": dates,
"chart_data": chart_data,
"analysis": analysis
}
})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
# 4. 模型管理API
@app.route('/api/models', methods=['GET'])
@swag_from({
'tags': ['模型管理'],
'summary': '获取模型列表',
'description': '获取系统中的模型列表可按产品ID和模型类型筛选',
'parameters': [
{
'name': 'product_id',
'in': 'query',
'type': 'string',
'required': False,
'description': '按产品ID筛选'
},
{
'name': 'model_type',
'in': 'query',
'type': 'string',
'required': False,
'description': '按模型类型筛选 (mlstm, kan, transformer)'
}
],
'responses': {
200: {
'description': '成功获取模型列表',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'model_id': {'type': 'string'},
'product_id': {'type': 'string'},
'product_name': {'type': 'string'},
'model_type': {'type': 'string'},
'created_at': {'type': 'string'},
'metrics': {'type': 'object'}
}
}
}
}
}
},
500: {
'description': '服务器内部错误'
}
}
})
def list_models():
"""
列出所有可用的模型
---
tags:
- 模型管理
parameters:
- name: product_id
in: query
type: string
required: false
description: 按产品ID筛选模型
- name: model_type
in: query
type: string
required: false
description: 按模型类型筛选 (mlstm, kan, transformer)
responses:
200:
description: 模型列表
schema:
type: object
properties:
status:
type: string
example: success
data:
type: array
items:
type: object
properties:
model_id:
type: string
product_id:
type: string
product_name:
type: string
model_type:
type: string
created_at:
type: string
metrics:
type: object
"""
models_dir = 'models'
model_types = ['mlstm', 'kan', 'transformer', 'kan_optimized']
model_type_mapping = {
'mlstm': 'mlstm',
'kan': 'kan',
'transformer': 'transformer',
'kan_optimized': 'optimized_kan' # 将kan_optimized目录映射为optimized_kan类型
}
available_models = []
product_id_filter = request.args.get('product_id')
model_type_filter = request.args.get('model_type')
if not os.path.exists(models_dir):
return jsonify({"status": "success", "data": []})
for model_type in model_types:
# 如果指定了模型类型过滤器,检查是否匹配当前类型或其映射类型
if model_type_filter:
# 如果过滤器是optimized_kan我们需要查找kan_optimized目录
if model_type_filter == 'optimized_kan' and model_type != 'kan_optimized':
continue
# 如果过滤器不是optimized_kan但当前目录是kan_optimized跳过
elif model_type_filter != 'optimized_kan' and model_type_filter != model_type:
continue
type_dir = os.path.join(models_dir, model_type)
if not os.path.exists(type_dir):
continue
for file_name in os.listdir(type_dir):
if file_name.endswith('_log.json'):
product_id = file_name.replace('_log.json', '')
if product_id_filter and product_id_filter != product_id:
continue
log_path = os.path.join(type_dir, file_name)
try:
with open(log_path, 'r', encoding='utf-8') as f:
log_data = json.load(f)
# 使用映射表获取显示的模型类型
display_model_type = model_type_mapping.get(model_type, model_type)
model_info = {
"model_id": f"{display_model_type}_{product_id}",
"product_id": log_data.get('product_id'),
"product_name": log_data.get('product_name'),
"model_type": display_model_type, # 使用映射后的模型类型
"created_at": log_data.get('training_completed_at'),
"metrics": log_data.get('metrics'),
"file_path": log_data.get('file_path')
}
available_models.append(model_info)
except Exception as e:
print(f"读取日志文件 {log_path} 失败: {e}")
# 按创建时间降序排序
available_models.sort(key=lambda x: x.get('created_at', ''), reverse=True)
return jsonify({"status": "success", "data": available_models})
@app.route('/api/models/<model_id>', methods=['GET'])
@swag_from({
'tags': ['模型管理'],
'summary': '获取模型详情',
'description': '获取特定模型的详细信息',
'parameters': [
{
'name': 'model_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '模型ID,格式为 {model_type}_{product_id}'
}
],
'responses': {
200: {
'description': '成功获取模型详情',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'product_id': {'type': 'string'},
'model_type': {'type': 'string'},
'version': {'type': 'string'},
'created_at': {'type': 'string'},
'file_path': {'type': 'string'},
'file_size': {'type': 'string'},
'features': {'type': 'array'},
'look_back': {'type': 'integer'},
'T': {'type': 'integer'},
'metrics': {'type': 'object'}
}
}
}
}
},
400: {
'description': '无效的模型ID格式'
},
404: {
'description': '模型不存在'
},
500: {
'description': '服务器内部错误'
}
}
})
def get_model_details(model_id):
"""
获取单个模型的详细信息
---
tags:
- 模型管理
parameters:
- name: model_id
in: path
type: string
required: true
description: 模型的唯一标识符 (格式: model_type_product_id)
responses:
200:
description: 模型的详细信息
404:
description: 未找到模型
"""
try:
        # model_id 格式为 {model_type}_{product_id};optimized_kan 本身含下划线,需特殊处理
        if model_id.startswith('optimized_kan_'):
            model_type, product_id = 'optimized_kan', model_id[len('optimized_kan_'):]
        else:
            model_type, product_id = model_id.split('_', 1)
        # 处理优化版KAN模型的路径
        actual_model_type = model_type
        if model_type == 'optimized_kan':
            actual_model_type = 'kan_optimized'
            print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_log.json'")
        log_path = os.path.join('models', actual_model_type, f'{product_id}_log.json')
if not os.path.exists(log_path):
return jsonify({"status": "error", "error": "模型未找到"}), 404
with open(log_path, 'r', encoding='utf-8') as f:
log_data = json.load(f)
# 确保返回的模型类型是optimized_kan而不是kan_optimized
if actual_model_type == 'kan_optimized':
log_data['model_type'] = 'optimized_kan'
return jsonify({"status": "success", "data": log_data})
except ValueError:
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
except Exception as e:
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
@app.route('/api/models/<model_id>', methods=['DELETE'])
@swag_from({
'tags': ['模型管理'],
'summary': '删除模型',
'description': '删除特定模型',
'parameters': [
{
'name': 'model_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '模型ID,格式为 {model_type}_{product_id}'
}
],
'responses': {
200: {
'description': '模型删除成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'message': {'type': 'string'}
}
}
},
400: {
'description': '无效的模型ID格式'
},
500: {
'description': '服务器内部错误'
}
}
})
def delete_model(model_id):
"""
删除一个模型及其关联文件
---
tags:
- 模型管理
parameters:
- name: model_id
in: path
type: string
required: true
description: 要删除的模型的ID (格式: model_type_product_id)
responses:
200:
description: 模型删除成功
404:
description: 模型未找到
"""
try:
        # model_id 格式为 {model_type}_{product_id};optimized_kan 本身含下划线,需特殊处理
        if model_id.startswith('optimized_kan_'):
            model_type, product_id = 'optimized_kan', model_id[len('optimized_kan_'):]
        else:
            model_type, product_id = model_id.split('_', 1)
        # 处理优化版KAN模型的路径
        actual_model_type = model_type
        if model_type == 'optimized_kan':
            actual_model_type = 'kan_optimized'
            print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'")
        model_dir = os.path.join('models', actual_model_type)
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
log_path = os.path.join(model_dir, f'{product_id}_log.json')
if not os.path.exists(model_path) and not os.path.exists(log_path):
return jsonify({"status": "error", "error": "模型未找到"}), 404
if os.path.exists(model_path):
os.remove(model_path)
if os.path.exists(log_path):
os.remove(log_path)
return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
except ValueError:
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
except Exception as e:
return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
@app.route('/api/models/<model_id>/export', methods=['GET'])
@swag_from({
'tags': ['模型管理'],
'summary': '导出模型',
'description': '导出特定模型文件',
'parameters': [
{
'name': 'model_id',
'in': 'path',
'type': 'string',
'required': True,
'description': '模型ID,格式为 {model_type}_{product_id}'
}
],
'responses': {
200: {
'description': '模型文件下载',
'content': {
'application/octet-stream': {}
}
},
400: {
'description': '无效的模型ID格式'
},
500: {
'description': '服务器内部错误'
}
}
})
def export_model(model_id):
try:
        # model_id 格式为 {model_type}_{product_id};optimized_kan 本身含下划线,需特殊处理
        if model_id.startswith('optimized_kan_'):
            model_type, product_id = 'optimized_kan', model_id[len('optimized_kan_'):]
        else:
            model_type, product_id = model_id.split('_', 1)
        # 处理优化版KAN模型的路径
        actual_model_type = model_type
        if model_type == 'optimized_kan':
            actual_model_type = 'kan_optimized'
            print(f"优化版KAN模型: 使用路径 'models/{actual_model_type}/{product_id}_model.pt'")
        model_path = os.path.join('models', actual_model_type, f'{product_id}_model.pt')
if not os.path.exists(model_path):
return jsonify({"status": "error", "error": "模型文件未找到"}), 404
return send_file(
model_path,
as_attachment=True,
download_name=f'{model_id}.pt',
mimetype='application/octet-stream'
)
except Exception as e:
return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
@app.route('/api/models/import', methods=['POST'])
@swag_from({
'tags': ['模型管理'],
'summary': '导入模型',
'description': '导入模型文件',
'consumes': ['multipart/form-data'],
'parameters': [
{
'name': 'file',
'in': 'formData',
'type': 'file',
'required': True,
'description': 'PyTorch模型文件(.pt)'
}
],
'responses': {
200: {
'description': '模型导入成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'message': {'type': 'string'},
'data': {
'type': 'object',
'properties': {
'model_path': {'type': 'string'}
}
}
}
}
},
400: {
'description': '请求错误,文件格式不正确或缺少文件'
},
500: {
'description': '服务器内部错误'
}
}
})
def import_model():
try:
if 'file' not in request.files:
return jsonify({"status": "error", "message": "没有上传文件"}), 400
file = request.files['file']
if file.filename == '' or not file.filename.endswith('.pt'):
return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400
# 从文件名解析 product_id, model_type
# 假设文件名格式为 `transformer_P001.pt`
try:
model_type, product_id_ext = file.filename.split('_', 1)
product_id = os.path.splitext(product_id_ext)[0]
model_types = ['mlstm', 'kan', 'transformer']
if model_type not in model_types:
raise ValueError("无效的模型类型")
except ValueError:
return jsonify({
"status": "error",
"message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)"
}), 400
# 创建目标目录并保存文件
model_dir = os.path.join('models', model_type)
os.makedirs(model_dir, exist_ok=True)
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
# 检查是否需要创建关联的日志文件
log_path = os.path.join(model_dir, f'{product_id}_log.json')
if not os.path.exists(log_path):
# 尝试从.pt文件加载信息如果可能
try:
checkpoint = torch.load(file, map_location='cpu')
# 重新定位文件指针到开头,以便 `file.save` 正常工作
file.seek(0)
except Exception:
checkpoint = {} # 如果加载失败创建一个空的checkpoint
# 创建一个基础的log文件
log_data = {
'product_id': product_id,
'product_name': f"导入的产品 {product_id}", # 可能需要用户后续编辑
'model_type': model_type,
'training_completed_at': datetime.now().isoformat(),
'epochs': checkpoint.get('epochs', 'N/A'),
'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}),
'file_path': model_path
}
with open(log_path, 'w', encoding='utf-8') as f:
json.dump(log_data, f, indent=4, ensure_ascii=False)
# 保存模型文件
file.save(model_path)
return jsonify({
"status": "success",
"message": "模型已成功导入",
"data": {"model_path": model_path}
})
except Exception as e:
traceback.print_exc()
return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500
@app.route('/api/plots/<filename>')
def get_plot(filename):
"""Serve a plot file from the root directory."""
try:
return send_from_directory(app.root_path, filename)
except FileNotFoundError:
return jsonify({"status": "error", "error": "Plot not found"}), 404
@app.route('/api/csv/<filename>')
def get_csv(filename):
"""Serve a CSV file from the root directory."""
try:
return send_from_directory(app.root_path, filename, as_attachment=True)
except FileNotFoundError:
return jsonify({"status": "error", "error": "CSV file not found"}), 404
# 添加静态UI路由将/ui路径映射到wwwroot目录
@app.route('/ui/', defaults={'path': 'index.html'})
@app.route('/ui/<path:path>')
def serve_ui(path):
"""提供UI静态文件服务将/ui路径映射到wwwroot目录"""
try:
# 从wwwroot目录提供静态文件
return send_from_directory('wwwroot', path)
except FileNotFoundError:
# 如果是子路径请求(例如/ui/about)但文件不存在尝试返回index.html以支持SPA路由
if path != 'index.html':
try:
return send_from_directory('wwwroot', 'index.html')
except FileNotFoundError:
return jsonify({"status": "error", "error": "UI files not found"}), 404
return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404
@app.route('/api/predictions/<model_type>/<product_id>/<filename>')
def get_prediction_file(model_type, product_id, filename):
"""按模型类型和产品ID获取预测文件"""
try:
# 构建文件路径
file_path = os.path.join('predictions', model_type, product_id)
# 如果是图片文件,添加防缓存头
if filename.endswith('.png'):
response = send_from_directory(file_path, filename)
# 强制浏览器不缓存图片
response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0'
response.headers['Pragma'] = 'no-cache'
response.headers['Expires'] = '0'
response.headers['Last-Modified'] = datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT")
response.headers['Vary'] = '*'
print(f"提供图片文件 {filename} 并添加防缓存头")
return response
else:
return send_from_directory(file_path, filename)
except FileNotFoundError:
return jsonify({"status": "error", "error": f"预测文件 {filename} 未找到"}), 404
@app.route('/api/predictions/compare/<filename>')
def get_compare_file(filename):
"""
提供比较结果文件的下载
"""
directory = os.path.join(current_dir, 'predictions', 'compare')
return send_from_directory(directory, filename)
@app.route('/api/models/<model_type>/<product_id>/details', methods=['GET'])
def get_model_details_extended(model_type, product_id):
"""
获取模型详情,包括训练损失曲线、预测效果图等
---
tags:
- 模型管理
parameters:
- name: model_type
in: path
required: true
type: string
description: 模型类型
- name: product_id
in: path
required: true
type: string
description: 产品ID
responses:
200:
description: 模型详情
schema:
type: object
properties:
status:
type: string
example: success
data:
type: object
properties:
model_info:
type: object
description: 模型基本信息
training_metrics:
type: object
description: 训练评估指标
chart_data:
type: object
description: 图表数据
"""
try:
# 处理优化版KAN模型的路径
actual_model_path = model_type
if model_type == 'optimized_kan':
actual_model_path = 'kan_optimized'
model_dir = f'models/{actual_model_path}'
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
log_path = os.path.join(model_dir, f'{product_id}_log.json')
# 检查模型文件是否存在
if not os.path.exists(model_path):
return jsonify({
"status": "error",
"error": f"模型 '{model_type}/{product_id}' 不存在"
}), 404
# 读取模型日志
model_log = {}
if os.path.exists(log_path):
with open(log_path, 'r', encoding='utf-8') as f:
model_log = json.load(f)
# 加载模型检查点,获取训练损失数据
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
# 获取训练和测试损失
train_losses = checkpoint.get('train_loss', [])
test_losses = checkpoint.get('test_loss', [])
# 获取产品名称
product_name = ""
try:
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id]
if not product_df.empty:
product_name = product_df['product_name'].iloc[0]
except Exception as e:
print(f"获取产品名称时出错: {e}")
# 检查是否有损失曲线图和预测效果图
loss_curve_path = f'{product_id}_{model_type}_loss_curve.png'
prediction_path = f'{product_id}_{model_type}_prediction.png'
# 获取当前服务器主机和端口
host_url = request.host_url.rstrip('/')
# 准备图表数据
chart_data = {
"loss_chart": {
"epochs": list(range(1, len(train_losses) + 1)),
"train_loss": train_losses,
"test_loss": test_losses if test_losses else []
}
}
# 准备模型信息
model_info = {
"model_id": f"{model_type}_{product_id}",
"model_type": model_type,
"product_id": product_id,
"product_name": product_name,
"created_at": model_log.get("training_completed_at", ""),
"file_path": model_path
}
# 准备训练指标
training_metrics = model_log.get("metrics", {})
return jsonify({
"status": "success",
"data": {
"model_info": model_info,
"training_metrics": training_metrics,
"chart_data": chart_data
}
})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({
"status": "error",
"error": str(e)
}), 500
# 创建SQLite数据库连接函数
def get_db_connection():
"""获取SQLite数据库连接"""
conn = sqlite3.connect('prediction_history.db')
conn.row_factory = sqlite3.Row
return conn
# 初始化数据库
def init_db():
"""初始化SQLite数据库创建必要的表"""
conn = get_db_connection()
cursor = conn.cursor()
# 创建历史预测记录表
cursor.execute('''
CREATE TABLE IF NOT EXISTS prediction_history (
id TEXT PRIMARY KEY,
product_id TEXT NOT NULL,
product_name TEXT NOT NULL,
model_type TEXT NOT NULL,
model_id TEXT NOT NULL,
start_date TEXT NOT NULL,
future_days INTEGER NOT NULL,
created_at TEXT NOT NULL,
file_path TEXT NOT NULL
)
''')
conn.commit()
conn.close()
# 在应用启动时初始化数据库
init_db()
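# 模块导入时即初始化数据库;__main__ 中会再次调用 init_db(),
# 由于建表语句使用 IF NOT EXISTS,重复调用没有副作用。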
# 添加保存预测结果的函数
def save_prediction_result(prediction_data, product_id, product_name, model_type, model_id, start_date, future_days):
"""保存预测结果到文件和数据库"""
# 生成唯一ID
prediction_id = str(uuid.uuid4())
# 创建历史预测目录
history_dir = os.path.join('predictions', 'history')
os.makedirs(history_dir, exist_ok=True)
# 创建模型类型子目录
model_dir = os.path.join(history_dir, model_type)
os.makedirs(model_dir, exist_ok=True)
# 创建文件名
timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
filename = f"{product_id}_{start_date}_{future_days}days_{timestamp}.json"
file_path = os.path.join(model_dir, filename)
    # 保存预测结果为JSON文件,复用模块顶部的 CustomJSONEncoder 处理 Pandas Timestamp 和 NumPy 类型
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(prediction_data, f, ensure_ascii=False, indent=2, cls=CustomJSONEncoder)
# 将记录保存到数据库
conn = get_db_connection()
cursor = conn.cursor()
cursor.execute(
'''
INSERT INTO prediction_history
(id, product_id, product_name, model_type, model_id, start_date, future_days, created_at, file_path)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
''',
(
prediction_id,
product_id,
product_name,
model_type,
model_id,
start_date,
future_days,
datetime.now().isoformat(),
file_path
)
)
conn.commit()
conn.close()
return prediction_id, file_path
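# 预测结果文件的保存路径形如:
#   predictions/history/<model_type>/<product_id>_<start_date>_<future_days>days_<timestamp>.json
# 对应的元数据记录保存在 prediction_history.db 的 prediction_history 表中。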
# 添加获取历史预测列表的API
@app.route('/api/prediction/history', methods=['GET'])
def get_prediction_history():
"""
获取历史预测列表
---
tags:
- 预测分析
parameters:
- name: product_id
in: query
type: string
required: false
description: 按产品ID筛选
- name: model_type
in: query
type: string
required: false
description: 按模型类型筛选
- name: page
in: query
type: integer
required: false
default: 1
description: 页码
- name: page_size
in: query
type: integer
required: false
default: 10
description: 每页记录数
responses:
200:
description: 历史预测列表
schema:
type: object
properties:
status:
type: string
example: success
data:
type: array
items:
type: object
properties:
id:
type: string
example: 550e8400-e29b-41d4-a716-446655440000
product_id:
type: string
example: P001
product_name:
type: string
example: 阿司匹林
model_type:
type: string
example: mlstm
model_id:
type: string
example: mlstm_P001_20230101
start_date:
type: string
example: 2023-01-01
future_days:
type: integer
example: 7
created_at:
type: string
example: 2023-01-01T12:00:00
total:
type: integer
example: 100
page:
type: integer
example: 1
page_size:
type: integer
example: 10
"""
try:
# 获取查询参数
product_id = request.args.get('product_id')
model_type = request.args.get('model_type')
page = int(request.args.get('page', 1))
page_size = int(request.args.get('page_size', 10))
# 构建SQL查询
query = "SELECT * FROM prediction_history"
params = []
conditions = []
if product_id:
conditions.append("product_id = ?")
params.append(product_id)
if model_type:
conditions.append("model_type = ?")
params.append(model_type)
if conditions:
query += " WHERE " + " AND ".join(conditions)
# 添加排序和分页
query += " ORDER BY created_at DESC"
# 获取总记录数
conn = get_db_connection()
count_result = conn.execute(query, params).fetchall()
total_count = len(count_result)
# 添加分页限制
query += f" LIMIT {page_size} OFFSET {(page - 1) * page_size}"
# 执行查询
result = conn.execute(query, params).fetchall()
# 转换为JSON格式
history_list = []
for row in result:
history_list.append({
'id': row['id'],
'product_id': row['product_id'],
'product_name': row['product_name'],
'model_type': row['model_type'],
'model_id': row['model_id'],
'start_date': row['start_date'],
'future_days': row['future_days'],
'created_at': row['created_at'],
'file_path': row['file_path']
})
conn.close()
return jsonify({
'status': 'success',
'data': history_list,
'total': total_count,
'page': page,
'page_size': page_size
})
except Exception as e:
print(f"获取历史预测列表失败: {str(e)}")
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
# 添加获取历史预测详情的API
@app.route('/api/prediction/history/<prediction_id>', methods=['GET'])
def get_prediction_history_detail(prediction_id):
"""
获取历史预测详情
---
tags:
- 预测分析
parameters:
- name: prediction_id
in: path
type: string
required: true
description: 预测ID
responses:
200:
description: 预测详情
schema:
type: object
properties:
status:
type: string
example: success
data:
type: object
"""
try:
# 查询数据库获取文件路径
conn = get_db_connection()
result = conn.execute("SELECT * FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone()
conn.close()
if not result:
return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404
file_path = result['file_path']
# 读取预测结果文件
try:
with open(file_path, 'r', encoding='utf-8') as f:
prediction_data = json.load(f)
return jsonify({
'status': 'success',
'data': prediction_data,
'meta': {
'id': result['id'],
'product_id': result['product_id'],
'product_name': result['product_name'],
'model_type': result['model_type'],
'model_id': result['model_id'],
'start_date': result['start_date'],
'future_days': result['future_days'],
'created_at': result['created_at']
}
})
except FileNotFoundError:
return jsonify({"status": "error", "error": "预测结果文件不存在"}), 404
except Exception as e:
return jsonify({"status": "error", "error": f"读取预测结果文件失败: {str(e)}"}), 500
except Exception as e:
print(f"获取历史预测详情失败: {str(e)}")
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
# 添加删除历史预测记录的API
@app.route('/api/prediction/history/<prediction_id>', methods=['DELETE'])
def delete_prediction_history(prediction_id):
"""
删除历史预测记录
---
tags:
- 预测分析
parameters:
- name: prediction_id
in: path
type: string
required: true
description: 预测ID
responses:
200:
description: 删除结果
schema:
type: object
properties:
status:
type: string
example: success
"""
try:
# 查询数据库获取文件路径
conn = get_db_connection()
result = conn.execute("SELECT file_path FROM prediction_history WHERE id = ?", (prediction_id,)).fetchone()
if not result:
conn.close()
return jsonify({"status": "error", "error": "未找到指定的预测记录"}), 404
file_path = result['file_path']
# 删除数据库记录
conn.execute("DELETE FROM prediction_history WHERE id = ?", (prediction_id,))
conn.commit()
conn.close()
# 删除预测结果文件
try:
if os.path.exists(file_path):
os.remove(file_path)
except Exception as e:
print(f"删除预测结果文件失败: {str(e)}")
return jsonify({'status': 'success'})
except Exception as e:
print(f"删除历史预测记录失败: {str(e)}")
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
# 获取产品名称的辅助函数
def get_product_name(product_id):
"""根据产品ID获取产品名称"""
try:
# 从产品列表中查找产品名称
products_file = 'data/products.json'
if os.path.exists(products_file):
with open(products_file, 'r', encoding='utf-8') as f:
products = json.load(f)
for product in products:
if product['product_id'] == product_id:
return product['product_name']
return None
except Exception as e:
print(f"获取产品名称失败: {str(e)}")
return None
# 获取最新模型ID的辅助函数
def get_latest_model_id(model_type, product_id):
"""根据模型类型和产品ID获取最新的模型ID"""
try:
# 处理优化版KAN模型的路径
actual_model_path = model_type
if model_type == 'optimized_kan':
actual_model_path = 'kan_optimized'
print(f"优化版KAN模型: 当查找最新模型ID时使用路径 'models/{actual_model_path}/{product_id}_model.pt'")
# 查找模型目录中的模型文件
model_dir = os.path.join('models', actual_model_path)
if not os.path.exists(model_dir):
print(f"模型目录不存在: {model_dir}")
return None
# 查找匹配的模型文件
model_files = [f for f in os.listdir(model_dir) if f.startswith(f"{product_id}_") and f.endswith('.pt')]
if not model_files:
print(f"在目录 {model_dir} 中未找到产品 {product_id} 的模型文件")
return None
# 按照文件修改时间排序,获取最新的模型文件
model_files.sort(key=lambda x: os.path.getmtime(os.path.join(model_dir, x)), reverse=True)
latest_model_file = model_files[0]
# 从文件名中提取模型ID
model_id = latest_model_file.replace('.pt', '')
print(f"找到最新模型: {model_id} 在目录 {model_dir}")
return model_id
except Exception as e:
print(f"获取最新模型ID失败: {str(e)}")
return None
# 执行预测的辅助函数
def run_prediction(model_type, product_id, model_id, future_days, start_date):
"""执行模型预测"""
try:
# 处理优化版KAN模型的路径在 load_model_and_predict 中已经实现,无需在此处理
# 直接使用原始的 model_type 调用函数
# 解包返回的元组为result和analysis
result_tuple = load_model_and_predict(product_id, model_type, future_days, start_date=start_date, analyze_result=True)
# 处理返回值可能是None的情况
if result_tuple is None:
raise Exception(f"模型 '{model_type}' 类型的模型文件未找到或加载失败")
# 解包元组 - result_tuple可能是(result, analysis)格式或者只是result
if isinstance(result_tuple, tuple):
result = result_tuple[0]
analysis = result_tuple[1] if len(result_tuple) > 1 else None
else:
result = result_tuple
analysis = None
if result is None:
raise Exception(f"模型预测失败")
# 获取预测数据
predictions_df = result['predictions_df']
# 确保日期是日期时间格式以便处理
if 'date' in predictions_df.columns and not pd.api.types.is_datetime64_any_dtype(predictions_df['date']):
predictions_df['date'] = pd.to_datetime(predictions_df['date'])
# 分离历史数据和预测数据
history_df = predictions_df[predictions_df['data_type'] == '历史销量']
prediction_df = predictions_df[predictions_df['data_type'] == '预测销量']
# 如果历史数据和预测数据的日期有重叠,调整预测数据的日期
if not history_df.empty and not prediction_df.empty:
last_history_date = history_df['date'].max()
# 检查预测数据是否与历史数据重叠
if any(prediction_df['date'] <= last_history_date):
print(f"检测到预测数据与历史数据日期重叠,调整预测数据日期...")
# 获取预测数据的起始日期,确保它在历史数据的最后一天之后
prediction_start_date = last_history_date + pd.Timedelta(days=1)
# 创建新的日期序列
new_dates = pd.date_range(start=prediction_start_date, periods=len(prediction_df), freq='D')
# 更新预测数据的日期
prediction_df['date'] = new_dates
# 更新原始DataFrame
predictions_df.loc[predictions_df['data_type'] == '预测销量', 'date'] = prediction_df['date'].values
print(f"预测数据日期已调整为从 {prediction_start_date} 开始")
# 将处理后的DataFrame放回结果
result['predictions_df'] = predictions_df
chart_path = result.get('chart_path')
csv_path = result.get('csv_path')
# 将DataFrame转换为JSON
predictions_json = predictions_df.to_dict(orient='records')
# 准备响应数据
response_data = {
"status": "success",
"data": predictions_json, # 包含历史和预测数据
"history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'),
"prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records'),
}
return response_data
except Exception as e:
print(f"执行预测失败: {str(e)}")
traceback.print_exc()
raise e
# 准备图表数据的辅助函数
def prepare_chart_data(prediction_result):
"""准备图表数据"""
try:
# 从预测结果中提取数据
predictions_df = pd.DataFrame(prediction_result['data'])
# 确保日期列是日期时间格式,以便正确排序
if 'date' in predictions_df.columns:
# 先转换为日期时间类型以便排序
if isinstance(predictions_df['date'][0], str):
predictions_df['date'] = pd.to_datetime(predictions_df['date'])
# 按日期排序
predictions_df = predictions_df.sort_values('date')
# 重置索引
predictions_df = predictions_df.reset_index(drop=True)
# 最后再转换回字符串格式用于JSON
predictions_df['date'] = predictions_df['date'].dt.strftime('%Y-%m-%d')
# 分离历史数据和预测数据
history_df = predictions_df[predictions_df['data_type'] == '历史销量']
prediction_df = predictions_df[predictions_df['data_type'] == '预测销量']
# 准备图表数据
chart_data = {
"dates": predictions_df['date'].tolist(),
"sales": predictions_df['sales'].tolist(),
"types": predictions_df['data_type'].tolist()
}
# 为前端debug提供额外信息
chart_data["debug"] = {
"history_dates": history_df['date'].tolist() if not history_df.empty else [],
"history_sales": history_df['sales'].tolist() if not history_df.empty else [],
"prediction_dates": prediction_df['date'].tolist() if not prediction_df.empty else [],
"prediction_sales": prediction_df['sales'].tolist() if not prediction_df.empty else [],
}
print(f"历史数据日期范围: {history_df['date'].min() if not history_df.empty else 'N/A'}{history_df['date'].max() if not history_df.empty else 'N/A'}")
print(f"预测数据日期范围: {prediction_df['date'].min() if not prediction_df.empty else 'N/A'}{prediction_df['date'].max() if not prediction_df.empty else 'N/A'}")
return chart_data
except Exception as e:
print(f"准备图表数据失败: {str(e)}")
traceback.print_exc()
return {}
# 分析预测结果的辅助函数
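# 注意:下面这个辅助函数与上方 /api/prediction/analyze 的视图函数同名。
# Flask 在注册路由时已保存原视图函数的引用,因此该路由不受影响;
# 而 predict() 中调用的 analyze_prediction(prediction_result) 在运行时解析到的是这里的辅助函数。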
def analyze_prediction(prediction_result):
"""分析预测结果"""
try:
# 从预测结果中提取数据
prediction_data = prediction_result.get('prediction_data', [])
history_data = prediction_result.get('history_data', [])
if not prediction_data:
return None
# 转换为DataFrame以便分析
prediction_df = pd.DataFrame(prediction_data)
# 确保日期列是日期时间格式
if 'date' in prediction_df.columns:
if isinstance(prediction_df['date'][0], str):
prediction_df['date'] = pd.to_datetime(prediction_df['date'])
# 计算统计数据
sales = prediction_df['sales'].values
mean_sales = np.mean(sales)
max_sales = np.max(sales)
min_sales = np.min(sales)
std_sales = np.std(sales)
# 计算日环比变化
day_over_day_changes = []
for i in range(1, len(sales)):
if sales[i-1] == 0:
day_over_day_changes.append(0)
else:
change_pct = ((sales[i] - sales[i-1]) / sales[i-1]) * 100
day_over_day_changes.append(change_pct)
# 确定趋势
if len(sales) < 2:
trend = "unknown"
else:
# 计算简单线性回归的斜率
x = np.arange(len(sales))
slope = np.polyfit(x, sales, 1)[0]
# 计算变化的标准差
changes = np.diff(sales)
changes_std = np.std(changes)
if abs(slope) < 0.1 * mean_sales:
if changes_std > 0.2 * mean_sales:
trend = "fluctuating"
else:
trend = "stable"
elif slope > 0:
trend = "increasing"
else:
trend = "decreasing"
# 历史数据对比
has_historical_data = len(history_data) > 0
historical_comparison = {
"has_historical_data": has_historical_data,
"mean_difference_pct": 0
}
if has_historical_data:
history_df = pd.DataFrame(history_data)
history_mean = history_df['sales'].mean()
prediction_mean = mean_sales
if history_mean > 0:
mean_difference_pct = ((prediction_mean - history_mean) / history_mean) * 100
historical_comparison["mean_difference_pct"] = mean_difference_pct
# 模拟影响因素
factors = [
{"name": "季节性", "importance": "high", "description": "季节变化对销量有显著影响"},
{"name": "促销活动", "importance": "medium", "description": "促销活动可能会短期提升销量"},
{"name": "市场竞争", "importance": "low", "description": "市场竞争对销量有轻微影响"}
]
# 生成解释文本
if trend == "increasing":
explanation = f"预测期内销量整体呈上升趋势,平均日销量为{mean_sales:.2f},相比历史数据"
if has_historical_data and historical_comparison["mean_difference_pct"] > 0:
explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。"
elif has_historical_data:
explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。"
else:
explanation += "无法比较。"
explanation += "建议适当增加库存以应对销量增长。"
elif trend == "decreasing":
explanation = f"预测期内销量整体呈下降趋势,平均日销量为{mean_sales:.2f},相比历史数据"
if has_historical_data and historical_comparison["mean_difference_pct"] > 0:
explanation += f"增长了{historical_comparison['mean_difference_pct']:.2f}%。"
elif has_historical_data:
explanation += f"下降了{abs(historical_comparison['mean_difference_pct']):.2f}%。"
else:
explanation += "无法比较。"
explanation += "建议控制库存以避免积压。"
elif trend == "fluctuating":
explanation = f"预测期内销量波动较大,平均日销量为{mean_sales:.2f},标准差为{std_sales:.2f}。建议密切关注销售情况,灵活调整库存。"
else:
explanation = f"预测期内销量保持稳定,平均日销量为{mean_sales:.2f},最高销量为{max_sales:.2f},最低销量为{min_sales:.2f}"
# 构建分析结果
analysis = {
"explanation": explanation,
"trend": trend,
"statistics": {
"mean": mean_sales,
"max": max_sales,
"min": min_sales,
"std": std_sales,
"day_over_day_changes": day_over_day_changes
},
"historical_comparison": historical_comparison,
"factors": factors
}
return analysis
except Exception as e:
print(f"分析预测结果失败: {str(e)}")
return None
# 新增模型性能分析接口
@app.route('/api/models/analyze-metrics', methods=['POST'])
@swag_from({
'tags': ['模型管理'],
'summary': '分析模型评估指标',
'description': '接收一组模型指标,并返回详细的文字解读和评级。',
'parameters': [
{
'name': 'body',
'in': 'body',
'required': True,
'schema': {
'type': 'object',
'description': '模型的评估指标',
'example': {
'R2': 0.7067,
'RMSE': 6.8670,
'MAE': 4.4062,
'MAPE': 18.14
}
}
}
],
'responses': {
200: {
'description': '分析成功',
'schema': {
'type': 'object',
'properties': {
'status': {'type': 'string'},
'data': {'type': 'object'}
}
}
},
400: {
'description': '请求错误,缺少指标数据'
}
}
})
def analyze_model_metrics():
"""
分析模型的评估指标并提供解读
"""
try:
metrics = request.json
if not metrics:
return jsonify({"status": "error", "error": "缺少指标数据"}), 400
analysis = {}
# 1. 分析 R²
r2 = metrics.get('R2')
if r2 is not None:
if r2 >= 0.9:
r2_rating = "优秀"
r2_desc = f"R²值为{r2:.4f}表现非常出色。模型能够解释超过90%的销售数据波动,意味着它与历史数据的拟合度极高,预测结果非常可靠。"
elif r2 >= 0.7:
r2_rating = "良好"
r2_desc = f"R²值为{r2:.4f},表现良好。模型能解释{(r2*100):.1f}%的销售数据变化,说明模型捕捉到了大部分关键的销售模式,具备很高的实用价值。"
elif r2 >= 0.5:
r2_rating = "中等"
r2_desc = f"R²值为{r2:.4f},表现中等。模型解释了约一半的销售数据变化,说明它掌握了一定的规律,但可能忽略了一些次要因素。预测结果可作为参考,但需结合其他信息判断。"
else:
r2_rating = "较弱"
r2_desc = f"R²值为{r2:.4f},表现较弱。模型对销售数据变化的解释能力有限,预测的准确性可能不高。建议尝试优化模型或增加更多有效特征。"
analysis['R2'] = {"value": r2, "rating": r2_rating, "description": r2_desc}
# 2. 分析 MAPE
mape = metrics.get('MAPE')
if mape is not None:
if mape <= 10:
mape_rating = "优秀"
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率极低,预测精度非常高。"
elif mape <= 20:
mape_rating = "良好"
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率在可接受的范围内,表明模型的预测结果在大多数情况下与真实值偏差不大。"
elif mape <= 30:
mape_rating = "中等"
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率中等。在销量波动较大的场景下可以接受,但对于追求高精度预测的场景,仍有优化空间。"
else:
mape_rating = "较弱"
mape_desc = f"平均绝对百分比误差为{mape:.2f}%,误差率偏高。模型的预测值与真实值偏差较大,建议谨慎使用其预测结果。"
analysis['MAPE'] = {"value": mape, "rating": mape_rating, "description": mape_desc}
# 3. 分析 RMSE 和 MAE
rmse = metrics.get('RMSE')
mae = metrics.get('MAE')
if rmse is not None and mae is not None:
rmse_desc = f"均方根误差为{rmse:.4f}。这个值衡量了预测误差的典型大小。因为它对较大的误差值更敏感,所以可以反映模型是否存在'离谱'的预测。"
mae_desc = f"平均绝对误差为{mae:.4f}。这个值直观地表示了模型平均预测会偏离真实销售量多少个单位(如'')。"
# 比较RMSE和MAE
if rmse > mae * 1.5: # 经验阈值
comparison_desc = f"RMSE ({rmse:.4f}) 明显大于 MAE ({mae:.4f}),这通常意味着模型在某些数据点上存在较大的预测误差(离群点)。虽然总体平均误差不大,但偶尔可能会有'猜错得离谱'的情况。"
else:
comparison_desc = f"RMSE ({rmse:.4f}) 与 MAE ({mae:.4f}) 的值较为接近,表明模型的误差分布比较均匀,没有出现极端异常的预测错误。"
analysis['RMSE'] = {"value": rmse, "rating": "参考指标", "description": rmse_desc}
analysis['MAE'] = {"value": mae, "rating": "参考指标", "description": mae_desc}
analysis['RMSE_MAE_COMP'] = {"description": comparison_desc}
# 4. 形成总体结论
overall_ratings = [a.get('rating') for a in analysis.values() if a.get('rating') in ["优秀", "良好", "中等", "较弱"]]
if "较弱" in overall_ratings:
overall_summary = "该模型的综合性能表现较弱,预测结果可能存在较大偏差,建议进行优化或谨慎使用。"
elif "中等" in overall_ratings:
overall_summary = "该模型的综合性能表现中等,具备一定的预测能力,但仍有提升空间。其预测结果可作为重要参考。"
elif "优秀" in overall_ratings and overall_ratings.count("良好") == 0:
overall_summary = "该模型的综合性能表现非常优秀,各项指标均显示其预测精度高、稳定性好,预测结果高度可信。"
else: # 主要为良好
overall_summary = "该模型的综合性能表现良好,能够可靠地预测销售趋势,误差在可接受范围内,是决策的有力支持。"
analysis['overall_summary'] = overall_summary
return jsonify({"status": "success", "data": analysis})
except Exception as e:
import traceback
traceback.print_exc()
return jsonify({"status": "error", "error": str(e)}), 500
# 添加一个主函数入口点用于直接运行API服务器
if __name__ == '__main__':
# 初始化数据库
init_db()
# 使用waitress作为生产环境的WSGI服务器比Flask默认的开发服务器更健壮
# from waitress import serve
# serve(app, host="0.0.0.0", port=5000)
# 或者为了方便调试仍然使用Flask内置的开发服务器
# 注意debug=True 模式在生产环境中非常不安全,请仅在开发阶段使用
app.run(host="0.0.0.0", port=5000, debug=True)