ShopTRAINING/server/modern_api.py
2025-07-02 11:05:23 +08:00

297 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
现代化API服务器 - 使用loguru日志和独立训练进程
"""
import sys
import os
import json
from datetime import datetime
from flask import Flask, jsonify, request
from flask_cors import CORS
from flask_socketio import SocketIO
import argparse
# 获取当前脚本所在目录的绝对路径
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
# 使用新的现代化日志系统
try:
from utils.logging_config import setup_api_logging, get_logger
# 初始化现代化日志系统
logger = setup_api_logging(log_dir=".", log_level="INFO")
logger.info("✅ 现代化日志系统导入成功")
except Exception as e:
print(f"❌ 日志系统导入失败: {e}")
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
try:
from utils.training_process_manager import get_training_manager
# 获取训练进程管理器
training_manager = get_training_manager()
logger.info("✅ 训练进程管理器导入成功")
except Exception as e:
logger.error(f"❌ 训练进程管理器导入失败: {e}")
training_manager = None
# 初始化数据库
def init_db():
"""初始化数据库"""
import sqlite3
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 创建预测历史表
cursor.execute('''
CREATE TABLE IF NOT EXISTS prediction_history (
id INTEGER PRIMARY KEY AUTOINCREMENT,
prediction_id TEXT UNIQUE NOT NULL,
product_id TEXT NOT NULL,
product_name TEXT NOT NULL,
model_type TEXT NOT NULL,
model_id TEXT NOT NULL,
start_date TEXT,
future_days INTEGER,
created_at TEXT NOT NULL,
predictions_data TEXT,
metrics TEXT,
chart_data TEXT,
analysis TEXT,
file_path TEXT
)
''')
# 创建模型版本表
cursor.execute('''
CREATE TABLE IF NOT EXISTS model_versions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
product_id TEXT NOT NULL,
model_type TEXT NOT NULL,
version TEXT NOT NULL,
file_path TEXT NOT NULL,
created_at TEXT NOT NULL,
metrics TEXT,
is_active INTEGER DEFAULT 1,
UNIQUE(product_id, model_type, version)
)
''')
conn.commit()
conn.close()
logger.info("数据库初始化完成,包含模型版本管理表")
# 创建 Flask 应用
app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'
# 启用CORS
CORS(app, origins="*")
# 初始化 SocketIO
socketio = SocketIO(app, cors_allowed_origins="*", namespace='/training')
@app.route('/api/products', methods=['GET'])
def get_products():
"""获取产品列表"""
return jsonify({
"status": "success",
"data": [
{"id": "P001", "name": "感冒灵颗粒"},
{"id": "P002", "name": "布洛芬片"},
{"id": "P003", "name": "维生素C片"},
{"id": "P004", "name": "阿莫西林胶囊"},
{"id": "P005", "name": "板蓝根颗粒"}
]
})
@app.route('/api/training', methods=['POST'])
def start_training():
"""启动模型训练 - 使用现代化进程管理器"""
data = request.get_json()
# 参数验证
model_type = data.get('model_type')
product_id = data.get('product_id', 'P001')
epochs = data.get('epochs', 3)
training_mode = data.get('training_mode', 'product')
store_id = data.get('store_id')
if not model_type:
return jsonify({'error': '缺少model_type参数'}), 400
# 检查模型类型是否有效
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
if model_type not in valid_model_types:
return jsonify({'error': '无效的模型类型'}), 400
# 检查训练进程管理器是否可用
if training_manager is None:
logger.error("❌ 训练进程管理器不可用")
return jsonify({'error': '训练进程管理器初始化失败,请检查系统配置'}), 500
# 使用新的训练进程管理器提交任务
try:
task_id = training_manager.submit_task(
product_id=product_id,
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]} | {model_type} | {product_id}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
'training_mode': training_mode,
'model_type': model_type,
'product_id': product_id,
'epochs': epochs
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
@app.route('/api/training', methods=['GET'])
def get_all_training_tasks():
"""获取所有训练任务的状态"""
if training_manager is None:
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
try:
all_tasks = training_manager.get_all_tasks()
# 为了方便前端使用我们将任务ID也包含在每个任务信息中
tasks_with_id = []
for task_id, task_info in all_tasks.items():
task_copy = task_info.copy()
task_copy['task_id'] = task_id
tasks_with_id.append(task_copy)
# 按开始时间降序排序,最新的任务在前面
sorted_tasks = sorted(tasks_with_id,
key=lambda x: x.get('start_time', ''),
reverse=True)
return jsonify({"status": "success", "data": sorted_tasks})
except Exception as e:
logger.error(f"获取训练任务列表失败: {str(e)}")
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/training/<task_id>', methods=['GET'])
def get_training_status(task_id):
"""查询特定训练任务状态"""
if training_manager is None:
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
try:
task_info = training_manager.get_task_status(task_id)
if not task_info:
return jsonify({"status": "error", "message": "任务不存在"}), 404
# 如果任务已完成,添加模型详情链接
if task_info['status'] == 'completed':
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
return jsonify({
"status": "success",
"data": task_info
})
except Exception as e:
logger.error(f"查询训练任务状态失败: {str(e)}")
return jsonify({"status": "error", "message": str(e)}), 500
@app.route('/api/version', methods=['GET'])
def api_version():
"""检查API版本和状态"""
return jsonify({
"status": "success",
"data": {
"version": "2.0.0-modern",
"description": "现代化药店销售预测系统API",
"features": [
"loguru现代化日志系统",
"独立训练进程管理",
"完美中文和emoji支持",
"实时WebSocket进度推送"
],
"timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
})
# WebSocket 事件处理
@socketio.on('connect', namespace='/training')
def on_connect():
logger.info("WebSocket客户端已连接")
@socketio.on('disconnect', namespace='/training')
def on_disconnect():
logger.info("WebSocket客户端已断开")
if __name__ == '__main__':
# 初始化数据库
init_db()
# 解析命令行参数
parser = argparse.ArgumentParser(description='现代化药店销售预测系统API服务')
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
args = parser.parse_args()
# 确保目录存在
os.makedirs('static/plots', exist_ok=True)
os.makedirs('static/csv', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)
# 启动信息输出
logger.info("=" * 60)
logger.info("现代化药店销售预测系统API服务启动")
logger.info("=" * 60)
logger.info(f"服务器地址: {args.host}:{args.port}")
logger.info(f"调试模式: {args.debug}")
logger.info(f"WebSocket: ws://{args.host}:{args.port}/training")
logger.info(f"模型目录: saved_models")
logger.info("特性: loguru日志 + 独立训练进程 + 中文支持")
logger.info("=" * 60)
# 启动训练进程管理器
if training_manager is not None:
logger.info("🚀 启动训练进程管理器...")
try:
training_manager.start()
# 设置WebSocket回调
def websocket_callback(event, data):
try:
socketio.emit(event, data, namespace='/training')
except Exception as e:
logger.error(f"WebSocket回调失败: {e}")
training_manager.set_websocket_callback(websocket_callback)
logger.info("✅ 训练进程管理器已启动")
except Exception as e:
logger.error(f"❌ 训练进程管理器启动失败: {e}")
else:
logger.warning("⚠️ 训练进程管理器不可用,将以有限功能模式运行")
try:
# 使用 SocketIO 启动应用
socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True)
finally:
# 确保在退出时停止训练进程管理器
if training_manager is not None:
logger.info("🛑 正在停止训练进程管理器...")
try:
training_manager.stop()
except Exception as e:
logger.error(f"停止训练进程管理器时出错: {e}")
logger.info("👋 API服务器已关闭")