import os import sys # 获取当前脚本所在目录的绝对路径 current_dir = os.path.dirname(os.path.abspath(__file__)) # 将当前目录添加到系统路径 sys.path.append(current_dir) # 或者添加父目录 #parent_dir = os.path.dirname(current_dir) #sys.path.append(parent_dir) from flask import Flask, request, jsonify, send_file, redirect, send_from_directory from flask_cors import CORS import os import uuid import pandas as pd from pharmacy_predictor import ( train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer, load_model_and_predict ) import threading import base64 import matplotlib.pyplot as plt from io import BytesIO import json from flasgger import Swagger, swag_from import argparse from datetime import datetime from concurrent.futures import ThreadPoolExecutor from threading import Lock import traceback import torch app = Flask(__name__) CORS(app) # 启用CORS支持 # Swagger配置 swagger_config = { "headers": [], "specs": [ { "endpoint": "apispec", "route": "/apispec.json", "rule_filter": lambda rule: True, # 包含所有路由 "model_filter": lambda tag: True, # 包含所有模型 } ], "static_url_path": "/flasgger_static", "swagger_ui": True, "specs_route": "/swagger/" } swagger_template = { "swagger": "2.0", "info": { "title": "药店销售预测系统API", "description": "用于药店销售预测的RESTful API", "version": "1.0.0", "contact": { "name": "API开发团队", "email": "support@example.com" } }, "tags": [ { "name": "数据管理", "description": "数据上传和查询相关接口" }, { "name": "模型训练", "description": "模型训练相关接口" }, { "name": "模型预测", "description": "预测销售数据相关接口" }, { "name": "模型管理", "description": "模型查询、导出和删除接口" } ] } swagger = Swagger(app, config=swagger_config, template=swagger_template) # 存储训练任务状态 training_tasks = {} tasks_lock = Lock() # 线程池用于后台训练 executor = ThreadPoolExecutor(max_workers=2) # 辅助函数:将图像转换为Base64 def fig_to_base64(fig): buf = BytesIO() fig.savefig(buf, format='png') buf.seek(0) img_str = base64.b64encode(buf.read()).decode('utf-8') return img_str # 根路由 - 重定向到UI界面 @app.route('/') def index(): """重定向到UI界面""" return redirect('/ui/') # Swagger UI路由 @app.route('/swagger') def swagger_ui(): """重定向到Swagger UI文档页面""" return redirect('/swagger/') # 1. 数据管理API @app.route('/api/products', methods=['GET']) @swag_from({ 'tags': ['数据管理'], 'summary': '获取所有产品列表', 'description': '返回系统中所有产品的ID和名称', 'responses': { 200: { 'description': '成功获取产品列表', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'product_name': {'type': 'string'} } } } } } }, 500: { 'description': '服务器内部错误' } } }) def get_products(): try: df = pd.read_excel('pharmacy_sales.xlsx') products = df[['product_id', 'product_name']].drop_duplicates().to_dict('records') return jsonify({"status": "success", "data": products}) except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/products/', methods=['GET']) @swag_from({ 'tags': ['数据管理'], 'summary': '获取单个产品详情', 'description': '返回指定产品ID的详细信息', 'parameters': [ { 'name': 'product_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '产品ID,例如P001' } ], 'responses': { 200: { 'description': '成功获取产品详情', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'product_name': {'type': 'string'}, 'data_points': {'type': 'integer'}, 'date_range': { 'type': 'object', 'properties': { 'start': {'type': 'string'}, 'end': {'type': 'string'} } } } } } } }, 404: { 'description': '产品不存在' }, 500: { 'description': '服务器内部错误' } } }) def get_product(product_id): try: df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id] if product_df.empty: return jsonify({"status": "error", "message": "产品不存在"}), 404 product_info = { "product_id": product_id, "product_name": product_df['product_name'].iloc[0], "data_points": len(product_df), "date_range": { "start": product_df['date'].min().strftime('%Y-%m-%d'), "end": product_df['date'].max().strftime('%Y-%m-%d') } } return jsonify({"status": "success", "data": product_info}) except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/products//sales', methods=['GET']) @swag_from({ 'tags': ['数据管理'], 'summary': '获取产品销售数据', 'description': '返回指定产品在特定日期范围内的销售数据', 'parameters': [ { 'name': 'product_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '产品ID,例如P001' }, { 'name': 'start_date', 'in': 'query', 'type': 'string', 'required': False, 'description': '开始日期,格式为YYYY-MM-DD' }, { 'name': 'end_date', 'in': 'query', 'type': 'string', 'required': False, 'description': '结束日期,格式为YYYY-MM-DD' } ], 'responses': { 200: { 'description': '成功获取销售数据', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'array', 'items': { 'type': 'object' } } } } }, 404: { 'description': '产品不存在' }, 500: { 'description': '服务器内部错误' } } }) def get_product_sales(product_id): try: start_date = request.args.get('start_date') end_date = request.args.get('end_date') df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id].sort_values('date') if product_df.empty: return jsonify({"status": "error", "message": "产品不存在"}), 404 # 如果提供了日期范围,进行过滤 if start_date: product_df = product_df[product_df['date'] >= pd.to_datetime(start_date)] if end_date: product_df = product_df[product_df['date'] <= pd.to_datetime(end_date)] # 转换日期为字符串以便JSON序列化 product_df['date'] = product_df['date'].dt.strftime('%Y-%m-%d') sales_data = product_df.to_dict('records') return jsonify({"status": "success", "data": sales_data}) except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/data/upload', methods=['POST']) @swag_from({ 'tags': ['数据管理'], 'summary': '上传销售数据', 'description': '上传新的销售数据文件(Excel格式)', 'consumes': ['multipart/form-data'], 'parameters': [ { 'name': 'file', 'in': 'formData', 'type': 'file', 'required': True, 'description': 'Excel文件(.xlsx),包含销售数据' } ], 'responses': { 200: { 'description': '数据上传成功', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'message': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'products': {'type': 'integer'}, 'rows': {'type': 'integer'} } } } } }, 400: { 'description': '请求错误,可能是文件格式不正确或缺少必要字段' }, 500: { 'description': '服务器内部错误' } } }) def upload_data(): try: if 'file' not in request.files: return jsonify({"status": "error", "message": "没有上传文件"}), 400 file = request.files['file'] if file.filename == '': return jsonify({"status": "error", "message": "没有选择文件"}), 400 if file and file.filename.endswith('.xlsx'): file_path = 'uploaded_data.xlsx' file.save(file_path) # 验证数据格式 try: df = pd.read_excel(file_path) required_columns = ['date', 'product_id', 'product_name', 'sales', 'price'] missing_columns = [col for col in required_columns if col not in df.columns] if missing_columns: return jsonify({ "status": "error", "message": f"上传的数据缺少必要的列: {', '.join(missing_columns)}" }), 400 # 合并到现有数据或替换现有数据 existing_df = pd.read_excel('pharmacy_sales.xlsx') # 这里可以实现数据合并逻辑,例如按日期和产品ID去重后合并 # 简单示例:保存上传的数据 df.to_excel('pharmacy_sales.xlsx', index=False) return jsonify({ "status": "success", "message": "数据上传成功", "data": { "products": len(df['product_id'].unique()), "rows": len(df) } }) except Exception as e: return jsonify({"status": "error", "message": f"数据验证失败: {str(e)}"}), 400 else: return jsonify({"status": "error", "message": "只支持Excel文件(.xlsx)"}), 400 except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 # 2. 模型训练API @app.route('/api/training', methods=['GET']) @swag_from({ 'tags': ['模型训练'], 'summary': '获取所有训练任务列表', 'description': '返回所有正在进行、已完成或失败的训练任务', 'responses': { 200: { 'description': '成功获取任务列表', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'task_id': {'type': 'string'}, 'product_id': {'type': 'string'}, 'model_type': {'type': 'string'}, 'status': {'type': 'string'}, 'start_time': {'type': 'string'}, 'metrics': {'type': 'object'}, 'error': {'type': 'string'}, 'model_path': {'type': 'string'} } } } } } } } }) def get_all_training_tasks(): """获取所有训练任务的状态""" try: # 为了方便前端使用,我们将任务ID也包含在每个任务信息中 tasks_with_id = [] for task_id, task_info in training_tasks.items(): task_copy = task_info.copy() task_copy['task_id'] = task_id tasks_with_id.append(task_copy) # 按开始时间降序排序,最新的任务在前面 sorted_tasks = sorted(tasks_with_id, key=lambda x: x['start_time'], reverse=True) return jsonify({"status": "success", "data": sorted_tasks}) except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/training', methods=['POST']) @swag_from({ 'tags': ['模型训练'], 'summary': '启动模型训练任务', 'description': '为指定产品启动一个新的模型训练任务', 'parameters': [ { 'name': 'body', 'in': 'body', 'required': True, 'schema': { 'type': 'object', 'properties': { 'product_id': {'type': 'string', 'description': '例如 P001'}, 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan'], 'description': '要训练的模型类型'}, 'epochs': {'type': 'integer', 'default': 50, 'description': '训练轮次'} }, 'required': ['product_id', 'model_type'] } } ], 'responses': { 200: { 'description': '训练任务已启动', 'schema': { 'type': 'object', 'properties': { 'message': {'type': 'string'}, 'task_id': {'type': 'string'} } } }, 400: { 'description': '请求错误' } } }) def start_training(): """ 启动模型训练 --- post: ... """ data = request.get_json() product_id = data.get('product_id') model_type = data.get('model_type') epochs = data.get('epochs', 50) # 默认为50轮 if not product_id or not model_type: return jsonify({'error': '缺少product_id或model_type'}), 400 global training_tasks task_id = str(uuid.uuid4()) model_train_functions = { 'mlstm': train_product_model_with_mlstm, 'kan': train_product_model_with_kan, 'transformer': train_product_model_with_transformer } if model_type not in model_train_functions: return jsonify({'error': '无效的模型类型'}), 400 train_function = model_train_functions[model_type] def train_task(product_id, epochs, model_type): global training_tasks try: print(f"任务 {task_id}: 开始为产品 {product_id} 训练 {model_type} 模型,共 {epochs} 个轮次。") # 这里的 train_function 会返回 (model, metrics) _, metrics = train_function(product_id, epochs) training_tasks[task_id]['status'] = 'completed' training_tasks[task_id]['metrics'] = metrics # 保存模型路径 model_path = f'models/{model_type}/{product_id}_model.pt' training_tasks[task_id]['model_path'] = model_path print(f"任务 {task_id}: 训练完成。评估指标: {metrics}") except Exception as e: import traceback traceback.print_exc() print(f"任务 {task_id}: 训练失败。错误: {e}") training_tasks[task_id]['status'] = 'failed' training_tasks[task_id]['error'] = str(e) thread = threading.Thread(target=train_task, args=(product_id, epochs, model_type)) thread.start() training_tasks[task_id] = { 'status': 'running', 'product_id': product_id, 'model_type': model_type, 'start_time': datetime.now().isoformat(), 'metrics': None, 'error': None, 'model_path': None } return jsonify({'message': '模型训练已开始', 'task_id': task_id}) @app.route('/api/training/', methods=['GET']) @swag_from({ 'tags': ['模型训练'], 'summary': '查询训练任务状态', 'description': '获取特定训练任务的当前状态和详情', 'parameters': [ { 'name': 'task_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '训练任务ID' } ], 'responses': { 200: { 'description': '成功获取任务状态', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'model_type': {'type': 'string'}, 'parameters': {'type': 'object'}, 'status': {'type': 'string', 'enum': ['pending', 'running', 'completed', 'failed']}, 'created_at': {'type': 'string'}, 'model_path': {'type': 'string'}, 'metrics': {'type': 'object'}, 'model_details_url': {'type': 'string'} } } } } }, 404: { 'description': '任务不存在' }, 500: { 'description': '服务器内部错误' } } }) def get_training_status(task_id): try: if task_id not in training_tasks: return jsonify({"status": "error", "message": "任务不存在"}), 404 task_info = training_tasks[task_id].copy() # 如果任务已完成,添加模型详情链接 if task_info['status'] == 'completed': task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}" return jsonify({ "status": "success", "data": task_info }) except Exception as e: return jsonify({"status": "error", "message": str(e)}), 500 # 3. 模型预测API @app.route('/api/prediction', methods=['POST']) @swag_from({ 'tags': ['模型预测'], 'summary': '使用模型进行预测', 'description': '使用指定模型预测未来销售数据', 'parameters': [ { 'name': 'body', 'in': 'body', 'required': True, 'schema': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'model_type': {'type': 'string', 'enum': ['mlstm', 'transformer', 'kan']}, 'version': {'type': 'string'}, 'future_days': {'type': 'integer'}, 'include_visualization': {'type': 'boolean'}, 'start_date': {'type': 'string', 'description': '预测起始日期,格式为YYYY-MM-DD'} }, 'required': ['product_id', 'model_type'] } } ], 'responses': { 200: { 'description': '预测成功', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'product_name': {'type': 'string'}, 'model_type': {'type': 'string'}, 'predictions': {'type': 'array'}, 'visualization': {'type': 'string'} } } } } }, 400: { 'description': '请求错误,缺少必要参数或参数格式错误' }, 404: { 'description': '产品或模型不存在' }, 500: { 'description': '服务器内部错误' } } }) def predict(): """ 使用指定的模型进行预测 --- tags: - 模型预测 parameters: - in: body name: body schema: type: object required: - product_id - model_type properties: product_id: type: string description: 产品ID model_type: type: string description: 模型类型 (mlstm, kan, transformer) future_days: type: integer description: 预测未来天数 default: 7 start_date: type: string description: 预测起始日期,格式为YYYY-MM-DD default: '' responses: 200: description: 预测成功 400: description: 请求参数错误 404: description: 模型文件未找到 """ data = request.json product_id = data.get('product_id') model_type = data.get('model_type') future_days = data.get('future_days', 7) start_date = data.get('start_date', '') print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, future_days={future_days}, start_date={start_date}") if not product_id or not model_type: return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400 try: result = load_model_and_predict(product_id, model_type, future_days, start_date=start_date) if result is None: return jsonify({"status": "error", "error": f"模型 'models/{model_type}/{product_id}_model.pt' 未找到或加载失败"}), 404 # 获取预测数据 predictions_df = result['predictions_df'] chart_path = result['chart_path'] csv_path = result['csv_path'] # 将DataFrame转换为JSON predictions_json = predictions_df.to_dict(orient='records') # 获取当前服务器主机和端口 host_url = request.host_url.rstrip('/') # 移除末尾的斜杠 response_data = { "status": "success", "data": predictions_json, # 现在包含历史和预测数据 "history_data": predictions_df[predictions_df['data_type'] == '历史销量'].to_dict(orient='records'), "prediction_data": predictions_df[predictions_df['data_type'] == '预测销量'].to_dict(orient='records') } # 只有在文件确实存在时才添加URL timestamp = datetime.now().timestamp() # 获取文件名中的日期和预测天数部分 start_date = data.get('start_date', '') future_days = data.get('future_days', 7) if start_date: try: start_date_obj = datetime.strptime(start_date, '%Y-%m-%d') start_date_str = start_date_obj.strftime('%Y%m%d') except: start_date_str = datetime.now().strftime('%Y%m%d') else: # 如果未提供日期,使用当前日期 start_date_str = datetime.now().strftime('%Y%m%d') # 构建文件名 png_filename = f"forecast_{start_date_str}_days{future_days}.png" csv_filename = f"forecast_{start_date_str}_days{future_days}.csv" # 构建图片路径 png_path = os.path.join('predictions', model_type, product_id, png_filename) csv_path = os.path.join('predictions', model_type, product_id, csv_filename) history_png_filename = f"history_{start_date_str}.png" history_png_path = os.path.join('predictions', model_type, product_id, history_png_filename) # 检查预测图表 if os.path.exists(png_path): # 构建完整的URL,包含主机名和端口 response_data["image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{png_filename}?t={timestamp}" print(f"图表URL: {response_data['image_url']}") else: response_data["image_url"] = None print(f"警告: 预测图表文件未生成或不存在: {png_path}") # 检查历史图表 if os.path.exists(history_png_path): response_data["history_image_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{history_png_filename}?t={timestamp}" print(f"历史图表URL: {response_data['history_image_url']}") else: response_data["history_image_url"] = None if result and 'history_chart_path' in result and result['history_chart_path'] is None: print(f"警告: 历史图表生成过程中出现错误,请检查服务器日志") else: print(f"警告: 历史图表文件未生成或不存在: {history_png_path}") # 检查CSV文件 if os.path.exists(csv_path): response_data["csv_url"] = f"{host_url}/api/predictions/{model_type}/{product_id}/{csv_filename}" else: response_data["csv_url"] = None return jsonify(response_data) except Exception as e: traceback.print_exc() return jsonify({"status": "error", "error": f"预测过程中发生错误: {str(e)}"}), 500 @app.route('/api/prediction/compare', methods=['POST']) @swag_from({ 'tags': ['模型预测'], 'summary': '比较不同模型预测结果', 'description': '比较不同模型对同一产品的预测结果', 'parameters': [ { 'name': 'body', 'in': 'body', 'required': True, 'schema': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'model_types': {'type': 'array', 'items': {'type': 'string'}}, 'versions': {'type': 'array', 'items': {'type': 'string'}}, 'include_visualization': {'type': 'boolean'} }, 'required': ['product_id', 'model_types'] } } ], 'responses': { 200: { 'description': '比较成功', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'product_name': {'type': 'string'}, 'model_types': {'type': 'array'}, 'comparison': {'type': 'array'}, 'visualization': {'type': 'string'} } } } } }, 400: { 'description': '请求错误,缺少必要参数或参数格式错误' }, 404: { 'description': '产品或模型不存在' }, 500: { 'description': '服务器内部错误' } } }) def compare_predictions(): try: data = request.json product_id = data.get('product_id') model_types = data.get('model_types') if not product_id or not model_types: return jsonify({"status": "error", "error": "product_id 和 model_types 是必需的"}), 400 all_predictions = {} plt.figure(figsize=(12, 8)) # 加载历史数据用于绘图 df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id].sort_values('date') product_name = product_df['product_name'].iloc[0] history_days = 30 history_dates = product_df['date'].iloc[-history_days:].values history_sales = product_df['sales'].iloc[-history_days:].values plt.plot(history_dates, history_sales, 'b-', label='历史销量') comparison_data = [] future_dates = None # 创建比较结果目录 compare_dir = f'predictions/compare' os.makedirs(compare_dir, exist_ok=True) for model_type in model_types: result = load_model_and_predict(product_id, model_type) if result is not None: predictions_df = result['predictions_df'] if future_dates is None: future_dates = predictions_df['date'] comparison_data = [{'date': d.strftime('%Y-%m-%d')} for d in future_dates] plt.plot(predictions_df['date'], predictions_df['predicted_sales'], '--', label=f'{model_type.upper()} 预测') preds = predictions_df['predicted_sales'].tolist() for i in range(len(comparison_data)): comparison_data[i][f'{model_type}_prediction'] = preds[i] if i < len(preds) else None plt.title(f'{product_name} - 多模型预测比较') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() chart_path = f'{compare_dir}/{product_id}_model_comparison.png' plt.savefig(chart_path) plt.close() # 关闭图表,释放内存 # 获取当前服务器主机和端口 host_url = request.host_url.rstrip('/') # 移除末尾的斜杠 # 生成带时间戳的URL以避免缓存 timestamp = datetime.now().timestamp() image_url = f"{host_url}/api/predictions/compare/{product_id}_model_comparison.png?t={timestamp}" return jsonify({ "status": "success", "data": { "product_id": product_id, "product_name": product_name, "model_types": model_types, "comparison": comparison_data, "visualization_url": image_url } }) except Exception as e: traceback.print_exc() return jsonify({"status": "error", "error": f"比较预测时出错: {e}"}), 500 # 4. 模型管理API @app.route('/api/models', methods=['GET']) @swag_from({ 'tags': ['模型管理'], 'summary': '获取模型列表', 'description': '获取系统中的模型列表,可按产品ID和模型类型筛选', 'parameters': [ { 'name': 'product_id', 'in': 'query', 'type': 'string', 'required': False, 'description': '按产品ID筛选' }, { 'name': 'model_type', 'in': 'query', 'type': 'string', 'required': False, 'description': '按模型类型筛选' } ], 'responses': { 200: { 'description': '成功获取模型列表', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'array', 'items': { 'type': 'object', 'properties': { 'model_id': {'type': 'string'}, 'product_id': {'type': 'string'}, 'product_name': {'type': 'string'}, 'model_type': {'type': 'string'}, 'created_at': {'type': 'string'}, 'metrics': {'type': 'object'} } } } } } }, 500: { 'description': '服务器内部错误' } } }) def list_models(): """ 列出所有可用的模型 --- tags: - 模型管理 parameters: - name: product_id in: query type: string required: false description: 按产品ID筛选模型 - name: model_type in: query type: string required: false description: 按模型类型筛选 (mlstm, kan, transformer) responses: 200: description: 模型列表 schema: type: object properties: status: type: string example: success data: type: array items: type: object properties: model_id: type: string product_id: type: string product_name: type: string model_type: type: string created_at: type: string metrics: type: object """ models_dir = 'models' model_types = ['mlstm', 'kan', 'transformer'] available_models = [] product_id_filter = request.args.get('product_id') model_type_filter = request.args.get('model_type') if not os.path.exists(models_dir): return jsonify({"status": "success", "data": []}) for model_type in model_types: if model_type_filter and model_type_filter != model_type: continue type_dir = os.path.join(models_dir, model_type) if not os.path.exists(type_dir): continue for file_name in os.listdir(type_dir): if file_name.endswith('_log.json'): product_id = file_name.replace('_log.json', '') if product_id_filter and product_id_filter != product_id: continue log_path = os.path.join(type_dir, file_name) try: with open(log_path, 'r', encoding='utf-8') as f: log_data = json.load(f) model_info = { "model_id": f"{model_type}_{product_id}", "product_id": log_data.get('product_id'), "product_name": log_data.get('product_name'), "model_type": log_data.get('model_type'), "created_at": log_data.get('training_completed_at'), "metrics": log_data.get('metrics'), "file_path": log_data.get('file_path') } available_models.append(model_info) except Exception as e: print(f"读取日志文件 {log_path} 失败: {e}") # 按创建时间降序排序 available_models.sort(key=lambda x: x.get('created_at', ''), reverse=True) return jsonify({"status": "success", "data": available_models}) @app.route('/api/models/', methods=['GET']) @swag_from({ 'tags': ['模型管理'], 'summary': '获取模型详情', 'description': '获取特定模型的详细信息', 'parameters': [ { 'name': 'model_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '模型ID,格式为{product_id}_{model_type}_v{version}' } ], 'responses': { 200: { 'description': '成功获取模型详情', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'product_id': {'type': 'string'}, 'model_type': {'type': 'string'}, 'version': {'type': 'string'}, 'created_at': {'type': 'string'}, 'file_path': {'type': 'string'}, 'file_size': {'type': 'string'}, 'features': {'type': 'array'}, 'look_back': {'type': 'integer'}, 'T': {'type': 'integer'}, 'metrics': {'type': 'object'} } } } } }, 400: { 'description': '无效的模型ID格式' }, 404: { 'description': '模型不存在' }, 500: { 'description': '服务器内部错误' } } }) def get_model_details(model_id): """ 获取单个模型的详细信息 --- tags: - 模型管理 parameters: - name: model_id in: path type: string required: true description: 模型的唯一标识符 (格式: model_type_product_id) responses: 200: description: 模型的详细信息 404: description: 未找到模型 """ try: model_type, product_id = model_id.split('_', 1) log_path = os.path.join('models', model_type, f'{product_id}_log.json') if not os.path.exists(log_path): return jsonify({"status": "error", "error": "模型未找到"}), 404 with open(log_path, 'r', encoding='utf-8') as f: log_data = json.load(f) # 还可以添加从 .pt 文件读取更多信息的逻辑 # checkpoint = torch.load(log_data['file_path'], map_location='cpu') # log_data['details_from_pt'] = ... return jsonify({"status": "success", "data": log_data}) except ValueError: return jsonify({"status": "error", "error": "无效的model_id格式"}), 400 except Exception as e: return jsonify({"status": "error", "error": f"获取模型详情失败: {e}"}), 500 @app.route('/api/models/', methods=['DELETE']) @swag_from({ 'tags': ['模型管理'], 'summary': '删除模型', 'description': '删除特定模型', 'parameters': [ { 'name': 'model_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '模型ID,格式为{product_id}_{model_type}_v{version}' } ], 'responses': { 200: { 'description': '模型删除成功', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'message': {'type': 'string'} } } }, 400: { 'description': '无效的模型ID格式' }, 500: { 'description': '服务器内部错误' } } }) def delete_model(model_id): """ 删除一个模型及其关联文件 --- tags: - 模型管理 parameters: - name: model_id in: path type: string required: true description: 要删除的模型的ID (格式: model_type_product_id) responses: 200: description: 模型删除成功 404: description: 模型未找到 """ try: model_type, product_id = model_id.split('_', 1) model_dir = os.path.join('models', model_type) model_path = os.path.join(model_dir, f'{product_id}_model.pt') log_path = os.path.join(model_dir, f'{product_id}_log.json') if not os.path.exists(model_path) and not os.path.exists(log_path): return jsonify({"status": "error", "error": "模型未找到"}), 404 if os.path.exists(model_path): os.remove(model_path) if os.path.exists(log_path): os.remove(log_path) return jsonify({"status": "success", "message": f"模型 {model_id} 已删除"}) except ValueError: return jsonify({"status": "error", "error": "无效的model_id格式"}), 400 except Exception as e: return jsonify({"status": "error", "error": f"删除模型失败: {e}"}), 500 @app.route('/api/models//export', methods=['GET']) @swag_from({ 'tags': ['模型管理'], 'summary': '导出模型', 'description': '导出特定模型文件', 'parameters': [ { 'name': 'model_id', 'in': 'path', 'type': 'string', 'required': True, 'description': '模型ID,格式为{product_id}_{model_type}_v{version}' } ], 'responses': { 200: { 'description': '模型文件下载', 'content': { 'application/octet-stream': {} } }, 400: { 'description': '无效的模型ID格式' }, 500: { 'description': '服务器内部错误' } } }) def export_model(model_id): try: model_type, product_id = model_id.split('_', 1) model_path = os.path.join('models', model_type, f'{product_id}_model.pt') if not os.path.exists(model_path): return jsonify({"status": "error", "error": "模型文件未找到"}), 404 return send_file( model_path, as_attachment=True, download_name=f'{model_id}.pt', mimetype='application/octet-stream' ) except Exception as e: return jsonify({"status": "error", "error": f"导出模型失败: {e}"}), 500 @app.route('/api/models/import', methods=['POST']) @swag_from({ 'tags': ['模型管理'], 'summary': '导入模型', 'description': '导入模型文件', 'consumes': ['multipart/form-data'], 'parameters': [ { 'name': 'file', 'in': 'formData', 'type': 'file', 'required': True, 'description': 'PyTorch模型文件(.pt)' } ], 'responses': { 200: { 'description': '模型导入成功', 'schema': { 'type': 'object', 'properties': { 'status': {'type': 'string'}, 'message': {'type': 'string'}, 'data': { 'type': 'object', 'properties': { 'model_path': {'type': 'string'} } } } } }, 400: { 'description': '请求错误,文件格式不正确或缺少文件' }, 500: { 'description': '服务器内部错误' } } }) def import_model(): try: if 'file' not in request.files: return jsonify({"status": "error", "message": "没有上传文件"}), 400 file = request.files['file'] if file.filename == '' or not file.filename.endswith('.pt'): return jsonify({"status": "error", "message": "请上传有效的.pt模型文件"}), 400 # 从文件名解析 product_id, model_type # 假设文件名格式为 `transformer_P001.pt` try: model_type, product_id_ext = file.filename.split('_', 1) product_id = os.path.splitext(product_id_ext)[0] model_types = ['mlstm', 'kan', 'transformer'] if model_type not in model_types: raise ValueError("无效的模型类型") except ValueError: return jsonify({ "status": "error", "message": "文件名格式不正确,应为 'model_type_product_id.pt' (例如: mlstm_P001.pt)" }), 400 # 创建目标目录并保存文件 model_dir = os.path.join('models', model_type) os.makedirs(model_dir, exist_ok=True) model_path = os.path.join(model_dir, f'{product_id}_model.pt') # 检查是否需要创建关联的日志文件 log_path = os.path.join(model_dir, f'{product_id}_log.json') if not os.path.exists(log_path): # 尝试从.pt文件加载信息(如果可能) try: checkpoint = torch.load(file, map_location='cpu') # 重新定位文件指针到开头,以便 `file.save` 正常工作 file.seek(0) except Exception: checkpoint = {} # 如果加载失败,创建一个空的checkpoint # 创建一个基础的log文件 log_data = { 'product_id': product_id, 'product_name': f"导入的产品 {product_id}", # 可能需要用户后续编辑 'model_type': model_type, 'training_completed_at': datetime.now().isoformat(), 'epochs': checkpoint.get('epochs', 'N/A'), 'metrics': checkpoint.get('metrics', {'info': '导入的模型,无详细指标'}), 'file_path': model_path } with open(log_path, 'w', encoding='utf-8') as f: json.dump(log_data, f, indent=4, ensure_ascii=False) # 保存模型文件 file.save(model_path) return jsonify({ "status": "success", "message": "模型已成功导入", "data": {"model_path": model_path} }) except Exception as e: traceback.print_exc() return jsonify({"status": "error", "error": f"导入模型失败: {e}"}), 500 @app.route('/api/plots/') def get_plot(filename): """Serve a plot file from the root directory.""" try: return send_from_directory(app.root_path, filename) except FileNotFoundError: return jsonify({"status": "error", "error": "Plot not found"}), 404 @app.route('/api/csv/') def get_csv(filename): """Serve a CSV file from the root directory.""" try: return send_from_directory(app.root_path, filename, as_attachment=True) except FileNotFoundError: return jsonify({"status": "error", "error": "CSV file not found"}), 404 # 添加静态UI路由,将/ui路径映射到wwwroot目录 @app.route('/ui/', defaults={'path': 'index.html'}) @app.route('/ui/') def serve_ui(path): """提供UI静态文件服务,将/ui路径映射到wwwroot目录""" try: # 从wwwroot目录提供静态文件 return send_from_directory('wwwroot', path) except FileNotFoundError: # 如果是子路径请求(例如/ui/about)但文件不存在,尝试返回index.html以支持SPA路由 if path != 'index.html': try: return send_from_directory('wwwroot', 'index.html') except FileNotFoundError: return jsonify({"status": "error", "error": "UI files not found"}), 404 return jsonify({"status": "error", "error": f"UI file {path} not found"}), 404 @app.route('/api/predictions///') def get_prediction_file(model_type, product_id, filename): """按模型类型和产品ID获取预测文件""" try: # 构建文件路径 file_path = os.path.join('predictions', model_type, product_id) # 如果是图片文件,添加防缓存头 if filename.endswith('.png'): response = send_from_directory(file_path, filename) # 强制浏览器不缓存图片 response.headers['Cache-Control'] = 'no-store, no-cache, must-revalidate, max-age=0, post-check=0, pre-check=0' response.headers['Pragma'] = 'no-cache' response.headers['Expires'] = '0' response.headers['Last-Modified'] = datetime.now().strftime("%a, %d %b %Y %H:%M:%S GMT") response.headers['Vary'] = '*' print(f"提供图片文件 {filename} 并添加防缓存头") return response else: return send_from_directory(file_path, filename) except FileNotFoundError: return jsonify({"status": "error", "error": f"预测文件 {filename} 未找到"}), 404 @app.route('/api/predictions/compare/') def get_compare_file(filename): """ 提供比较结果文件的下载 """ directory = os.path.join(current_dir, 'predictions', 'compare') return send_from_directory(directory, filename) if __name__ == '__main__': # 检查可用的设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 运行Flask应用 # 在生产环境中,应使用Gunicorn或uWSGI等WSGI服务器 # 例如: gunicorn --workers 4 --bind 0.0.0.0:5000 api:app # 使用--host=0.0.0.0可以使服务在局域网内可访问 app.run(host='0.0.0.0', port=5000, debug=True)