ShopTRAINING/api.py

1469 lines
50 KiB
Python
Raw Normal View History

import os
import sys
# Absolute path of the directory that contains this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Put it on the import path so sibling modules (e.g. pharmacy_predictor) load
# regardless of the working directory the server is started from. To import
# from the parent directory instead, extend with os.path.dirname(current_dir).
sys.path.extend([current_dir])
from flask import Flask, request, jsonify, send_file, redirect, send_from_directory
from flask_cors import CORS
import os
import uuid
import pandas as pd
from pharmacy_predictor import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_transformer,
load_model_and_predict
)
import threading
import base64
import matplotlib.pyplot as plt
from io import BytesIO
import json
from flasgger import Swagger, swag_from
import argparse
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
import traceback
import torch
# Flask application; CORS is enabled for every route so a browser UI served
# from another origin can call the API.
app = Flask(__name__)
CORS(app)

# flasgger configuration: one spec document at /apispec.json covering every
# route and model, with the interactive Swagger UI mounted at /swagger/.
swagger_config = {
    "headers": [],
    "specs": [
        {
            "endpoint": "apispec",
            "route": "/apispec.json",
            "rule_filter": lambda rule: True,  # include every route
            "model_filter": lambda tag: True,  # include every model
        }
    ],
    "static_url_path": "/flasgger_static",
    "swagger_ui": True,
    "specs_route": "/swagger/",
}

# Base Swagger 2.0 document. Each (name, description) pair becomes a tag that
# the views below reference through @swag_from({'tags': [...]}).
swagger_template = {
    "swagger": "2.0",
    "info": {
        "title": "药店销售预测系统API",
        "description": "用于药店销售预测的RESTful API",
        "version": "1.0.0",
        "contact": {"name": "API开发团队", "email": "support@example.com"},
    },
    "tags": [
        {"name": tag_name, "description": tag_description}
        for tag_name, tag_description in (
            ("数据管理", "数据上传和查询相关接口"),
            ("模型训练", "模型训练相关接口"),
            ("模型预测", "预测销售数据相关接口"),
            ("模型管理", "模型查询、导出和删除接口"),
        )
    ],
}
swagger = Swagger(app, config=swagger_config, template=swagger_template)

# In-memory registry of training jobs keyed by task id (a uuid4 string);
# its contents are lost when the process restarts.
training_tasks = {}
# Lock intended to guard training_tasks across request and worker threads.
tasks_lock = Lock()
# Thread pool (at most two workers) created for background training.
# NOTE(review): nothing in this module submits work to it; start_training
# launches a plain threading.Thread instead.
executor = ThreadPoolExecutor(max_workers=2)
# Helper: turn a chart image into a Base64 string (e.g. for embedding in JSON).
def fig_to_base64(fig):
    """Render a figure as PNG and return the bytes Base64-encoded as ``str``.

    Args:
        fig: Any object with a ``savefig(file, format=...)`` method, typically
            a ``matplotlib.figure.Figure``.

    Returns:
        The PNG data, Base64-encoded and decoded to text as UTF-8.
    """
    with BytesIO() as png_buffer:
        fig.savefig(png_buffer, format='png')
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode('utf-8')
# Root route: send visitors to the web UI.
@app.route('/')
def index():
    """Redirect ``/`` to the static web UI served under ``/ui/``."""
    ui_entry = '/ui/'
    return redirect(ui_entry)
# Convenience route: /swagger (no trailing slash) -> the Swagger UI page.
@app.route('/swagger')
def swagger_ui():
    """Redirect ``/swagger`` to the flasgger UI mounted at ``/swagger/``."""
    docs_entry = '/swagger/'
    return redirect(docs_entry)
# 1. Data management API
@app.route('/api/products', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取所有产品列表',
    'description': '返回系统中所有产品的ID和名称',
    'responses': {
        200: {
            'description': '成功获取产品列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'product_id': {'type': 'string'},
                                'product_name': {'type': 'string'}
                            }
                        }
                    }
                }
            }
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def get_products():
    # List every distinct (product_id, product_name) pair in the sales workbook.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        sales = pd.read_excel('pharmacy_sales.xlsx')
        catalogue = sales.loc[:, ['product_id', 'product_name']].drop_duplicates()
        return jsonify({"status": "success", "data": catalogue.to_dict('records')})
    except Exception as exc:
        # Any failure (missing workbook, missing columns, ...) becomes a JSON 500.
        return jsonify({"status": "error", "message": str(exc)}), 500
@app.route('/api/products/<product_id>', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取单个产品详情',
    'description': '返回指定产品ID的详细信息',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '产品ID例如P001'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取产品详情',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'data_points': {'type': 'integer'},
                            'date_range': {
                                'type': 'object',
                                'properties': {
                                    'start': {'type': 'string'},
                                    'end': {'type': 'string'}
                                }
                            }
                        }
                    }
                }
            }
        },
        404: {
            'description': '产品不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def get_product(product_id):
    # Summarise one product: its name, number of sales rows and date span.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        sales = pd.read_excel('pharmacy_sales.xlsx')
        rows = sales.loc[sales['product_id'] == product_id]
        if rows.empty:
            return jsonify({"status": "error", "message": "产品不存在"}), 404
        # Assumes read_excel parsed 'date' as datetimes; otherwise .strftime
        # raises and the 500 branch below reports it.
        dates = rows['date']
        summary = {
            "product_id": product_id,
            "product_name": rows['product_name'].iloc[0],
            "data_points": len(rows),
            "date_range": {
                "start": dates.min().strftime('%Y-%m-%d'),
                "end": dates.max().strftime('%Y-%m-%d'),
            },
        }
        return jsonify({"status": "success", "data": summary})
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
@app.route('/api/products/<product_id>/sales', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取产品销售数据',
    'description': '返回指定产品在特定日期范围内的销售数据',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '产品ID例如P001'
        },
        {
            'name': 'start_date',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '开始日期格式为YYYY-MM-DD'
        },
        {
            'name': 'end_date',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '结束日期格式为YYYY-MM-DD'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取销售数据',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object'
                        }
                    }
                }
            }
        },
        404: {
            'description': '产品不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def get_product_sales(product_id):
    # Return one product's sales rows sorted by date, optionally limited to the
    # inclusive [start_date, end_date] window given as query parameters.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        start_date = request.args.get('start_date')
        end_date = request.args.get('end_date')
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        if product_df.empty:
            return jsonify({"status": "error", "message": "产品不存在"}), 404
        # Parse the bounds explicitly so a malformed date is reported as a
        # client error (400) rather than a generic server error (500).
        try:
            start_ts = pd.to_datetime(start_date) if start_date else None
            end_ts = pd.to_datetime(end_date) if end_date else None
        except (ValueError, TypeError) as e:
            return jsonify({"status": "error", "message": f"日期格式无效,应为YYYY-MM-DD: {e}"}), 400
        if start_ts is not None:
            product_df = product_df[product_df['date'] >= start_ts]
        if end_ts is not None:
            product_df = product_df[product_df['date'] <= end_ts]
        # assign() builds a new frame, so formatting dates as strings for JSON
        # never writes into a slice of `df` (no SettingWithCopyWarning).
        product_df = product_df.assign(date=product_df['date'].dt.strftime('%Y-%m-%d'))
        sales_data = product_df.to_dict('records')
        return jsonify({"status": "success", "data": sales_data})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/data/upload', methods=['POST'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '上传销售数据',
    'description': '上传新的销售数据文件(Excel格式)',
    'consumes': ['multipart/form-data'],
    'parameters': [
        {
            'name': 'file',
            'in': 'formData',
            'type': 'file',
            'required': True,
            'description': 'Excel文件(.xlsx),包含销售数据'
        }
    ],
    'responses': {
        200: {
            'description': '数据上传成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'products': {'type': 'integer'},
                            'rows': {'type': 'integer'}
                        }
                    }
                }
            }
        },
        400: {
            'description': '请求错误,可能是文件格式不正确或缺少必要字段'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def upload_data():
    # Accept an .xlsx upload, check that it has the required columns and, if
    # so, replace pharmacy_sales.xlsx with it.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        if 'file' not in request.files:
            return jsonify({"status": "error", "message": "没有上传文件"}), 400
        file = request.files['file']
        filename = file.filename or ''
        if filename == '':
            return jsonify({"status": "error", "message": "没有选择文件"}), 400
        # Case-insensitive, so "SALES.XLSX" is accepted as well.
        if not filename.lower().endswith('.xlsx'):
            return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400
        file_path = 'uploaded_data.xlsx'
        file.save(file_path)
        # Only problems reading/validating the upload count as a client error.
        try:
            df = pd.read_excel(file_path)
        except Exception as e:
            return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
        required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            return jsonify({
                "status": "error",
                "message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
            }), 400
        # The upload replaces the dataset wholesale. (An unused read of the
        # existing pharmacy_sales.xlsx used to happen here, which made the very
        # first upload fail whenever that file did not exist yet.)
        # TODO: merge with existing data (e.g. de-duplicate on date + product_id).
        df.to_excel('pharmacy_sales.xlsx', index=False)
        return jsonify({
            "status": "success",
            "message": "数据上传成功",
            "data": {
                "products": len(df['product_id'].unique()),
                "rows": len(df)
            }
        })
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
# 2. Model training API
@app.route('/api/training', methods=['GET'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '获取所有训练任务列表',
    'description': '返回所有正在进行、已完成或失败的训练任务',
    'responses': {
        200: {
            'description': '成功获取任务列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'task_id': {'type': 'string'},
                                'product_id': {'type': 'string'},
                                'model_type': {'type': 'string'},
                                'status': {'type': 'string'},
                                'start_time': {'type': 'string'},
                                'metrics': {'type': 'object'},
                                'error': {'type': 'string'},
                                'model_path': {'type': 'string'}
                            }
                        }
                    }
                }
            }
        }
    }
})
def get_all_training_tasks():
    """获取所有训练任务的状态"""
    # Return every known training task (newest first), each with its own
    # task_id included for the front end. The docstring above is kept as-is
    # because flasgger also reads view docstrings into the API spec.
    try:
        # Copy the registry under tasks_lock so writers that take the same
        # lock cannot insert a task mid-iteration (a dict that grows while it
        # is iterated raises RuntimeError).
        with tasks_lock:
            tasks_with_id = [
                {**task_info, 'task_id': task_id}
                for task_id, task_info in training_tasks.items()
            ]
        # Newest first; a missing/None start_time sorts last instead of
        # breaking the comparison.
        tasks_with_id.sort(key=lambda task: task.get('start_time') or '', reverse=True)
        return jsonify({"status": "success", "data": tasks_with_id})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/training', methods=['POST'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '启动模型训练任务',
    'description': '为指定产品启动一个新的模型训练任务',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string', 'description': '例如 P001'},
                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan'], 'description': '要训练的模型类型'},
                    'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
                },
                'required': ['product_id', 'model_type']
            }
        }
    ],
    'responses': {
        200: {
            'description': '训练任务已启动',
            'schema': {
                'type': 'object',
                'properties': {
                    'message': {'type': 'string'},
                    'task_id': {'type': 'string'}
                }
            }
        },
        400: {
            'description': '请求错误'
        }
    }
})
def start_training():
    # Start a background training job for one product and return its task id
    # immediately; clients poll GET /api/training/<task_id> for progress.
    #
    # No docstring on purpose: flasgger parses view docstrings into the API
    # spec, and the former placeholder YAML ("post: ...") leaked into it.
    # The @swag_from spec above is the single source of API documentation.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing or non-object JSON body: fall through to the 400 below.
        data = {}
    product_id = data.get('product_id')
    model_type = data.get('model_type')
    epochs = data.get('epochs', 50)  # 50 epochs by default
    if not product_id or not model_type:
        return jsonify({'error': '缺少product_id或model_type'}), 400
    model_train_functions = {
        'mlstm': train_product_model_with_mlstm,
        'kan': train_product_model_with_kan,
        'transformer': train_product_model_with_transformer
    }
    if model_type not in model_train_functions:
        return jsonify({'error': '无效的模型类型'}), 400
    # Reject non-numeric / non-positive epochs now instead of letting the
    # worker thread fail later.
    try:
        epochs = int(epochs)
    except (TypeError, ValueError):
        epochs = 0
    if epochs < 1:
        return jsonify({'error': 'epochs必须是正整数'}), 400
    train_function = model_train_functions[model_type]
    task_id = str(uuid.uuid4())

    def update_task(**fields):
        # Apply all field changes in one locked step so readers never observe
        # a half-updated record (e.g. 'completed' without metrics).
        with tasks_lock:
            training_tasks[task_id].update(fields)

    def train_task(product_id, epochs, model_type):
        try:
            print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
            # train_function returns (model, metrics)
            _, metrics = train_function(product_id, epochs)
            update_task(
                status='completed',
                metrics=metrics,
                model_path=f'models/{model_type}/{product_id}_model.pt',
            )
            print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
        except Exception as e:
            traceback.print_exc()
            print(f"任务 {task_id}: 训练失败。错误: {e}")
            update_task(status='failed', error=str(e))

    # Register the task BEFORE starting its thread. The thread used to be
    # started first, so a job that failed (or finished) quickly hit a KeyError
    # on training_tasks[task_id] and the task then stayed 'running' forever.
    # NOTE(review): the module-level `executor` would cap concurrent jobs, but
    # queued jobs would then report 'running' before they actually start.
    with tasks_lock:
        training_tasks[task_id] = {
            'status': 'running',
            'product_id': product_id,
            'model_type': model_type,
            'start_time': datetime.now().isoformat(),
            'metrics': None,
            'error': None,
            'model_path': None
        }
    thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
    thread.start()
    return jsonify({'message': '模型训练已开始', 'task_id': task_id})
@app.route('/api/training/<task_id>', methods=['GET'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '查询训练任务状态',
    'description': '获取特定训练任务的当前状态和详情',
    'parameters': [
        {
            'name': 'task_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '训练任务ID'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取任务状态',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'parameters': {'type': 'object'},
                            'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
                            'created_at': {'type': 'string'},
                            'model_path': {'type': 'string'},
                            'metrics': {'type': 'object'},
                            'model_details_url': {'type': 'string'}
                        }
                    }
                }
            }
        },
        404: {
            'description': '任务不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def get_training_status(task_id):
    # Return a copy of one task's record; completed tasks also get a link to
    # the matching GET /api/models query.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        task = training_tasks.get(task_id)
        if task is None:
            return jsonify({"status": "error", "message": "任务不存在"}), 404
        # Shallow copy so the extra URL key never lands in the shared registry.
        info = dict(task)
        if info['status'] == 'completed':
            info['model_details_url'] = (
                f"/api/models?product_id={info['product_id']}"
                f"&model_type={info['model_type']}"
            )
        return jsonify({"status": "success", "data": info})
    except Exception as exc:
        return jsonify({"status": "error", "message": str(exc)}), 500
# 3. Model prediction API
@app.route('/api/prediction', methods=['POST'])
@swag_from({
    'tags': ['模型预测'],
    'summary': '使用模型进行预测',
    'description': '使用指定模型预测未来销售数据',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string'},
                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan']},
                    'version': {'type': 'string'},
                    'future_days': {'type': 'integer'},
                    'include_visualization': {'type': 'boolean'},
                    'start_date': {'type': 'string', 'description': '预测起始日期格式为YYYY-MM-DD'}
                },
                'required': ['product_id', 'model_type']
            }
        }
    ],
    'responses': {
        200: {
            'description': '预测成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'predictions': {'type': 'array'},
                            'visualization': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {
            'description': '请求错误,缺少必要参数或参数格式错误'
        },
        404: {
            'description': '产品或模型不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def predict():
    """
    使用指定的模型进行预测
    ---
    tags:
      - 模型预测
    parameters:
      - in: body
        name: body
        schema:
          type: object
          required:
            - product_id
            - model_type
          properties:
            product_id:
              type: string
              description: 产品ID
            model_type:
              type: string
              description: 模型类型 (mlstm, kan, transformer)
            future_days:
              type: integer
              description: 预测未来天数
              default: 7
            start_date:
              type: string
              description: 预测起始日期格式为YYYY-MM-DD
              default: ''
    responses:
      200:
        description: 预测成功
      400:
        description: 请求参数错误
      404:
        description: 模型文件未找到
    """
    # Forecast sales for one product with the requested model and return the
    # history + forecast rows plus URLs for chart/CSV files found on disk.
    # The docstring above is flasgger YAML (merged into the published spec),
    # so it is kept as-is; explanations live in these comments.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Missing or non-object JSON body: fall through to the 400 below
        # instead of crashing on data.get(...).
        data = {}
    product_id = data.get('product_id')
    model_type = data.get('model_type')
    future_days = data.get('future_days', 7)
    start_date = data.get('start_date', '')
    print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
    if not product_id or not model_type:
        return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
    try:
        future_days = int(future_days)
    except (TypeError, ValueError):
        future_days = 0
    if future_days < 1:
        return jsonify({"status": "error", "error": "future_days 必须是正整数"}), 400
    try:
        result = load_model_and_predict(product_id, model_type, future_days, start_date=start_date)
        if result is None:
            return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404
        predictions_df = result['predictions_df']
        # Current server scheme/host/port without the trailing slash.
        host_url = request.host_url.rstrip('/')
        response_data = {
            "status": "success",
            "data": predictions_df.to_dict(orient='records'),  # history and forecast rows
            "history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'),
            "prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records')
        }
        # Cache-busting query value for the image URLs.
        timestamp = datetime.now().timestamp()
        # The artefact names below embed the forecast start date as YYYYMMDD
        # (today when start_date is missing or invalid). This assumes they
        # mirror how load_model_and_predict names its files — verify if that
        # naming ever changes.
        start_date_str = datetime.now().strftime('%Y%m%d')
        if start_date:
            try:
                start_date_str = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y%m%d')
            except (TypeError, ValueError):
                pass  # keep today's date
        prediction_dir = os.path.join('predictions', model_type, product_id)
        file_url_prefix = f"{host_url}/api/predictions/{model_type}/{product_id}"
        png_filename = f"forecast_{start_date_str}_days{future_days}.png"
        csv_filename = f"forecast_{start_date_str}_days{future_days}.csv"
        history_png_filename = f"history_{start_date_str}.png"
        png_path = os.path.join(prediction_dir, png_filename)
        csv_path = os.path.join(prediction_dir, csv_filename)
        history_png_path = os.path.join(prediction_dir, history_png_filename)
        # Only advertise URLs for files that actually exist.
        if os.path.exists(png_path):
            response_data["image_url"] = f"{file_url_prefix}/{png_filename}?t={timestamp}"
            print(f"图表URL: {response_data['image_url']}")
        else:
            response_data["image_url"] = None
            print(f"警告: 预测图表文件未生成或不存在: {png_path}")
        if os.path.exists(history_png_path):
            response_data["history_image_url"] = f"{file_url_prefix}/{history_png_filename}?t={timestamp}"
            print(f"历史图表URL: {response_data['history_image_url']}")
        else:
            response_data["history_image_url"] = None
            if 'history_chart_path' in result and result['history_chart_path'] is None:
                print(f"警告: 历史图表生成过程中出现错误,请检查服务器日志")
            else:
                print(f"警告: 历史图表文件未生成或不存在: {history_png_path}")
        if os.path.exists(csv_path):
            response_data["csv_url"] = f"{file_url_prefix}/{csv_filename}"
        else:
            response_data["csv_url"] = None
        return jsonify(response_data)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"预测过程中发生错误: {str(e)}"}), 500
@app.route('/api/prediction/compare', methods=['POST'])
@swag_from({
    'tags': ['模型预测'],
    'summary': '比较不同模型预测结果',
    'description': '比较不同模型对同一产品的预测结果',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string'},
                    'model_types': {'type': 'array', 'items': {'type': 'string'}},
                    'versions': {'type': 'array', 'items': {'type': 'string'}},
                    'include_visualization': {'type': 'boolean'}
                },
                'required': ['product_id', 'model_types']
            }
        }
    ],
    'responses': {
        200: {
            'description': '比较成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'model_types': {'type': 'array'},
                            'comparison': {'type': 'array'},
                            'visualization': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {
            'description': '请求错误,缺少必要参数或参数格式错误'
        },
        404: {
            'description': '产品或模型不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def compare_predictions():
    # Run several models for one product, plot the last 30 days of history
    # together with every model's forecast, save the chart under
    # predictions/compare/ and return per-date predictions side by side.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    fig = None
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        product_id = data.get('product_id')
        model_types = data.get('model_types')
        if not product_id or not model_types:
            return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400
        # A bare string would otherwise be iterated character by character.
        if not isinstance(model_types, list):
            return jsonify({"status": "error", "error": "model_types 必须是数组"}), 400
        # Historical data for the chart; an unknown product is a 404 rather
        # than an IndexError from .iloc[0].
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        if product_df.empty:
            return jsonify({"status": "error", "error": "产品不存在"}), 404
        product_name = product_df['product_name'].iloc[0]
        history_days = 30
        history_tail = product_df.iloc[-history_days:]

        # Explicit figure/axes instead of pyplot's global "current figure",
        # which other request threads could otherwise draw into.
        fig, ax = plt.subplots(figsize=(12, 8))
        ax.plot(history_tail['date'].values, history_tail['sales'].values, 'b-', label='历史销量')

        comparison_data = []
        future_dates = None
        compare_dir = 'predictions/compare'
        os.makedirs(compare_dir, exist_ok=True)
        for model_type in model_types:
            result = load_model_and_predict(product_id, model_type)
            if result is None:
                continue  # model missing or failed to load: leave it out
            predictions_df = result['predictions_df']
            # The first successful model fixes the date rows of the table;
            # later models are aligned to it by position.
            if future_dates is None:
                future_dates = predictions_df['date']
                comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates]
            ax.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测')
            preds = predictions_df['predicted_sales'].tolist()
            for i, row in enumerate(comparison_data):
                row[f'{model_type}_prediction'] = preds[i] if i < len(preds) else None

        ax.set_title(f'{product_name} - 多模型预测比较')
        ax.set_xlabel('日期')
        ax.set_ylabel('销量')
        ax.legend()
        ax.grid(True)
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
        chart_path = f'{compare_dir}/{product_id}_model_comparison.png'
        fig.savefig(chart_path)

        host_url = request.host_url.rstrip('/')
        # Timestamp query parameter defeats browser caching of the chart.
        timestamp = datetime.now().timestamp()
        image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}"
        return jsonify({
            "status": "success",
            "data": {
                "product_id": product_id,
                "product_name": product_name,
                "model_types": model_types,
                "comparison": comparison_data,
                "visualization_url": image_url
            }
        })
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500
    finally:
        # Always release the figure, including on the error paths above.
        if fig is not None:
            plt.close(fig)
# 4. Model management API
@app.route('/api/models', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '获取模型列表',
    'description': '获取系统中的模型列表可按产品ID和模型类型筛选',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '按产品ID筛选'
        },
        {
            'name': 'model_type',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '按模型类型筛选'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取模型列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'model_id': {'type': 'string'},
                                'product_id': {'type': 'string'},
                                'product_name': {'type': 'string'},
                                'model_type': {'type': 'string'},
                                'created_at': {'type': 'string'},
                                'metrics': {'type': 'object'}
                            }
                        }
                    }
                }
            }
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def list_models():
    """
    列出所有可用的模型
    ---
    tags:
      - 模型管理
    parameters:
      - name: product_id
        in: query
        type: string
        required: false
        description: 按产品ID筛选模型
      - name: model_type
        in: query
        type: string
        required: false
        description: 按模型类型筛选 (mlstm, kan, transformer)
    responses:
      200:
        description: 模型列表
        schema:
          type: object
          properties:
            status:
              type: string
              example: success
            data:
              type: array
              items:
                type: object
                properties:
                  model_id:
                    type: string
                  product_id:
                    type: string
                  product_name:
                    type: string
                  model_type:
                    type: string
                  created_at:
                    type: string
                  metrics:
                    type: object
    """
    # Build the model list from models/<model_type>/<product_id>_log.json
    # files, optionally filtered, newest first. The docstring above is
    # flasgger YAML and is kept as-is.
    models_dir = 'models'
    model_types = ['mlstm', 'kan', 'transformer']
    available_models = []
    product_id_filter = request.args.get('product_id')
    model_type_filter = request.args.get('model_type')
    if not os.path.exists(models_dir):
        return jsonify({"status": "success", "data": []})
    for model_type in model_types:
        if model_type_filter and model_type_filter != model_type:
            continue
        type_dir = os.path.join(models_dir, model_type)
        if not os.path.exists(type_dir):
            continue
        for file_name in os.listdir(type_dir):
            if not file_name.endswith('_log.json'):
                continue
            product_id = file_name[:-len('_log.json')]
            if product_id_filter and product_id_filter != product_id:
                continue
            log_path = os.path.join(type_dir, file_name)
            try:
                with open(log_path, 'r', encoding='utf-8') as f:
                    log_data = json.load(f)
                available_models.append({
                    "model_id": f"{model_type}_{product_id}",
                    "product_id": log_data.get('product_id'),
                    "product_name": log_data.get('product_name'),
                    "model_type": log_data.get('model_type'),
                    "created_at": log_data.get('training_completed_at'),
                    "metrics": log_data.get('metrics'),
                    "file_path": log_data.get('file_path')
                })
            except Exception as e:
                # One unreadable/corrupt log must not hide the other models.
                print(f"读取日志文件 {log_path} 失败: {e}")
    # Newest first. A log without 'training_completed_at' yields
    # created_at=None; `or ''` (plus str) keeps the sort from comparing None
    # with strings, which previously raised TypeError and failed the request.
    available_models.sort(key=lambda m: str(m.get('created_at') or ''), reverse=True)
    return jsonify({"status": "success", "data": available_models})
@app.route('/api/models/<model_id>', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '获取模型详情',
    'description': '获取特定模型的详细信息',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取模型详情',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'version': {'type': 'string'},
                            'created_at': {'type': 'string'},
                            'file_path': {'type': 'string'},
                            'file_size': {'type': 'string'},
                            'features': {'type': 'array'},
                            'look_back': {'type': 'integer'},
                            'T': {'type': 'integer'},
                            'metrics': {'type': 'object'}
                        }
                    }
                }
            }
        },
        400: {
            'description': '无效的模型ID格式'
        },
        404: {
            'description': '模型不存在'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def get_model_details(model_id):
    """
    获取单个模型的详细信息
    ---
    tags:
      - 模型管理
    parameters:
      - name: model_id
        in: path
        type: string
        required: true
        description: "模型的唯一标识符 (格式: model_type_product_id)"
    responses:
      200:
        description: 模型的详细信息
      404:
        description: 未找到模型
    """
    # Return the JSON training log stored for one model. model_id is
    # "<model_type>_<product_id>" split on the first "_", the id that
    # list_models emits. The docstring is flasgger YAML; its description is
    # quoted because an unquoted ": " inside a plain YAML scalar does not parse.
    try:
        model_type, product_id = model_id.split('_', 1)
    except ValueError:
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    # Only known model types and plain file-name product ids, so the id can
    # never resolve outside models/<model_type>/.
    if (model_type not in ('mlstm', 'kan', 'transformer') or not product_id
            or '/' in product_id or '\\' in product_id):
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    try:
        log_path = os.path.join('models', model_type, f'{product_id}_log.json')
        if not os.path.exists(log_path):
            return jsonify({"status": "error", "error": "模型未找到"}), 404
        # A corrupt log raises JSONDecodeError (a ValueError); it is now
        # reported as a server error instead of "invalid model_id".
        with open(log_path, 'r', encoding='utf-8') as f:
            log_data = json.load(f)
        # Further details could be read from the .pt checkpoint here, e.g.
        # torch.load(log_data['file_path'], map_location='cpu').
        return jsonify({"status": "success", "data": log_data})
    except Exception as e:
        return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
@app.route('/api/models/<model_id>', methods=['DELETE'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '删除模型',
    'description': '删除特定模型',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '模型删除成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'}
                }
            }
        },
        400: {
            'description': '无效的模型ID格式'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def delete_model(model_id):
    """
    删除一个模型及其关联文件
    ---
    tags:
      - 模型管理
    parameters:
      - name: model_id
        in: path
        type: string
        required: true
        description: "要删除的模型的ID (格式: model_type_product_id)"
    responses:
      200:
        description: 模型删除成功
      404:
        description: 模型未找到
    """
    # Delete models/<model_type>/<product_id>_model.pt and its _log.json.
    # 404 only when neither file exists. The docstring is flasgger YAML; its
    # description is quoted because an unquoted ": " does not parse as YAML.
    try:
        model_type, product_id = model_id.split('_', 1)
    except ValueError:
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    # This endpoint deletes files: accept only known model types and plain
    # file-name product ids so nothing outside models/<model_type>/ is touched.
    if (model_type not in ('mlstm', 'kan', 'transformer') or not product_id
            or '/' in product_id or '\\' in product_id):
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    try:
        model_dir = os.path.join('models', model_type)
        model_path = os.path.join(model_dir, f'{product_id}_model.pt')
        log_path = os.path.join(model_dir, f'{product_id}_log.json')
        if not os.path.exists(model_path) and not os.path.exists(log_path):
            return jsonify({"status": "error", "error": "模型未找到"}), 404
        for path in (model_path, log_path):
            if os.path.exists(path):
                os.remove(path)
        return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
    except Exception as e:
        return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
@app.route('/api/models/<model_id>/export', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '导出模型',
    'description': '导出特定模型文件',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '模型文件下载',
            'content': {
                'application/octet-stream': {}
            }
        },
        400: {
            'description': '无效的模型ID格式'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def export_model(model_id):
    # Download models/<model_type>/<product_id>_model.pt as "<model_id>.pt".
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        model_type, product_id = model_id.split('_', 1)
    except ValueError:
        # Malformed ids are a client error, consistent with the details and
        # delete endpoints (previously this surfaced as a 500).
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    # Only known model types and plain file-name product ids, so the id can
    # never point outside models/<model_type>/.
    if (model_type not in ('mlstm', 'kan', 'transformer') or not product_id
            or '/' in product_id or '\\' in product_id):
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    try:
        model_path = os.path.join('models', model_type, f'{product_id}_model.pt')
        if not os.path.exists(model_path):
            return jsonify({"status": "error", "error": "模型文件未找到"}), 404
        return send_file(
            model_path,
            as_attachment=True,
            download_name=f'{model_id}.pt',
            mimetype='application/octet-stream'
        )
    except Exception as e:
        return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
@app.route('/api/models/import', methods=['POST'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '导入模型',
    'description': '导入模型文件',
    'consumes': ['multipart/form-data'],
    'parameters': [
        {
            'name': 'file',
            'in': 'formData',
            'type': 'file',
            'required': True,
            'description': 'PyTorch模型文件(.pt)'
        }
    ],
    'responses': {
        200: {
            'description': '模型导入成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'model_path': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {
            'description': '请求错误,文件格式不正确或缺少文件'
        },
        500: {
            'description': '服务器内部错误'
        }
    }
})
def import_model():
    # Store an uploaded "<model_type>_<product_id>.pt" checkpoint as
    # models/<model_type>/<product_id>_model.pt and create a minimal
    # <product_id>_log.json next to it if none exists yet.
    # (No docstring: flasgger also reads view docstrings into the API spec.)
    try:
        if 'file' not in request.files:
            return jsonify({"status": "error", "message": "没有上传文件"}), 400
        file = request.files['file']
        filename = file.filename or ''
        if filename == '' or not filename.endswith('.pt'):
            return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400
        # The file name encodes the target, e.g. `transformer_P001.pt`.
        try:
            model_type, product_id_ext = filename.split('_', 1)
            product_id = product_id_ext[:-len('.pt')]
            model_types = ['mlstm', 'kan', 'transformer']
            if model_type not in model_types:
                raise ValueError("无效的模型类型")
            # The name is client-supplied: refuse empty ids and path
            # separators so nothing is written outside models/<model_type>/.
            if not product_id or '/' in product_id or '\\' in product_id:
                raise ValueError("无效的产品ID")
        except ValueError:
            return jsonify({
                "status": "error",
                "message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)"
            }), 400
        model_dir = os.path.join('models', model_type)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f'{product_id}_model.pt')
        log_path = os.path.join(model_dir, f'{product_id}_log.json')
        create_log = not os.path.exists(log_path)

        checkpoint = {}
        if create_log:
            # Peek into the checkpoint for metadata (epochs, metrics).
            # SECURITY: the upload is untrusted and torch.load unpickles, so
            # only weights-only loading is attempted. If the installed torch
            # does not support it, or the checkpoint needs arbitrary classes,
            # it raises and the log just gets placeholder values.
            try:
                loaded = torch.load(file, map_location='cpu', weights_only=True)
            except Exception:
                loaded = None
            finally:
                # Rewind in every case: a failed load can leave the stream
                # mid-way, and file.save() copies from the current position.
                file.seek(0)
            # Only dict checkpoints carry .get()-able metadata.
            if isinstance(loaded, dict):
                checkpoint = loaded

        # Save the model before writing its log, so a failed save cannot leave
        # a log entry that points at a missing file.
        file.save(model_path)

        if create_log:
            log_data = {
                'product_id': product_id,
                'product_name': f"导入的产品 {product_id}",  # may need editing later
                'model_type': model_type,
                'training_completed_at': datetime.now().isoformat(),
                'epochs': checkpoint.get('epochs', 'N/A'),
                'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}),
                'file_path': model_path
            }
            with open(log_path, 'w', encoding='utf-8') as f:
                json.dump(log_data, f, indent=4, ensure_ascii=False)
        return jsonify({
            "status": "success",
            "message": "模型已成功导入",
            "data": {"model_path": model_path}
        })
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500
@app.route('/api/plots/<filename>')
def get_plot(filename):
    """Serve a chart image stored directly in the application root directory.

    Only .png/.jpg/.jpeg files are served. ``app.root_path`` is this module's
    own directory (it contains the API source), so serving arbitrary files
    from it would expose code and data. Other extensions and missing files
    get a JSON 404.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    # send_from_directory reports a missing file with werkzeug's NotFound, not
    # FileNotFoundError, so check up front to return the intended JSON 404.
    if (not filename.lower().endswith(image_extensions)
            or not os.path.isfile(os.path.join(app.root_path, filename))):
        return jsonify({"status": "error", "error": "Plot not found"}), 404
    return send_from_directory(app.root_path, filename)
@app.route('/api/csv/<filename>')
def get_csv(filename):
    """Send a CSV file from the application root directory as a download.

    Only ``.csv`` files are served; ``app.root_path`` also holds the API
    source, which must not be downloadable. Other names and missing files
    get a JSON 404.
    """
    # send_from_directory reports a missing file with werkzeug's NotFound, not
    # FileNotFoundError, so check up front to return the intended JSON 404.
    if (not filename.lower().endswith('.csv')
            or not os.path.isfile(os.path.join(app.root_path, filename))):
        return jsonify({"status": "error", "error": "CSV file not found"}), 404
    return send_from_directory(app.root_path, filename, as_attachment=True)
# Static UI route: map /ui/... onto files in the wwwroot directory.
@app.route('/ui/', defaults={'path': 'index.html'})
@app.route('/ui/<path:path>')
def serve_ui(path):
    """Serve the web UI from ``wwwroot`` with a single-page-app fallback.

    An existing file is returned as-is. Any other sub-path (e.g. ``/ui/about``)
    receives ``index.html`` so client-side routing can handle it. A JSON 404
    is returned only when ``index.html`` itself is missing.
    """
    # Flask resolves the relative 'wwwroot' against app.root_path, so check
    # the same location. The explicit check is needed because a missing file
    # makes send_from_directory raise werkzeug's NotFound, not
    # FileNotFoundError, so the previous except-based fallback never ran.
    # (send_from_directory still rejects paths that try to escape wwwroot.)
    ui_root = os.path.join(app.root_path, 'wwwroot')
    if os.path.isfile(os.path.join(ui_root, path)):
        return send_from_directory('wwwroot', path)
    if path == 'index.html':
        return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404
    if os.path.isfile(os.path.join(ui_root, 'index.html')):
        return send_from_directory('wwwroot', 'index.html')
    return jsonify({"status": "error", "error": "UI files not found"}), 404
@app.route('/api/predictions/<model_type>/<product_id>/<filename>')
def get_prediction_file(model_type, product_id, filename):
    """Serve a prediction file from ``predictions/<model_type>/<product_id>/``.

    PNG charts are sent with headers that forbid caching, so a regenerated
    chart is always fetched again. A JSON 404 is returned when the file is
    missing or when a directory segment is empty, ``.``/``..``, or contains a
    path separator.
    """
    from email.utils import formatdate

    not_found_message = f"预测文件 {filename} 未找到"
    # send_from_directory only sanitises `filename`; its directory argument is
    # trusted, so refuse segments that could climb out of predictions/.
    for segment in (model_type, product_id):
        if segment in ('', '.', '..') or '/' in segment or '\\' in segment:
            return jsonify({"status": "error", "error": not_found_message}), 404
    file_path = os.path.join('predictions', model_type, product_id)
    # A missing file makes send_from_directory raise werkzeug's NotFound (not
    # FileNotFoundError), so check first to return the JSON 404. Flask resolves
    # the relative directory against app.root_path, hence the same base here.
    if not os.path.isfile(os.path.join(app.root_path, file_path, filename)):
        return jsonify({"status": "error", "error": not_found_message}), 404
    response = send_from_directory(file_path, filename)
    if filename.endswith('.png'):
        # Force browsers not to cache chart images.
        response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0'
        response.headers['Pragma'] = 'no-cache'
        response.headers['Expires'] = '0'
        # A real GMT HTTP-date. The old strftime of local time was labelled
        # "GMT" and used locale-dependent day/month names.
        response.headers['Last-Modified'] = formatdate(usegmt=True)
        response.headers['Vary'] = '*'
        print(f"提供图片文件 {filename} 并添加防缓存头")
    return response
@app.route('/api/predictions/compare/<filename>')
def get_compare_file(filename):
    """Serve a model-comparison file (e.g. a chart) from ``predictions/compare``.

    Returns a JSON 404 when the file does not exist.
    """
    file_path = os.path.join('predictions', 'compare')
    # send_from_directory reports a missing file with werkzeug's NotFound, not
    # FileNotFoundError, so check up front to return the intended JSON 404.
    # Flask resolves the relative directory against app.root_path.
    # NOTE(review): compare_predictions writes relative to the CWD; the two
    # only agree when the server is started from the app directory.
    if not os.path.isfile(os.path.join(app.root_path, file_path, filename)):
        return jsonify({"status": "error", "error": f"比较结果文件 {filename} 未找到"}), 404
    return send_from_directory(file_path, filename)
if __name__ == '__main__':
    # Command-line entry point: parse options, make sure the output
    # directories exist, then start Flask's built-in server.
    parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
    parser.add_argument('--host', type=str, default='0.0.0.0', help='API服务监听的主机地址')
    parser.add_argument('--port', type=int, default=5000, help='API服务监听的端口')
    # '--swagger' and '--debug' are store_true with default=True, so on their
    # own they can never switch anything off. The '--no-*' flags share the
    # same dest and make both options disableable; the defaults are unchanged.
    parser.add_argument('--swagger', action='store_true', default=True, help='是否启用Swagger UI')
    parser.add_argument('--no-swagger', dest='swagger', action='store_false', help='禁用Swagger UI')
    parser.add_argument('--debug', action='store_true', default=True, help='是否启用调试模式')
    parser.add_argument('--no-debug', dest='debug', action='store_false', help='禁用调试模式(对外提供服务时建议使用)')
    args = parser.parse_args()
    if not args.swagger:
        # NOTE(review): Swagger(app, ...) already registered its routes at
        # import time, so this config change probably does not remove the
        # UI — confirm before relying on it.
        app.config['SWAGGER'] = {'enabled': False}
        print("Swagger UI已禁用")
    else:
        print(f"Swagger UI已启用访问: http://{args.host}:{args.port}/swagger/")
    # Output directories for trained models and prediction artefacts.
    os.makedirs('models', exist_ok=True)
    os.makedirs('predictions', exist_ok=True)
    # Debug mode enables the Werkzeug interactive debugger. With the default
    # host 0.0.0.0 it is reachable from the network, so pass --no-debug
    # whenever the server is exposed.
    app.run(debug=args.debug, host=args.host, port=args.port)