""" 药店销售预测系统 - 核心预测器类 (已重构) 支持多店铺销售预测功能,并完全集成新的ModelManager """ import os import pandas as pd import time from datetime import datetime from trainers import ( train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_tcn, train_product_model_with_transformer ) from predictors.model_predictor import load_model_and_predict from utils.multi_store_data_utils import ( load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data ) from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH from utils.model_manager import model_manager class PharmacyPredictor: """ 药店销售预测系统核心类,用于训练模型和进行预测 """ def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR): """ 初始化预测器 """ self.data_path = data_path if data_path else DEFAULT_DATA_PATH self.model_dir = model_dir self.device = DEVICE if not os.path.exists(model_dir): os.makedirs(model_dir) print(f"使用设备: {self.device}") try: self.data = load_multi_store_data(self.data_path) print(f"已加载多店铺数据,来源: {self.data_path}") except Exception as e: print(f"加载数据失败: {e}") self.data = None def _prepare_product_params(self, product_id, store_scope, **kwargs): """为 'product' 训练模式准备参数""" if not product_id: raise ValueError("进行 'product' 模式训练时,必须提供 product_id。") agg_store_id = None final_scope_suffix = store_scope if store_scope == 'specific': store_ids = kwargs.get('store_ids') if not store_ids: raise ValueError("当 store_scope 为 'specific' 时, 必须提供 store_ids 列表。") agg_store_id = store_ids final_scope_suffix = f"specific_{'_'.join(store_ids)}" elif store_scope != 'all': # 假设 store_scope 本身就是一个店铺ID agg_store_id = [store_scope] return { 'agg_store_id': agg_store_id, 'agg_product_id': [product_id], 'final_scope': f"{product_id}_{final_scope_suffix}", } def _prepare_store_params(self, store_id, product_scope, **kwargs): """为 'store' 训练模式准备参数""" if not store_id: raise ValueError("进行 'store' 模式训练时,必须提供 store_id。") agg_product_id = None final_scope_suffix = product_scope if product_scope == 'specific': product_ids = kwargs.get('product_ids') if not product_ids: raise ValueError("当 product_scope 为 'specific' 时, 必须提供 product_ids 列表。") agg_product_id = product_ids final_scope_suffix = f"specific_{'_'.join(product_ids)}" elif product_scope != 'all': # 假设 product_scope 本身就是一个药品ID agg_product_id = [product_scope] return { 'agg_store_id': [store_id], 'agg_product_id': agg_product_id, 'final_scope': f"{store_id}_{final_scope_suffix}", } def _prepare_global_params(self, global_scope, store_ids, product_ids, **kwargs): """为 'global' 训练模式准备参数""" agg_store_id, agg_product_id = None, None if global_scope == 'all': final_scope = 'all' elif global_scope == 'selected_stores': if not store_ids: raise ValueError("global_scope 为 'selected_stores' 时必须提供 store_ids。") final_scope = f"stores/{'_'.join(store_ids)}" agg_store_id = store_ids elif global_scope == 'selected_products': if not product_ids: raise ValueError("global_scope 为 'selected_products' 时必须提供 product_ids。") final_scope = f"products/{'_'.join(product_ids)}" agg_product_id = product_ids elif global_scope == 'custom': if not store_ids or not product_ids: raise ValueError("global_scope 为 'custom' 时必须提供 store_ids 和 product_ids。") final_scope = f"custom/{'_'.join(store_ids)}/{'_'.join(product_ids)}" agg_store_id = store_ids agg_product_id = product_ids else: raise ValueError(f"不支持的 global_scope: '{global_scope}'") return { 'agg_store_id': agg_store_id, 'agg_product_id': agg_product_id, 'final_scope': final_scope, } def _prepare_training_params(self, training_mode, **kwargs): """参数准备分发器""" if training_mode == 'product': return self._prepare_product_params(**kwargs) elif training_mode == 'store': return self._prepare_store_params(**kwargs) elif training_mode == 'global': return self._prepare_global_params(**kwargs) else: raise ValueError(f"不支持的 training_mode: '{training_mode}'") def train_model(self, **kwargs): """ 训练预测模型 - 完全适配新的训练器接口和模型保存规则 """ # 从 kwargs 中安全地提取参数 product_id = kwargs.get('product_id') model_type = kwargs.get('model_type', 'transformer') epochs = kwargs.get('epochs', 100) learning_rate = kwargs.get('learning_rate', 0.001) use_optimized = kwargs.get('use_optimized', False) store_id = kwargs.get('store_id') training_mode = kwargs.get('training_mode', 'product') aggregation_method = kwargs.get('aggregation_method', 'sum') product_scope = kwargs.get('product_scope', 'all') store_scope = kwargs.get('store_scope', 'all') global_scope = kwargs.get('global_scope', 'all') product_ids = kwargs.get('product_ids') store_ids = kwargs.get('store_ids') socketio = kwargs.get('socketio') task_id = kwargs.get('task_id') progress_callback = kwargs.get('progress_callback') patience = kwargs.get('patience', 10) def log_message(message, log_type='info'): print(f"[{log_type.upper()}] {message}", flush=True) if progress_callback: try: progress_callback({'log_type': log_type, 'message': message}) except Exception as e: print(f"[ERROR] 进度回调失败: {e}", flush=True) if self.data is None: log_message("没有可用的数据,请先加载或生成数据", 'error') return None try: # 将所有相关参数打包以便传递 prep_args = { 'training_mode': training_mode, 'product_id': product_id, 'store_id': store_id, 'product_scope': product_scope, 'store_scope': store_scope, 'global_scope': global_scope, 'product_ids': product_ids, 'store_ids': store_ids } params = self._prepare_training_params(**prep_args) product_data = aggregate_multi_store_data( store_id=params['agg_store_id'], product_id=params['agg_product_id'], aggregation_method=aggregation_method, file_path=self.data_path ) if product_data is None or product_data.empty: raise ValueError(f"聚合后数据为空,无法继续训练。模式: {training_mode}, Scope: {params['final_scope']}") except ValueError as e: log_message(f"参数校验或数据准备失败: {e}", 'error') return None except Exception as e: import traceback log_message(f"数据准备过程中发生未知错误: {e}", 'error') log_message(traceback.format_exc(), 'error') return None trainers = { 'transformer': train_product_model_with_transformer, 'mlstm': train_product_model_with_mlstm, 'tcn': train_product_model_with_tcn, 'kan': train_product_model_with_kan, 'optimized_kan': train_product_model_with_kan, } if model_type not in trainers: log_message(f"不支持的模型类型: {model_type}", 'error') return None trainer_func = trainers[model_type] trainer_args = { "product_df": product_data, "training_mode": training_mode, "aggregation_method": aggregation_method, "scope": params['final_scope'], "epochs": epochs, "socketio": socketio, "task_id": task_id, "progress_callback": progress_callback, "patience": patience, "learning_rate": learning_rate } if 'kan' in model_type: trainer_args['use_optimized'] = (model_type == 'optimized_kan') # 确保将 product_id 和 store_id 传递给训练器 if product_id: trainer_args['product_id'] = product_id if store_id: trainer_args['store_id'] = store_id try: log_message(f"🤖 开始调用 {model_type} 训练器 with scope: '{params['final_scope']}'") model, metrics, version, model_version_path = trainer_func(**trainer_args) log_message(f"✅ {model_type} 训练器成功返回", 'success') if metrics: relative_model_path = os.path.relpath(model_version_path, os.getcwd()) metrics.update({ 'model_type': model_type, 'version': version, 'model_path': relative_model_path.replace('\\', '/'), 'training_mode': training_mode, 'scope': params['final_scope'], 'aggregation_method': aggregation_method }) log_message(f"📈 最终返回的metrics: {metrics}", 'success') return metrics else: log_message("⚠️ 训练器返回的metrics为空", 'warning') return None except Exception as e: import traceback log_message(f"模型训练过程中发生严重错误: {e}\n{traceback.format_exc()}", 'error') return None def predict(self, model_version_path, future_days=7, start_date=None, analyze_result=False): """ 使用已训练的模型进行预测 - 直接使用模型版本路径 """ if not os.path.exists(model_version_path): raise FileNotFoundError(f"指定的模型路径不存在: {model_version_path}") return load_model_and_predict( model_version_path=model_version_path, future_days=future_days, start_date=start_date, analyze_result=analyze_result ) def list_models(self, **kwargs): """ 列出所有可用的模型版本。 直接调用 ModelManager 的 list_models 方法。 支持的过滤参数: model_type, training_mode, scope, version """ return model_manager.list_models(**kwargs) def delete_model(self, model_version_path): """ 删除一个指定的模型版本目录。 """ return model_manager.delete_model_version(model_version_path) def compare_models(self, product_id, epochs=50, **kwargs): """ 在相同数据上训练并比较多个模型的性能。 """ results = {} model_types_to_compare = ['tcn', 'mlstm', 'transformer', 'kan', 'optimized_kan'] for model_type in model_types_to_compare: print(f"\n{'='*20} 训练模型: {model_type.upper()} {'='*20}") try: metrics = self.train_model( product_id=product_id, model_type=model_type, epochs=epochs, **kwargs ) results[model_type] = metrics if metrics else {} except Exception as e: print(f"训练 {model_type} 模型失败: {e}") results[model_type] = {'error': str(e)} # 打印比较结果 print(f"\n{'='*25} 模型性能比较 {'='*25}") # 准备数据 df_data = [] for model, metrics in results.items(): if metrics and 'rmse' in metrics: df_data.append({ 'Model': model.upper(), 'RMSE': metrics.get('rmse'), 'R²': metrics.get('r2'), 'MAPE (%)': metrics.get('mape'), 'Time (s)': metrics.get('training_time') }) if not df_data: print("没有可供比较的模型结果。") return results comparison_df = pd.DataFrame(df_data).set_index('Model') print(comparison_df.to_string(float_format="%.4f")) return results