ShopTRAINING/server/core/predictor.py
xz2000 a9a0e51769 # 修改记录日志 (日期: 2025-07-16)
## 1. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
    - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
    - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 2. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。

## 3. 启动依赖

- **Python**: 3.x
- **主要库**:
    - Flask
    - Flask-SocketIO
    - Flasgger
    - pandas
    - numpy
    - torch
    - scikit-learn
    - matplotlib
- **启动命令**: `python server/api.py`
2025-07-16 15:34:57 +08:00

313 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类 (已重构)
支持多店铺销售预测功能并完全集成新的ModelManager
"""
import os
import pandas as pd
import time
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
from predictors.model_predictor import load_model_and_predict
from utils.multi_store_data_utils import (
load_multi_store_data,
get_store_product_sales_data,
aggregate_multi_store_data
)
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
from utils.model_manager import model_manager
class PharmacyPredictor:
"""
药店销售预测系统核心类,用于训练模型和进行预测
"""
def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
初始化预测器
"""
self.data_path = data_path if data_path else DEFAULT_DATA_PATH
self.model_dir = model_dir
self.device = DEVICE
if not os.path.exists(model_dir):
os.makedirs(model_dir)
print(f"使用设备: {self.device}")
try:
self.data = load_multi_store_data(self.data_path)
print(f"已加载多店铺数据,来源: {self.data_path}")
except Exception as e:
print(f"加载数据失败: {e}")
self.data = None
def _prepare_global_params(self, **kwargs):
"""'global' (all_stores_all_products) 模式准备参数"""
return {
'final_training_mode': 'global',
'agg_store_id': None,
'agg_product_id': None,
'path_store_id': 'all',
'path_product_id': 'all',
}
def _prepare_stores_params(self, **kwargs):
"""'stores' (selected_stores) 模式准备参数并校验"""
store_ids_list = kwargs.get('store_ids')
if not store_ids_list:
raise ValueError("进行 'selected_stores' 范围训练时,必须提供 store_ids 列表。")
return {
'final_training_mode': 'stores',
'agg_store_id': store_ids_list,
'agg_product_id': None,
'path_store_id': store_ids_list[0],
'path_product_id': 'all',
}
def _prepare_products_params(self, **kwargs):
"""'products' (selected_products) 模式准备参数并校验"""
product_ids_list = kwargs.get('product_ids')
if not product_ids_list:
raise ValueError("进行 'selected_products' 范围训练时,必须提供 product_ids 列表。")
return {
'final_training_mode': 'products',
'agg_store_id': None,
'agg_product_id': product_ids_list,
'path_store_id': 'all',
'path_product_id': product_ids_list[0],
}
def _prepare_custom_params(self, **kwargs):
"""'custom' 模式准备参数并校验"""
store_ids_list = kwargs.get('store_ids')
product_ids_list = kwargs.get('product_ids')
if not store_ids_list or not product_ids_list:
raise ValueError("进行 'custom' 范围训练时,必须同时提供 store_ids 和 product_ids 列表。")
return {
'final_training_mode': 'custom',
'agg_store_id': store_ids_list,
'agg_product_id': product_ids_list,
'path_store_id': store_ids_list[0],
'path_product_id': product_ids_list[0],
}
def _prepare_training_params(self, training_scope, product_id, store_id, **kwargs):
"""
参数准备分发器:根据 training_scope 调用相应的处理函数。
"""
scope_handlers = {
'all_stores_all_products': self._prepare_global_params,
'selected_stores': self._prepare_stores_params,
'selected_products': self._prepare_products_params,
'custom': self._prepare_custom_params,
}
handler = scope_handlers.get(training_scope)
if not handler:
raise ValueError(f"不支持的训练范围: '{training_scope}'")
# 将所有相关参数合并到一个字典中,然后传递给处理函数
all_params = kwargs.copy()
all_params['training_scope'] = training_scope
all_params['product_id'] = product_id
all_params['store_id'] = store_id
return handler(**all_params)
def train_model(self, product_id, model_type='transformer', epochs=100,
learning_rate=0.001, use_optimized=False,
store_id=None, training_mode='product', aggregation_method='sum',
product_scope='all', product_ids=None, store_ids=None,
socketio=None, task_id=None, progress_callback=None, patience=10, **kwargs):
"""
训练预测模型 - 完全适配新的训练器接口
"""
def log_message(message, log_type='info'):
print(f"[{log_type.upper()}] {message}", flush=True)
if progress_callback:
try:
progress_callback({'log_type': log_type, 'message': message})
except Exception as e:
print(f"[ERROR] 进度回调失败: {e}", flush=True)
if self.data is None:
log_message("没有可用的数据,请先加载或生成数据", 'error')
return None
try:
# 从kwargs复制一份避免修改原始字典
call_kwargs = kwargs.copy()
training_scope = call_kwargs.pop('training_scope', None)
# The dispatcher will pop the legacy store_id and product_id
params = self._prepare_training_params(
training_scope=training_scope,
store_id=store_id,
product_id=product_id,
product_ids=product_ids,
store_ids=store_ids,
**call_kwargs
)
product_data = aggregate_multi_store_data(
store_id=params['agg_store_id'],
product_id=params['agg_product_id'],
aggregation_method=aggregation_method,
file_path=self.data_path
)
if product_data is None or product_data.empty:
raise ValueError(f"聚合后数据为空,无法继续训练。模式: {params['final_training_mode']}")
except ValueError as e:
log_message(f"参数校验或数据准备失败: {e}", 'error')
return None
except Exception as e:
import traceback
log_message(f"数据准备过程中发生未知错误: {e}", 'error')
log_message(traceback.format_exc(), 'error')
return None
if product_data.empty:
log_message(f"找不到产品 {product_id} 的数据", 'error')
return None
trainers = {
'transformer': train_product_model_with_transformer,
'mlstm': train_product_model_with_mlstm,
'tcn': train_product_model_with_tcn,
'kan': train_product_model_with_kan,
'optimized_kan': train_product_model_with_kan,
}
if model_type not in trainers:
log_message(f"不支持的模型类型: {model_type}", 'error')
return None
trainer_func = trainers[model_type]
trainer_args = {
"product_id": params['path_product_id'],
"product_df": product_data,
"store_id": params['path_store_id'],
"training_mode": params['final_training_mode'],
"aggregation_method": aggregation_method,
"scope": kwargs.get('training_scope'),
"epochs": epochs,
"socketio": socketio,
"task_id": task_id,
"progress_callback": progress_callback,
"patience": patience,
"learning_rate": learning_rate
}
if 'kan' in model_type:
trainer_args['use_optimized'] = (model_type == 'optimized_kan')
try:
log_message(f"🤖 开始调用 {model_type} 训练器")
model, metrics, version, model_version_path = trainer_func(**trainer_args)
log_message(f"{model_type} 训练器成功返回", 'success')
if metrics:
relative_model_path = os.path.relpath(model_version_path, os.getcwd())
metrics.update({
'model_type': model_type,
'version': version,
'model_path': relative_model_path.replace('\\', '/'),
'training_mode': params['final_training_mode'],
'store_id': params['path_store_id'],
'product_id': params['path_product_id'],
'aggregation_method': aggregation_method if params['final_training_mode'] == 'global' else None
})
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
return metrics
else:
log_message("⚠️ 训练器返回的metrics为空", 'warning')
return None
except Exception as e:
import traceback
log_message(f"模型训练过程中发生严重错误: {e}\n{traceback.format_exc()}", 'error')
return None
def predict(self, model_version_path, future_days=7, start_date=None, analyze_result=False):
"""
使用已训练的模型进行预测 - 直接使用模型版本路径
"""
if not os.path.exists(model_version_path):
raise FileNotFoundError(f"指定的模型路径不存在: {model_version_path}")
return load_model_and_predict(
model_version_path=model_version_path,
future_days=future_days,
start_date=start_date,
analyze_result=analyze_result
)
def list_models(self, **kwargs):
"""
列出所有可用的模型版本。
直接调用 ModelManager 的 list_models 方法。
支持的过滤参数: model_type, training_mode, scope, version
"""
return model_manager.list_models(**kwargs)
def delete_model(self, model_version_path):
"""
删除一个指定的模型版本目录。
"""
return model_manager.delete_model_version(model_version_path)
def compare_models(self, product_id, epochs=50, **kwargs):
"""
在相同数据上训练并比较多个模型的性能。
"""
results = {}
model_types_to_compare = ['tcn', 'mlstm', 'transformer', 'kan', 'optimized_kan']
for model_type in model_types_to_compare:
print(f"\n{'='*20} 训练模型: {model_type.upper()} {'='*20}")
try:
metrics = self.train_model(
product_id=product_id,
model_type=model_type,
epochs=epochs,
**kwargs
)
results[model_type] = metrics if metrics else {}
except Exception as e:
print(f"训练 {model_type} 模型失败: {e}")
results[model_type] = {'error': str(e)}
# 打印比较结果
print(f"\n{'='*25} 模型性能比较 {'='*25}")
# 准备数据
df_data = []
for model, metrics in results.items():
if metrics and 'rmse' in metrics:
df_data.append({
'Model': model.upper(),
'RMSE': metrics.get('rmse'),
'': metrics.get('r2'),
'MAPE (%)': metrics.get('mape'),
'Time (s)': metrics.get('training_time')
})
if not df_data:
print("没有可供比较的模型结果。")
return results
comparison_df = pd.DataFrame(df_data).set_index('Model')
print(comparison_df.to_string(float_format="%.4f"))
return results