ShopTRAINING/server/utils/training_process_manager.py
xz2000 5b2cdfa74a ---
**日期**: 2025-07-18
**主题**: 模型保存逻辑重构与集中化管理

### 目标
根据 `xz训练模型保存规则.md`,将系统中分散的模型文件保存逻辑统一重构,创建一个集中、健壮且可测试的路径管理系统。

### 核心成果
1.  **创建了 `server/utils/file_save.py` 模块**: 这个新模块现在是系统中处理模型文件保存路径的唯一权威来源。
2.  **实现了三种训练模式的路径生成**: 系统现在可以为“按店铺”、“按药品”和“全局”三种训练模式正确生成层级化的、可追溯的目录结构。
3.  **集成了智能ID处理**:
    *   对于包含**多个ID**的训练场景,系统会自动计算一个简短的哈希值作为目录名。
    *   对于全局训练中只包含**单个店铺或药品ID**的场景,系统会直接使用该ID作为目录名,增强了路径的可读性。
4.  **重构了整个训练流程**: 修改了API层、进程管理层以及所有模型训练器,使它们能够协同使用新的路径管理模块。
5.  **添加了自动化测试**: 创建了 `test/test_file_save_logic.py` 脚本,用于验证所有路径生成和版本管理逻辑的正确性。

### 详细文件修改记录

1.  **`server/utils/file_save.py`**
    *   **操作**: 创建
    *   **内容**: 实现了 `ModelPathManager` 类,包含以下核心方法:
        *   `_hash_ids`: 对ID列表进行排序和哈希。
        *   `_generate_identifier`: 根据训练模式和参数生成唯一的模型标识符。
        *   `get_next_version` / `save_version_info`: 线程安全地管理 `versions.json` 文件,实现版本号的获取和更新。
        *   `get_model_paths`: 作为主入口,协调以上方法,生成包含所有产物路径的字典。

2.  **`server/api.py`**
    *   **操作**: 修改
    *   **位置**: `start_training` 函数 (`/api/training` 端点)。
    *   **内容**:
        *   导入并实例化 `ModelPathManager`。
        *   在接收到训练请求后,调用 `path_manager.get_model_paths()` 来获取所有路径信息。
        *   将获取到的 `path_info` 字典和原始请求参数 `training_params` 一并传递给后台训练任务管理器。
        *   修复了因重复传递关键字参数 (`model_type`, `training_mode`) 导致的 `TypeError`。
        *   修复了 `except` 块中因未导入 `traceback` 模块导致的 `UnboundLocalError`。

3.  **`server/utils/training_process_manager.py`**
    *   **操作**: 修改
    *   **内容**:
        *   修改 `submit_task` 方法,使其能接收 `training_params` 和 `path_info` 字典。
        *   在 `TrainingTask` 数据类中增加了 `path_info` 字段来存储路径信息。
        *   在 `TrainingWorker` 中,将 `path_info` 传递给实际的训练函数。
        *   在 `_monitor_results` 方法中,当任务成功完成时,调用 `path_manager.save_version_info` 来更新 `versions.json`,完成版本管理的闭环。

4.  **所有训练器文件** (`mlstm_trainer.py`, `kan_trainer.py`, `tcn_trainer.py`, `transformer_trainer.py`)
    *   **操作**: 修改
    *   **内容**:
        *   统一修改了主训练函数的签名,增加了 `path_info=None` 参数。
        *   移除了所有内部手动构建文件路径的逻辑。
        *   所有保存操作(最终模型、检查点、损失曲线图)现在都直接从传入的 `path_info` 字典中获取预先生成好的路径。
        *   简化了 `save_checkpoint` 辅助函数,使其也依赖 `path_info`。

5.  **`test/test_file_save_logic.py`**
    *   **操作**: 创建
    *   **内容**:
        *   编写了一个独立的测试脚本,用于验证 `ModelPathManager` 的所有功能。
        *   覆盖了所有训练模式及其子场景(包括单ID和多ID哈希)。
        *   测试了版本号的正确递增和 `versions.json` 的写入。
        *   修复了测试脚本中因绝对/相对路径不匹配和重复关键字参数导致的多个 `AssertionError` 和 `TypeError`。

---
**日期**: 2025-07-18 (后续修复)
**主题**: 修复API层调用路径管理器时的 `TypeError`

### 问题描述
在完成所有重构和测试后,实际运行API时,`POST /api/training` 端点在调用 `path_manager.get_model_paths` 时崩溃,并抛出 `TypeError: get_model_paths() got multiple values for keyword argument 'training_mode'`。

### 根本原因
这是一个回归错误。在修复测试脚本 `test_file_save_logic.py` 中的类似问题时,我未能将相同的修复逻辑应用回 `server/api.py`。代码在调用 `get_model_paths` 时,既通过关键字参数 `training_mode=...` 明确传递了该参数,又通过 `**data` 将其再次传入,导致了冲突。

### 解决方案
1.  **文件**: `server/api.py`
2.  **位置**: `start_training` 函数。
3.  **操作**: 修改了对 `get_model_paths` 的调用逻辑。
4.  **内容**:
    ```python
    # 移除 model_type 和 training_mode 以避免重复关键字参数错误
    data_for_path = data.copy()
    data_for_path.pop('model_type', None)
    data_for_path.pop('training_mode', None)
    path_info = path_manager.get_model_paths(
        training_mode=training_mode,
        model_type=model_type,
        **data_for_path  # 传递剩余的payload
    )
    ```
5.  **原因**: 在通过 `**` 解包传递参数之前,先从字典副本中移除了所有会被明确指定的关键字参数,从而确保了函数调用签名的正确性。

---
**日期**: 2025-07-18 (最终修复)
**主题**: 修复因中间层函数签名未更新导致的 `TypeError`

### 问题描述
在完成所有重构后,实际运行API并触发训练任务时,程序在后台进程中因 `TypeError: train_model() got an unexpected keyword argument 'path_info'` 而崩溃。

### 根本原因
这是一个典型的“中间人”遗漏错误。我成功地修改了调用链的两端(`api.py` -> `training_process_manager.py` 和 `*_trainer.py`),但忘记了修改它们之间的中间层——`server/core/predictor.py` 中的 `train_model` 方法。`training_process_manager` 尝试将 `path_info` 传递给 `predictor.train_model`,但后者的函数签名中并未包含这个新参数,导致了 `TypeError`。

### 解决方案
1.  **文件**: `server/core/predictor.py`
2.  **位置**: `train_model` 函数的定义处。
3.  **操作**: 在函数签名中增加了 `path_info=None` 参数。
4.  **内容**:
    ```python
    def train_model(self, ..., progress_callback=None, path_info=None):
        # ...
    ```
5.  **位置**: `train_model` 函数内部,对所有具体训练器(`train_product_model_with_mlstm`, `_with_kan`, etc.)的调用处。
6.  **操作**: 在所有调用中,将接收到的 `path_info` 参数透传下去。
7.  **内容**:
    ```python
    # ...
    metrics = train_product_model_with_transformer(
        ...,
        path_info=path_info
    )
    # ...
    ```
8.  **原因**: 通过在中间层函数上“打通”`path_info` 参数的传递通道,确保了从API层到最终训练器层的完整数据流,解决了 `TypeError`。

---
**日期**: 2025-07-18 (最终修复)
**主题**: 修复“按药品训练-聚合所有店铺”模式下的路径生成错误

### 问题描述
在实际运行中发现,当进行“按药品训练”并选择“聚合所有店铺”时,生成的模型保存路径中包含了错误的后缀 `_None`,而不是预期的 `_all` (例如 `.../17002608_None/...`)。

### 根本原因
在 `server/utils/file_save.py` 的 `_generate_identifier` 和 `get_model_paths` 方法中,当 `store_id` 从前端传来为 `None` 时,代码 `scope = store_id if store_id else 'all'` 会因为 `store_id` 是 `None` 而正确地将 `scope` 设为 `'all'`。然而,在 `get_model_paths` 方法中,我错误地使用了 `kwargs.get('store_id', 'all')`,这在 `store_id` 键存在但值为 `None` 时,仍然会返回 `None`,导致了路径拼接错误。

### 解决方案
1.  **文件**: `server/utils/file_save.py`
2.  **位置**: `_generate_identifier` 和 `get_model_paths` 方法中处理 `product` 训练模式的部分。
3.  **操作**: 将逻辑从 `scope = kwargs.get('store_id', 'all')` 修改为更严谨的 `scope = store_id if store_id is not None else 'all'`。
4.  **内容**:
    ```python
    # in _generate_identifier
    scope = store_id if store_id is not None else 'all'

    # in get_model_paths
    store_id = kwargs.get('store_id')
    scope = store_id if store_id is not None else 'all'
    scope_folder = f"{product_id}_{scope}"
    ```
5.  **原因**: 这种写法能正确处理 `store_id` 键不存在、或键存在但值为 `None` 的两种情况,确保在这两种情况下 `scope` 都被正确地设置为 `'all'`,从而生成符合规范的路径。

---
**日期**: 2025-07-18 (最终修复)
**主题**: 修复 `KeyError: 'price'` 和单ID哈希错误

### 问题描述
在完成大规模重构后,实际运行时发现了两个隐藏的bug:
1.  在“按店铺训练”模式下,训练因 `KeyError: 'price'` 而失败。
2.  在“按店铺训练”模式下,当只选择一个“指定药品”时,系统仍然错误地对该药品的ID进行了哈希处理,而不是直接使用ID。

### 根本原因
1.  **`KeyError`**: `server/utils/multi_store_data_utils.py` 中的 `get_store_product_sales_data` 函数包含了一个硬编码的列校验,该校验要求 `price` 列必须存在,但这与当前的数据源不符。
2.  **哈希错误**: `server/utils/file_save.py` 中的 `get_model_paths` 方法在处理 `store` 训练模式时,没有复用 `_generate_identifier` 中已经写好的单ID判断逻辑,导致了逻辑不一致。

### 解决方案
1.  **修复 `KeyError`**:
    *   **文件**: `server/utils/multi_store_data_utils.py`
    *   **位置**: `get_store_product_sales_data` 函数。
    *   **操作**: 从 `required_columns` 列表中移除了 `'price'`,根除了这个硬性依赖。
2.  **修复哈希逻辑**:
    *   **文件**: `server/utils/file_save.py`
    *   **位置**: `_generate_identifier` 和 `get_model_paths` 方法中处理 `store` 训练模式的部分。
    *   **操作**: 统一了逻辑,确保在这两个地方都使用了 `scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)` 的判断,从而在只选择一个药品时直接使用其ID。
3.  **更新测试**:
    *   **文件**: `test/test_file_save_logic.py`
    *   **操作**: 增加了新的测试用例,专门验证“按店铺训练-单个指定药品”场景下的路径生成是否正确。

---
**日期**: 2025-07-18 (最终修复)
**主题**: 修复全局训练范围值不匹配导致的 `ValueError`

### 问题描述
在完成所有重构后,实际运行API并触发“全局训练-所有店铺所有药品”时,程序因 `ValueError: 未知的全局训练范围: all_stores_all_products` 而崩溃。

### 根本原因
前端传递的 `training_scope` 值为 `all_stores_all_products`,而 `server/utils/file_save.py` 中的 `_generate_identifier` 和 `get_model_paths` 方法只处理了 `all` 这个值,未能兼容前端传递的具体字符串,导致逻辑判断失败。

### 解决方案
1.  **文件**: `server/utils/file_save.py`
2.  **位置**: `_generate_identifier` 和 `get_model_paths` 方法中处理 `global` 训练模式的部分。
3.  **操作**: 将逻辑判断从 `if training_scope == 'all':` 修改为 `if training_scope in ['all', 'all_stores_all_products']:`。
4.  **原因**: 使代码能够同时兼容两种表示“所有范围”的字符串,确保了前端请求的正确处理。
5.  **更新测试**:
    *   **文件**: `test/test_file_save_logic.py`
    *   **操作**: 增加了新的测试用例,专门验证 `training_scope` 为 `all_stores_all_products` 时的路径生成是否正确。

---
**日期**: 2025-07-18 (最终优化)
**主题**: 优化全局训练自定义模式下的单ID路径生成

### 问题描述
根据用户反馈,希望在全局训练的“自定义范围”模式下,如果只选择单个店铺和/或单个药品,路径中应直接使用ID而不是哈希值,以增强可读性。

### 解决方案
1.  **文件**: `server/utils/file_save.py`
2.  **位置**: `_generate_identifier` 和 `get_model_paths` 方法中处理 `global` 训练模式 `custom` 范围的部分。
3.  **操作**: 为 `store_ids` 和 `product_ids` 分别增加了单ID判断逻辑。
4.  **内容**:
    ```python
    # in _generate_identifier
    s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
    p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
    scope_part = f"custom_s_{s_id}_p_{p_id}"

    # in get_model_paths
    store_ids = kwargs.get('store_ids', [])
    product_ids = kwargs.get('product_ids', [])
    s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
    p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
    scope_parts.extend(['custom', s_id, p_id])
    ```
5.  **原因**: 使 `custom` 模式下的路径生成逻辑与 `selected_stores` 和 `selected_products` 模式保持一致,在只选择一个ID时优先使用ID本身,提高了路径的可读性和一致性。
6.  **更新测试**:
    *   **文件**: `test/test_file_save_logic.py`
    *   **操作**: 增加了新的测试用例,专门验证“全局训练-自定义范围-单ID”场景下的路径生成是否正确。
2025-07-18 16:45:29 +08:00

500 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
独立训练进程管理器
使用multiprocessing实现真正的并行训练避免GIL限制
"""
import os
import sys
import uuid
import time
import json
import queue
import multiprocessing as mp
from multiprocessing import Process, Queue, Manager
from dataclasses import dataclass, asdict
from typing import Dict, Any, Optional, Callable
from threading import Thread, Lock
from pathlib import Path
# 添加当前目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__))
server_dir = os.path.dirname(current_dir)
sys.path.append(server_dir)
from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
from utils.file_save import ModelPathManager
import numpy as np
def convert_numpy_types(obj):
"""递归地将字典/列表中的NumPy类型转换为Python原生类型"""
if isinstance(obj, dict):
return {k: convert_numpy_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_numpy_types(i) for i in obj]
elif isinstance(obj, np.generic):
return obj.item()
return obj
@dataclass
class TrainingTask:
"""训练任务数据结构"""
task_id: str
product_id: str
model_type: str
training_mode: str
store_id: Optional[str] = None
epochs: int = 100
status: str = "pending" # pending, running, completed, failed
start_time: Optional[str] = None
end_time: Optional[str] = None
progress: float = 0.0
message: str = ""
error: Optional[str] = None
metrics: Optional[Dict[str, Any]] = None
process_id: Optional[int] = None
path_info: Optional[Dict[str, Any]] = None # 新增字段
class TrainingWorker:
"""训练工作进程"""
def __init__(self, task_queue: Queue, result_queue: Queue, progress_queue: Queue):
self.task_queue = task_queue
self.result_queue = result_queue
self.progress_queue = progress_queue
def run_training_task(self, task: TrainingTask):
"""执行训练任务"""
try:
# 设置进程级别的日志
logger = setup_api_logging(log_level="INFO")
training_logger = get_training_logger(task.task_id, task.model_type, task.product_id)
# 发送日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"🚀 训练进程启动 - PID: {os.getpid()}"
})
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"📋 任务参数: {task.model_type} | {task.product_id} | {task.epochs}轮次"
})
training_logger.info(f"🚀 训练进程启动 - PID: {os.getpid()}")
training_logger.info(f"📋 任务参数: {task.model_type} | {task.product_id} | {task.epochs}轮次")
# 更新任务状态
task.status = "running"
task.start_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.process_id = os.getpid()
self.result_queue.put(('update', asdict(task)))
# 模拟训练进度更新
for epoch in range(1, task.epochs + 1):
progress = (epoch / task.epochs) * 100
# 发送进度更新
self.progress_queue.put({
'task_id': task.task_id,
'progress': progress,
'epoch': epoch,
'total_epochs': task.epochs,
'message': f"Epoch {epoch}/{task.epochs}"
})
training_logger.info(f"🔄 训练进度: Epoch {epoch}/{task.epochs} ({progress:.1f}%)")
# 模拟训练时间
time.sleep(1) # 实际训练中这里会是真正的训练代码
# 导入真正的训练函数
try:
# 添加服务器目录到路径确保能找到core模块
server_dir = os.path.dirname(os.path.dirname(__file__))
if server_dir not in sys.path:
sys.path.append(server_dir)
from core.predictor import PharmacyPredictor
predictor = PharmacyPredictor()
training_logger.info("🤖 开始调用实际训练器")
# 发送训练开始日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"🤖 开始执行 {task.model_type} 模型训练..."
})
# 创建子进程内的进度回调函数
def progress_callback(progress_data):
"""子进程内的进度回调,通过队列发送到主进程"""
try:
# 添加任务ID到进度数据
progress_data['task_id'] = task.task_id
self.progress_queue.put(progress_data)
except Exception as e:
training_logger.error(f"进度回调失败: {e}")
# 执行真正的训练,传递进度回调和路径信息
metrics = predictor.train_model(
product_id=task.product_id,
model_type=task.model_type,
epochs=task.epochs,
store_id=task.store_id,
training_mode=task.training_mode,
socketio=None, # 子进程中不能直接使用socketio
task_id=task.task_id,
progress_callback=progress_callback, # 传递进度回调函数
path_info=task.path_info # 传递路径信息
)
# 发送训练完成日志到主控制台
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'success',
'message': f"{task.model_type} 模型训练完成!"
})
if metrics:
self.progress_queue.put({
'task_id': task.task_id,
'log_type': 'info',
'message': f"📊 训练指标: MSE={metrics.get('mse', 'N/A'):.4f}, RMSE={metrics.get('rmse', 'N/A'):.4f}"
})
except ImportError as e:
training_logger.error(f"❌ 导入训练器失败: {e}")
# 返回模拟的训练结果用于测试
metrics = {
"mse": 0.001,
"rmse": 0.032,
"mae": 0.025,
"r2": 0.95,
"mape": 2.5,
"training_time": task.epochs * 2,
"note": "模拟训练结果(导入失败时的备用方案)"
}
training_logger.warning("⚠️ 使用模拟训练结果")
# 训练完成
task.status = "completed"
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.progress = 100.0
task.metrics = metrics
task.message = "训练完成"
training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
if metrics:
training_logger.info(f"📊 训练指标: {metrics}")
self.result_queue.put(('complete', asdict(task)))
except Exception as e:
error_msg = str(e)
task.status = "failed"
task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
task.error = error_msg
task.message = f"训练失败: {error_msg}"
training_logger.error(f"❌ 训练任务失败: {error_msg}")
self.result_queue.put(('error', asdict(task)))
def start(self):
"""启动工作进程"""
while True:
try:
# 从队列获取任务超时5秒
task_data = self.task_queue.get(timeout=5)
if task_data is None: # 毒丸,退出信号
break
task = TrainingTask(**task_data)
self.run_training_task(task)
except queue.Empty:
continue
except Exception as e:
print(f"工作进程错误: {e}")
continue
class TrainingProcessManager:
"""训练进程管理器"""
def __init__(self, max_workers: int = 2):
self.max_workers = max_workers
self.tasks: Dict[str, TrainingTask] = {}
self.processes: Dict[str, Process] = {}
self.task_queue = Queue()
self.result_queue = Queue()
self.progress_queue = Queue()
self.running = False
self.lock = Lock()
# WebSocket回调
self.websocket_callback: Optional[Callable] = None
# 设置日志
self.logger = setup_api_logging()
self.path_manager = ModelPathManager() # 实例化
def start(self):
"""启动进程管理器"""
if self.running:
return
self.running = True
self.logger.info(f"🚀 训练进程管理器启动 - 最大工作进程数: {self.max_workers}")
# 启动工作进程
for i in range(self.max_workers):
worker = TrainingWorker(self.task_queue, self.result_queue, self.progress_queue)
process = Process(target=worker.start, name=f"TrainingWorker-{i}")
process.start()
self.processes[f"worker-{i}"] = process
self.logger.info(f"🔧 工作进程 {i} 启动 - PID: {process.pid}")
# 启动结果监听线程
self.result_thread = Thread(target=self._monitor_results, daemon=True)
self.result_thread.start()
# 启动进度监听线程
self.progress_thread = Thread(target=self._monitor_progress, daemon=True)
self.progress_thread.start()
def stop(self):
"""停止进程管理器"""
if not self.running:
return
self.logger.info("🛑 正在停止训练进程管理器...")
self.running = False
# 发送停止信号给所有工作进程
for _ in range(self.max_workers):
self.task_queue.put(None)
# 等待所有进程结束
for name, process in self.processes.items():
process.join(timeout=10)
if process.is_alive():
self.logger.warning(f"⚠️ 强制终止进程: {name}")
process.terminate()
self.logger.info("✅ 训练进程管理器已停止")
def submit_task(self, training_params: Dict[str, Any], path_info: Dict[str, Any]) -> str:
"""
提交训练任务
Args:
training_params (Dict[str, Any]): 来自API请求的原始参数
path_info (Dict[str, Any]): 由ModelPathManager生成的路径和版本信息
"""
task_id = str(uuid.uuid4())
task = TrainingTask(
task_id=task_id,
product_id=training_params.get('product_id'),
model_type=training_params.get('model_type'),
training_mode=training_params.get('training_mode', 'product'),
store_id=training_params.get('store_id'),
epochs=training_params.get('epochs', 100),
path_info=path_info # 存储路径信息
)
with self.lock:
self.tasks[task_id] = task
# 将任务放入队列
self.task_queue.put(asdict(task))
self.logger.info(f"📋 训练任务已提交: {task_id[:8]} | {task.model_type} | {task.product_id}")
return task_id
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""获取任务状态"""
with self.lock:
task = self.tasks.get(task_id)
if task:
return asdict(task)
return None
def get_all_tasks(self) -> Dict[str, Dict[str, Any]]:
"""获取所有任务状态"""
with self.lock:
return {task_id: asdict(task) for task_id, task in self.tasks.items()}
def cancel_task(self, task_id: str) -> bool:
"""取消任务(仅对未开始的任务有效)"""
with self.lock:
task = self.tasks.get(task_id)
if task and task.status == "pending":
task.status = "cancelled"
task.message = "任务已取消"
return True
return False
def _monitor_results(self):
"""监听训练结果"""
while self.running:
try:
result = self.result_queue.get(timeout=1)
action, task_data = result
task_id = task_data['task_id']
# 立即对从队列中取出的数据进行类型转换
serializable_task_data = convert_numpy_types(task_data)
with self.lock:
if task_id in self.tasks:
# 使用转换后的数据更新任务状态
for key, value in serializable_task_data.items():
setattr(self.tasks[task_id], key, value)
# 如果任务成功完成,则更新版本文件
if action == 'complete':
task = self.tasks[task_id]
if task.path_info:
identifier = task.path_info.get('identifier')
version = task.path_info.get('version')
if identifier and version:
try:
self.path_manager.save_version_info(identifier, version)
self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
except Exception as e:
self.logger.error(f"❌ 更新版本文件失败: {e}")
# WebSocket通知 - 使用已转换的数据
if self.websocket_callback:
try:
if action == 'complete':
# 训练完成 - 发送完成状态
self.websocket_callback('training_update', {
'task_id': task_id,
'action': 'completed',
'status': 'completed',
'progress': 100,
'message': serializable_task_data.get('message', '训练完成'),
'metrics': serializable_task_data.get('metrics'),
'end_time': serializable_task_data.get('end_time'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
})
# 额外发送一个完成事件,确保前端能收到
self.websocket_callback('training_completed', {
'task_id': task_id,
'status': 'completed',
'progress': 100,
'message': serializable_task_data.get('message', '训练完成'),
'metrics': serializable_task_data.get('metrics'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
})
elif action == 'error':
# 训练失败
self.websocket_callback('training_update', {
'task_id': task_id,
'action': 'failed',
'status': 'failed',
'progress': 0,
'message': serializable_task_data.get('message', '训练失败'),
'error': serializable_task_data.get('error'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
})
else:
# 状态更新
self.websocket_callback('training_update', {
'task_id': task_id,
'action': action,
'status': serializable_task_data.get('status'),
'progress': serializable_task_data.get('progress', 0),
'message': serializable_task_data.get('message', ''),
'metrics': serializable_task_data.get('metrics'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
})
except Exception as e:
self.logger.error(f"WebSocket通知失败: {e}")
except queue.Empty:
continue
except Exception as e:
self.logger.error(f"结果监听错误: {e}")
def _monitor_progress(self):
"""监听训练进度"""
while self.running:
try:
progress_data = self.progress_queue.get(timeout=1)
task_id = progress_data['task_id']
# 处理日志消息,显示到主控制台
if 'log_type' in progress_data:
log_type = progress_data['log_type']
message = progress_data['message']
task_short_id = task_id[:8]
if log_type == 'info':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.info(f"[{task_short_id}] {message}")
elif log_type == 'success':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.success(f"[{task_short_id}] {message}")
# 如果是训练完成的成功消息发送WebSocket通知
if "训练完成" in message:
if self.websocket_callback:
try:
self.websocket_callback('training_progress', {
'task_id': task_id,
'progress': 100,
'message': message,
'log_type': 'success',
'timestamp': time.time()
})
except Exception as e:
self.logger.error(f"成功消息WebSocket通知失败: {e}")
elif log_type == 'error':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.error(f"[{task_short_id}] {message}")
elif log_type == 'warning':
print(f"[{task_short_id}] {message}", flush=True)
self.logger.warning(f"[{task_short_id}] {message}")
# 更新任务进度只处理包含progress的消息
if 'progress' in progress_data:
with self.lock:
if task_id in self.tasks:
self.tasks[task_id].progress = progress_data['progress']
self.tasks[task_id].message = progress_data.get('message', '')
# WebSocket通知进度更新
if self.websocket_callback and 'progress' in progress_data:
try:
# 在发送前确保所有数据类型都是JSON可序列化的
serializable_data = convert_numpy_types(progress_data)
self.websocket_callback('training_progress', serializable_data)
except Exception as e:
self.logger.error(f"进度WebSocket通知失败: {e}")
except queue.Empty:
continue
except Exception as e:
self.logger.error(f"进度监听错误: {e}")
def set_websocket_callback(self, callback: Callable):
"""设置WebSocket回调函数"""
self.websocket_callback = callback
# 全局进程管理器实例
training_manager = TrainingProcessManager()
def get_training_manager() -> TrainingProcessManager:
"""获取训练进程管理器实例"""
return training_manager