297 lines
10 KiB
Python
297 lines
10 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
现代化API服务器 - 使用loguru日志和独立训练进程
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import json
|
||
from datetime import datetime
|
||
from flask import Flask, jsonify, request
|
||
from flask_cors import CORS
|
||
from flask_socketio import SocketIO
|
||
import argparse
|
||
|
||
# 获取当前脚本所在目录的绝对路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
sys.path.append(current_dir)
|
||
|
||
# 使用新的现代化日志系统
|
||
try:
|
||
from utils.logging_config import setup_api_logging, get_logger
|
||
# 初始化现代化日志系统
|
||
logger = setup_api_logging(log_dir=".", log_level="INFO")
|
||
logger.info("✅ 现代化日志系统导入成功")
|
||
except Exception as e:
|
||
print(f"❌ 日志系统导入失败: {e}")
|
||
import logging
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
try:
|
||
from utils.training_process_manager import get_training_manager
|
||
# 获取训练进程管理器
|
||
training_manager = get_training_manager()
|
||
logger.info("✅ 训练进程管理器导入成功")
|
||
except Exception as e:
|
||
logger.error(f"❌ 训练进程管理器导入失败: {e}")
|
||
training_manager = None
|
||
|
||
# 初始化数据库
|
||
def init_db():
|
||
"""初始化数据库"""
|
||
import sqlite3
|
||
conn = sqlite3.connect('prediction_history.db')
|
||
cursor = conn.cursor()
|
||
|
||
# 创建预测历史表
|
||
cursor.execute('''
|
||
CREATE TABLE IF NOT EXISTS prediction_history (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
prediction_id TEXT UNIQUE NOT NULL,
|
||
product_id TEXT NOT NULL,
|
||
product_name TEXT NOT NULL,
|
||
model_type TEXT NOT NULL,
|
||
model_id TEXT NOT NULL,
|
||
start_date TEXT,
|
||
future_days INTEGER,
|
||
created_at TEXT NOT NULL,
|
||
predictions_data TEXT,
|
||
metrics TEXT,
|
||
chart_data TEXT,
|
||
analysis TEXT,
|
||
file_path TEXT
|
||
)
|
||
''')
|
||
|
||
# 创建模型版本表
|
||
cursor.execute('''
|
||
CREATE TABLE IF NOT EXISTS model_versions (
|
||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
product_id TEXT NOT NULL,
|
||
model_type TEXT NOT NULL,
|
||
version TEXT NOT NULL,
|
||
file_path TEXT NOT NULL,
|
||
created_at TEXT NOT NULL,
|
||
metrics TEXT,
|
||
is_active INTEGER DEFAULT 1,
|
||
UNIQUE(product_id, model_type, version)
|
||
)
|
||
''')
|
||
|
||
conn.commit()
|
||
conn.close()
|
||
logger.info("数据库初始化完成,包含模型版本管理表")
|
||
|
||
# 创建 Flask 应用
|
||
app = Flask(__name__)
|
||
app.config['SECRET_KEY'] = 'your-secret-key-here'
|
||
|
||
# 启用CORS
|
||
CORS(app, origins="*")
|
||
|
||
# 初始化 SocketIO
|
||
socketio = SocketIO(app, cors_allowed_origins="*", namespace='/training')
|
||
|
||
@app.route('/api/products', methods=['GET'])
|
||
def get_products():
|
||
"""获取产品列表"""
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": [
|
||
{"id": "P001", "name": "感冒灵颗粒"},
|
||
{"id": "P002", "name": "布洛芬片"},
|
||
{"id": "P003", "name": "维生素C片"},
|
||
{"id": "P004", "name": "阿莫西林胶囊"},
|
||
{"id": "P005", "name": "板蓝根颗粒"}
|
||
]
|
||
})
|
||
|
||
@app.route('/api/training', methods=['POST'])
|
||
def start_training():
|
||
"""启动模型训练 - 使用现代化进程管理器"""
|
||
data = request.get_json()
|
||
|
||
# 参数验证
|
||
model_type = data.get('model_type')
|
||
product_id = data.get('product_id', 'P001')
|
||
epochs = data.get('epochs', 3)
|
||
training_mode = data.get('training_mode', 'product')
|
||
store_id = data.get('store_id')
|
||
|
||
if not model_type:
|
||
return jsonify({'error': '缺少model_type参数'}), 400
|
||
|
||
# 检查模型类型是否有效
|
||
valid_model_types = ['mlstm', 'kan', 'optimized_kan', 'transformer', 'tcn']
|
||
if model_type not in valid_model_types:
|
||
return jsonify({'error': '无效的模型类型'}), 400
|
||
|
||
# 检查训练进程管理器是否可用
|
||
if training_manager is None:
|
||
logger.error("❌ 训练进程管理器不可用")
|
||
return jsonify({'error': '训练进程管理器初始化失败,请检查系统配置'}), 500
|
||
|
||
# 使用新的训练进程管理器提交任务
|
||
try:
|
||
task_id = training_manager.submit_task(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
training_mode=training_mode,
|
||
store_id=store_id,
|
||
epochs=epochs
|
||
)
|
||
|
||
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]} | {model_type} | {product_id}")
|
||
|
||
return jsonify({
|
||
'message': '模型训练已开始(使用独立进程)',
|
||
'task_id': task_id,
|
||
'training_mode': training_mode,
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'epochs': epochs
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 提交训练任务失败: {str(e)}")
|
||
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
|
||
|
||
@app.route('/api/training', methods=['GET'])
|
||
def get_all_training_tasks():
|
||
"""获取所有训练任务的状态"""
|
||
if training_manager is None:
|
||
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
|
||
|
||
try:
|
||
all_tasks = training_manager.get_all_tasks()
|
||
|
||
# 为了方便前端使用,我们将任务ID也包含在每个任务信息中
|
||
tasks_with_id = []
|
||
for task_id, task_info in all_tasks.items():
|
||
task_copy = task_info.copy()
|
||
task_copy['task_id'] = task_id
|
||
tasks_with_id.append(task_copy)
|
||
|
||
# 按开始时间降序排序,最新的任务在前面
|
||
sorted_tasks = sorted(tasks_with_id,
|
||
key=lambda x: x.get('start_time', ''),
|
||
reverse=True)
|
||
|
||
return jsonify({"status": "success", "data": sorted_tasks})
|
||
except Exception as e:
|
||
logger.error(f"获取训练任务列表失败: {str(e)}")
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/training/<task_id>', methods=['GET'])
|
||
def get_training_status(task_id):
|
||
"""查询特定训练任务状态"""
|
||
if training_manager is None:
|
||
return jsonify({"status": "error", "message": "训练进程管理器不可用"}), 500
|
||
|
||
try:
|
||
task_info = training_manager.get_task_status(task_id)
|
||
|
||
if not task_info:
|
||
return jsonify({"status": "error", "message": "任务不存在"}), 404
|
||
|
||
# 如果任务已完成,添加模型详情链接
|
||
if task_info['status'] == 'completed':
|
||
task_info['model_details_url'] = f"/api/models?product_id={task_info['product_id']}&model_type={task_info['model_type']}"
|
||
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": task_info
|
||
})
|
||
except Exception as e:
|
||
logger.error(f"查询训练任务状态失败: {str(e)}")
|
||
return jsonify({"status": "error", "message": str(e)}), 500
|
||
|
||
@app.route('/api/version', methods=['GET'])
|
||
def api_version():
|
||
"""检查API版本和状态"""
|
||
return jsonify({
|
||
"status": "success",
|
||
"data": {
|
||
"version": "2.0.0-modern",
|
||
"description": "现代化药店销售预测系统API",
|
||
"features": [
|
||
"loguru现代化日志系统",
|
||
"独立训练进程管理",
|
||
"完美中文和emoji支持",
|
||
"实时WebSocket进度推送"
|
||
],
|
||
"timestamp": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
}
|
||
})
|
||
|
||
# WebSocket 事件处理
|
||
@socketio.on('connect', namespace='/training')
|
||
def on_connect():
|
||
logger.info("WebSocket客户端已连接")
|
||
|
||
@socketio.on('disconnect', namespace='/training')
|
||
def on_disconnect():
|
||
logger.info("WebSocket客户端已断开")
|
||
|
||
if __name__ == '__main__':
|
||
# 初始化数据库
|
||
init_db()
|
||
|
||
# 解析命令行参数
|
||
parser = argparse.ArgumentParser(description='现代化药店销售预测系统API服务')
|
||
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
|
||
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
|
||
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 确保目录存在
|
||
os.makedirs('static/plots', exist_ok=True)
|
||
os.makedirs('static/csv', exist_ok=True)
|
||
os.makedirs('saved_models', exist_ok=True)
|
||
|
||
# 启动信息输出
|
||
logger.info("=" * 60)
|
||
logger.info("现代化药店销售预测系统API服务启动")
|
||
logger.info("=" * 60)
|
||
logger.info(f"服务器地址: {args.host}:{args.port}")
|
||
logger.info(f"调试模式: {args.debug}")
|
||
logger.info(f"WebSocket: ws://{args.host}:{args.port}/training")
|
||
logger.info(f"模型目录: saved_models")
|
||
logger.info("特性: loguru日志 + 独立训练进程 + 中文支持")
|
||
logger.info("=" * 60)
|
||
|
||
# 启动训练进程管理器
|
||
if training_manager is not None:
|
||
logger.info("🚀 启动训练进程管理器...")
|
||
try:
|
||
training_manager.start()
|
||
|
||
# 设置WebSocket回调
|
||
def websocket_callback(event, data):
|
||
try:
|
||
socketio.emit(event, data, namespace='/training')
|
||
except Exception as e:
|
||
logger.error(f"WebSocket回调失败: {e}")
|
||
|
||
training_manager.set_websocket_callback(websocket_callback)
|
||
logger.info("✅ 训练进程管理器已启动")
|
||
except Exception as e:
|
||
logger.error(f"❌ 训练进程管理器启动失败: {e}")
|
||
else:
|
||
logger.warning("⚠️ 训练进程管理器不可用,将以有限功能模式运行")
|
||
|
||
try:
|
||
# 使用 SocketIO 启动应用
|
||
socketio.run(app, host=args.host, port=args.port, debug=args.debug, allow_unsafe_werkzeug=True)
|
||
finally:
|
||
# 确保在退出时停止训练进程管理器
|
||
if training_manager is not None:
|
||
logger.info("🛑 正在停止训练进程管理器...")
|
||
try:
|
||
training_manager.stop()
|
||
except Exception as e:
|
||
logger.error(f"停止训练进程管理器时出错: {e}")
|
||
logger.info("👋 API服务器已关闭") |