## 1. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
  - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
  - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 2. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。

## 3. 启动依赖

- **Python**: 3.x
- **主要库**:
  - Flask
  - Flask-SocketIO
  - Flasgger
  - pandas
  - numpy
  - torch
  - scikit-learn
  - matplotlib
- **启动命令**: `python server/api.py`
248 lines
10 KiB
Python
248 lines
10 KiB
Python
"""
|
||
统一模型管理工具
|
||
处理模型文件的统一命名、存储和检索
|
||
遵循层级式目录结构和文件版本管理规则
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import torch
|
||
import glob
|
||
from datetime import datetime
|
||
from typing import List, Dict, Optional, Any
|
||
from threading import Lock
|
||
from core.config import DEFAULT_MODEL_DIR
|
||
|
||
class ModelManager:
    """
    Unified model manager.

    Stores model artifacts under a hierarchical directory layout and tracks
    the latest version number of every model in a single JSON file
    (``versions.json``) located at the root of the model directory.
    """

    # Name of the JSON file mapping model identifiers to their latest version.
    VERSION_FILE = 'versions.json'

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        """
        Args:
            model_dir: Root directory for all model artifacts; created on
                demand if it does not exist yet.
        """
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        # Protects every read/write of the shared versions.json file.
        self._lock = Lock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Ensure the model root directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _read_versions_unlocked(self) -> Dict[str, int]:
        """Read the version file. The caller MUST already hold ``self._lock``."""
        if not os.path.exists(self.versions_path):
            return {}
        try:
            with open(self.versions_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (json.JSONDecodeError, IOError):
            # A corrupt or unreadable file is treated as "no versions yet".
            return {}

    def _write_versions_unlocked(self, versions: Dict[str, int]):
        """Write the version file. The caller MUST already hold ``self._lock``."""
        with open(self.versions_path, 'w', encoding='utf-8') as f:
            json.dump(versions, f, indent=2)

    def _read_versions(self) -> Dict[str, int]:
        """Thread-safely read the version file."""
        with self._lock:
            return self._read_versions_unlocked()

    def _write_versions(self, versions: Dict[str, int]):
        """Thread-safely write the version file."""
        with self._lock:
            self._write_versions_unlocked(versions)

    def get_model_identifier(self,
                             model_type: str,
                             training_mode: str,
                             scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """
        Build the unique identifier used as a key in the version file.

        NOTE: for ``training_mode == 'global'`` the aggregation method (which
        may be ``None``, rendered as the string ``"None"``) is embedded in the
        identifier; other modes ignore it.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

    def get_next_version_number(self, model_identifier: str) -> int:
        """
        Return the next (integer) version number for the given model.

        Unknown identifiers start at version 1.
        """
        versions = self._read_versions()
        return versions.get(model_identifier, 0) + 1

    def update_version(self, model_identifier: str, new_version: int):
        """
        Record ``new_version`` as the latest version of the model.

        BUG FIX: the read-modify-write cycle now runs under a single lock
        acquisition. Previously ``_read_versions`` and ``_write_versions``
        each took the lock independently, so two concurrent updates could
        interleave between the read and the write and one update would be
        silently lost.
        """
        with self._lock:
            versions = self._read_versions_unlocked()
            versions[model_identifier] = new_version
            self._write_versions_unlocked(versions)

    def get_model_version_path(self,
                               model_type: str,
                               version: int,
                               training_mode: str,
                               aggregation_method: Optional[str] = None,
                               store_id: Optional[str] = None,
                               product_id: Optional[str] = None,
                               scope: Optional[str] = None) -> str:  # `scope` kept for backward compatibility with old callers
        """
        Build the full path of a model version directory according to the
        rules defined in `xz训练模型保存规则.md`.

        Args:
            model_type: Model architecture name (e.g. 'mlstm').
            version: Integer version number; rendered as ``v{version}``.
            training_mode: One of 'global', 'stores', 'products', 'custom'.
            aggregation_method: Aggregation method segment of the path.
            store_id: Required for 'stores' and 'custom' modes.
            product_id: Required for 'products' and 'custom' modes.
            scope: Unused; accepted only for backward compatibility.

        Raises:
            ValueError: if a required id is missing for the given mode, or
                the mode itself is not supported.
        """
        # Per the saving rules, every mode is rooted at <model_dir>/global.
        path_parts = [self.model_dir, 'global']

        if training_mode == 'global':
            # global/all/{aggregation_method}/{model_type}/v{N}/
            path_parts.extend(['all', str(aggregation_method)])
        elif training_mode == 'stores':
            # global/stores/{store_id}/{aggregation_method}/{model_type}/v{N}/
            if not store_id:
                raise ValueError("store_id is required for 'stores' training mode.")
            path_parts.extend(['stores', store_id, str(aggregation_method)])
        elif training_mode == 'products':
            # global/products/{product_id}/{aggregation_method}/{model_type}/v{N}/
            if not product_id:
                raise ValueError("product_id is required for 'products' training mode.")
            path_parts.extend(['products', product_id, str(aggregation_method)])
        elif training_mode == 'custom':
            # global/custom/{store_id}/{product_id}/{aggregation_method}/{model_type}/v{N}/
            if not store_id or not product_id:
                raise ValueError("store_id and product_id are required for 'custom' training mode.")
            path_parts.extend(['custom', store_id, product_id, str(aggregation_method)])
        else:
            raise ValueError(f"不支持的 training_mode: {training_mode}")

        path_parts.extend([model_type, f'v{version}'])
        return os.path.join(*path_parts)

    def save_model_artifact(self,
                            artifact_data: Any,
                            artifact_name: str,
                            model_version_path: str):
        """
        Save one artifact into the given model version directory.

        The handler is chosen by file extension: ``.pth`` via ``torch.save``,
        ``.png`` via the object's ``savefig`` (matplotlib-style figures),
        ``.json`` via ``json.dump``.

        Args:
            artifact_data: The data to save (e.g. model state dict, figure
                object, JSON-serializable dict).
            artifact_name: Standardized artifact file name
                (e.g. 'model.pth', 'loss_curve.png', 'metadata.json').
            model_version_path: Target model version directory; created on
                demand.

        Raises:
            ValueError: if the artifact name's extension is not supported.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)

        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"不支持的产物类型: {artifact_name}")

        print(f"产物已保存: {full_path}")

    def list_models(self,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List all models by scanning the directory structure (new layout).

        Only version directories containing a ``metadata.json`` are counted.
        Results are sorted by their 'timestamp' field, newest first.

        Args:
            page: 1-based page number; pagination applies only when both
                ``page`` and ``page_size`` are truthy.
            page_size: Number of entries per page.

        Returns:
            Dict with 'models' (the page of model infos) and 'pagination'
            ('total', 'page', 'page_size', 'total_pages').
        """
        all_models = []
        # Find every version directory (v1, v2, ...) anywhere under the root.
        search_pattern = os.path.join(self.model_dir, '**', 'v*')

        for version_path in glob.glob(search_pattern, recursive=True):
            # Only accept directories that carry a metadata.json file.
            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.isdir(version_path) and os.path.exists(metadata_path):
                model_info = self._parse_info_from_path(version_path)
                if model_info:
                    all_models.append(model_info)

        # Newest first; entries without a timestamp sort last.
        all_models.sort(key=lambda x: x.get('timestamp', ''), reverse=True)

        total_count = len(all_models)
        if page and page_size:
            start_index = (page - 1) * page_size
            paginated_models = all_models[start_index:start_index + page_size]
        else:
            paginated_models = all_models

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """
        Parse model info from a version directory path (new layout).

        Expected layout: global/{scope_type}/{id...}/{agg_method}/{model_type}/v{N}.
        Fields derived from the path are authoritative; metadata.json only
        fills in fields the path does not provide.

        Returns:
            The merged info dict, or ``None`` if the path does not follow
            the new layout.
        """
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)

            relative_path = os.path.relpath(norm_path, norm_model_dir)
            parts = relative_path.split(os.sep)

            # Expected: global/{scope_type}/{id...}/{agg_method}/{model_type}/v{N}
            if parts[0] != 'global' or len(parts) < 5:
                return None  # Not a conforming new-layout path.

            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2],
                'store_id': None,
                'product_id': None,
            }

            scope_type = parts[1]  # all, stores, products, custom

            if scope_type == 'all':
                # global/all/sum/mlstm/v1
                info['training_mode'] = 'global'
                info['aggregation_method'] = parts[2]
            elif scope_type == 'stores':
                # global/stores/S001/sum/mlstm/v1
                info['training_mode'] = 'stores'
                info['store_id'] = parts[2]
                info['aggregation_method'] = parts[3]
            elif scope_type == 'products':
                # global/products/P001/sum/mlstm/v1
                info['training_mode'] = 'products'
                info['product_id'] = parts[2]
                info['aggregation_method'] = parts[3]
            elif scope_type == 'custom':
                # global/custom/S001/P001/sum/mlstm/v1
                info['training_mode'] = 'custom'
                info['store_id'] = parts[2]
                info['product_id'] = parts[3]
                info['aggregation_method'] = parts[4]
            else:
                return None

            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                # BUG FIX: the path is the authoritative source for ids,
                # mode, version and model_type, so path-derived fields must
                # override the metadata — not the other way around.
                # Previously `info.update(metadata)` let metadata clobber
                # everything and only 'version'/'model_type' were restored.
                merged = dict(metadata)
                merged.update(info)
                info = merged

            return info
        except (IndexError, IOError) as e:
            print(f"解析路径失败 {version_path}: {e}")
            return None
|
||
|
||
# Global singleton ModelManager instance, shared across the application.
model_manager = ModelManager()