1469 lines
50 KiB
Python
1469 lines
50 KiB
Python
import os
|
||
import sys
|
||
|
||
# 获取当前脚本所在目录的绝对路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
|
||
# 将当前目录添加到系统路径
|
||
sys.path.append(current_dir)
|
||
|
||
# 或者添加父目录
|
||
#parent_dir = os.path.dirname(current_dir)
|
||
#sys.path.append(parent_dir)
|
||
|
||
|
||
from flask import Flask, request, jsonify, send_file, redirect, send_from_directory
|
||
from flask_cors import CORS
|
||
import os
|
||
import uuid
|
||
import pandas as pd
|
||
from pharmacy_predictor import (
|
||
train_product_model_with_mlstm,
|
||
train_product_model_with_kan,
|
||
train_product_model_with_transformer,
|
||
load_model_and_predict
|
||
)
|
||
import threading
|
||
import base64
|
||
import matplotlib.pyplot as plt
|
||
from io import BytesIO
|
||
import json
|
||
from flasgger import Swagger, swag_from
|
||
import argparse
|
||
from datetime import datetime
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from threading import Lock
|
||
import traceback
|
||
import torch
|
||
|
||
|
||
|
||
app = Flask(__name__)
|
||
CORS(app) # 启用CORS支持
|
||
|
||
# Swagger配置
|
||
swagger_config = {
|
||
"headers": [],
|
||
"specs": [
|
||
{
|
||
"endpoint": "apispec",
|
||
"route": "/apispec.json",
|
||
"rule_filter": lambda rule: True, # 包含所有路由
|
||
"model_filter": lambda tag: True, # 包含所有模型
|
||
}
|
||
],
|
||
"static_url_path": "/flasgger_static",
|
||
"swagger_ui": True,
|
||
"specs_route": "/swagger/"
|
||
}
|
||
|
||
swagger_template = {
|
||
"swagger": "2.0",
|
||
"info": {
|
||
"title": "药店销售预测系统API",
|
||
"description": "用于药店销售预测的RESTful API",
|
||
"version": "1.0.0",
|
||
"contact": {
|
||
"name": "API开发团队",
|
||
"email": "support@example.com"
|
||
}
|
||
},
|
||
"tags": [
|
||
{
|
||
"name": "数据管理",
|
||
"description": "数据上传和查询相关接口"
|
||
},
|
||
{
|
||
"name": "模型训练",
|
||
"description": "模型训练相关接口"
|
||
},
|
||
{
|
||
"name": "模型预测",
|
||
"description": "预测销售数据相关接口"
|
||
},
|
||
{
|
||
"name": "模型管理",
|
||
"description": "模型查询、导出和删除接口"
|
||
}
|
||
]
|
||
}
|
||
|
||
swagger = Swagger(app, config=swagger_config, template=swagger_template)
|
||
|
||
# 存储训练任务状态
|
||
training_tasks = {}
|
||
tasks_lock = Lock()
|
||
|
||
# 线程池用于后台训练
|
||
executor = ThreadPoolExecutor(max_workers=2)
|
||
|
||
# 辅助函数:将图像转换为Base64
|
||
def fig_to_base64(fig):
|
||
buf = BytesIO()
|
||
fig.savefig(buf, format='png')
|
||
buf.seek(0)
|
||
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
||
return img_str
|
||
|
||
# 根路由 - 重定向到UI界面
|
||
@app.route('/')
|
||
def index():
|
||
"""重定向到UI界面"""
|
||
return redirect('/ui/')
|
||
|
||
# Swagger UI路由
|
||
@app.route('/swagger')
|
||
def swagger_ui():
|
||
"""重定向到Swagger UI文档页面"""
|
||
return redirect('/swagger/')
|
||
|
||
# 1. 数据管理API
|
||
@app.route('/api/products', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取所有产品列表',
|
||
'description': '返回系统中所有产品的ID和名称',
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取产品列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_products():
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
products = df[['product_id', 'product_name']].drop_duplicates().to_dict('records')
|
||
return jsonify({"status": "success", "data": products})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/products/<product_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取单个产品详情',
|
||
'description': '返回指定产品ID的详细信息',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '产品ID,例如P001'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取产品详情',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'data_points': {'type': 'integer'},
|
||
'date_range': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'start': {'type': 'string'},
|
||
'end': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '产品不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_product(product_id):
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id]
|
||
|
||
if product_df.empty:
|
||
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
||
|
||
product_info = {
|
||
"product_id": product_id,
|
||
"product_name": product_df['product_name'].iloc[0],
|
||
"data_points": len(product_df),
|
||
"date_range": {
|
||
"start": product_df['date'].min().strftime('%Y-%m-%d'),
|
||
"end": product_df['date'].max().strftime('%Y-%m-%d')
|
||
}
|
||
}
|
||
|
||
return jsonify({"status": "success", "data": product_info})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/products/<product_id>/sales', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '获取产品销售数据',
|
||
'description': '返回指定产品在特定日期范围内的销售数据',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '产品ID,例如P001'
|
||
},
|
||
{
|
||
'name': 'start_date',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '开始日期,格式为YYYY-MM-DD'
|
||
},
|
||
{
|
||
'name': 'end_date',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '结束日期,格式为YYYY-MM-DD'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取销售数据',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object'
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '产品不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_product_sales(product_id):
|
||
try:
|
||
start_date = request.args.get('start_date')
|
||
end_date = request.args.get('end_date')
|
||
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
|
||
if product_df.empty:
|
||
return jsonify({"status": "error", "message": "产品不存在"}), 404
|
||
|
||
# 如果提供了日期范围,进行过滤
|
||
if start_date:
|
||
product_df = product_df[product_df['date'] >= pd.to_datetime(start_date)]
|
||
if end_date:
|
||
product_df = product_df[product_df['date'] <= pd.to_datetime(end_date)]
|
||
|
||
# 转换日期为字符串以便JSON序列化
|
||
product_df['date'] = product_df['date'].dt.strftime('%Y-%m-%d')
|
||
|
||
sales_data = product_df.to_dict('records')
|
||
return jsonify({"status": "success", "data": sales_data})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/data/upload', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['数据管理'],
|
||
'summary': '上传销售数据',
|
||
'description': '上传新的销售数据文件(Excel格式)',
|
||
'consumes': ['multipart/form-data'],
|
||
'parameters': [
|
||
{
|
||
'name': 'file',
|
||
'in': 'formData',
|
||
'type': 'file',
|
||
'required': True,
|
||
'description': 'Excel文件(.xlsx),包含销售数据'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '数据上传成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'products': {'type': 'integer'},
|
||
'rows': {'type': 'integer'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,可能是文件格式不正确或缺少必要字段'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def upload_data():
|
||
try:
|
||
if 'file' not in request.files:
|
||
return jsonify({"status": "error", "message": "没有上传文件"}), 400
|
||
|
||
file = request.files['file']
|
||
if file.filename == '':
|
||
return jsonify({"status": "error", "message": "没有选择文件"}), 400
|
||
|
||
if file and file.filename.endswith('.xlsx'):
|
||
file_path = 'uploaded_data.xlsx'
|
||
file.save(file_path)
|
||
|
||
# 验证数据格式
|
||
try:
|
||
df = pd.read_excel(file_path)
|
||
required_columns = ['date', 'product_id', 'product_name', 'sales', 'price']
|
||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||
|
||
if missing_columns:
|
||
return jsonify({
|
||
"status": "error",
|
||
"message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}"
|
||
}), 400
|
||
|
||
# 合并到现有数据或替换现有数据
|
||
existing_df = pd.read_excel('pharmacy_sales.xlsx')
|
||
# 这里可以实现数据合并逻辑,例如按日期和产品ID去重后合并
|
||
|
||
# 简单示例:保存上传的数据
|
||
df.to_excel('pharmacy_sales.xlsx', index=False)
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"message": "数据上传成功",
|
||
"data": {
|
||
"products": len(df['product_id'].unique()),
|
||
"rows": len(df)
|
||
}
|
||
})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400
|
||
else:
|
||
return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
# 2. 模型训练API
|
||
@app.route('/api/training', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '获取所有训练任务列表',
|
||
'description': '返回所有正在进行、已完成或失败的训练任务',
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取任务列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'task_id': {'type': 'string'},
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'status': {'type': 'string'},
|
||
'start_time': {'type': 'string'},
|
||
'metrics': {'type': 'object'},
|
||
'error': {'type': 'string'},
|
||
'model_path': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
})
|
||
def get_all_training_tasks():
|
||
"""获取所有训练任务的状态"""
|
||
try:
|
||
# 为了方便前端使用,我们将任务ID也包含在每个任务信息中
|
||
tasks_with_id = []
|
||
for task_id, task_info in training_tasks.items():
|
||
task_copy = task_info.copy()
|
||
task_copy['task_id'] = task_id
|
||
tasks_with_id.append(task_copy)
|
||
|
||
# 按开始时间降序排序,最新的任务在前面
|
||
sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True)
|
||
|
||
return jsonify({"status": "success", "data": sorted_tasks})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/training', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '启动模型训练任务',
|
||
'description': '为指定产品启动一个新的模型训练任务',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string', 'description': '例如 P001'},
|
||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan'], 'description': '要训练的模型类型'},
|
||
'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'}
|
||
},
|
||
'required': ['product_id', 'model_type']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '训练任务已启动',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'message': {'type': 'string'},
|
||
'task_id': {'type': 'string'}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误'
|
||
}
|
||
}
|
||
})
|
||
def start_training():
|
||
"""
|
||
启动模型训练
|
||
---
|
||
post:
|
||
...
|
||
"""
|
||
data = request.get_json()
|
||
product_id = data.get('product_id')
|
||
model_type = data.get('model_type')
|
||
epochs = data.get('epochs', 50) # 默认为50轮
|
||
|
||
if not product_id or not model_type:
|
||
return jsonify({'error': '缺少product_id或model_type'}), 400
|
||
|
||
global training_tasks
|
||
task_id = str(uuid.uuid4())
|
||
|
||
model_train_functions = {
|
||
'mlstm': train_product_model_with_mlstm,
|
||
'kan': train_product_model_with_kan,
|
||
'transformer': train_product_model_with_transformer
|
||
}
|
||
|
||
if model_type not in model_train_functions:
|
||
return jsonify({'error': '无效的模型类型'}), 400
|
||
|
||
train_function = model_train_functions[model_type]
|
||
|
||
def train_task(product_id, epochs, model_type):
|
||
global training_tasks
|
||
try:
|
||
print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。")
|
||
# 这里的 train_function 会返回 (model, metrics)
|
||
_, metrics = train_function(product_id, epochs)
|
||
training_tasks[task_id]['status'] = 'completed'
|
||
training_tasks[task_id]['metrics'] = metrics
|
||
# 保存模型路径
|
||
model_path = f'models/{model_type}/{product_id}_model.pt'
|
||
training_tasks[task_id]['model_path'] = model_path
|
||
print(f"任务 {task_id}: 训练完成。评估指标: {metrics}")
|
||
except Exception as e:
|
||
import traceback
|
||
traceback.print_exc()
|
||
print(f"任务 {task_id}: 训练失败。错误: {e}")
|
||
training_tasks[task_id]['status'] = 'failed'
|
||
training_tasks[task_id]['error'] = str(e)
|
||
|
||
thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type))
|
||
thread.start()
|
||
|
||
training_tasks[task_id] = {
|
||
'status': 'running',
|
||
'product_id': product_id,
|
||
'model_type': model_type,
|
||
'start_time': datetime.now().isoformat(),
|
||
'metrics': None,
|
||
'error': None,
|
||
'model_path': None
|
||
}
|
||
|
||
return jsonify({'message': '模型训练已开始', 'task_id': task_id})
|
||
|
||
@app.route('/api/training/<task_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型训练'],
|
||
'summary': '查询训练任务状态',
|
||
'description': '获取特定训练任务的当前状态和详情',
|
||
'parameters': [
|
||
{
|
||
'name': 'task_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '训练任务ID'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取任务状态',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'parameters': {'type': 'object'},
|
||
'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']},
|
||
'created_at': {'type': 'string'},
|
||
'model_path': {'type': 'string'},
|
||
'metrics': {'type': 'object'},
|
||
'model_details_url': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
404: {
|
||
'description': '任务不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_training_status(task_id):
|
||
try:
|
||
if task_id not in training_tasks:
|
||
return jsonify({"status": "error", "message": "任务不存在"}), 404
|
||
|
||
task_info = training_tasks[task_id].copy()
|
||
|
||
# 如果任务已完成,添加模型详情链接
|
||
if task_info['status'] == 'completed':
|
||
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": task_info
|
||
})
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
# 3. 模型预测API
|
||
@app.route('/api/prediction', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型预测'],
|
||
'summary': '使用模型进行预测',
|
||
'description': '使用指定模型预测未来销售数据',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan']},
|
||
'version': {'type': 'string'},
|
||
'future_days': {'type': 'integer'},
|
||
'include_visualization': {'type': 'boolean'},
|
||
'start_date': {'type': 'string', 'description': '预测起始日期,格式为YYYY-MM-DD'}
|
||
},
|
||
'required': ['product_id', 'model_type']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '预测成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'predictions': {'type': 'array'},
|
||
'visualization': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,缺少必要参数或参数格式错误'
|
||
},
|
||
404: {
|
||
'description': '产品或模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def predict():
|
||
"""
|
||
使用指定的模型进行预测
|
||
---
|
||
tags:
|
||
- 模型预测
|
||
parameters:
|
||
- in: body
|
||
name: body
|
||
schema:
|
||
type: object
|
||
required:
|
||
- product_id
|
||
- model_type
|
||
properties:
|
||
product_id:
|
||
type: string
|
||
description: 产品ID
|
||
model_type:
|
||
type: string
|
||
description: 模型类型 (mlstm, kan, transformer)
|
||
future_days:
|
||
type: integer
|
||
description: 预测未来天数
|
||
default: 7
|
||
start_date:
|
||
type: string
|
||
description: 预测起始日期,格式为YYYY-MM-DD
|
||
default: ''
|
||
responses:
|
||
200:
|
||
description: 预测成功
|
||
400:
|
||
description: 请求参数错误
|
||
404:
|
||
description: 模型文件未找到
|
||
"""
|
||
data = request.json
|
||
product_id = data.get('product_id')
|
||
model_type = data.get('model_type')
|
||
future_days = data.get('future_days', 7)
|
||
start_date = data.get('start_date', '')
|
||
|
||
print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}")
|
||
|
||
if not product_id or not model_type:
|
||
return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
|
||
|
||
try:
|
||
result = load_model_and_predict(product_id, model_type, future_days, start_date=start_date)
|
||
if result is None:
|
||
return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404
|
||
|
||
# 获取预测数据
|
||
predictions_df = result['predictions_df']
|
||
chart_path = result['chart_path']
|
||
csv_path = result['csv_path']
|
||
|
||
# 将DataFrame转换为JSON
|
||
predictions_json = predictions_df.to_dict(orient='records')
|
||
|
||
# 获取当前服务器主机和端口
|
||
host_url = request.host_url.rstrip('/') # 移除末尾的斜杠
|
||
|
||
response_data = {
|
||
"status": "success",
|
||
"data": predictions_json, # 现在包含历史和预测数据
|
||
"history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'),
|
||
"prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records')
|
||
}
|
||
|
||
# 只有在文件确实存在时才添加URL
|
||
timestamp = datetime.now().timestamp()
|
||
|
||
# 获取文件名中的日期和预测天数部分
|
||
start_date = data.get('start_date', '')
|
||
future_days = data.get('future_days', 7)
|
||
|
||
if start_date:
|
||
try:
|
||
start_date_obj = datetime.strptime(start_date, '%Y-%m-%d')
|
||
start_date_str = start_date_obj.strftime('%Y%m%d')
|
||
except:
|
||
start_date_str = datetime.now().strftime('%Y%m%d')
|
||
else:
|
||
# 如果未提供日期,使用当前日期
|
||
start_date_str = datetime.now().strftime('%Y%m%d')
|
||
|
||
# 构建文件名
|
||
png_filename = f"forecast_{start_date_str}_days{future_days}.png"
|
||
csv_filename = f"forecast_{start_date_str}_days{future_days}.csv"
|
||
|
||
# 构建图片路径
|
||
png_path = os.path.join('predictions', model_type, product_id, png_filename)
|
||
csv_path = os.path.join('predictions', model_type, product_id, csv_filename)
|
||
history_png_filename = f"history_{start_date_str}.png"
|
||
history_png_path = os.path.join('predictions', model_type, product_id, history_png_filename)
|
||
|
||
# 检查预测图表
|
||
if os.path.exists(png_path):
|
||
# 构建完整的URL,包含主机名和端口
|
||
response_data["image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{png_filename}?t={timestamp}"
|
||
print(f"图表URL: {response_data['image_url']}")
|
||
else:
|
||
response_data["image_url"] = None
|
||
print(f"警告: 预测图表文件未生成或不存在: {png_path}")
|
||
|
||
# 检查历史图表
|
||
if os.path.exists(history_png_path):
|
||
response_data["history_image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{history_png_filename}?t={timestamp}"
|
||
print(f"历史图表URL: {response_data['history_image_url']}")
|
||
else:
|
||
response_data["history_image_url"] = None
|
||
if result and 'history_chart_path' in result and result['history_chart_path'] is None:
|
||
print(f"警告: 历史图表生成过程中出现错误,请检查服务器日志")
|
||
else:
|
||
print(f"警告: 历史图表文件未生成或不存在: {history_png_path}")
|
||
|
||
# 检查CSV文件
|
||
if os.path.exists(csv_path):
|
||
response_data["csv_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{csv_filename}"
|
||
else:
|
||
response_data["csv_url"] = None
|
||
|
||
return jsonify(response_data)
|
||
|
||
except Exception as e:
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": f"预测过程中发生错误: {str(e)}"}), 500
|
||
|
||
@app.route('/api/prediction/compare', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型预测'],
|
||
'summary': '比较不同模型预测结果',
|
||
'description': '比较不同模型对同一产品的预测结果',
|
||
'parameters': [
|
||
{
|
||
'name': 'body',
|
||
'in': 'body',
|
||
'required': True,
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_types': {'type': 'array', 'items': {'type': 'string'}},
|
||
'versions': {'type': 'array', 'items': {'type': 'string'}},
|
||
'include_visualization': {'type': 'boolean'}
|
||
},
|
||
'required': ['product_id', 'model_types']
|
||
}
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '比较成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_types': {'type': 'array'},
|
||
'comparison': {'type': 'array'},
|
||
'visualization': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,缺少必要参数或参数格式错误'
|
||
},
|
||
404: {
|
||
'description': '产品或模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def compare_predictions():
|
||
try:
|
||
data = request.json
|
||
|
||
product_id = data.get('product_id')
|
||
model_types = data.get('model_types')
|
||
|
||
if not product_id or not model_types:
|
||
return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400
|
||
|
||
all_predictions = {}
|
||
plt.figure(figsize=(12, 8))
|
||
|
||
# 加载历史数据用于绘图
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
product_name = product_df['product_name'].iloc[0]
|
||
history_days = 30
|
||
history_dates = product_df['date'].iloc[-history_days:].values
|
||
history_sales = product_df['sales'].iloc[-history_days:].values
|
||
plt.plot(history_dates, history_sales, 'b-', label='历史销量')
|
||
|
||
comparison_data = []
|
||
future_dates = None
|
||
|
||
# 创建比较结果目录
|
||
compare_dir = f'predictions/compare'
|
||
os.makedirs(compare_dir, exist_ok=True)
|
||
|
||
for model_type in model_types:
|
||
result = load_model_and_predict(product_id, model_type)
|
||
if result is not None:
|
||
predictions_df = result['predictions_df']
|
||
|
||
if future_dates is None:
|
||
future_dates = predictions_df['date']
|
||
comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates]
|
||
|
||
plt.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测')
|
||
|
||
preds = predictions_df['predicted_sales'].tolist()
|
||
for i in range(len(comparison_data)):
|
||
comparison_data[i][f'{model_type}_prediction'] = preds[i] if i < len(preds) else None
|
||
|
||
plt.title(f'{product_name} - 多模型预测比较')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45)
|
||
plt.tight_layout()
|
||
|
||
chart_path = f'{compare_dir}/{product_id}_model_comparison.png'
|
||
plt.savefig(chart_path)
|
||
plt.close() # 关闭图表,释放内存
|
||
|
||
# 获取当前服务器主机和端口
|
||
host_url = request.host_url.rstrip('/') # 移除末尾的斜杠
|
||
|
||
# 生成带时间戳的URL以避免缓存
|
||
timestamp = datetime.now().timestamp()
|
||
image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}"
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": {
|
||
"product_id": product_id,
|
||
"product_name": product_name,
|
||
"model_types": model_types,
|
||
"comparison": comparison_data,
|
||
"visualization_url": image_url
|
||
}
|
||
})
|
||
|
||
except Exception as e:
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500
|
||
|
||
# 4. 模型管理API
|
||
@app.route('/api/models', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '获取模型列表',
|
||
'description': '获取系统中的模型列表,可按产品ID和模型类型筛选',
|
||
'parameters': [
|
||
{
|
||
'name': 'product_id',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '按产品ID筛选'
|
||
},
|
||
{
|
||
'name': 'model_type',
|
||
'in': 'query',
|
||
'type': 'string',
|
||
'required': False,
|
||
'description': '按模型类型筛选'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取模型列表',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'array',
|
||
'items': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'model_id': {'type': 'string'},
|
||
'product_id': {'type': 'string'},
|
||
'product_name': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'created_at': {'type': 'string'},
|
||
'metrics': {'type': 'object'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def list_models():
|
||
"""
|
||
列出所有可用的模型
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: product_id
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按产品ID筛选模型
|
||
- name: model_type
|
||
in: query
|
||
type: string
|
||
required: false
|
||
description: 按模型类型筛选 (mlstm, kan, transformer)
|
||
responses:
|
||
200:
|
||
description: 模型列表
|
||
schema:
|
||
type: object
|
||
properties:
|
||
status:
|
||
type: string
|
||
example: success
|
||
data:
|
||
type: array
|
||
items:
|
||
type: object
|
||
properties:
|
||
model_id:
|
||
type: string
|
||
product_id:
|
||
type: string
|
||
product_name:
|
||
type: string
|
||
model_type:
|
||
type: string
|
||
created_at:
|
||
type: string
|
||
metrics:
|
||
type: object
|
||
"""
|
||
models_dir = 'models'
|
||
model_types = ['mlstm', 'kan', 'transformer']
|
||
available_models = []
|
||
|
||
product_id_filter = request.args.get('product_id')
|
||
model_type_filter = request.args.get('model_type')
|
||
|
||
if not os.path.exists(models_dir):
|
||
return jsonify({"status": "success", "data": []})
|
||
|
||
for model_type in model_types:
|
||
if model_type_filter and model_type_filter != model_type:
|
||
continue
|
||
|
||
type_dir = os.path.join(models_dir, model_type)
|
||
if not os.path.exists(type_dir):
|
||
continue
|
||
|
||
for file_name in os.listdir(type_dir):
|
||
if file_name.endswith('_log.json'):
|
||
product_id = file_name.replace('_log.json', '')
|
||
|
||
if product_id_filter and product_id_filter != product_id:
|
||
continue
|
||
|
||
log_path = os.path.join(type_dir, file_name)
|
||
try:
|
||
with open(log_path, 'r', encoding='utf-8') as f:
|
||
log_data = json.load(f)
|
||
model_info = {
|
||
"model_id": f"{model_type}_{product_id}",
|
||
"product_id": log_data.get('product_id'),
|
||
"product_name": log_data.get('product_name'),
|
||
"model_type": log_data.get('model_type'),
|
||
"created_at": log_data.get('training_completed_at'),
|
||
"metrics": log_data.get('metrics'),
|
||
"file_path": log_data.get('file_path')
|
||
}
|
||
available_models.append(model_info)
|
||
except Exception as e:
|
||
print(f"读取日志文件 {log_path} 失败: {e}")
|
||
|
||
# 按创建时间降序排序
|
||
available_models.sort(key=lambda x: x.get('created_at', ''), reverse=True)
|
||
|
||
return jsonify({"status": "success", "data": available_models})
|
||
|
||
@app.route('/api/models/<model_id>', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '获取模型详情',
|
||
'description': '获取特定模型的详细信息',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '成功获取模型详情',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'product_id': {'type': 'string'},
|
||
'model_type': {'type': 'string'},
|
||
'version': {'type': 'string'},
|
||
'created_at': {'type': 'string'},
|
||
'file_path': {'type': 'string'},
|
||
'file_size': {'type': 'string'},
|
||
'features': {'type': 'array'},
|
||
'look_back': {'type': 'integer'},
|
||
'T': {'type': 'integer'},
|
||
'metrics': {'type': 'object'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
404: {
|
||
'description': '模型不存在'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def get_model_details(model_id):
|
||
"""
|
||
获取单个模型的详细信息
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: model_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 模型的唯一标识符 (格式: model_type_product_id)
|
||
responses:
|
||
200:
|
||
description: 模型的详细信息
|
||
404:
|
||
description: 未找到模型
|
||
"""
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
log_path = os.path.join('models', model_type, f'{product_id}_log.json')
|
||
|
||
if not os.path.exists(log_path):
|
||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||
|
||
with open(log_path, 'r', encoding='utf-8') as f:
|
||
log_data = json.load(f)
|
||
|
||
# 还可以添加从 .pt 文件读取更多信息的逻辑
|
||
# checkpoint = torch.load(log_data['file_path'], map_location='cpu')
|
||
# log_data['details_from_pt'] = ...
|
||
|
||
return jsonify({"status": "success", "data": log_data})
|
||
except ValueError:
|
||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/<model_id>', methods=['DELETE'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '删除模型',
|
||
'description': '删除特定模型',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型删除成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def delete_model(model_id):
|
||
"""
|
||
删除一个模型及其关联文件
|
||
---
|
||
tags:
|
||
- 模型管理
|
||
parameters:
|
||
- name: model_id
|
||
in: path
|
||
type: string
|
||
required: true
|
||
description: 要删除的模型的ID (格式: model_type_product_id)
|
||
responses:
|
||
200:
|
||
description: 模型删除成功
|
||
404:
|
||
description: 模型未找到
|
||
"""
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
model_dir = os.path.join('models', model_type)
|
||
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
|
||
log_path = os.path.join(model_dir, f'{product_id}_log.json')
|
||
|
||
if not os.path.exists(model_path) and not os.path.exists(log_path):
|
||
return jsonify({"status": "error", "error": "模型未找到"}), 404
|
||
|
||
if os.path.exists(model_path):
|
||
os.remove(model_path)
|
||
if os.path.exists(log_path):
|
||
os.remove(log_path)
|
||
|
||
return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"})
|
||
except ValueError:
|
||
return jsonify({"status": "error", "error": "无效的model_id格式"}), 400
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/<model_id>/export', methods=['GET'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '导出模型',
|
||
'description': '导出特定模型文件',
|
||
'parameters': [
|
||
{
|
||
'name': 'model_id',
|
||
'in': 'path',
|
||
'type': 'string',
|
||
'required': True,
|
||
'description': '模型ID,格式为{product_id}_{model_type}_v{version}'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型文件下载',
|
||
'content': {
|
||
'application/octet-stream': {}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '无效的模型ID格式'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def export_model(model_id):
|
||
try:
|
||
model_type, product_id = model_id.split('_', 1)
|
||
model_path = os.path.join('models', model_type, f'{product_id}_model.pt')
|
||
|
||
if not os.path.exists(model_path):
|
||
return jsonify({"status": "error", "error": "模型文件未找到"}), 404
|
||
|
||
return send_file(
|
||
model_path,
|
||
as_attachment=True,
|
||
download_name=f'{model_id}.pt',
|
||
mimetype='application/octet-stream'
|
||
)
|
||
except Exception as e:
|
||
return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/models/import', methods=['POST'])
|
||
@swag_from({
|
||
'tags': ['模型管理'],
|
||
'summary': '导入模型',
|
||
'description': '导入模型文件',
|
||
'consumes': ['multipart/form-data'],
|
||
'parameters': [
|
||
{
|
||
'name': 'file',
|
||
'in': 'formData',
|
||
'type': 'file',
|
||
'required': True,
|
||
'description': 'PyTorch模型文件(.pt)'
|
||
}
|
||
],
|
||
'responses': {
|
||
200: {
|
||
'description': '模型导入成功',
|
||
'schema': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'status': {'type': 'string'},
|
||
'message': {'type': 'string'},
|
||
'data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'model_path': {'type': 'string'}
|
||
}
|
||
}
|
||
}
|
||
}
|
||
},
|
||
400: {
|
||
'description': '请求错误,文件格式不正确或缺少文件'
|
||
},
|
||
500: {
|
||
'description': '服务器内部错误'
|
||
}
|
||
}
|
||
})
|
||
def import_model():
|
||
try:
|
||
if 'file' not in request.files:
|
||
return jsonify({"status": "error", "message": "没有上传文件"}), 400
|
||
|
||
file = request.files['file']
|
||
if file.filename == '' or not file.filename.endswith('.pt'):
|
||
return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400
|
||
|
||
# 从文件名解析 product_id, model_type
|
||
# 假设文件名格式为 `transformer_P001.pt`
|
||
try:
|
||
model_type, product_id_ext = file.filename.split('_', 1)
|
||
product_id = os.path.splitext(product_id_ext)[0]
|
||
model_types = ['mlstm', 'kan', 'transformer']
|
||
if model_type not in model_types:
|
||
raise ValueError("无效的模型类型")
|
||
except ValueError:
|
||
return jsonify({
|
||
"status": "error",
|
||
"message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)"
|
||
}), 400
|
||
|
||
# 创建目标目录并保存文件
|
||
model_dir = os.path.join('models', model_type)
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
model_path = os.path.join(model_dir, f'{product_id}_model.pt')
|
||
|
||
# 检查是否需要创建关联的日志文件
|
||
log_path = os.path.join(model_dir, f'{product_id}_log.json')
|
||
if not os.path.exists(log_path):
|
||
# 尝试从.pt文件加载信息(如果可能)
|
||
try:
|
||
checkpoint = torch.load(file, map_location='cpu')
|
||
# 重新定位文件指针到开头,以便 `file.save` 正常工作
|
||
file.seek(0)
|
||
except Exception:
|
||
checkpoint = {} # 如果加载失败,创建一个空的checkpoint
|
||
|
||
# 创建一个基础的log文件
|
||
log_data = {
|
||
'product_id': product_id,
|
||
'product_name': f"导入的产品 {product_id}", # 可能需要用户后续编辑
|
||
'model_type': model_type,
|
||
'training_completed_at': datetime.now().isoformat(),
|
||
'epochs': checkpoint.get('epochs', 'N/A'),
|
||
'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}),
|
||
'file_path': model_path
|
||
}
|
||
with open(log_path, 'w', encoding='utf-8') as f:
|
||
json.dump(log_data, f, indent=4, ensure_ascii=False)
|
||
|
||
# 保存模型文件
|
||
file.save(model_path)
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"message": "模型已成功导入",
|
||
"data": {"model_path": model_path}
|
||
})
|
||
|
||
except Exception as e:
|
||
traceback.print_exc()
|
||
return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500
|
||
|
||
@app.route('/api/plots/<filename>')
|
||
def get_plot(filename):
|
||
"""Serve a plot file from the root directory."""
|
||
try:
|
||
return send_from_directory(app.root_path, filename)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "Plot not found"}), 404
|
||
|
||
@app.route('/api/csv/<filename>')
|
||
def get_csv(filename):
|
||
"""Serve a CSV file from the root directory."""
|
||
try:
|
||
return send_from_directory(app.root_path, filename, as_attachment=True)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "CSV file not found"}), 404
|
||
|
||
# 添加静态UI路由,将/ui路径映射到wwwroot目录
|
||
@app.route('/ui/', defaults={'path': 'index.html'})
|
||
@app.route('/ui/<path:path>')
|
||
def serve_ui(path):
|
||
"""提供UI静态文件服务,将/ui路径映射到wwwroot目录"""
|
||
try:
|
||
# 从wwwroot目录提供静态文件
|
||
return send_from_directory('wwwroot', path)
|
||
except FileNotFoundError:
|
||
# 如果是子路径请求(例如/ui/about)但文件不存在,尝试返回index.html以支持SPA路由
|
||
if path != 'index.html':
|
||
try:
|
||
return send_from_directory('wwwroot', 'index.html')
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": "UI files not found"}), 404
|
||
return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404
|
||
|
||
@app.route('/api/predictions/<model_type>/<product_id>/<filename>')
|
||
def get_prediction_file(model_type, product_id, filename):
|
||
"""按模型类型和产品ID获取预测文件"""
|
||
try:
|
||
# 构建文件路径
|
||
file_path = os.path.join('predictions', model_type, product_id)
|
||
|
||
# 如果是图片文件,添加防缓存头
|
||
if filename.endswith('.png'):
|
||
response = send_from_directory(file_path, filename)
|
||
|
||
# 强制浏览器不缓存图片
|
||
response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0'
|
||
response.headers['Pragma'] = 'no-cache'
|
||
response.headers['Expires'] = '0'
|
||
response.headers['Last-Modified'] = datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT")
|
||
response.headers['Vary'] = '*'
|
||
|
||
print(f"提供图片文件 {filename} 并添加防缓存头")
|
||
return response
|
||
else:
|
||
return send_from_directory(file_path, filename)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": f"预测文件 {filename} 未找到"}), 404
|
||
|
||
@app.route('/api/predictions/compare/<filename>')
|
||
def get_compare_file(filename):
|
||
"""获取模型比较结果文件"""
|
||
try:
|
||
file_path = os.path.join('predictions', 'compare')
|
||
return send_from_directory(file_path, filename)
|
||
except FileNotFoundError:
|
||
return jsonify({"status": "error", "error": f"比较结果文件 {filename} 未找到"}), 404
|
||
|
||
if __name__ == '__main__':
|
||
# 命令行参数解析
|
||
parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
|
||
parser.add_argument('--host', type=str, default='0.0.0.0', help='API服务监听的主机地址')
|
||
parser.add_argument('--port', type=int, default=5000, help='API服务监听的端口')
|
||
parser.add_argument('--swagger', action='store_true', default=True, help='是否启用Swagger UI')
|
||
parser.add_argument('--debug', action='store_true', default=True, help='是否启用调试模式')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 如果不启用Swagger,则关闭Swagger UI
|
||
if not args.swagger:
|
||
app.config['SWAGGER'] = {'enabled': False}
|
||
print("Swagger UI已禁用")
|
||
else:
|
||
print(f"Swagger UI已启用,访问: http://{args.host}:{args.port}/swagger/")
|
||
|
||
# 确保 models 目录存在
|
||
if not os.path.exists('models'):
|
||
os.makedirs('models')
|
||
|
||
# 确保预测结果目录存在
|
||
if not os.path.exists('predictions'):
|
||
os.makedirs('predictions')
|
||
|
||
app.run(debug=args.debug, host=args.host, port=args.port) |