ShopTRAINING/api.py
2025-06-14 05:00:17 +08:00

1452 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
# 获取当前脚本所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
# 将当前目录添加到系统路径
sys.path.append(current_dir)
# 或者添加父目录
#parent_dir = os.path.dirname(current_dir)
#sys.path.append(parent_dir)
from flask import Flask, request, jsonify, send_file, redirect, send_from_directory
from flask_cors import CORS
import os
import uuid
import pandas as pd
from pharmacy_predictor import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_transformer,
load_model_and_predict
)
import threading
import base64
import matplotlib.pyplot as plt
from io import BytesIO
import json
from flasgger import Swagger, swag_from
import argparse
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
from threading import Lock
import traceback
import torch
# Flask application object; every route in this module is registered on it.
app = Flask(__name__)
CORS(app)  # Enable CORS support
# Swagger (flasgger) configuration
swagger_config = {
    "headers": [],
    "specs": [
        {
            "endpoint": "apispec",
            "route": "/apispec.json",
            "rule_filter": lambda rule: True,  # include every route
            "model_filter": lambda tag: True,  # include every model
        }
    ],
    "static_url_path": "/flasgger_static",
    "swagger_ui": True,
    "specs_route": "/swagger/"
}
# Top-level Swagger 2.0 document: API info plus the four tag groups that the
# route specs below refer to by name.
swagger_template = {
    "swagger": "2.0",
    "info": {
        "title": "药店销售预测系统API",
        "description": "用于药店销售预测的RESTful API",
        "version": "1.0.0",
        "contact": {
            "name": "API开发团队",
            "email": "support@example.com"
        }
    },
    "tags": [
        {
            "name": "数据管理",
            "description": "数据上传和查询相关接口"
        },
        {
            "name": "模型训练",
            "description": "模型训练相关接口"
        },
        {
            "name": "模型预测",
            "description": "预测销售数据相关接口"
        },
        {
            "name": "模型管理",
            "description": "模型查询、导出和删除接口"
        }
    ]
}
swagger = Swagger(app, config=swagger_config, template=swagger_template)
# In-memory store of training-task state, keyed by task id (a UUID string).
# The contents are lost when the process restarts.
training_tasks = {}
# Lock intended to guard `training_tasks`, which is shared between request
# handlers and background training threads.
tasks_lock = Lock()
# Thread pool for background training.
# NOTE(review): the handlers visible in this file start training with
# threading.Thread and do not reference `executor` - confirm whether it is
# still needed.
executor = ThreadPoolExecutor(max_workers=2)
# Helper: convert a figure into a Base64 string
def fig_to_base64(fig):
    """Render *fig* as PNG and return the image bytes Base64-encoded.

    Args:
        fig: Any object exposing ``savefig(buffer, format='png')`` (for
            example a matplotlib figure).

    Returns:
        The PNG data as a Base64 text string (no data-URI prefix).
    """
    stream = BytesIO()
    fig.savefig(stream, format='png')
    # getvalue() returns the whole buffer regardless of the write position.
    png_bytes = stream.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')
# Root route - redirect to the UI
@app.route('/')
def index():
    """Redirect requests for the site root to the bundled UI under ``/ui/``."""
    ui_entry_point = '/ui/'
    return redirect(ui_entry_point)
# Swagger UI route
@app.route('/swagger')
def swagger_ui():
    """Redirect ``/swagger`` (no trailing slash) to the Swagger UI page."""
    docs_page = '/swagger/'
    return redirect(docs_page)
# 1. Data management API
@app.route('/api/products', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取所有产品列表',
    'description': '返回系统中所有产品的ID和名称',
    'responses': {
        200: {
            'description': '成功获取产品列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'product_id': {'type': 'string'},
                                'product_name': {'type': 'string'}
                            }
                        }
                    }
                }
            }
        },
        500: {'description': '服务器内部错误'}
    }
})
def get_products():
    """Return the distinct (product_id, product_name) pairs from the sales file.

    Reads ``pharmacy_sales.xlsx`` on every call. Any failure (missing file,
    missing columns, ...) is reported as a JSON error with HTTP 500.
    """
    try:
        sales = pd.read_excel('pharmacy_sales.xlsx')
        id_and_name = sales[['product_id', 'product_name']]
        unique_products = id_and_name.drop_duplicates()
        payload = {
            "status": "success",
            "data": unique_products.to_dict('records'),
        }
        return jsonify(payload)
    except Exception as err:
        return jsonify({"status": "error", "message": str(err)}), 500
@app.route('/api/products/<product_id>', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取单个产品详情',
    'description': '返回指定产品ID的详细信息',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '产品ID例如P001'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取产品详情',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'data_points': {'type': 'integer'},
                            'date_range': {
                                'type': 'object',
                                'properties': {
                                    'start': {'type': 'string'},
                                    'end': {'type': 'string'}
                                }
                            }
                        }
                    }
                }
            }
        },
        404: {'description': '产品不存在'},
        500: {'description': '服务器内部错误'}
    }
})
def get_product(product_id):
    """Return summary information for one product.

    The summary contains the product name (taken from the first matching
    row), the number of rows for the product and the first/last sales date.
    Responds with 404 when no row matches ``product_id`` and with 500 on any
    other failure.
    """
    try:
        sales = pd.read_excel('pharmacy_sales.xlsx')
        rows = sales[sales['product_id'] == product_id]
        if rows.empty:
            return jsonify({"status": "error", "message": "产品不存在"}), 404

        name = rows['product_name'].iloc[0]
        row_count = len(rows)
        first_day = rows['date'].min().strftime('%Y-%m-%d')
        last_day = rows['date'].max().strftime('%Y-%m-%d')

        summary = {
            "product_id": product_id,
            "product_name": name,
            "data_points": row_count,
            "date_range": {"start": first_day, "end": last_day},
        }
        return jsonify({"status": "success", "data": summary})
    except Exception as err:
        return jsonify({"status": "error", "message": str(err)}), 500
@app.route('/api/products/<product_id>/sales', methods=['GET'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '获取产品销售数据',
    'description': '返回指定产品在特定日期范围内的销售数据',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '产品ID例如P001'
        },
        {
            'name': 'start_date',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '开始日期格式为YYYY-MM-DD'
        },
        {
            'name': 'end_date',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '结束日期格式为YYYY-MM-DD'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取销售数据',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {'type': 'object'}
                    }
                }
            }
        },
        400: {'description': '日期参数格式错误'},
        404: {'description': '产品不存在'},
        500: {'description': '服务器内部错误'}
    }
})
def get_product_sales(product_id):
    """Return the sales rows of one product, optionally limited to a date range.

    Query parameters:
        start_date: Optional inclusive lower bound (YYYY-MM-DD).
        end_date: Optional inclusive upper bound (YYYY-MM-DD).

    Returns:
        JSON ``{"status": "success", "data": [...]}`` with rows sorted by
        date and the ``date`` column rendered as ``YYYY-MM-DD`` strings.
        404 when the product has no rows at all, 400 when a date parameter
        cannot be parsed, 500 on any other failure.
    """
    try:
        start_date = request.args.get('start_date')
        end_date = request.args.get('end_date')

        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        if product_df.empty:
            return jsonify({"status": "error", "message": "产品不存在"}), 404

        # A malformed date is a client error, not a server fault: report it
        # as 400 instead of letting it fall through to the generic 500.
        try:
            start_ts = pd.to_datetime(start_date) if start_date else None
            end_ts = pd.to_datetime(end_date) if end_date else None
        except (ValueError, TypeError):
            return jsonify({"status": "error", "message": "日期格式无效,应为YYYY-MM-DD"}), 400

        # Apply the date range when one was supplied.
        if start_ts is not None:
            product_df = product_df[product_df['date'] >= start_ts]
        if end_ts is not None:
            product_df = product_df[product_df['date'] <= end_ts]

        # Convert dates to strings for JSON serialisation. `assign` builds a
        # new frame, so we never write into a filtered slice of `df`.
        product_df = product_df.assign(
            date=product_df['date'].dt.strftime('%Y-%m-%d')
        )
        sales_data = product_df.to_dict('records')
        return jsonify({"status": "success", "data": sales_data})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/data/upload', methods=['POST'])
@swag_from({
    'tags': ['数据管理'],
    'summary': '上传销售数据',
    'description': '上传新的销售数据文件(Excel格式)',
    'consumes': ['multipart/form-data'],
    'parameters': [
        {
            'name': 'file',
            'in': 'formData',
            'type': 'file',
            'required': True,
            'description': 'Excel文件(.xlsx),包含销售数据'
        }
    ],
    'responses': {
        200: {
            'description': '数据上传成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'products': {'type': 'integer'},
                            'rows': {'type': 'integer'}
                        }
                    }
                }
            }
        },
        400: {'description': '请求错误,可能是文件格式不正确或缺少必要字段'},
        500: {'description': '服务器内部错误'}
    }
})
def upload_data():
    """Replace the sales data set with an uploaded ``.xlsx`` file.

    The upload is stored as ``uploaded_data.xlsx``, checked for the required
    columns (date, product_id, product_name, sales, price) and then written
    to ``pharmacy_sales.xlsx``, replacing its previous content.

    Returns:
        JSON with the number of distinct products and rows on success;
        400 when the file is missing, not ``.xlsx``, unreadable or lacks
        required columns; 500 on any other failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({"status": "error", "message": "没有上传文件"}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({"status": "error", "message": "没有选择文件"}), 400
        # Compare the extension case-insensitively so "DATA.XLSX" is accepted.
        if not file.filename.lower().endswith('.xlsx'):
            return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400

        file_path = 'uploaded_data.xlsx'
        file.save(file_path)

        # Validate the data format.
        try:
            df = pd.read_excel(file_path)
            required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                return jsonify({
                    "status": "error",
                    "message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
                }), 400

            # The upload replaces the existing data set. The previous code
            # also read the existing 'pharmacy_sales.xlsx' here without using
            # the result, which made every upload fail whenever that file
            # was missing or unreadable; that read has been removed.
            # A merge (e.g. de-duplicating on date + product_id) could be
            # implemented here instead of a plain replace.
            df.to_excel('pharmacy_sales.xlsx', index=False)
            return jsonify({
                "status": "success",
                "message": "数据上传成功",
                "data": {
                    "products": len(df['product_id'].unique()),
                    "rows": len(df)
                }
            })
        except Exception as e:
            return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
# 2. Model training API
@app.route('/api/training', methods=['GET'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '获取所有训练任务列表',
    'description': '返回所有正在进行、已完成或失败的训练任务',
    'responses': {
        200: {
            'description': '成功获取任务列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'task_id': {'type': 'string'},
                                'product_id': {'type': 'string'},
                                'model_type': {'type': 'string'},
                                'status': {'type': 'string'},
                                'start_time': {'type': 'string'},
                                'metrics': {'type': 'object'},
                                'error': {'type': 'string'},
                                'model_path': {'type': 'string'}
                            }
                        }
                    }
                }
            }
        }
    }
})
def get_all_training_tasks():
    """Return the state of every training task, newest first.

    Each entry is a copy of the stored task record with its ``task_id``
    added, so the front end does not need to track ids separately.
    """
    try:
        # Background training threads add to and update `training_tasks`
        # while requests are served. Take a snapshot under the lock (and via
        # list(), which copies the items in one step) so iteration never
        # runs over a dict that is being resized.
        with tasks_lock:
            snapshot = list(training_tasks.items())
        tasks_with_id = [
            dict(task_info.copy(), task_id=task_id)
            for task_id, task_info in snapshot
        ]
        # Sort by start time, most recent task first.
        sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True)
        return jsonify({"status": "success", "data": sorted_tasks})
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/training', methods=['POST'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '启动模型训练任务',
    'description': '为指定产品启动一个新的模型训练任务',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string', 'description': '例如 P001'},
                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan'], 'description': '要训练的模型类型'},
                    'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
                },
                'required': ['product_id', 'model_type']
            }
        }
    ],
    'responses': {
        200: {
            'description': '训练任务已启动',
            'schema': {
                'type': 'object',
                'properties': {
                    'message': {'type': 'string'},
                    'task_id': {'type': 'string'}
                }
            }
        },
        400: {'description': '请求错误'}
    }
})
# Starts a background training run for one product and returns its task id.
# The docstring below is kept verbatim because flasgger parses docstrings
# that contain a '---' separator at runtime.
def start_training():
    """
    启动模型训练
    ---
    post:
    ...
    """
    # A missing or non-JSON body must yield a 400, not an AttributeError.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    product_id = data.get('product_id')
    model_type = data.get('model_type')
    epochs = data.get('epochs', 50)  # defaults to 50 epochs
    if not product_id or not model_type:
        return jsonify({'error': '缺少product_id或model_type'}), 400

    model_train_functions = {
        'mlstm': train_product_model_with_mlstm,
        'kan': train_product_model_with_kan,
        'transformer': train_product_model_with_transformer
    }
    if model_type not in model_train_functions:
        return jsonify({'error': '无效的模型类型'}), 400
    train_function = model_train_functions[model_type]
    task_id = str(uuid.uuid4())

    def train_task(product_id, epochs, model_type):
        """Run the training function and record the outcome on the task."""
        try:
            print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
            # train_function returns (model, metrics); only metrics are kept.
            _, metrics = train_function(product_id, epochs)
            # Publish status, metrics and model path in one step so readers
            # never observe a 'completed' task without its results.
            with tasks_lock:
                training_tasks[task_id].update(
                    status='completed',
                    metrics=metrics,
                    model_path=f'models/{model_type}/{product_id}_model.pt',
                )
            print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
        except Exception as e:
            traceback.print_exc()
            print(f"任务 {task_id}: 训练失败。错误: {e}")
            with tasks_lock:
                training_tasks[task_id].update(status='failed', error=str(e))

    # Register the task BEFORE starting the thread. Previously the entry was
    # created after thread.start(), so a run that failed quickly either hit
    # a KeyError or had its 'failed' status overwritten with 'running'.
    with tasks_lock:
        training_tasks[task_id] = {
            'status': 'running',
            'product_id': product_id,
            'model_type': model_type,
            'start_time': datetime.now().isoformat(),
            'metrics': None,
            'error': None,
            'model_path': None
        }
    thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
    thread.start()
    return jsonify({'message': '模型训练已开始', 'task_id': task_id})
@app.route('/api/training/<task_id>', methods=['GET'])
@swag_from({
    'tags': ['模型训练'],
    'summary': '查询训练任务状态',
    'description': '获取特定训练任务的当前状态和详情',
    'parameters': [
        {
            'name': 'task_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            'description': '训练任务ID'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取任务状态',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'parameters': {'type': 'object'},
                            'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
                            'created_at': {'type': 'string'},
                            'model_path': {'type': 'string'},
                            'metrics': {'type': 'object'},
                            'model_details_url': {'type': 'string'}
                        }
                    }
                }
            }
        },
        404: {'description': '任务不存在'},
        500: {'description': '服务器内部错误'}
    }
})
def get_training_status(task_id):
    """Return the stored record of one training task.

    Responds with 404 for an unknown ``task_id``. For completed tasks the
    returned copy additionally carries ``model_details_url``, a link to the
    model listing filtered by the task's product and model type.
    """
    try:
        stored = training_tasks.get(task_id)
        if stored is None:
            return jsonify({"status": "error", "message": "任务不存在"}), 404

        # Work on a copy so the extra link never leaks into the shared store.
        snapshot = dict(stored)
        finished = snapshot['status'] == 'completed'
        if finished:
            snapshot['model_details_url'] = (
                f"/api/models?product_id={snapshot['product_id']}"
                f"&model_type={snapshot['model_type']}"
            )
        body = {"status": "success", "data": snapshot}
        return jsonify(body)
    except Exception as err:
        return jsonify({"status": "error", "message": str(err)}), 500
# 3. Model prediction API
@app.route('/api/prediction', methods=['POST'])
@swag_from({
    'tags': ['模型预测'],
    'summary': '使用模型进行预测',
    'description': '使用指定模型预测未来销售数据',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string'},
                    'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan']},
                    'version': {'type': 'string'},
                    'future_days': {'type': 'integer'},
                    'include_visualization': {'type': 'boolean'},
                    'start_date': {'type': 'string', 'description': '预测起始日期格式为YYYY-MM-DD'}
                },
                'required': ['product_id', 'model_type']
            }
        }
    ],
    'responses': {
        200: {
            'description': '预测成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'predictions': {'type': 'array'},
                            'visualization': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {'description': '请求错误,缺少必要参数或参数格式错误'},
        404: {'description': '产品或模型不存在'},
        500: {'description': '服务器内部错误'}
    }
})
# Runs a stored model for one product and returns history + forecast rows
# together with URLs of the chart/CSV files that exist on disk afterwards.
# The docstring below is kept verbatim (only its YAML nesting is restored)
# because flasgger parses docstrings containing '---' at runtime.
def predict():
    """
    使用指定的模型进行预测
    ---
    tags:
      - 模型预测
    parameters:
      - in: body
        name: body
        schema:
          type: object
          required:
            - product_id
            - model_type
          properties:
            product_id:
              type: string
              description: 产品ID
            model_type:
              type: string
              description: 模型类型 (mlstm, kan, transformer)
            future_days:
              type: integer
              description: 预测未来天数
              default: 7
            start_date:
              type: string
              description: 预测起始日期格式为YYYY-MM-DD
              default: ''
    responses:
      200:
        description: 预测成功
      400:
        description: 请求参数错误
      404:
        description: 模型文件未找到
    """
    # A missing or non-object JSON body is answered with the 400 below
    # instead of crashing on `None.get(...)` outside the try block.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}
    product_id = data.get('product_id')
    model_type = data.get('model_type')
    future_days = data.get('future_days', 7)
    start_date = data.get('start_date', '')
    print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
    if not product_id or not model_type:
        return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
    try:
        result = load_model_and_predict(product_id, model_type, future_days, start_date=start_date)
        if result is None:
            return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404

        predictions_df = result['predictions_df']
        # Base URL of this server (scheme, host, port) without trailing slash.
        host_url = request.host_url.rstrip('/')
        response_data = {
            "status": "success",
            "data": predictions_df.to_dict(orient='records'),  # history and forecast rows
            "history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'),
            "prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records')
        }

        # Cache-busting query value appended to the image URLs.
        timestamp = datetime.now().timestamp()

        # Date tag used in the output file names: the requested start date,
        # or today when none was given or it cannot be parsed.
        start_date_str = datetime.now().strftime('%Y%m%d')
        if start_date:
            try:
                start_date_str = datetime.strptime(start_date, '%Y-%m-%d').strftime('%Y%m%d')
            except (ValueError, TypeError):
                pass  # keep today's date as the fallback tag

        # File names and on-disk locations of the generated artefacts.
        png_filename = f"forecast_{start_date_str}_days{future_days}.png"
        csv_filename = f"forecast_{start_date_str}_days{future_days}.csv"
        history_png_filename = f"history_{start_date_str}.png"
        output_dir = os.path.join('predictions', model_type, product_id)
        png_path = os.path.join(output_dir, png_filename)
        csv_path = os.path.join(output_dir, csv_filename)
        history_png_path = os.path.join(output_dir, history_png_filename)

        # Only advertise a URL when the file really exists.
        # Forecast chart
        if os.path.exists(png_path):
            response_data["image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{png_filename}?t={timestamp}"
            print(f"图表URL: {response_data['image_url']}")
        else:
            response_data["image_url"] = None
            print(f"警告: 预测图表文件未生成或不存在: {png_path}")

        # History chart
        if os.path.exists(history_png_path):
            response_data["history_image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{history_png_filename}?t={timestamp}"
            print(f"历史图表URL: {response_data['history_image_url']}")
        else:
            response_data["history_image_url"] = None
            if result and 'history_chart_path' in result and result['history_chart_path'] is None:
                print(f"警告: 历史图表生成过程中出现错误,请检查服务器日志")
            else:
                print(f"警告: 历史图表文件未生成或不存在: {history_png_path}")

        # CSV export
        if os.path.exists(csv_path):
            response_data["csv_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{csv_filename}"
        else:
            response_data["csv_url"] = None

        return jsonify(response_data)
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"预测过程中发生错误: {str(e)}"}), 500
@app.route('/api/prediction/compare', methods=['POST'])
@swag_from({
    'tags': ['模型预测'],
    'summary': '比较不同模型预测结果',
    'description': '比较不同模型对同一产品的预测结果',
    'parameters': [
        {
            'name': 'body',
            'in': 'body',
            'required': True,
            'schema': {
                'type': 'object',
                'properties': {
                    'product_id': {'type': 'string'},
                    'model_types': {'type': 'array', 'items': {'type': 'string'}},
                    'versions': {'type': 'array', 'items': {'type': 'string'}},
                    'include_visualization': {'type': 'boolean'}
                },
                'required': ['product_id', 'model_types']
            }
        }
    ],
    'responses': {
        200: {
            'description': '比较成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'product_name': {'type': 'string'},
                            'model_types': {'type': 'array'},
                            'comparison': {'type': 'array'},
                            'visualization': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {'description': '请求错误,缺少必要参数或参数格式错误'},
        404: {'description': '产品或模型不存在'},
        500: {'description': '服务器内部错误'}
    }
})
def compare_predictions():
    """Plot and tabulate the forecasts of several models for one product.

    For every requested model type whose prediction can be loaded, the
    forecast is drawn next to the last 30 historical sales values and added
    to a per-date comparison table. The chart is written to
    ``predictions/compare/<product_id>_model_comparison.png``.

    Returns:
        JSON with the comparison table and the chart URL; 400 when
        ``product_id`` or ``model_types`` is missing, 404 when the product
        has no sales rows, 500 on any other failure.
    """
    fig = None
    try:
        # Tolerate a missing or non-object JSON body; it becomes a 400 below.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            data = {}
        product_id = data.get('product_id')
        model_types = data.get('model_types')
        if not product_id or not model_types:
            return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400

        # Load the history used as the baseline of the chart.
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        # An unknown product used to raise IndexError on `.iloc[0]` and was
        # reported as 500; the documented answer is 404.
        if product_df.empty:
            return jsonify({"status": "error", "error": "产品不存在"}), 404
        product_name = product_df['product_name'].iloc[0]

        fig = plt.figure(figsize=(12, 8))
        history_days = 30
        history_dates = product_df['date'].iloc[-history_days:].values
        history_sales = product_df['sales'].iloc[-history_days:].values
        plt.plot(history_dates, history_sales, 'b-', label='历史销量')

        comparison_data = []
        future_dates = None
        # Directory that receives the comparison chart.
        compare_dir = 'predictions/compare'
        os.makedirs(compare_dir, exist_ok=True)

        for model_type in model_types:
            result = load_model_and_predict(product_id, model_type)
            if result is None:
                continue  # model missing or failed to load: skip it
            predictions_df = result['predictions_df']
            if future_dates is None:
                # The first available model defines the date axis of the table.
                future_dates = predictions_df['date']
                comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates]
            plt.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测')
            preds = predictions_df['predicted_sales'].tolist()
            # Pad with None when this model produced fewer rows than the axis.
            for i, row in enumerate(comparison_data):
                row[f'{model_type}_prediction'] = preds[i] if i < len(preds) else None

        plt.title(f'{product_name} - 多模型预测比较')
        plt.xlabel('日期')
        plt.ylabel('销量')
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        chart_path = f'{compare_dir}/{product_id}_model_comparison.png'
        plt.savefig(chart_path)

        # Base URL of this server without trailing slash.
        host_url = request.host_url.rstrip('/')
        # The timestamp query parameter defeats browser caching.
        timestamp = datetime.now().timestamp()
        image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}"
        return jsonify({
            "status": "success",
            "data": {
                "product_id": product_id,
                "product_name": product_name,
                "model_types": model_types,
                "comparison": comparison_data,
                "visualization_url": image_url
            }
        })
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500
    finally:
        # Always release the figure, including on errors, so failed requests
        # do not accumulate open matplotlib figures.
        if fig is not None:
            plt.close(fig)
# 4. Model management API
@app.route('/api/models', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '获取模型列表',
    'description': '获取系统中的模型列表可按产品ID和模型类型筛选',
    'parameters': [
        {
            'name': 'product_id',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '按产品ID筛选'
        },
        {
            'name': 'model_type',
            'in': 'query',
            'type': 'string',
            'required': False,
            'description': '按模型类型筛选'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取模型列表',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'array',
                        'items': {
                            'type': 'object',
                            'properties': {
                                'model_id': {'type': 'string'},
                                'product_id': {'type': 'string'},
                                'product_name': {'type': 'string'},
                                'model_type': {'type': 'string'},
                                'created_at': {'type': 'string'},
                                'metrics': {'type': 'object'}
                            }
                        }
                    }
                }
            }
        },
        500: {'description': '服务器内部错误'}
    }
})
# Lists the models found as '<product_id>_log.json' files under
# models/<model_type>/, optionally filtered by product id and model type,
# newest first. The docstring below is kept verbatim (only its YAML nesting
# is restored) because flasgger parses docstrings containing '---'.
def list_models():
    """
    列出所有可用的模型
    ---
    tags:
      - 模型管理
    parameters:
      - name: product_id
        in: query
        type: string
        required: false
        description: 按产品ID筛选模型
      - name: model_type
        in: query
        type: string
        required: false
        description: 按模型类型筛选 (mlstm, kan, transformer)
    responses:
      200:
        description: 模型列表
        schema:
          type: object
          properties:
            status:
              type: string
              example: success
            data:
              type: array
              items:
                type: object
                properties:
                  model_id:
                    type: string
                  product_id:
                    type: string
                  product_name:
                    type: string
                  model_type:
                    type: string
                  created_at:
                    type: string
                  metrics:
                    type: object
    """
    models_dir = 'models'
    model_types = ['mlstm', 'kan', 'transformer']
    log_suffix = '_log.json'
    available_models = []
    product_id_filter = request.args.get('product_id')
    model_type_filter = request.args.get('model_type')

    if not os.path.exists(models_dir):
        return jsonify({"status": "success", "data": []})

    for model_type in model_types:
        if model_type_filter and model_type_filter != model_type:
            continue
        type_dir = os.path.join(models_dir, model_type)
        if not os.path.exists(type_dir):
            continue
        for file_name in os.listdir(type_dir):
            if not file_name.endswith(log_suffix):
                continue
            product_id = file_name.replace(log_suffix, '')
            if product_id_filter and product_id_filter != product_id:
                continue
            log_path = os.path.join(type_dir, file_name)
            try:
                with open(log_path, 'r', encoding='utf-8') as f:
                    log_data = json.load(f)
                available_models.append({
                    "model_id": f"{model_type}_{product_id}",
                    "product_id": log_data.get('product_id'),
                    "product_name": log_data.get('product_name'),
                    "model_type": log_data.get('model_type'),
                    "created_at": log_data.get('training_completed_at'),
                    "metrics": log_data.get('metrics'),
                    "file_path": log_data.get('file_path')
                })
            except Exception as e:
                # One unreadable log must not hide the remaining models.
                print(f"读取日志文件 {log_path} 失败: {e}")

    # Newest first. 'created_at' is always present but may be None when the
    # log lacks 'training_completed_at'; normalise it to a string so the
    # sort cannot fail comparing None with str.
    available_models.sort(key=lambda m: str(m.get('created_at') or ''), reverse=True)
    return jsonify({"status": "success", "data": available_models})
@app.route('/api/models/<model_id>', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '获取模型详情',
    'description': '获取特定模型的详细信息',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            # Fixed: the handler parses "<model_type>_<product_id>", not the
            # "{product_id}_{model_type}_v{version}" form documented before.
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '成功获取模型详情',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'product_id': {'type': 'string'},
                            'model_type': {'type': 'string'},
                            'version': {'type': 'string'},
                            'created_at': {'type': 'string'},
                            'file_path': {'type': 'string'},
                            'file_size': {'type': 'string'},
                            'features': {'type': 'array'},
                            'look_back': {'type': 'integer'},
                            'T': {'type': 'integer'},
                            'metrics': {'type': 'object'}
                        }
                    }
                }
            }
        },
        400: {'description': '无效的模型ID格式'},
        404: {'description': '模型不存在'},
        500: {'description': '服务器内部错误'}
    }
})
# Returns the content of models/<model_type>/<product_id>_log.json for a
# model id of the form "<model_type>_<product_id>".
# The docstring below is kept verbatim (only its YAML nesting is restored)
# because flasgger parses docstrings containing '---' at runtime.
# NOTE(review): its `description` value contains ": ", which YAML plain
# scalars do not allow - confirm that flasgger accepts it.
def get_model_details(model_id):
    """
    获取单个模型的详细信息
    ---
    tags:
      - 模型管理
    parameters:
      - name: model_id
        in: path
        type: string
        required: true
        description: 模型的唯一标识符 (格式: model_type_product_id)
    responses:
      200:
        description: 模型的详细信息
      404:
        description: 未找到模型
    """
    try:
        # Raises ValueError (-> 400) when the id contains no underscore.
        model_type, product_id = model_id.split('_', 1)
        # Both parts end up in a file-system path, so only accept a known
        # model type and a product id that is a plain file-name component.
        if model_type not in ('mlstm', 'kan', 'transformer'):
            raise ValueError("无效的模型类型")
        if not product_id or product_id != os.path.basename(product_id):
            raise ValueError("无效的产品ID")

        log_path = os.path.join('models', model_type, f'{product_id}_log.json')
        if not os.path.exists(log_path):
            return jsonify({"status": "error", "error": "模型未找到"}), 404
        with open(log_path, 'r', encoding='utf-8') as f:
            log_data = json.load(f)
        # Further details could be read from the .pt checkpoint here.
        return jsonify({"status": "success", "data": log_data})
    except ValueError:
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    except Exception as e:
        return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
@app.route('/api/models/<model_id>', methods=['DELETE'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '删除模型',
    'description': '删除特定模型',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            # Fixed: the handler parses "<model_type>_<product_id>", not the
            # "{product_id}_{model_type}_v{version}" form documented before.
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '模型删除成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'}
                }
            }
        },
        400: {'description': '无效的模型ID格式'},
        404: {'description': '模型不存在'},
        500: {'description': '服务器内部错误'}
    }
})
# Deletes models/<model_type>/<product_id>_model.pt and the matching
# '<product_id>_log.json' for a model id "<model_type>_<product_id>".
# The docstring below is kept verbatim (only its YAML nesting is restored)
# because flasgger parses docstrings containing '---' at runtime.
# NOTE(review): its `description` value contains ": ", which YAML plain
# scalars do not allow - confirm that flasgger accepts it.
def delete_model(model_id):
    """
    删除一个模型及其关联文件
    ---
    tags:
      - 模型管理
    parameters:
      - name: model_id
        in: path
        type: string
        required: true
        description: 要删除的模型的ID (格式: model_type_product_id)
    responses:
      200:
        description: 模型删除成功
      404:
        description: 模型未找到
    """
    try:
        # Raises ValueError (-> 400) when the id contains no underscore.
        model_type, product_id = model_id.split('_', 1)
        # This handler deletes files, so never build the path from an
        # unchecked id: require a known model type and a product id that is
        # a plain file-name component.
        if model_type not in ('mlstm', 'kan', 'transformer'):
            raise ValueError("无效的模型类型")
        if not product_id or product_id != os.path.basename(product_id):
            raise ValueError("无效的产品ID")

        model_dir = os.path.join('models', model_type)
        model_path = os.path.join(model_dir, f'{product_id}_model.pt')
        log_path = os.path.join(model_dir, f'{product_id}_log.json')

        # Remove whatever exists; a file that is already gone is not an
        # error (avoids the exists()/remove() race).
        removed_any = False
        for path in (model_path, log_path):
            try:
                os.remove(path)
                removed_any = True
            except FileNotFoundError:
                continue
        if not removed_any:
            return jsonify({"status": "error", "error": "模型未找到"}), 404
        return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
    except ValueError:
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    except Exception as e:
        return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
@app.route('/api/models/<model_id>/export', methods=['GET'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '导出模型',
    'description': '导出特定模型文件',
    'parameters': [
        {
            'name': 'model_id',
            'in': 'path',
            'type': 'string',
            'required': True,
            # Fixed: the handler parses "<model_type>_<product_id>", not the
            # "{product_id}_{model_type}_v{version}" form documented before.
            'description': '模型ID格式为{model_type}_{product_id}'
        }
    ],
    'responses': {
        200: {
            'description': '模型文件下载',
            'content': {
                'application/octet-stream': {}
            }
        },
        400: {'description': '无效的模型ID格式'},
        404: {'description': '模型文件未找到'},
        500: {'description': '服务器内部错误'}
    }
})
def export_model(model_id):
    """Send the checkpoint of one model as a file download.

    Args:
        model_id: ``<model_type>_<product_id>``.

    Returns:
        The file ``models/<model_type>/<product_id>_model.pt`` as an
        attachment named ``<model_id>.pt``; 400 for a malformed id, 404 when
        the file does not exist, 500 on any other failure.
    """
    try:
        # Raises ValueError (-> 400) when the id contains no underscore.
        model_type, product_id = model_id.split('_', 1)
        # Only accept a known model type and a product id that is a plain
        # file-name component before building a path from them.
        if model_type not in ('mlstm', 'kan', 'transformer'):
            raise ValueError("无效的模型类型")
        if not product_id or product_id != os.path.basename(product_id):
            raise ValueError("无效的产品ID")

        # Resolve against the working directory once, so the existence check
        # and send_file look at the same file (send_file may otherwise
        # resolve a relative path against the application root).
        model_path = os.path.abspath(
            os.path.join('models', model_type, f'{product_id}_model.pt')
        )
        if not os.path.exists(model_path):
            return jsonify({"status": "error", "error": "模型文件未找到"}), 404
        return send_file(
            model_path,
            as_attachment=True,
            download_name=f'{model_id}.pt',
            mimetype='application/octet-stream'
        )
    except ValueError:
        # Same answer as the sibling detail/delete handlers for a bad id.
        return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
    except Exception as e:
        return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
@app.route('/api/models/import', methods=['POST'])
@swag_from({
    'tags': ['模型管理'],
    'summary': '导入模型',
    'description': '导入模型文件',
    'consumes': ['multipart/form-data'],
    'parameters': [
        {
            'name': 'file',
            'in': 'formData',
            'type': 'file',
            'required': True,
            'description': 'PyTorch模型文件(.pt)'
        }
    ],
    'responses': {
        200: {
            'description': '模型导入成功',
            'schema': {
                'type': 'object',
                'properties': {
                    'status': {'type': 'string'},
                    'message': {'type': 'string'},
                    'data': {
                        'type': 'object',
                        'properties': {
                            'model_path': {'type': 'string'}
                        }
                    }
                }
            }
        },
        400: {'description': '请求错误,文件格式不正确或缺少文件'},
        500: {'description': '服务器内部错误'}
    }
})
def import_model():
    """Store an uploaded ``.pt`` checkpoint as a model of this system.

    The upload must be named ``<model_type>_<product_id>.pt`` (for example
    ``mlstm_P001.pt``). It is saved as
    ``models/<model_type>/<product_id>_model.pt``; when no
    ``<product_id>_log.json`` exists yet, a basic one is created, using the
    ``epochs`` and ``metrics`` entries of the checkpoint when it can be
    loaded as a dict.

    Returns:
        JSON with the stored model path; 400 for a missing file or a bad
        file name, 500 on any other failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({"status": "error", "message": "没有上传文件"}), 400
        file = request.files['file']
        if file.filename == '' or not file.filename.endswith('.pt'):
            return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400

        # Derive model_type and product_id from a name like `transformer_P001.pt`.
        try:
            model_type, product_id_ext = file.filename.split('_', 1)
            product_id = os.path.splitext(product_id_ext)[0]
            if model_type not in ('mlstm', 'kan', 'transformer'):
                raise ValueError("无效的模型类型")
            # The client controls the file name: refuse anything that is not
            # a plain file-name component so the upload cannot be written
            # outside models/<model_type>/.
            if (not product_id or product_id in ('.', '..')
                    or product_id != os.path.basename(product_id)):
                raise ValueError("无效的产品ID")
        except ValueError:
            return jsonify({
                "status": "error",
                "message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)"
            }), 400

        # Target locations.
        model_dir = os.path.join('models', model_type)
        os.makedirs(model_dir, exist_ok=True)
        model_path = os.path.join(model_dir, f'{product_id}_model.pt')
        log_path = os.path.join(model_dir, f'{product_id}_log.json')

        needs_log = not os.path.exists(log_path)
        checkpoint = {}
        if needs_log:
            # Best effort: read metadata from the checkpoint.
            # NOTE(security): torch.load unpickles the upload, which can run
            # arbitrary code from an untrusted file. Restrict this endpoint
            # to trusted users or switch to a safe loading mode.
            try:
                checkpoint = torch.load(file, map_location='cpu')
            except Exception:
                checkpoint = {}  # unreadable checkpoint: fall back to defaults
            finally:
                # Always rewind. Previously the stream stayed at an arbitrary
                # offset when loading failed, so a truncated file was saved.
                file.seek(0)
            # A checkpoint that is not a dict (e.g. a whole pickled module)
            # has no .get(); treat it as "no metadata".
            if not isinstance(checkpoint, dict):
                checkpoint = {}

        # Save the model first so a failed save never leaves behind a log
        # entry that points to a missing file.
        file.save(model_path)

        if needs_log:
            # Create a basic log file for the imported model.
            log_data = {
                'product_id': product_id,
                'product_name': f"导入的产品 {product_id}",  # may need editing by the user later
                'model_type': model_type,
                'training_completed_at': datetime.now().isoformat(),
                'epochs': checkpoint.get('epochs', 'N/A'),
                'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}),
                'file_path': model_path
            }
            with open(log_path, 'w', encoding='utf-8') as f:
                json.dump(log_data, f, indent=4, ensure_ascii=False)

        return jsonify({
            "status": "success",
            "message": "模型已成功导入",
            "data": {"model_path": model_path}
        })
    except Exception as e:
        traceback.print_exc()
        return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500
@app.route('/api/plots/<filename>')
def get_plot(filename):
    """Serve a plot image from the application root directory.

    Only image files are served: the application root also holds source
    code and data files, which must not be downloadable through this route.
    Responds with a JSON 404 when the file is missing or is not an image.
    """
    not_found = (jsonify({"status": "error", "error": "Plot not found"}), 404)
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.svg')):
        return not_found
    try:
        return send_from_directory(app.root_path, filename)
    except Exception as exc:
        # send_from_directory reports a missing file with an HTTP 404
        # exception rather than FileNotFoundError; accept both.
        if isinstance(exc, FileNotFoundError) or getattr(exc, 'code', None) == 404:
            return not_found
        raise
@app.route('/api/csv/<filename>')
def get_csv(filename):
    """Serve a CSV file from the application root directory as a download.

    Only ``.csv`` files are served: the application root also holds source
    code and data files, which must not be downloadable through this route.
    Responds with a JSON 404 when the file is missing or is not a CSV file.
    """
    not_found = (jsonify({"status": "error", "error": "CSV file not found"}), 404)
    if not filename.lower().endswith('.csv'):
        return not_found
    try:
        return send_from_directory(app.root_path, filename, as_attachment=True)
    except Exception as exc:
        # send_from_directory reports a missing file with an HTTP 404
        # exception rather than FileNotFoundError; accept both.
        if isinstance(exc, FileNotFoundError) or getattr(exc, 'code', None) == 404:
            return not_found
        raise
# Static UI route: maps the /ui path onto the wwwroot directory
@app.route('/ui/', defaults={'path': 'index.html'})
@app.route('/ui/<path:path>')
def serve_ui(path):
    """Serve the UI's static files from ``wwwroot``.

    Unknown sub-paths (for example ``/ui/about``) fall back to
    ``index.html`` so that client-side (SPA) routing keeps working. A JSON
    404 is returned when ``index.html`` itself is missing.
    """
    def _is_not_found(exc):
        # send_from_directory signals a missing file with an HTTP 404
        # exception, not FileNotFoundError, so the old
        # `except FileNotFoundError` never ran and the fallback below was
        # dead code. Accept both forms.
        return isinstance(exc, FileNotFoundError) or getattr(exc, 'code', None) == 404

    try:
        # Serve the static file from wwwroot.
        return send_from_directory('wwwroot', path)
    except Exception as exc:
        if not _is_not_found(exc):
            raise

    if path != 'index.html':
        # Sub-path without a matching file: hand over to the SPA entry page.
        try:
            return send_from_directory('wwwroot', 'index.html')
        except Exception as exc:
            if not _is_not_found(exc):
                raise
            return jsonify({"status": "error", "error": "UI files not found"}), 404
    return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404
@app.route('/api/predictions/<model_type>/<product_id>/<filename>')
def get_prediction_file(model_type, product_id, filename):
    """Serve a prediction artefact stored under predictions/<model_type>/<product_id>/.

    PNG files are sent with headers that forbid caching, so a regenerated
    chart is always re-fetched. Responds with a JSON 404 when the file does
    not exist or the requested path would leave the ``predictions``
    directory.
    """
    from email.utils import formatdate

    # Pass the whole relative path to send_from_directory so that its path
    # check covers model_type and product_id as well: a '..' in either of
    # them can then no longer escape the 'predictions' directory.
    relative_path = '/'.join((model_type, product_id, filename))
    try:
        response = send_from_directory('predictions', relative_path)
    except Exception as exc:
        # A missing file is signalled with an HTTP 404 exception rather
        # than FileNotFoundError; accept both.
        if isinstance(exc, FileNotFoundError) or getattr(exc, 'code', None) == 404:
            return jsonify({"status": "error", "error": f"预测文件 {filename} 未找到"}), 404
        raise

    if filename.endswith('.png'):
        # Force browsers not to cache chart images.
        response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0'
        response.headers['Pragma'] = 'no-cache'
        response.headers['Expires'] = '0'
        # Real GMT timestamp (the previous value was local time labelled GMT).
        response.headers['Last-Modified'] = formatdate(usegmt=True)
        response.headers['Vary'] = '*'
        print(f"提供图片文件 {filename} 并添加防缓存头")
    return response
@app.route('/api/predictions/compare/<filename>')
def get_compare_file(filename):
    """Serve a file from the ``predictions/compare`` directory next to this script."""
    compare_root = os.path.join(current_dir, 'predictions', 'compare')
    response = send_from_directory(compare_root, filename)
    return response
if __name__ == '__main__':
    # Report which device torch will use.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    # Run the Flask development server.
    # In production use a WSGI server such as Gunicorn or uWSGI instead,
    # e.g.: gunicorn --workers 4 --bind 0.0.0.0:5000 api:app
    # host='0.0.0.0' makes the service reachable from the local network.
    # Because of that, debug mode is opt-in (FLASK_DEBUG=1): the Werkzeug
    # debugger allows code execution and must not be exposed to other hosts.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', port=5000, debug=debug_enabled)