数据文件保存结构改为:

### 1.2. 文件存储位置

- **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
- **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
   - *示例*: `product/P001_all/mlstm/v2`
2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
   - *示例*: `product_P001_all_mlstm_v2`
3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀:
   - `_model.pth`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### 完整示例

- **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
559 lines
25 KiB
Python
559 lines
25 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
独立训练进程管理器
|
||
使用multiprocessing实现真正的并行训练,避免GIL限制
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import uuid
|
||
import time
|
||
import json
|
||
import queue
|
||
import multiprocessing as mp
|
||
from multiprocessing import Process, Queue, Manager
|
||
from dataclasses import dataclass, asdict
|
||
from typing import Dict, Any, Optional, Callable
|
||
from threading import Thread, Lock
|
||
from pathlib import Path
|
||
|
||
# Make the server root (the parent of this file's directory) importable, so that
# the project packages imported below (``utils.*``) and later (``core.*``)
# resolve regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
server_dir = os.path.dirname(current_dir)
if server_dir not in sys.path:  # avoid accumulating duplicate entries
    sys.path.append(server_dir)
|
||
|
||
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
|
||
from utils.file_save import ModelPathManager
|
||
import numpy as np
|
||
|
||
def convert_numpy_types(obj):
    """Recursively convert NumPy values inside *obj* into native Python types.

    Used before task/progress data is handed to JSON-based transports (the
    WebSocket callback), which cannot serialize NumPy objects.

    * ``dict``      -> values converted recursively (keys left unchanged)
    * ``list``      -> elements converted recursively
    * ``tuple``     -> elements converted recursively, returned as a ``tuple``
    * ``np.ndarray``-> converted with ``tolist()`` (then recursed, which matters
      for object-dtype arrays)
    * NumPy scalar (``np.generic``) -> ``item()``
    * anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    if isinstance(obj, np.ndarray):
        # tolist() already yields native scalars for numeric dtypes; recursing
        # also handles object arrays that contain NumPy scalars or dicts.
        return convert_numpy_types(obj.tolist())
    if isinstance(obj, np.generic):
        return obj.item()
    return obj
|
||
|
||
@dataclass
class TrainingTask:
    """State record for one training job.

    The manager keeps one instance per task; it crosses process boundaries as
    a plain dict (``asdict(task)``) and is rebuilt in the worker with
    ``TrainingTask(**data)``.  Field order defines the positional constructor
    signature, so it must not be changed.
    """

    # ---- what to train ------------------------------------------------
    task_id: str
    product_id: str
    model_type: str
    training_mode: str
    store_id: Optional[str] = None
    aggregation_method: Optional[str] = None  # forwarded unchanged to the trainer
    product_scope: str = 'all'
    product_ids: Optional[list] = None
    epochs: int = 100

    # ---- runtime state ------------------------------------------------
    # Status values assigned in this module: "pending", "running",
    # "completed", "failed", "cancelled".
    status: str = "pending"
    start_time: Optional[str] = None  # local time, '%Y-%m-%d %H:%M:%S'
    end_time: Optional[str] = None  # local time, '%Y-%m-%d %H:%M:%S'
    progress: float = 0.0  # percentage in the range 0-100
    message: str = ""
    error: Optional[str] = None
    metrics: Optional[Dict[str, Any]] = None
    process_id: Optional[int] = None  # PID of the worker process that ran the task

    # ---- model path / version bookkeeping ------------------------------
    path_info: Optional[Dict[str, Any]] = None  # path and version info supplied at submission
    version: Optional[int] = None  # set by the manager once version info is saved
|
||
|
||
class TrainingWorker:
    """Executes training tasks inside a worker process.

    The worker talks to the parent process only through the three queues it
    is constructed with:

    * ``task_queue``     - task dicts (``asdict(TrainingTask)``); ``None`` is
      the shutdown signal.
    * ``result_queue``   - ``(action, task_dict)`` tuples where ``action`` is
      ``'update'``, ``'complete'`` or ``'error'``.
    * ``progress_queue`` - progress dicts from the trainer and console log
      messages (dicts carrying a ``'log_type'`` key).
    """

    def __init__(self, task_queue: Queue, result_queue: Queue, progress_queue: Queue):
        self.task_queue = task_queue
        self.result_queue = result_queue
        self.progress_queue = progress_queue

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _now() -> str:
        """Current local time in the timestamp format stored on tasks."""
        return time.strftime('%Y-%m-%d %H:%M:%S')

    @staticmethod
    def _format_metric(value):
        """Format a numeric metric (incl. NumPy scalars) with 4 decimals; pass others through."""
        if isinstance(value, (int, float, np.integer, np.floating)):
            return f"{value:.4f}"
        return value

    def _emit_log(self, task_id: str, log_type: str, message: str) -> None:
        """Send a console log line to the parent process through ``progress_queue``."""
        self.progress_queue.put({
            'task_id': task_id,
            'log_type': log_type,
            'message': message,
        })

    def _report_failure(self, task: "TrainingTask", error_msg: str, training_logger=None) -> None:
        """Mark *task* as failed and publish it on ``result_queue`` as an ``'error'`` result."""
        task.status = "failed"
        task.end_time = self._now()
        task.error = error_msg
        task.message = f"训练失败: {error_msg}"

        log_line = f"❌ 训练任务失败: {error_msg}"
        if training_logger is not None:
            training_logger.error(log_line)
        else:
            # The per-task logger was never created (setup itself failed);
            # fall back to stdout so the failure is still visible.
            print(log_line, flush=True)
        self.result_queue.put(('error', asdict(task)))

    def _train(self, task: "TrainingTask", training_logger):
        """Run the real trainer for *task* and return whatever it returns (the metrics).

        If ``core.predictor`` cannot be imported, simulated metrics are
        returned instead (testing fallback).  Exceptions raised by the trainer
        itself propagate to the caller.
        """
        # Make sure the server root (which contains the ``core`` package) is importable.
        server_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        if server_dir not in sys.path:
            sys.path.append(server_dir)

        try:
            from core.predictor import PharmacyPredictor
        except ImportError as e:
            training_logger.error(f"❌ 导入训练器失败: {e}")
            # NOTE(review): these fake metrics look like a successful run
            # (r2 >= 0); the manager's version bookkeeping treats r2 >= 0 as
            # success, so confirm this fallback is acceptable outside tests.
            metrics = {
                "mse": 0.001,
                "rmse": 0.032,
                "mae": 0.025,
                "r2": 0.95,
                "mape": 2.5,
                "training_time": task.epochs * 2,
                "note": "模拟训练结果(导入失败时的备用方案)",
            }
            training_logger.warning("⚠️ 使用模拟训练结果")
            return metrics

        predictor = PharmacyPredictor()
        training_logger.info("🤖 开始调用实际训练器")
        self._emit_log(task.task_id, 'info', f"🤖 开始执行 {task.model_type} 模型训练...")

        def progress_callback(progress_data):
            """Forward trainer progress to the parent, tagged with this task's id."""
            try:
                progress_data['task_id'] = task.task_id
                self.progress_queue.put(progress_data)
            except Exception as e:
                training_logger.error(f"进度回调失败: {e}")

        metrics = predictor.train_model(
            product_id=task.product_id,
            model_type=task.model_type,
            epochs=task.epochs,
            store_id=task.store_id,
            training_mode=task.training_mode,
            aggregation_method=task.aggregation_method,
            product_scope=task.product_scope,
            product_ids=task.product_ids,
            socketio=None,  # a socketio object cannot be used from the child process
            task_id=task.task_id,
            progress_callback=progress_callback,
            path_info=task.path_info,
        )

        # Only announce success when the trainer did not report an error.
        if metrics and 'error' in metrics:
            self._emit_log(task.task_id, 'error', f"❌ 训练返回错误: {metrics['error']}")
        else:
            self._emit_log(task.task_id, 'success', f"✅ {task.model_type} 模型训练完成!")
            if metrics:
                mse_str = self._format_metric(metrics.get('mse', 'N/A'))
                rmse_str = self._format_metric(metrics.get('rmse', 'N/A'))
                self._emit_log(task.task_id, 'info', f"📊 训练指标: MSE={mse_str}, RMSE={rmse_str}")
        return metrics

    # ------------------------------------------------------------ entry points

    def run_training_task(self, task: "TrainingTask"):
        """Execute one training task and publish its lifecycle on the queues.

        Puts ``('update', ...)`` once the task is running, then exactly one of
        ``('complete', ...)`` or ``('error', ...)``.  Training failures
        (including a trainer result containing ``'error'``) are reported as
        ``'error'`` results rather than raised.
        """
        started_at = time.time()
        training_logger = None
        try:
            # Configure logging for this worker process.
            setup_api_logging(log_level="INFO")
            training_logger = get_training_logger(task.task_id, task.model_type, task.product_id)

            startup_messages = (
                f"🚀 训练进程启动 - PID: {os.getpid()}",
                f"📋 任务参数: {task.model_type} | {task.product_id} | {task.epochs}轮次",
            )
            for message in startup_messages:
                self._emit_log(task.task_id, 'info', message)
            for message in startup_messages:
                training_logger.info(message)

            task.status = "running"
            task.start_time = self._now()
            task.process_id = os.getpid()
            self.result_queue.put(('update', asdict(task)))

            metrics = self._train(task, training_logger)

            if metrics and 'error' in metrics:
                # The trainer signalled failure through its return value.
                self._report_failure(task, str(metrics['error']), training_logger)
                return

            task.status = "completed"
            task.end_time = self._now()
            task.metrics = metrics if metrics else {}
            elapsed = time.time() - started_at
            if metrics:
                task.progress = 100.0
                task.message = "训练完成"
                training_logger.success(f"✅ 训练任务完成 - 耗时: {elapsed:.1f}秒 (结束时间: {task.end_time})")
                training_logger.info(f"📊 训练指标: {metrics}")
            else:
                # Empty metrics: still marked completed so the user can decide
                # whether the resulting model is usable.
                task.message = "训练完成(性能可能不佳)"
                training_logger.warning(f"⚠️ 训练完成,但性能可能不佳 (metrics: {metrics})")
            self.result_queue.put(('complete', asdict(task)))

        except Exception as e:
            self._report_failure(task, str(e), training_logger)

    def start(self):
        """Worker main loop: run tasks until a ``None`` shutdown signal arrives.

        The loop also exits when the task queue itself becomes unusable
        (closed or broken pipe: ``EOFError``/``OSError``/``ValueError``)
        instead of retrying the same failing ``get`` forever.
        """
        while True:
            try:
                task_data = self.task_queue.get(timeout=5)
            except queue.Empty:
                continue
            except (EOFError, OSError, ValueError) as e:
                print(f"工作进程错误: {e}")
                break
            except Exception as e:
                print(f"工作进程错误: {e}")
                continue

            if task_data is None:  # poison pill -> shut down
                break

            try:
                self.run_training_task(TrainingTask(**task_data))
            except Exception as e:
                print(f"工作进程错误: {e}")
|
||
|
||
class TrainingProcessManager:
    """Coordinates a pool of ``TrainingWorker`` processes from the parent process.

    The manager owns three ``multiprocessing`` queues shared with every worker:

    * ``task_queue``     - task dicts to execute; ``None`` tells a worker to exit.
    * ``result_queue``   - ``(action, task_dict)`` status transitions from workers.
    * ``progress_queue`` - progress updates and console log messages from workers.

    Task records live in ``self.tasks`` and are guarded by ``self.lock``,
    because two daemon monitor threads update them while the public
    query/cancel methods read them.  Worker output is optionally forwarded to
    ``websocket_callback(event_name, payload)``.
    """

    # Worker log types that are echoed to the console / logger.
    _CONSOLE_LOG_TYPES = ('info', 'success', 'error', 'warning')

    def __init__(self, max_workers: int = 2):
        self.max_workers = max_workers
        self.tasks: Dict[str, TrainingTask] = {}
        self.processes: Dict[str, Process] = {}
        self.task_queue = Queue()
        self.result_queue = Queue()
        self.progress_queue = Queue()
        self.running = False
        self.lock = Lock()
        self.result_thread: Optional[Thread] = None
        self.progress_thread: Optional[Thread] = None

        # WebSocket callback, invoked as callback(event_name, payload_dict)
        self.websocket_callback: Optional[Callable] = None

        # Logging and model-version bookkeeping
        self.logger = setup_api_logging()
        self.path_manager = ModelPathManager()

    # ------------------------------------------------------------ lifecycle

    def start(self):
        """Spawn ``max_workers`` worker processes and the two monitor threads.

        Does nothing if the manager is already running.
        """
        if self.running:
            return

        self.running = True
        self.logger.info(f"🚀 训练进程管理器启动 - 最大工作进程数: {self.max_workers}")

        for i in range(self.max_workers):
            worker = TrainingWorker(self.task_queue, self.result_queue, self.progress_queue)
            process = Process(target=worker.start, name=f"TrainingWorker-{i}")
            process.start()
            self.processes[f"worker-{i}"] = process
            self.logger.info(f"🔧 工作进程 {i} 启动 - PID: {process.pid}")

        self.result_thread = Thread(target=self._monitor_results, daemon=True)
        self.result_thread.start()

        self.progress_thread = Thread(target=self._monitor_progress, daemon=True)
        self.progress_thread.start()

    def stop(self):
        """Stop all worker processes; the monitor threads exit on their next poll.

        Each worker receives a ``None`` poison pill and gets up to 10 seconds
        to exit; workers still alive after that (e.g. busy training) are
        terminated.  Does nothing if the manager is not running.
        """
        if not self.running:
            return

        self.logger.info("🛑 正在停止训练进程管理器...")
        self.running = False

        for _ in range(self.max_workers):
            self.task_queue.put(None)

        for name, process in self.processes.items():
            process.join(timeout=10)
            if process.is_alive():
                self.logger.warning(f"⚠️ 强制终止进程: {name}")
                process.terminate()
                process.join(timeout=5)  # reap the terminated process
        # Forget the dead processes so a later start() begins from a clean slate.
        self.processes.clear()

        self.logger.info("✅ 训练进程管理器已停止")

    # ------------------------------------------------------------ public API

    def submit_task(self, training_params: Dict[str, Any], path_info: Dict[str, Any]) -> str:
        """Create a task record and enqueue it for the workers.

        Args:
            training_params (Dict[str, Any]): raw parameters from the API request.
            path_info (Dict[str, Any]): path and version info generated by ModelPathManager.

        Returns:
            str: the new task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())

        task = TrainingTask(
            task_id=task_id,
            product_id=training_params.get('product_id'),
            model_type=training_params.get('model_type'),
            training_mode=training_params.get('training_mode', 'product'),
            store_id=training_params.get('store_id'),
            epochs=training_params.get('epochs', 100),
            aggregation_method=training_params.get('aggregation_method'),
            product_scope=training_params.get('product_scope', 'all'),
            product_ids=training_params.get('product_ids'),
            path_info=path_info,
        )

        with self.lock:
            self.tasks[task_id] = task

        # Workers receive a plain-dict copy; later changes to `task` are not seen by them.
        self.task_queue.put(asdict(task))

        self.logger.info(f"📋 训练任务已提交: {task_id[:8]} | {task.model_type} | {task.product_id}")
        return task_id

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a dict snapshot of one task, or ``None`` if the id is unknown."""
        with self.lock:
            task = self.tasks.get(task_id)
            if task:
                return asdict(task)
            return None

    def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
        """Return dict snapshots of all known tasks, keyed by task id."""
        with self.lock:
            return {task_id: asdict(task) for task_id, task in self.tasks.items()}

    def cancel_task(self, task_id: str) -> bool:
        """Mark a still-pending task as cancelled; return True if it was pending.

        NOTE(review): this only updates the manager-side record.  The task
        dict is already on ``task_queue`` (see submit_task), so a worker that
        dequeues it still runs it, and its results overwrite this status.
        """
        with self.lock:
            task = self.tasks.get(task_id)
            if task and task.status == "pending":
                task.status = "cancelled"
                task.message = "任务已取消"
                return True
            return False

    def set_websocket_callback(self, callback: Callable):
        """Register ``callback(event_name, payload)`` for task notifications."""
        self.websocket_callback = callback

    # ------------------------------------------------------------ result monitor

    @staticmethod
    def _metrics_indicate_success(metrics) -> bool:
        """Return True when *metrics* is a dict whose ``r2`` is a number >= 0.

        A missing or non-numeric ``r2`` (or NaN) counts as not successful.
        """
        if not isinstance(metrics, dict):
            return False
        r2 = metrics.get('r2', -1)
        return isinstance(r2, (int, float)) and r2 >= 0

    def _monitor_results(self):
        """Thread body: consume ``(action, task_dict)`` results until stopped."""
        while self.running:
            try:
                action, task_data = self.result_queue.get(timeout=1)
                self._handle_result(action, task_data)
            except queue.Empty:
                continue
            except Exception as e:
                self.logger.error(f"结果监听错误: {e}")

    def _handle_result(self, action: str, task_data: Dict[str, Any]) -> None:
        """Apply one worker result to the task record, then notify listeners."""
        task_id = task_data['task_id']
        # Convert right away so both the stored task and the WebSocket payloads
        # contain only JSON-serializable native types.
        data = convert_numpy_types(task_data)

        version = None
        with self.lock:
            task = self.tasks.get(task_id)
            if task is not None:
                for key, value in data.items():
                    if hasattr(task, key):
                        setattr(task, key, value)
                if action == 'complete':
                    self._record_version(task_id, task)
                version = task.version

        if self.websocket_callback:
            try:
                self._notify_result(task_id, action, data, version)
            except Exception as e:
                self.logger.error(f"WebSocket通知失败: {e}")

    def _record_version(self, task_id: str, task: "TrainingTask") -> None:
        """Persist version info for a successfully completed task (caller holds the lock)."""
        if not self._metrics_indicate_success(task.metrics):
            self.logger.warning(f"⚠️ 任务 {task_id} 训练性能不佳或失败,不保存版本信息。")
            return
        if not task.path_info:
            return

        # Use the normalized identifier produced together with the path info.
        identifier = task.path_info.get('identifier')
        version = task.path_info.get('version')
        if not (identifier and version):
            return

        try:
            self.path_manager.save_version_info(identifier, version)
            self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
            task.version = version  # authoritative version reported to the front-end
        except Exception as e:
            self.logger.error(f"❌ 更新版本文件失败: {e}")

    def _notify_result(self, task_id: str, action: str, data: Dict[str, Any], version) -> None:
        """Translate one worker result into WebSocket event(s)."""
        callback = self.websocket_callback
        if action == 'complete':
            callback('training_update', {
                'task_id': task_id,
                'action': 'completed',
                'status': 'completed',
                'progress': 100,
                'message': data.get('message', '训练完成'),
                'metrics': data.get('metrics'),
                'end_time': data.get('end_time'),
                'product_id': data.get('product_id'),
                'model_type': data.get('model_type'),
                'version': version,
                'product_scope': data.get('product_scope'),
                'product_ids': data.get('product_ids'),
            })
            # A dedicated completion event is sent too, so front-ends that only
            # listen for 'training_completed' are notified as well.
            callback('training_completed', {
                'task_id': task_id,
                'status': 'completed',
                'progress': 100,
                'message': data.get('message', '训练完成'),
                'metrics': data.get('metrics'),
                'product_id': data.get('product_id'),
                'model_type': data.get('model_type'),
                'version': version,
                'product_scope': data.get('product_scope'),
                'product_ids': data.get('product_ids'),
            })
        elif action == 'error':
            callback('training_update', {
                'task_id': task_id,
                'action': 'failed',
                'status': 'failed',
                'progress': 0,
                'message': data.get('message', '训练失败'),
                'error': data.get('error'),
                'product_id': data.get('product_id'),
                'model_type': data.get('model_type'),
                'product_scope': data.get('product_scope'),
                'product_ids': data.get('product_ids'),
            })
        else:
            callback('training_update', {
                'task_id': task_id,
                'action': action,
                'status': data.get('status'),
                'progress': data.get('progress', 0),
                'message': data.get('message', ''),
                'metrics': data.get('metrics'),
                'product_id': data.get('product_id'),
                'model_type': data.get('model_type'),
                'product_scope': data.get('product_scope'),
                'product_ids': data.get('product_ids'),
            })

    # ------------------------------------------------------------ progress monitor

    def _monitor_progress(self):
        """Thread body: consume progress/log messages until stopped."""
        while self.running:
            try:
                progress_data = self.progress_queue.get(timeout=1)
                self._handle_progress(progress_data)
            except queue.Empty:
                continue
            except Exception as e:
                self.logger.error(f"进度监听错误: {e}")

    def _handle_progress(self, progress_data: Dict[str, Any]) -> None:
        """Echo worker log lines and apply/forward progress updates."""
        task_id = progress_data['task_id']
        # Convert once: the values end up both in self.tasks (returned via
        # asdict by the public getters) and in WebSocket payloads.
        data = convert_numpy_types(progress_data)

        if 'log_type' in data:
            self._echo_log(task_id, data['log_type'], data.get('message', ''))

        if 'progress' not in data:
            return

        with self.lock:
            task = self.tasks.get(task_id)
            if task is not None:
                task.progress = data['progress']
                task.message = data.get('message', '')

        if self.websocket_callback:
            try:
                self.websocket_callback('training_progress', data)
            except Exception as e:
                self.logger.error(f"进度WebSocket通知失败: {e}")

    def _echo_log(self, task_id: str, log_type: str, message: str) -> None:
        """Print a worker log line and write it to the logger method named by *log_type*."""
        if log_type not in self._CONSOLE_LOG_TYPES:
            return

        line = f"[{task_id[:8]}] {message}"
        print(line, flush=True)
        getattr(self.logger, log_type)(line)

        # A success line announcing completion is also pushed to the UI as 100% progress.
        if log_type == 'success' and "训练完成" in message and self.websocket_callback:
            try:
                self.websocket_callback('training_progress', {
                    'task_id': task_id,
                    'progress': 100,
                    'message': message,
                    'log_type': 'success',
                    'timestamp': time.time(),
                })
            except Exception as e:
                self.logger.error(f"成功消息WebSocket通知失败: {e}")
|
||
|
||
# Module-level singleton, built once when this module is imported and shared
# by every caller of get_training_manager().  Worker processes are not started
# until its start() method is called.
# NOTE(review): with the 'spawn' start method, child processes import this
# module to unpickle the worker, so each child also builds its own (unused)
# manager here — confirm that side effect is acceptable.
training_manager = TrainingProcessManager()


def get_training_manager() -> TrainingProcessManager:
    """Return the shared module-level ``TrainingProcessManager`` instance."""
    shared_manager = training_manager
    return shared_manager