ShopTRAINING/server/utils/model_manager.py
xz2000 18f505a090 # 修改记录日志 (日期: 2025-07-16) ---未改全
## 1. 训练流程与模型保存逻辑修复 (重大)

- **背景**: 用户报告在“按店铺”和“按药品”模式下,如果选择了特定的子集(如为某个店铺选择特定药品),生成的模型范围 (`scope`) 不正确,始终为 `_all`。此外,所有模型都被错误地保存到 `global` 目录下,且在某些模式下训练会失败。
- **根本原因**:
    1.  `server/core/predictor.py` 中负责准备训练参数的内部函数 (`_prepare_product_params`, `_prepare_store_params`) 逻辑有误,未能正确处理传入的 `product_ids` 和 `store_ids` 列表来构建详细的 `scope`。
    2.  各个训练器 (`server/trainers/*.py`) 内部的日志记录和元数据生成逻辑不统一,且过于依赖 `product_id`,导致在全局或店铺模式下信息展示不清晰。

- **修复方案**:
    - **`server/core/predictor.py`**:
        - **重构 `_prepare_product_params` 和 `_prepare_store_params`**: 修改了这两个函数,使其能够正确使用 `product_ids` 和 `store_ids` 列表。现在,当选择特定范围时,会生成更具描述性的 `scope`,例如 `S001_specific_P001_P002`。
        - **结果**: 确保了传递给模型管理器的 `scope` 是准确且详细的,从而使模型能够根据训练范围被保存到正确的、独立的文件夹中。

    - **`server/trainers/*.py` (mlstm, kan, tcn, transformer)**:
        - **标准化日志与元数据**: 对所有四个训练器文件进行了统一修改。引入了一个通用的 `training_description` 变量,该变量整合了 `training_mode`、`scope` 和 `aggregation_method`。
        - **更新输出**: 修改了所有训练器中的日志消息、图表标题和 `metadata.json` 的生成逻辑,使其全部使用这个标准的 `training_description`。
        - **结果**: 确保了无论在哪种训练模式下,前端收到的日志、保存的图表和元数据都具有一致、清晰的格式,便于调试和结果追溯。

- **总体影响**: 此次修复从根本上解决了模型训练范围处理和模型保存路径的错误问题,使整个训练系统在所有模式下都能可靠、一致地运行。

---

## 2. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
    - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
    - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 3. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。
2025-07-16 16:51:38 +08:00

231 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
遵循层级式目录结构和文件版本管理规则
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Any
from threading import Lock
from core.config import DEFAULT_MODEL_DIR
class ModelManager:
"""
统一模型管理器,采用结构化目录和版本文件进行管理。
"""
VERSION_FILE = 'versions.json'
def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
self.model_dir = os.path.abspath(model_dir)
self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
self._lock = Lock()
self.ensure_model_dir()
def ensure_model_dir(self):
"""确保模型根目录存在"""
os.makedirs(self.model_dir, exist_ok=True)
def _read_versions(self) -> Dict[str, int]:
"""线程安全地读取版本文件"""
with self._lock:
if not os.path.exists(self.versions_path):
return {}
try:
with open(self.versions_path, 'r', encoding='utf-8') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return {}
def _write_versions(self, versions: Dict[str, int]):
"""线程安全地写入版本文件"""
with self._lock:
with open(self.versions_path, 'w', encoding='utf-8') as f:
json.dump(versions, f, indent=2)
def get_model_identifier(self,
model_type: str,
training_mode: str,
scope: str,
aggregation_method: Optional[str] = None) -> str:
"""
生成模型的唯一标识符用于版本文件中的key。
"""
if training_mode == 'global':
return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
return f"{training_mode}_{scope}_{model_type}"
def get_next_version_number(self, model_identifier: str) -> int:
"""
获取指定模型的下一个版本号(整数)。
"""
versions = self._read_versions()
current_version = versions.get(model_identifier, 0)
return current_version + 1
def update_version(self, model_identifier: str, new_version: int):
"""
更新模型的最新版本号。
"""
versions = self._read_versions()
versions[model_identifier] = new_version
self._write_versions(versions)
def get_model_version_path(self,
model_type: str,
version: int,
training_mode: str,
scope: str,
aggregation_method: Optional[str] = None) -> str:
"""
根据 `xz训练模型保存规则.md` 中定义的新规则生成模型版本目录的完整路径。
"""
base_path = self.model_dir
path_parts = [base_path]
if training_mode == 'product':
# saved_models/product/{scope}/{model_type}/v{N}/
if not scope: raise ValueError("scope is required for 'product' training mode.")
path_parts.extend(['product', scope, model_type, f'v{version}'])
elif training_mode == 'store':
# saved_models/store/{scope}/{model_type}/v{N}/
if not scope: raise ValueError("scope is required for 'store' training mode.")
path_parts.extend(['store', scope, model_type, f'v{version}'])
elif training_mode == 'global':
# saved_models/global/{scope_path}/{aggregation_method}/{model_type}/v{N}/
if not scope: raise ValueError("scope is required for 'global' training mode.")
if not aggregation_method: raise ValueError("aggregation_method is required for 'global' training mode.")
scope_parts = scope.split('/')
path_parts.extend(['global', *scope_parts, str(aggregation_method), model_type, f'v{version}'])
else:
raise ValueError(f"不支持的 training_mode: {training_mode}")
return os.path.join(*path_parts)
def save_model_artifact(self,
artifact_data: Any,
artifact_name: str,
model_version_path: str):
"""
在指定的模型版本目录下保存一个产物。
Args:
artifact_data: 要保存的数据 (e.g., model state dict, figure object).
artifact_name: 标准化的产物文件名 (e.g., 'model.pth', 'loss_curve.png').
model_version_path: 模型版本目录的路径.
"""
os.makedirs(model_version_path, exist_ok=True)
full_path = os.path.join(model_version_path, artifact_name)
if artifact_name.endswith('.pth'):
torch.save(artifact_data, full_path)
elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
elif artifact_name.endswith('.json'):
with open(full_path, 'w', encoding='utf-8') as f:
json.dump(artifact_data, f, indent=2, ensure_ascii=False)
else:
raise ValueError(f"不支持的产物类型: {artifact_name}")
print(f"产物已保存: {full_path}")
def list_models(self,
page: Optional[int] = None,
page_size: Optional[int] = None) -> Dict:
"""
通过扫描目录结构来列出所有模型 (适配新结构)。
"""
all_models = []
# 使用glob查找所有版本目录
search_pattern = os.path.join(self.model_dir, '**', 'v*')
for version_path in glob.glob(search_pattern, recursive=True):
# 确保它是一个目录并且包含 metadata.json
metadata_path = os.path.join(version_path, 'metadata.json')
if os.path.isdir(version_path) and os.path.exists(metadata_path):
model_info = self._parse_info_from_path(version_path)
if model_info:
all_models.append(model_info)
# 按时间戳降序排序
all_models.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
total_count = len(all_models)
if page and page_size:
start_index = (page - 1) * page_size
end_index = start_index + page_size
paginated_models = all_models[start_index:end_index]
else:
paginated_models = all_models
return {
'models': paginated_models,
'pagination': {
'total': total_count,
'page': page or 1,
'page_size': page_size or total_count,
'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
}
}
def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
"""根据新的目录结构从版本目录路径解析模型信息"""
try:
norm_path = os.path.normpath(version_path)
norm_model_dir = os.path.normpath(self.model_dir)
relative_path = os.path.relpath(norm_path, norm_model_dir)
parts = relative_path.split(os.sep)
if len(parts) < 4:
return None
info = {
'model_path': version_path,
'version': parts[-1],
'model_type': parts[-2],
'training_mode': parts[0],
'store_id': None,
'product_id': None,
'aggregation_method': None,
'scope': None
}
mode = parts[0]
if mode == 'product':
# product/{scope}/mlstm/v1
info['scope'] = parts[1]
elif mode == 'store':
# store/{scope}/mlstm/v1
info['scope'] = parts[1]
elif mode == 'global':
# global/{scope...}/sum/mlstm/v1
info['aggregation_method'] = parts[-3]
info['scope'] = '/'.join(parts[1:-3])
else:
return None # 未知模式
metadata_path = os.path.join(version_path, 'metadata.json')
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
info.update(metadata)
# 确保从路径解析出的关键信息覆盖元数据中的,因为路径是权威来源
info['version'] = parts[-1]
info['model_type'] = parts[-2]
info['training_mode'] = parts[0]
return info
except (IndexError, IOError) as e:
print(f"解析路径失败 {version_path}: {e}")
return None
# 全局模型管理器实例
model_manager = ModelManager()