ShopTRAINING/server/utils/model_manager.py
# Change Log (2025-07-16)
## 1. Core Bug Fixes

### File: `server/core/predictor.py`

- **Problem**: When the `train_model` method called the internal helper `_prepare_training_params`, it did not pass the `product_ids` and `store_ids` parameters, causing a `NameError` inside `_prepare_training_params`.
- **Fix**:
    - Corrected the call to `_prepare_training_params` inside `train_model` so that `product_ids` and `store_ids` are passed explicitly (see the sketch below).
    - The `train_model` signature had already been fixed earlier so that it correctly accepts `store_ids`.
- **Result**: The parameter-passing problem in the training flow is fully resolved, eliminating the `NameError` it caused.
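
For context, a minimal sketch of the corrected call, reconstructed from this log (the class name, parameter list, and helper behavior below are assumptions, not the actual code):

```python
class Predictor:  # stand-in for the class in server/core/predictor.py
    def train_model(self, model_type, product_ids=None, store_ids=None, **kwargs):
        # Before the fix, _prepare_training_params referenced product_ids and
        # store_ids without receiving them as arguments, raising NameError.
        params = self._prepare_training_params(
            model_type=model_type,
            product_ids=product_ids,  # now passed explicitly
            store_ids=store_ids,      # now passed explicitly
            **kwargs,
        )
        return params

    def _prepare_training_params(self, **kwargs):
        # Placeholder for the real helper, which builds the training config.
        return kwargs
```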

## 2. Code Cleanup and Refactoring

### File: `server/api.py`

- **Change**: Removed the legacy thread-based (`threading.Thread`) training logic left over in the `start_training` API endpoint.
- **Reason**: That code block has been fully replaced by the new multiprocessing-based `TrainingProcessManager`. The old code was also littered with `thread_safe_print` debug logging that no longer serves any purpose.
- **Result**: The `start_training` endpoint is now much clearer: it only validates parameters and submits the task to `TrainingProcessManager` (see the sketch below).
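
For illustration, a minimal sketch of the slimmed-down endpoint (the route, payload fields, and `submit_task` signature are assumptions; the real `TrainingProcessManager` lives in `server/utils/training_process_manager.py`):

```python
from flask import Flask, request, jsonify

app = Flask(__name__)

class TrainingProcessManager:  # stand-in for the real manager
    def submit_task(self, payload: dict) -> str:
        return 'task-001'

training_manager = TrainingProcessManager()

@app.route('/api/training', methods=['POST'])
def start_training():
    payload = request.get_json() or {}
    # The endpoint now only validates input...
    if not payload.get('model_type'):
        return jsonify({'error': 'model_type is required'}), 400
    # ...and hands the job off to the multiprocessing-based manager.
    task_id = training_manager.submit_task(payload)
    return jsonify({'task_id': task_id}), 202
```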

### File: `server/utils/training_process_manager.py`

- **Change**: Removed a `for` loop in `TrainingWorker.run_training_task` that only simulated training progress.
- **Reason**: The loop contained `time.sleep(1)` and existed solely to fake progress updates before real training logic was wired in; the actual trainer now reports real progress through a callback, so the simulation is no longer needed.
- **Result**: `TrainingWorker` now invokes the real trainer directly, with no artificial delay, bringing the code closer to production behavior (see the sketch below).
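
A sketch of the worker after the change (the queue layout, callback signature, and `get_trainer` factory are assumptions):

```python
from multiprocessing import Queue

class TrainingWorker:  # stand-in for the real worker
    def __init__(self, progress_queue: Queue):
        self.progress_queue = progress_queue

    def run_training_task(self, task: dict) -> None:
        def report_progress(percent: float, message: str = '') -> None:
            # Real progress now flows from the trainer through this callback,
            # replacing the removed loop that faked updates with time.sleep(1).
            self.progress_queue.put({
                'task_id': task['task_id'],
                'progress': percent,
                'message': message,
            })

        trainer = get_trainer(task['model_type'])  # hypothetical factory
        trainer.train(task, progress_callback=report_progress)
```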

## 3. Runtime Dependencies

- **Python**: 3.x
- **Key libraries**:
    - Flask
    - Flask-SocketIO
    - Flasgger
    - pandas
    - numpy
    - torch
    - scikit-learn
    - matplotlib
- **Start command**: `python server/api.py`

"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
遵循层级式目录结构和文件版本管理规则
"""
import os
import json
import torch
import glob
from typing import Dict, Optional, Any
from threading import Lock
from core.config import DEFAULT_MODEL_DIR


class ModelManager:
    """
    Unified model manager using a structured directory tree and a version file.
    """
    VERSION_FILE = 'versions.json'

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        self._lock = Lock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Ensure the model root directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _read_versions(self) -> Dict[str, int]:
        """Read the version file in a thread-safe manner."""
        with self._lock:
            if not os.path.exists(self.versions_path):
                return {}
            try:
                with open(self.versions_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                return {}

    def _write_versions(self, versions: Dict[str, int]):
        """Write the version file in a thread-safe manner."""
        with self._lock:
            with open(self.versions_path, 'w', encoding='utf-8') as f:
                json.dump(versions, f, indent=2)

    def get_model_identifier(self,
                             model_type: str,
                             training_mode: str,
                             scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """
        Generate the unique identifier used as the model's key in the version file.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

    def get_next_version_number(self, model_identifier: str) -> int:
        """
        Get the next (integer) version number for the given model.
        """
        versions = self._read_versions()
        current_version = versions.get(model_identifier, 0)
        return current_version + 1

    def update_version(self, model_identifier: str, new_version: int):
        """
        Record the model's latest version number.
        """
        versions = self._read_versions()
        versions[model_identifier] = new_version
        self._write_versions(versions)

    def get_model_version_path(self,
                               model_type: str,
                               version: int,
                               training_mode: str,
                               aggregation_method: Optional[str] = None,
                               store_id: Optional[str] = None,
                               product_id: Optional[str] = None,
                               scope: Optional[str] = None) -> str:  # scope kept for legacy callers
        """
        Build the full path of a model version directory according to the rules
        defined in `xz训练模型保存规则.md`.
        """
        # The base path is always self.model_dir.
        base_path = self.model_dir
        # Per the rules, every mode lives under the top-level 'global' directory.
        path_parts = [base_path, 'global']
        if training_mode == 'global':
            # global/all/{aggregation_method}/{model_type}/v{N}/
            path_parts.extend(['all', str(aggregation_method)])
        elif training_mode == 'stores':
            # global/stores/{store_id}/{aggregation_method}/{model_type}/v{N}/
            if not store_id:
                raise ValueError("store_id is required for 'stores' training mode.")
            path_parts.extend(['stores', store_id, str(aggregation_method)])
        elif training_mode == 'products':
            # global/products/{product_id}/{aggregation_method}/{model_type}/v{N}/
            if not product_id:
                raise ValueError("product_id is required for 'products' training mode.")
            path_parts.extend(['products', product_id, str(aggregation_method)])
        elif training_mode == 'custom':
            # global/custom/{store_id}/{product_id}/{aggregation_method}/{model_type}/v{N}/
            if not store_id or not product_id:
                raise ValueError("store_id and product_id are required for 'custom' training mode.")
            path_parts.extend(['custom', store_id, product_id, str(aggregation_method)])
        else:
            raise ValueError(f"Unsupported training_mode: {training_mode}")
        path_parts.extend([model_type, f'v{version}'])
        return os.path.join(*path_parts)

    def save_model_artifact(self,
                            artifact_data: Any,
                            artifact_name: str,
                            model_version_path: str):
        """
        Save a single artifact into the given model version directory.

        Args:
            artifact_data: The data to save (e.g. a model state dict or a figure object).
            artifact_name: Standardized artifact file name (e.g. 'model.pth', 'loss_curve.png').
            model_version_path: Path of the model version directory.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)
        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"Unsupported artifact type: {artifact_name}")
        print(f"Artifact saved: {full_path}")

    def list_models(self,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List all models by scanning the directory structure (new layout).
        """
        all_models = []
        # Find all version directories with glob.
        search_pattern = os.path.join(self.model_dir, '**', 'v*')
        for version_path in glob.glob(search_pattern, recursive=True):
            # Only keep directories that contain a metadata.json.
            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.isdir(version_path) and os.path.exists(metadata_path):
                model_info = self._parse_info_from_path(version_path)
                if model_info:
                    all_models.append(model_info)
        # Sort by timestamp, newest first.
        all_models.sort(key=lambda x: x.get('timestamp', ''), reverse=True)
        total_count = len(all_models)
        if page and page_size:
            start_index = (page - 1) * page_size
            end_index = start_index + page_size
            paginated_models = all_models[start_index:end_index]
        else:
            paginated_models = all_models
        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """Parse model info from a version directory path under the new layout."""
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)
            relative_path = os.path.relpath(norm_path, norm_model_dir)
            parts = relative_path.split(os.sep)
            # Expected path: global/{scope_type}/{id...}/{agg_method}/{model_type}/v{N}
            if parts[0] != 'global' or len(parts) < 5:
                return None  # Not a path that follows the new layout.
            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2],
                'store_id': None,
                'product_id': None,
            }
            scope_type = parts[1]  # all, stores, products, custom
            if scope_type == 'all':
                # global/all/sum/mlstm/v1
                info['training_mode'] = 'global'
                info['aggregation_method'] = parts[2]
            elif scope_type == 'stores':
                # global/stores/S001/sum/mlstm/v1
                info['training_mode'] = 'stores'
                info['store_id'] = parts[2]
                info['aggregation_method'] = parts[3]
            elif scope_type == 'products':
                # global/products/P001/sum/mlstm/v1
                info['training_mode'] = 'products'
                info['product_id'] = parts[2]
                info['aggregation_method'] = parts[3]
            elif scope_type == 'custom':
                # global/custom/S001/P001/sum/mlstm/v1
                info['training_mode'] = 'custom'
                info['store_id'] = parts[2]
                info['product_id'] = parts[3]
                info['aggregation_method'] = parts[4]
            else:
                return None
            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                # Merge metadata, then restore the path-derived fields:
                # the path is the authoritative source for version and model_type.
                info.update(metadata)
                info['version'] = parts[-1]
                info['model_type'] = parts[-2]
            return info
        except (IndexError, IOError, json.JSONDecodeError) as e:
            print(f"Failed to parse path {version_path}: {e}")
            return None


# Global model manager instance
model_manager = ModelManager()
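

# Illustrative usage sketch (added for this write-up, not part of the original
# module; the IDs, model type, and metadata below are made-up examples).
if __name__ == '__main__':
    identifier = model_manager.get_model_identifier('mlstm', 'stores', 'S001')
    version = model_manager.get_next_version_number(identifier)
    path = model_manager.get_model_version_path(
        model_type='mlstm',
        version=version,
        training_mode='stores',
        store_id='S001',
        aggregation_method='sum',
    )
    model_manager.save_model_artifact({'epochs': 50}, 'metadata.json', path)
    model_manager.update_version(identifier, version)
    print(model_manager.list_models(page=1, page_size=10))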