ShopTRAINING/server/modern_api.py

297 lines
10 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
现代化API服务器 - 使用loguru日志和独立训练进程
"""
import sys
import os
import json
from datetime import datetime
from flask import Flask, jsonify, request
from flask_cors import CORS
from flask_socketio import SocketIO
import argparse
# 获取当前脚本所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
# 使用新的现代化日志系统
try:
from utils.logging_config import setup_api_logging, get_logger
# 初始化现代化日志系统
logger = setup_api_logging(log_dir=".", log_level="INFO")
logger.info("✅ 现代化日志系统导入成功")
except Exception as e:
print(f"❌ 日志系统导入失败: {e}")
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
from utils.training_process_manager import get_training_manager
# 获取训练进程管理器
training_manager = get_training_manager()
logger.info("✅ 训练进程管理器导入成功")
except Exception as e:
logger.error(f"❌ 训练进程管理器导入失败: {e}")
training_manager = None
# 初始化数据库
def init_db():
"""初始化数据库"""
import sqlite3
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 创建预测历史表
cursor.execute('''
CREATE TABLE IF NOT EXISTS prediction_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prediction_id TEXT UNIQUE NOT NULL,
product_id TEXT NOT NULL,
product_name TEXT NOT NULL,
model_type TEXT NOT NULL,
model_id TEXT NOT NULL,
start_date TEXT,
future_days INTEGER,
created_at TEXT NOT NULL,
predictions_data TEXT,
metrics TEXT,
chart_data TEXT,
analysis TEXT,
file_path TEXT
)
''')
# 创建模型版本表
cursor.execute('''
CREATE TABLE IF NOT EXISTS model_versions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
product_id TEXT NOT NULL,
model_type TEXT NOT NULL,
version TEXT NOT NULL,
file_path TEXT NOT NULL,
created_at TEXT NOT NULL,
metrics TEXT,
is_active INTEGER DEFAULT 1,
UNIQUE(product_id, model_type, version)
)
''')
conn.commit()
conn.close()
logger.info("数据库初始化完成,包含模型版本管理表")
# 创建 Flask 应用
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'
# 启用CORS
CORS(app, origins="*")
# 初始化 SocketIO
socketio = SocketIO(app, cors_allowed_origins="*", namespace='/training')
@app.route('/api/products', methods=['GET'])
def get_products():
"""获取产品列表"""
return jsonify({
"status": "success",
"data": [
{"id": "P001", "name": "感冒灵颗粒"},
{"id": "P002", "name": "布洛芬片"},
{"id": "P003", "name": "维生素C片"},
{"id": "P004", "name": "阿莫西林胶囊"},
{"id": "P005", "name": "板蓝根颗粒"}
]
})
@app.route('/api/training', methods=['POST'])
def start_training():
"""启动模型训练 - 使用现代化进程管理器"""
data = request.get_json()
# 参数验证
model_type = data.get('model_type')
product_id = data.get('product_id', 'P001')
epochs = data.get('epochs', 3)
training_mode = data.get('training_mode', 'product')
store_id = data.get('store_id')
if not model_type:
return jsonify({'error': '缺少model_type参数'}), 400
# 检查模型类型是否有效
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
if model_type not in valid_model_types:
return jsonify({'error': '无效的模型类型'}), 400
# 检查训练进程管理器是否可用
if training_manager is None:
logger.error("❌ 训练进程管理器不可用")
return jsonify({'error': '训练进程管理器初始化失败,请检查系统配置'}), 500
# 使用新的训练进程管理器提交任务
try:
task_id = training_manager.submit_task(
product_id=product_id,
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]} | {model_type} | {product_id}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
'training_mode': training_mode,
'model_type': model_type,
'product_id': product_id,
'epochs': epochs
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
@app.route('/api/training', methods=['GET'])
def get_all_training_tasks():
"""获取所有训练任务的状态"""
if training_manager is None:
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
try:
all_tasks = training_manager.get_all_tasks()
# 为了方便前端使用我们将任务ID也包含在每个任务信息中
tasks_with_id = []
for task_id, task_info in all_tasks.items():
task_copy = task_info.copy()
task_copy['task_id'] = task_id
tasks_with_id.append(task_copy)
# 按开始时间降序排序,最新的任务在前面
sorted_tasks = sorted(tasks_with_id,
key=lambda x: x.get('start_time', ''),
reverse=True)
return jsonify({"status": "success", "data": sorted_tasks})
except Exception as e:
logger.error(f"获取训练任务列表失败: {str(e)}")
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/training/<task_id>', methods=['GET'])
def get_training_status(task_id):
"""查询特定训练任务状态"""
if training_manager is None:
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
try:
task_info = training_manager.get_task_status(task_id)
if not task_info:
return jsonify({"status": "error", "message": "任务不存在"}), 404
# 如果任务已完成,添加模型详情链接
if task_info['status'] == 'completed':
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
return jsonify({
"status": "success",
"data": task_info
})
except Exception as e:
logger.error(f"查询训练任务状态失败: {str(e)}")
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/version', methods=['GET'])
def api_version():
"""检查API版本和状态"""
return jsonify({
"status": "success",
"data": {
"version": "2.0.0-modern",
"description": "现代化药店销售预测系统API",
"features": [
"loguru现代化日志系统",
"独立训练进程管理",
"完美中文和emoji支持",
"实时WebSocket进度推送"
],
"timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
})
# WebSocket 事件处理
@socketio.on('connect', namespace='/training')
def on_connect():
logger.info("WebSocket客户端已连接")
@socketio.on('disconnect', namespace='/training')
def on_disconnect():
logger.info("WebSocket客户端已断开")
if __name__ == '__main__':
# 初始化数据库
init_db()
# 解析命令行参数
parser = argparse.ArgumentParser(description='现代化药店销售预测系统API服务')
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
args = parser.parse_args()
# 确保目录存在
os.makedirs('static/plots', exist_ok=True)
os.makedirs('static/csv', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)
# 启动信息输出
logger.info("=" * 60)
logger.info("现代化药店销售预测系统API服务启动")
logger.info("=" * 60)
logger.info(f"服务器地址: {args.host}:{args.port}")
logger.info(f"调试模式: {args.debug}")
logger.info(f"WebSocket: ws://{args.host}:{args.port}/training")
logger.info(f"模型目录: saved_models")
logger.info("特性: loguru日志 + 独立训练进程 + 中文支持")
logger.info("=" * 60)
# 启动训练进程管理器
if training_manager is not None:
logger.info("🚀 启动训练进程管理器...")
try:
training_manager.start()
# 设置WebSocket回调
def websocket_callback(event, data):
try:
socketio.emit(event, data, namespace='/training')
except Exception as e:
logger.error(f"WebSocket回调失败: {e}")
training_manager.set_websocket_callback(websocket_callback)
logger.info("✅ 训练进程管理器已启动")
except Exception as e:
logger.error(f"❌ 训练进程管理器启动失败: {e}")
else:
logger.warning("⚠️ 训练进程管理器不可用,将以有限功能模式运行")
try:
# 使用 SocketIO 启动应用
socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True)
finally:
# 确保在退出时停止训练进程管理器
if training_manager is not None:
logger.info("🛑 正在停止训练进程管理器...")
try:
training_manager.stop()
except Exception as e:
logger.error(f"停止训练进程管理器时出错: {e}")
logger.info("👋 API服务器已关闭")