## 1. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
  - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
  - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 2. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志,已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新,现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。

## 3. 启动依赖

- **Python**: 3.x
- **主要库**:
  - Flask
  - Flask-SocketIO
  - Flasgger
  - pandas
  - numpy
  - torch
  - scikit-learn
  - matplotlib
- **启动命令**: `python server/api.py`
313 lines
12 KiB
Python
313 lines
12 KiB
Python
"""
Pharmacy sales forecasting system - core predictor class (refactored).

Supports multi-store sales forecasting and is fully integrated with the new
ModelManager.
"""
|
||
|
||
import os
|
||
import pandas as pd
|
||
import time
|
||
from datetime import datetime
|
||
|
||
from trainers import (
|
||
train_product_model_with_mlstm,
|
||
train_product_model_with_kan,
|
||
train_product_model_with_tcn,
|
||
train_product_model_with_transformer
|
||
)
|
||
from predictors.model_predictor import load_model_and_predict
|
||
from utils.multi_store_data_utils import (
|
||
load_multi_store_data,
|
||
get_store_product_sales_data,
|
||
aggregate_multi_store_data
|
||
)
|
||
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
|
||
from utils.model_manager import model_manager
|
||
|
||
class PharmacyPredictor:
    """Core class of the pharmacy sales forecasting system.

    Trains forecasting models on aggregated multi-store sales data and runs
    predictions from previously trained model versions.
    """

    def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
        """Initialise the predictor and eagerly load the sales data.

        Args:
            data_path: Location of the multi-store data; falls back to
                ``DEFAULT_DATA_PATH`` when falsy.
            model_dir: Directory used for trained models; created if missing.

        A data-loading failure is reported on stdout and leaves ``self.data``
        set to ``None`` instead of raising, so ``train_model`` can report it.
        """
        self.data_path = data_path if data_path else DEFAULT_DATA_PATH
        self.model_dir = model_dir
        self.device = DEVICE

        # exist_ok=True avoids the check-then-create race when several
        # processes construct a predictor at the same time.
        os.makedirs(model_dir, exist_ok=True)

        print(f"使用设备: {self.device}")

        try:
            self.data = load_multi_store_data(self.data_path)
            print(f"已加载多店铺数据,来源: {self.data_path}")
        except Exception as e:
            print(f"加载数据失败: {e}")
            self.data = None

    def _prepare_global_params(self, **kwargs):
        """Build parameters for the 'global' (all_stores_all_products) mode.

        No filtering is applied, so both aggregation ids are ``None`` and the
        path ids are the literal ``'all'``. ``kwargs`` is accepted only so all
        scope handlers share one call signature.
        """
        return {
            'final_training_mode': 'global',
            'agg_store_id': None,
            'agg_product_id': None,
            'path_store_id': 'all',
            'path_product_id': 'all',
        }

    def _prepare_stores_params(self, **kwargs):
        """Build and validate parameters for the 'stores' (selected_stores) mode.

        Raises:
            ValueError: If ``store_ids`` is missing or empty.
        """
        store_ids_list = kwargs.get('store_ids')
        if not store_ids_list:
            raise ValueError("进行 'selected_stores' 范围训练时,必须提供 store_ids 列表。")
        return {
            'final_training_mode': 'stores',
            'agg_store_id': store_ids_list,
            'agg_product_id': None,
            # The first selected store is used as the path identifier.
            'path_store_id': store_ids_list[0],
            'path_product_id': 'all',
        }

    def _prepare_products_params(self, **kwargs):
        """Build and validate parameters for the 'products' (selected_products) mode.

        Raises:
            ValueError: If ``product_ids`` is missing or empty.
        """
        product_ids_list = kwargs.get('product_ids')
        if not product_ids_list:
            raise ValueError("进行 'selected_products' 范围训练时,必须提供 product_ids 列表。")
        return {
            'final_training_mode': 'products',
            'agg_store_id': None,
            'agg_product_id': product_ids_list,
            'path_store_id': 'all',
            # The first selected product is used as the path identifier.
            'path_product_id': product_ids_list[0],
        }

    def _prepare_custom_params(self, **kwargs):
        """Build and validate parameters for the 'custom' mode.

        Raises:
            ValueError: If either ``store_ids`` or ``product_ids`` is missing
                or empty.
        """
        store_ids_list = kwargs.get('store_ids')
        product_ids_list = kwargs.get('product_ids')
        if not store_ids_list or not product_ids_list:
            raise ValueError("进行 'custom' 范围训练时,必须同时提供 store_ids 和 product_ids 列表。")
        return {
            'final_training_mode': 'custom',
            'agg_store_id': store_ids_list,
            'agg_product_id': product_ids_list,
            'path_store_id': store_ids_list[0],
            'path_product_id': product_ids_list[0],
        }

    def _prepare_training_params(self, training_scope, product_id, store_id, **kwargs):
        """Dispatch to the parameter builder matching ``training_scope``.

        Args:
            training_scope: One of ``'all_stores_all_products'``,
                ``'selected_stores'``, ``'selected_products'`` or ``'custom'``.
            product_id: Legacy single product id, forwarded to the handler.
            store_id: Legacy single store id, forwarded to the handler.
            **kwargs: Extra values (e.g. ``store_ids``, ``product_ids``)
                forwarded to the handler.

        Returns:
            The dict produced by the selected handler, with the keys
            ``final_training_mode``, ``agg_store_id``, ``agg_product_id``,
            ``path_store_id`` and ``path_product_id``.

        Raises:
            ValueError: If the scope is unknown, or the handler rejects its
                inputs.
        """
        scope_handlers = {
            'all_stores_all_products': self._prepare_global_params,
            'selected_stores': self._prepare_stores_params,
            'selected_products': self._prepare_products_params,
            'custom': self._prepare_custom_params,
        }
        handler = scope_handlers.get(training_scope)
        if not handler:
            raise ValueError(f"不支持的训练范围: '{training_scope}'")

        # Merge everything into one dict so each handler can pick what it needs.
        all_params = kwargs.copy()
        all_params['training_scope'] = training_scope
        all_params['product_id'] = product_id
        all_params['store_id'] = store_id

        return handler(**all_params)

    def train_model(self, product_id, model_type='transformer', epochs=100,
                    learning_rate=0.001, use_optimized=False,
                    store_id=None, training_mode='product', aggregation_method='sum',
                    product_scope='all', product_ids=None, store_ids=None,
                    socketio=None, task_id=None, progress_callback=None, patience=10, **kwargs):
        """Train a forecasting model using the trainer selected by ``model_type``.

        The training scope is read from ``kwargs['training_scope']`` and
        resolved through ``_prepare_training_params``. The sales data is then
        aggregated accordingly and handed to the trainer.

        Args:
            product_id: Legacy single product id, forwarded to the scope handler.
            model_type: One of ``'transformer'``, ``'mlstm'``, ``'tcn'``,
                ``'kan'`` or ``'optimized_kan'``.
            epochs: Forwarded to the trainer.
            learning_rate: Forwarded to the trainer.
            use_optimized: Accepted for compatibility; this method does not
                read it. The KAN ``use_optimized`` flag is derived from
                ``model_type``.
            store_id: Legacy single store id, forwarded to the scope handler.
            training_mode: Accepted for compatibility; not read here.
            aggregation_method: Forwarded to the aggregation helper and the
                trainer.
            product_scope: Accepted for compatibility; not read here.
            product_ids: Product ids for the 'selected_products' and 'custom'
                scopes.
            store_ids: Store ids for the 'selected_stores' and 'custom' scopes.
            socketio: Forwarded to the trainer.
            task_id: Forwarded to the trainer.
            progress_callback: Optional callable that receives log records as
                ``{'log_type': ..., 'message': ...}``; also forwarded to the
                trainer.
            patience: Forwarded to the trainer.
            **kwargs: May contain ``training_scope``; the remaining entries go
                to the scope handler.

        Returns:
            The trainer's metrics dict, extended with model and scope
            metadata, or ``None`` when validation, data preparation or
            training fails. Failures are logged rather than raised.
        """
        def log_message(message, log_type='info'):
            print(f"[{log_type.upper()}] {message}", flush=True)
            if progress_callback:
                try:
                    progress_callback({'log_type': log_type, 'message': message})
                except Exception as e:
                    # A broken callback must not abort training.
                    print(f"[ERROR] 进度回调失败: {e}", flush=True)

        if self.data is None:
            log_message("没有可用的数据,请先加载或生成数据", 'error')
            return None

        try:
            # Work on a copy so the caller's kwargs are left untouched.
            call_kwargs = kwargs.copy()
            training_scope = call_kwargs.pop('training_scope', None)

            params = self._prepare_training_params(
                training_scope=training_scope,
                store_id=store_id,
                product_id=product_id,
                product_ids=product_ids,
                store_ids=store_ids,
                **call_kwargs
            )

            product_data = aggregate_multi_store_data(
                store_id=params['agg_store_id'],
                product_id=params['agg_product_id'],
                aggregation_method=aggregation_method,
                file_path=self.data_path
            )

            if product_data is None or product_data.empty:
                raise ValueError(f"聚合后数据为空,无法继续训练。模式: {params['final_training_mode']}")

        except ValueError as e:
            log_message(f"参数校验或数据准备失败: {e}", 'error')
            return None
        except Exception as e:
            import traceback
            log_message(f"数据准备过程中发生未知错误: {e}", 'error')
            log_message(traceback.format_exc(), 'error')
            return None

        # product_data is guaranteed non-empty here; the empty case already
        # returned above.

        trainers = {
            'transformer': train_product_model_with_transformer,
            'mlstm': train_product_model_with_mlstm,
            'tcn': train_product_model_with_tcn,
            'kan': train_product_model_with_kan,
            'optimized_kan': train_product_model_with_kan,
        }

        if model_type not in trainers:
            log_message(f"不支持的模型类型: {model_type}", 'error')
            return None

        trainer_func = trainers[model_type]

        trainer_args = {
            "product_id": params['path_product_id'],
            "product_df": product_data,
            "store_id": params['path_store_id'],
            "training_mode": params['final_training_mode'],
            "aggregation_method": aggregation_method,
            "scope": training_scope,
            "epochs": epochs,
            "socketio": socketio,
            "task_id": task_id,
            "progress_callback": progress_callback,
            "patience": patience,
            "learning_rate": learning_rate
        }

        # Both KAN variants share one trainer; the flag selects the variant.
        if 'kan' in model_type:
            trainer_args['use_optimized'] = (model_type == 'optimized_kan')

        try:
            log_message(f"🤖 开始调用 {model_type} 训练器")

            model, metrics, version, model_version_path = trainer_func(**trainer_args)

            log_message(f"✅ {model_type} 训练器成功返回", 'success')

            if not metrics:
                log_message("⚠️ 训练器返回的metrics为空", 'warning')
                return None

            try:
                relative_model_path = os.path.relpath(model_version_path, os.getcwd())
            except ValueError:
                # On Windows relpath raises ValueError when the path and the
                # working directory are on different drives. Keep the path as
                # returned so a finished training run is not discarded.
                relative_model_path = str(model_version_path)

            metrics.update({
                'model_type': model_type,
                'version': version,
                'model_path': relative_model_path.replace('\\', '/'),
                'training_mode': params['final_training_mode'],
                'store_id': params['path_store_id'],
                'product_id': params['path_product_id'],
                'aggregation_method': aggregation_method if params['final_training_mode'] == 'global' else None
            })
            log_message(f"📈 最终返回的metrics: {metrics}", 'success')
            return metrics

        except Exception as e:
            import traceback
            log_message(f"模型训练过程中发生严重错误: {e}\n{traceback.format_exc()}", 'error')
            return None

    def predict(self, model_version_path, future_days=7, start_date=None, analyze_result=False):
        """Run a prediction with the model stored at ``model_version_path``.

        Args:
            model_version_path: Path of the trained model version.
            future_days: Forwarded to ``load_model_and_predict``.
            start_date: Forwarded to ``load_model_and_predict``.
            analyze_result: Forwarded to ``load_model_and_predict``.

        Returns:
            Whatever ``load_model_and_predict`` returns.

        Raises:
            FileNotFoundError: If ``model_version_path`` does not exist.
        """
        if not os.path.exists(model_version_path):
            raise FileNotFoundError(f"指定的模型路径不存在: {model_version_path}")

        return load_model_and_predict(
            model_version_path=model_version_path,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result
        )

    def list_models(self, **kwargs):
        """List all available model versions.

        Delegates to ``model_manager.list_models``. Supported filters:
        model_type, training_mode, scope, version.
        """
        return model_manager.list_models(**kwargs)

    def delete_model(self, model_version_path):
        """Delete the given model version directory via the model manager."""
        return model_manager.delete_model_version(model_version_path)

    def compare_models(self, product_id, epochs=50, **kwargs):
        """Train several model types on the same data and compare their metrics.

        Args:
            product_id: Forwarded to ``train_model``.
            epochs: Number of epochs used for every model.
            **kwargs: Forwarded unchanged to ``train_model``.

        Returns:
            Dict mapping each model type to its metrics dict, to ``{}`` when
            training produced no metrics, or to ``{'error': message}`` when
            ``train_model`` raised. A comparison table is printed to stdout
            for the models that reported an ``rmse``.
        """
        results = {}
        model_types_to_compare = ['tcn', 'mlstm', 'transformer', 'kan', 'optimized_kan']

        for model_type in model_types_to_compare:
            print(f"\n{'='*20} 训练模型: {model_type.upper()} {'='*20}")
            try:
                metrics = self.train_model(
                    product_id=product_id,
                    model_type=model_type,
                    epochs=epochs,
                    **kwargs
                )
                results[model_type] = metrics if metrics else {}
            except Exception as e:
                print(f"训练 {model_type} 模型失败: {e}")
                results[model_type] = {'error': str(e)}

        # Print the comparison.
        print(f"\n{'='*25} 模型性能比较 {'='*25}")

        # Only models that reported an RMSE are included in the table.
        df_data = [
            {
                'Model': name.upper(),
                'RMSE': model_metrics.get('rmse'),
                'R²': model_metrics.get('r2'),
                'MAPE (%)': model_metrics.get('mape'),
                'Time (s)': model_metrics.get('training_time'),
            }
            for name, model_metrics in results.items()
            if model_metrics and 'rmse' in model_metrics
        ]

        if not df_data:
            print("没有可供比较的模型结果。")
            return results

        comparison_df = pd.DataFrame(df_data).set_index('Model')
        print(comparison_df.to_string(float_format="%.4f"))

        return results