ShopTRAINING/server/utils/training_process_manager.py
xz2000 28bae35783 # 扁平化模型数据处理规范 (最终版)
**版本**: 4.0 (最终版)
**核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。

---

## 一、 文件保存规则

### 1.1. 核心原则

所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_metadata.json`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 二、 文件读取规则

1.  **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。
2.  **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。
3.  **定位文件**:
    -   要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。
    -   要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。

---

## 三、 数据库存储规则

数据库用于索引,应存储足以重构文件名前缀的关键元数据。

#### **`models` 表结构建议:**

| 字段名 | 类型 | 描述 | 示例 |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | 主键 | 1 |
| `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` |
| `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` |
| `version` | INTEGER | 版本号 | `2` |
| `status` | TEXT | 模型状态 | `completed`, `training`, `failed` |
| `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` |
| `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` |

#### **保存逻辑:**
-   训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。

---

## 四、 版本记录规则

版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。

-   **文件名**: `versions.json`
-   **位置**: `saved_models/versions.json`
-   **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。
    -   **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`)
    -   **Value**: `Integer`

#### **`versions.json` 示例:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **版本管理流程:**

1.  **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。
2.  **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。

调试完成药品预测和店铺预测
2025-07-21 16:39:52 +08:00

544 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
独立训练进程管理器
使用multiprocessing实现真正的并行训练避免GIL限制
"""
import os
import sys
import uuid
import time
import json
import queue
import multiprocessing as mp
from multiprocessing import Process, Queue, Manager
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, Callable
from threading import Thread, Lock
from pathlib import Path
# 添加当前目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
server_dir = os.path.dirname(current_dir)
sys.path.append(server_dir)
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
from utils.file_save import ModelPathManager
import numpy as np
def convert_numpy_types(obj):
"""递归地将字典/列表中的NumPy类型转换为Python原生类型"""
if isinstance(obj, dict):
return {k: convert_numpy_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_numpy_types(i) for i in obj]
elif isinstance(obj, np.generic):
return obj.item()
return obj
@dataclass
class TrainingTask:
"""训练任务数据结构"""
task_id: str
product_id: str
model_type: str
training_mode: str
store_id: Optional[str] = None
aggregation_method: Optional[str] = None # 新增:聚合方式
product_scope: str = 'all'
product_ids: Optional[list] = None
epochs: int = 100
status: str = "pending" # pending, running, completed, failed
start_time: Optional[str] = None
end_time: Optional[str] = None
progress: float = 0.0
message: str = ""
error: Optional[str] = None
metrics: Optional[Dict[str, Any]] = None
process_id: Optional[int] = None
path_info: Optional[Dict[str, Any]] = None # 新增字段
version: Optional[int] = None # 新增版本字段
class TrainingWorker:
"""训练工作进程"""
def __init__(self, task_queue: Queue, result_queue: Queue, progress_queue: Queue):
self.task_queue = task_queue
self.result_queue = result_queue
self.progress_queue = progress_queue
def run_training_task(self, task: TrainingTask):
"""执行训练任务"""
try:
# 设置进程级别的日志
logger = setup_api_logging(log_level="INFO")
training_logger = get_training_logger(task.task_id, task.model_type, task.product_id)
# 发送日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"🚀 训练进程启动 - PID: {os.getpid()}"
})
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"📋 任务参数: {task.model_type} | {task.product_id} | {task.epochs}轮次"
})
training_logger.info(f"🚀 训练进程启动 - PID: {os.getpid()}")
training_logger.info(f"📋 任务参数: {task.model_type} | {task.product_id} | {task.epochs}轮次")
# 更新任务状态
task.status = "running"
task.start_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.process_id = os.getpid()
self.result_queue.put(('update', asdict(task)))
# 模拟训练进度更新
for epoch in range(1, task.epochs + 1):
progress = (epoch / task.epochs) * 100
# 发送进度更新
self.progress_queue.put({
'task_id': task.task_id,
'progress': progress,
'epoch': epoch,
'total_epochs': task.epochs,
'message': f"Epoch {epoch}/{task.epochs}"
})
training_logger.info(f"🔄 训练进度: Epoch {epoch}/{task.epochs} ({progress:.1f}%)")
# 模拟训练时间
time.sleep(1) # 实际训练中这里会是真正的训练代码
# 导入真正的训练函数
try:
# 添加服务器目录到路径确保能找到core模块
server_dir = os.path.dirname(os.path.dirname(__file__))
if server_dir not in sys.path:
sys.path.append(server_dir)
from core.predictor import PharmacyPredictor
predictor = PharmacyPredictor()
training_logger.info("🤖 开始调用实际训练器")
# 发送训练开始日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"🤖 开始执行 {task.model_type} 模型训练..."
})
# 创建子进程内的进度回调函数
def progress_callback(progress_data):
"""子进程内的进度回调,通过队列发送到主进程"""
try:
# 添加任务ID到进度数据
progress_data['task_id'] = task.task_id
self.progress_queue.put(progress_data)
except Exception as e:
training_logger.error(f"进度回调失败: {e}")
# 执行真正的训练,传递进度回调和路径信息
metrics = predictor.train_model(
product_id=task.product_id,
model_type=task.model_type,
epochs=task.epochs,
store_id=task.store_id,
training_mode=task.training_mode,
aggregation_method=task.aggregation_method, # 传递聚合方式
product_scope=task.product_scope, # 传递药品范围
product_ids=task.product_ids, # 传递药品ID列表
socketio=None, # 子进程中不能直接使用socketio
task_id=task.task_id,
progress_callback=progress_callback, # 传递进度回调函数
path_info=task.path_info # 传递路径信息
)
# 发送训练完成日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'success',
'message': f"{task.model_type} 模型训练完成!"
})
if metrics:
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"📊 训练指标: MSE={metrics.get('mse', 'N/A'):.4f}, RMSE={metrics.get('rmse', 'N/A'):.4f}"
})
except ImportError as e:
training_logger.error(f"❌ 导入训练器失败: {e}")
# 返回模拟的训练结果用于测试
metrics = {
"mse": 0.001,
"rmse": 0.032,
"mae": 0.025,
"r2": 0.95,
"mape": 2.5,
"training_time": task.epochs * 2,
"note": "模拟训练结果(导入失败时的备用方案)"
}
training_logger.warning("⚠️ 使用模拟训练结果")
# 检查训练是否成功
if metrics:
# 训练成功
task.status = "completed"
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.progress = 100.0
task.metrics = metrics
task.message = "训练完成"
training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
training_logger.info(f"📊 训练指标: {metrics}")
self.result_queue.put(('complete', asdict(task)))
else:
# 训练失败(性能不佳)
# 即使性能不佳,也标记为完成,让用户决定是否使用
task.status = "completed"
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.metrics = metrics if metrics else {}
task.message = "训练完成(性能可能不佳)"
training_logger.warning(f"⚠️ 训练完成,但性能可能不佳 (metrics: {metrics})")
self.result_queue.put(('complete', asdict(task)))
except Exception as e:
error_msg = str(e)
task.status = "failed"
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.error = error_msg
task.message = f"训练失败: {error_msg}"
training_logger.error(f"❌ 训练任务失败: {error_msg}")
self.result_queue.put(('error', asdict(task)))
def start(self):
"""启动工作进程"""
while True:
try:
# 从队列获取任务超时5秒
task_data = self.task_queue.get(timeout=5)
if task_data is None: # 毒丸,退出信号
break
task = TrainingTask(**task_data)
self.run_training_task(task)
except queue.Empty:
continue
except Exception as e:
print(f"工作进程错误: {e}")
continue
class TrainingProcessManager:
"""训练进程管理器"""
def __init__(self, max_workers: int = 2):
self.max_workers = max_workers
self.tasks: Dict[str, TrainingTask] = {}
self.processes: Dict[str, Process] = {}
self.task_queue = Queue()
self.result_queue = Queue()
self.progress_queue = Queue()
self.running = False
self.lock = Lock()
# WebSocket回调
self.websocket_callback: Optional[Callable] = None
# 设置日志
self.logger = setup_api_logging()
self.path_manager = ModelPathManager() # 实例化
def start(self):
"""启动进程管理器"""
if self.running:
return
self.running = True
self.logger.info(f"🚀 训练进程管理器启动 - 最大工作进程数: {self.max_workers}")
# 启动工作进程
for i in range(self.max_workers):
worker = TrainingWorker(self.task_queue, self.result_queue, self.progress_queue)
process = Process(target=worker.start, name=f"TrainingWorker-{i}")
process.start()
self.processes[f"worker-{i}"] = process
self.logger.info(f"🔧 工作进程 {i} 启动 - PID: {process.pid}")
# 启动结果监听线程
self.result_thread = Thread(target=self._monitor_results, daemon=True)
self.result_thread.start()
# 启动进度监听线程
self.progress_thread = Thread(target=self._monitor_progress, daemon=True)
self.progress_thread.start()
def stop(self):
"""停止进程管理器"""
if not self.running:
return
self.logger.info("🛑 正在停止训练进程管理器...")
self.running = False
# 发送停止信号给所有工作进程
for _ in range(self.max_workers):
self.task_queue.put(None)
# 等待所有进程结束
for name, process in self.processes.items():
process.join(timeout=10)
if process.is_alive():
self.logger.warning(f"⚠️ 强制终止进程: {name}")
process.terminate()
self.logger.info("✅ 训练进程管理器已停止")
def submit_task(self, training_params: Dict[str, Any], path_info: Dict[str, Any]) -> str:
"""
提交训练任务
Args:
training_params (Dict[str, Any]): 来自API请求的原始参数
path_info (Dict[str, Any]): 由ModelPathManager生成的路径和版本信息
"""
task_id = str(uuid.uuid4())
task = TrainingTask(
task_id=task_id,
product_id=training_params.get('product_id'),
model_type=training_params.get('model_type'),
training_mode=training_params.get('training_mode', 'product'),
store_id=training_params.get('store_id'),
epochs=training_params.get('epochs', 100),
aggregation_method=training_params.get('aggregation_method'), # 新增
product_scope=training_params.get('product_scope', 'all'),
product_ids=training_params.get('product_ids'),
path_info=path_info # 存储路径信息
)
with self.lock:
self.tasks[task_id] = task
# 将任务放入队列
self.task_queue.put(asdict(task))
self.logger.info(f"📋 训练任务已提交: {task_id[:8]} | {task.model_type} | {task.product_id}")
return task_id
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""获取任务状态"""
with self.lock:
task = self.tasks.get(task_id)
if task:
return asdict(task)
return None
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
"""获取所有任务状态"""
with self.lock:
return {task_id: asdict(task) for task_id, task in self.tasks.items()}
def cancel_task(self, task_id: str) -> bool:
"""取消任务(仅对未开始的任务有效)"""
with self.lock:
task = self.tasks.get(task_id)
if task and task.status == "pending":
task.status = "cancelled"
task.message = "任务已取消"
return True
return False
def _monitor_results(self):
"""监听训练结果"""
while self.running:
try:
result = self.result_queue.get(timeout=1)
action, task_data = result
task_id = task_data['task_id']
# 立即对从队列中取出的数据进行类型转换
serializable_task_data = convert_numpy_types(task_data)
with self.lock:
if task_id in self.tasks:
task = self.tasks[task_id]
# 使用转换后的数据更新任务状态
for key, value in serializable_task_data.items():
if hasattr(task, key):
setattr(task, key, value)
# 如果任务成功完成,则更新版本文件和任务对象中的版本号
if action == 'complete':
# 只有在训练成功metrics有效时才保存版本信息
if task.metrics and task.metrics.get('r2', -1) >= 0:
if task.path_info:
identifier = task.path_info.get('identifier')
version = task.path_info.get('version')
if identifier and version:
try:
self.path_manager.save_version_info(identifier, version)
self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
task.version = version # 关键修复:将版本号保存到任务对象中
except Exception as e:
self.logger.error(f"❌ 更新版本文件失败: {e}")
else:
self.logger.warning(f"⚠️ 任务 {task_id} 训练性能不佳或失败,不保存版本信息。")
# WebSocket通知 - 使用已转换的数据
if self.websocket_callback:
try:
if action == 'complete':
# 从任务信息中获取版本号
version = None
with self.lock:
task = self.tasks.get(task_id)
if task and task.path_info:
version = task.path_info.get('version')
# 训练完成 - 发送完成状态
self.websocket_callback('training_update', {
'task_id': task_id,
'action': 'completed',
'status': 'completed',
'progress': 100,
'message': serializable_task_data.get('message', '训练完成'),
'metrics': serializable_task_data.get('metrics'),
'end_time': serializable_task_data.get('end_time'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type'),
'version': version, # 添加版本号
'product_scope': serializable_task_data.get('product_scope'),
'product_ids': serializable_task_data.get('product_ids')
})
# 额外发送一个完成事件,确保前端能收到
self.websocket_callback('training_completed', {
'task_id': task_id,
'status': 'completed',
'progress': 100,
'message': serializable_task_data.get('message', '训练完成'),
'metrics': serializable_task_data.get('metrics'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type'),
'version': version, # 添加版本号
'product_scope': serializable_task_data.get('product_scope'),
'product_ids': serializable_task_data.get('product_ids')
})
elif action == 'error':
# 训练失败
self.websocket_callback('training_update', {
'task_id': task_id,
'action': 'failed',
'status': 'failed',
'progress': 0,
'message': serializable_task_data.get('message', '训练失败'),
'error': serializable_task_data.get('error'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type'),
'product_scope': serializable_task_data.get('product_scope'),
'product_ids': serializable_task_data.get('product_ids')
})
else:
# 状态更新
self.websocket_callback('training_update', {
'task_id': task_id,
'action': action,
'status': serializable_task_data.get('status'),
'progress': serializable_task_data.get('progress', 0),
'message': serializable_task_data.get('message', ''),
'metrics': serializable_task_data.get('metrics'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type'),
'product_scope': serializable_task_data.get('product_scope'),
'product_ids': serializable_task_data.get('product_ids')
})
except Exception as e:
self.logger.error(f"WebSocket通知失败: {e}")
except queue.Empty:
continue
except Exception as e:
self.logger.error(f"结果监听错误: {e}")
def _monitor_progress(self):
"""监听训练进度"""
while self.running:
try:
progress_data = self.progress_queue.get(timeout=1)
task_id = progress_data['task_id']
# 处理日志消息,显示到主控制台
if 'log_type' in progress_data:
log_type = progress_data['log_type']
message = progress_data['message']
task_short_id = task_id[:8]
if log_type == 'info':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.info(f"[{task_short_id}] {message}")
elif log_type == 'success':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.success(f"[{task_short_id}] {message}")
# 如果是训练完成的成功消息发送WebSocket通知
if "训练完成" in message:
if self.websocket_callback:
try:
self.websocket_callback('training_progress', {
'task_id': task_id,
'progress': 100,
'message': message,
'log_type': 'success',
'timestamp': time.time()
})
except Exception as e:
self.logger.error(f"成功消息WebSocket通知失败: {e}")
elif log_type == 'error':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.error(f"[{task_short_id}] {message}")
elif log_type == 'warning':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.warning(f"[{task_short_id}] {message}")
# 更新任务进度只处理包含progress的消息
if 'progress' in progress_data:
with self.lock:
if task_id in self.tasks:
self.tasks[task_id].progress = progress_data['progress']
self.tasks[task_id].message = progress_data.get('message', '')
# WebSocket通知进度更新
if self.websocket_callback and 'progress' in progress_data:
try:
# 在发送前确保所有数据类型都是JSON可序列化的
serializable_data = convert_numpy_types(progress_data)
self.websocket_callback('training_progress', serializable_data)
except Exception as e:
self.logger.error(f"进度WebSocket通知失败: {e}")
except queue.Empty:
continue
except Exception as e:
self.logger.error(f"进度监听错误: {e}")
def set_websocket_callback(self, callback: Callable):
"""设置WebSocket回调函数"""
self.websocket_callback = callback
# 全局进程管理器实例
training_manager = TrainingProcessManager()
def get_training_manager() -> TrainingProcessManager:
"""获取训练进程管理器实例"""
return training_manager