#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 现代化API服务器 - 使用loguru日志和独立训练进程 """ import sys import os import json from datetime import datetime from flask import Flask, jsonify, request from flask_cors import CORS from flask_socketio import SocketIO import argparse # 获取当前脚本所在目录的绝对路径 current_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(current_dir) # 使用新的现代化日志系统 try: from utils.logging_config import setup_api_logging, get_logger # 初始化现代化日志系统 logger = setup_api_logging(log_dir=".", log_level="INFO") logger.info("✅ 现代化日志系统导入成功") except Exception as e: print(f"❌ 日志系统导入失败: {e}") import logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) try: from utils.training_process_manager import get_training_manager # 获取训练进程管理器 training_manager = get_training_manager() logger.info("✅ 训练进程管理器导入成功") except Exception as e: logger.error(f"❌ 训练进程管理器导入失败: {e}") training_manager = None # 初始化数据库 def init_db(): """初始化数据库""" import sqlite3 conn = sqlite3.connect('prediction_history.db') cursor = conn.cursor() # 创建预测历史表 cursor.execute(''' CREATE TABLE IF NOT EXISTS prediction_history ( id INTEGER PRIMARY KEY AUTOINCREMENT, prediction_id TEXT UNIQUE NOT NULL, product_id TEXT NOT NULL, product_name TEXT NOT NULL, model_type TEXT NOT NULL, model_id TEXT NOT NULL, start_date TEXT, future_days INTEGER, created_at TEXT NOT NULL, predictions_data TEXT, metrics TEXT, chart_data TEXT, analysis TEXT, file_path TEXT ) ''') # 创建模型版本表 cursor.execute(''' CREATE TABLE IF NOT EXISTS model_versions ( id INTEGER PRIMARY KEY AUTOINCREMENT, product_id TEXT NOT NULL, model_type TEXT NOT NULL, version TEXT NOT NULL, file_path TEXT NOT NULL, created_at TEXT NOT NULL, metrics TEXT, is_active INTEGER DEFAULT 1, UNIQUE(product_id, model_type, version) ) ''') conn.commit() conn.close() logger.info("数据库初始化完成,包含模型版本管理表") # 创建 Flask 应用 app = Flask(__name__) app.config['SECRET_KEY'] = 'your-secret-key-here' # 启用CORS CORS(app, origins="*") # 初始化 SocketIO socketio = SocketIO(app, cors_allowed_origins="*", namespace='/training') @app.route('/api/products', methods=['GET']) def get_products(): """获取产品列表""" return jsonify({ "status": "success", "data": [ {"id": "P001", "name": "感冒灵颗粒"}, {"id": "P002", "name": "布洛芬片"}, {"id": "P003", "name": "维生素C片"}, {"id": "P004", "name": "阿莫西林胶囊"}, {"id": "P005", "name": "板蓝根颗粒"} ] }) @app.route('/api/training', methods=['POST']) def start_training(): """启动模型训练 - 使用现代化进程管理器""" data = request.get_json() # 参数验证 model_type = data.get('model_type') product_id = data.get('product_id', 'P001') epochs = data.get('epochs', 3) training_mode = data.get('training_mode', 'product') store_id = data.get('store_id') if not model_type: return jsonify({'error': '缺少model_type参数'}), 400 # 检查模型类型是否有效 valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn'] if model_type not in valid_model_types: return jsonify({'error': '无效的模型类型'}), 400 # 检查训练进程管理器是否可用 if training_manager is None: logger.error("❌ 训练进程管理器不可用") return jsonify({'error': '训练进程管理器初始化失败,请检查系统配置'}), 500 # 使用新的训练进程管理器提交任务 try: task_id = training_manager.submit_task( product_id=product_id, model_type=model_type, training_mode=training_mode, store_id=store_id, epochs=epochs ) logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]} | {model_type} | {product_id}") return jsonify({ 'message': '模型训练已开始(使用独立进程)', 'task_id': task_id, 'training_mode': training_mode, 'model_type': model_type, 'product_id': product_id, 'epochs': epochs }) except Exception as e: logger.error(f"❌ 提交训练任务失败: {str(e)}") return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500 @app.route('/api/training', methods=['GET']) def get_all_training_tasks(): """获取所有训练任务的状态""" if training_manager is None: return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500 try: all_tasks = training_manager.get_all_tasks() # 为了方便前端使用,我们将任务ID也包含在每个任务信息中 tasks_with_id = [] for task_id, task_info in all_tasks.items(): task_copy = task_info.copy() task_copy['task_id'] = task_id tasks_with_id.append(task_copy) # 按开始时间降序排序,最新的任务在前面 sorted_tasks = sorted(tasks_with_id, key=lambda x: x.get('start_time', ''), reverse=True) return jsonify({"status": "success", "data": sorted_tasks}) except Exception as e: logger.error(f"获取训练任务列表失败: {str(e)}") return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/training/', methods=['GET']) def get_training_status(task_id): """查询特定训练任务状态""" if training_manager is None: return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500 try: task_info = training_manager.get_task_status(task_id) if not task_info: return jsonify({"status": "error", "message": "任务不存在"}), 404 # 如果任务已完成,添加模型详情链接 if task_info['status'] == 'completed': task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}" return jsonify({ "status": "success", "data": task_info }) except Exception as e: logger.error(f"查询训练任务状态失败: {str(e)}") return jsonify({"status": "error", "message": str(e)}), 500 @app.route('/api/version', methods=['GET']) def api_version(): """检查API版本和状态""" return jsonify({ "status": "success", "data": { "version": "2.0.0-modern", "description": "现代化药店销售预测系统API", "features": [ "loguru现代化日志系统", "独立训练进程管理", "完美中文和emoji支持", "实时WebSocket进度推送" ], "timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S') } }) # WebSocket 事件处理 @socketio.on('connect', namespace='/training') def on_connect(): logger.info("WebSocket客户端已连接") @socketio.on('disconnect', namespace='/training') def on_disconnect(): logger.info("WebSocket客户端已断开") if __name__ == '__main__': # 初始化数据库 init_db() # 解析命令行参数 parser = argparse.ArgumentParser(description='现代化药店销售预测系统API服务') parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址') parser.add_argument('--port', type=int, default=5000, help='服务器端口') parser.add_argument('--debug', action='store_true', help='是否启用调试模式') args = parser.parse_args() # 确保目录存在 os.makedirs('static/plots', exist_ok=True) os.makedirs('static/csv', exist_ok=True) os.makedirs('saved_models', exist_ok=True) # 启动信息输出 logger.info("=" * 60) logger.info("现代化药店销售预测系统API服务启动") logger.info("=" * 60) logger.info(f"服务器地址: {args.host}:{args.port}") logger.info(f"调试模式: {args.debug}") logger.info(f"WebSocket: ws://{args.host}:{args.port}/training") logger.info(f"模型目录: saved_models") logger.info("特性: loguru日志 + 独立训练进程 + 中文支持") logger.info("=" * 60) # 启动训练进程管理器 if training_manager is not None: logger.info("🚀 启动训练进程管理器...") try: training_manager.start() # 设置WebSocket回调 def websocket_callback(event, data): try: socketio.emit(event, data, namespace='/training') except Exception as e: logger.error(f"WebSocket回调失败: {e}") training_manager.set_websocket_callback(websocket_callback) logger.info("✅ 训练进程管理器已启动") except Exception as e: logger.error(f"❌ 训练进程管理器启动失败: {e}") else: logger.warning("⚠️ 训练进程管理器不可用,将以有限功能模式运行") try: # 使用 SocketIO 启动应用 socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True) finally: # 确保在退出时停止训练进程管理器 if training_manager is not None: logger.info("🛑 正在停止训练进程管理器...") try: training_manager.stop() except Exception as e: logger.error(f"停止训练进程管理器时出错: {e}") logger.info("👋 API服务器已关闭")