ShopTRAINING/server/core/predictor.py
xz2000 e999ed4af2 ### 2025-07-15 (续): 训练器与核心调用层重构
**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**
*   **内容**: 完全重写了 `kan_trainer.py`。
    *   **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
    *   **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
    *   **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
    *   **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**
*   **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
    *   移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
    *   全面转向使用 `model_manager` 进行版本控制和文件保存。
    *   统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**
*   **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
    *   移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
    *   实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**
*   **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
    *   **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
    *   **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
    *   **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
    *   **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
2025-07-15 20:09:09 +08:00

222 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类 (已重构)
支持多店铺销售预测功能并完全集成新的ModelManager
"""
import os
import pandas as pd
import time
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
from predictors.model_predictor import load_model_and_predict
from utils.multi_store_data_utils import (
load_multi_store_data,
get_store_product_sales_data,
aggregate_multi_store_data
)
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
from utils.model_manager import model_manager
class PharmacyPredictor:
"""
药店销售预测系统核心类,用于训练模型和进行预测
"""
def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
初始化预测器
"""
self.data_path = data_path if data_path else DEFAULT_DATA_PATH
self.model_dir = model_dir
self.device = DEVICE
if not os.path.exists(model_dir):
os.makedirs(model_dir)
print(f"使用设备: {self.device}")
try:
self.data = load_multi_store_data(self.data_path)
print(f"已加载多店铺数据,来源: {self.data_path}")
except Exception as e:
print(f"加载数据失败: {e}")
self.data = None
def train_model(self, product_id, model_type='transformer', epochs=100,
learning_rate=0.001, use_optimized=False,
store_id=None, training_mode='product', aggregation_method='sum',
socketio=None, task_id=None, progress_callback=None, patience=10):
"""
训练预测模型 - 完全适配新的训练器接口
"""
def log_message(message, log_type='info'):
print(f"[{log_type.upper()}] {message}", flush=True)
if progress_callback:
try:
progress_callback({'log_type': log_type, 'message': message})
except Exception as e:
print(f"[ERROR] 进度回调失败: {e}", flush=True)
if self.data is None:
log_message("没有可用的数据,请先加载或生成数据", 'error')
return None
# --- 数据准备 ---
try:
if training_mode == 'store':
product_data = get_store_product_sales_data(store_id, product_id, self.data_path)
log_message(f"按店铺-产品训练: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
elif training_mode == 'global':
product_data = aggregate_multi_store_data(product_id, aggregation_method, self.data_path)
log_message(f"全局训练模式: 产品 {product_id}, 聚合方法 {aggregation_method}, 数据<E695B0><E68DAE><EFBFBD>: {len(product_data)}")
else: # 'product'
product_data = self.data[self.data['product_id'] == product_id].copy()
log_message(f"按产品训练模式: 产品 {product_id}, 数据量: {len(product_data)}")
except Exception as e:
log_message(f"数据准备失败: {e}", 'error')
return None
if product_data.empty:
log_message(f"找不到产品 {product_id} 的数据", 'error')
return None
# --- 训练器选择与参数准备 ---
trainers = {
'transformer': train_product_model_with_transformer,
'mlstm': train_product_model_with_mlstm,
'tcn': train_product_model_with_tcn,
'kan': train_product_model_with_kan,
'optimized_kan': train_product_model_with_kan,
}
if model_type not in trainers:
log_message(f"不支持的模型类型: {model_type}", 'error')
return None
trainer_func = trainers[model_type]
# 统一所有训练器的参数
trainer_args = {
"product_id": product_id,
"product_df": product_data,
"store_id": store_id,
"training_mode": training_mode,
"aggregation_method": aggregation_method,
"epochs": epochs,
"socketio": socketio,
"task_id": task_id,
"progress_callback": progress_callback,
"patience": patience,
"learning_rate": learning_rate
}
# 为 KAN 模型添加特殊参数
if 'kan' in model_type:
trainer_args['use_optimized'] = (model_type == 'optimized_kan')
# --- 调用训练器 ---
try:
log_message(f"🤖 开始调用 {model_type} 训练器")
model, metrics, version, model_version_path = trainer_func(**trainer_args)
log_message(f"{model_type} 训练器成功返回", 'success')
if metrics:
metrics.update({
'model_type': model_type,
'version': version,
'model_path': model_version_path,
'training_mode': training_mode,
'store_id': store_id,
'product_id': product_id,
'aggregation_method': aggregation_method if training_mode == 'global' else None
})
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
return metrics
else:
log_message("⚠️ 训练器返回的metrics为空", 'warning')
return None
except Exception as e:
import traceback
log_message(f"模型训练过程中发生严重错误: {e}\n{traceback.format_exc()}", 'error')
return None
def predict(self, model_version_path, future_days=7, start_date=None, analyze_result=False):
"""
使用已训练的模型进行预测 - 直接使用模型版本路径
"""
if not os.path.exists(model_version_path):
raise FileNotFoundError(f"指定的模型路径不存在: {model_version_path}")
return load_model_and_predict(
model_version_path=model_version_path,
future_days=future_days,
start_date=start_date,
analyze_result=analyze_result
)
def list_models(self, **kwargs):
"""
列出所有可用的模型版本。
直接调用 ModelManager 的 list_models 方法。
支持的过滤参数: model_type, training_mode, scope, version
"""
return model_manager.list_models(**kwargs)
def delete_model(self, model_version_path):
"""
删除一个指定的模型版本目录。
"""
return model_manager.delete_model_version(model_version_path)
def compare_models(self, product_id, epochs=50, **kwargs):
"""
在相同数据上训练并比较多个模型的性能。
"""
results = {}
model_types_to_compare = ['tcn', 'mlstm', 'transformer', 'kan', 'optimized_kan']
for model_type in model_types_to_compare:
print(f"\n{'='*20} 训练模型: {model_type.upper()} {'='*20}")
try:
metrics = self.train_model(
product_id=product_id,
model_type=model_type,
epochs=epochs,
**kwargs
)
results[model_type] = metrics if metrics else {}
except Exception as e:
print(f"训练 {model_type} 模型失败: {e}")
results[model_type] = {'error': str(e)}
# 打印比较结果
print(f"\n{'='*25} 模型性能比较 {'='*25}")
# 准备数据
df_data = []
for model, metrics in results.items():
if metrics and 'rmse' in metrics:
df_data.append({
'Model': model.upper(),
'RMSE': metrics.get('rmse'),
'': metrics.get('r2'),
'MAPE (%)': metrics.get('mape'),
'Time (s)': metrics.get('training_time')
})
if not df_data:
print("没有可供比较的模型结果。")
return results
comparison_df = pd.DataFrame(df_data).set_index('Model')
print(comparison_df.to_string(float_format="%.4f"))
return results