import os import torch import glob import pandas as pd import matplotlib.pyplot as plt from datetime import datetime import json import shutil from .utils import get_device, to_device from .mlstm_model import MLSTMTransformer from .transformer_model import TimeSeriesTransformer from .kan_model import KANForecaster class ModelManager: """ 模型管理类:负责模型的保存、加载、列出和删除等操作 """ def __init__(self, models_dir='models'): """ 初始化模型管理器 参数: models_dir: 模型存储目录 """ self.models_dir = models_dir self._ensure_model_dir() # 模型类型映射 self.model_types = { 'mlstm': MLSTMTransformer, 'transformer': TimeSeriesTransformer, 'kan': KANForecaster } def _ensure_model_dir(self): """确保模型目录存在""" if not os.path.exists(self.models_dir): try: os.makedirs(self.models_dir, exist_ok=True) print(f"创建模型目录: {os.path.abspath(self.models_dir)}") except Exception as e: print(f"创建模型目录失败: {str(e)}") raise def save_model(self, model, model_type, product_id, optimizer=None, train_loss=None, test_loss=None, scaler_X=None, scaler_y=None, features=None, look_back=None, T=None, metrics=None, version=None): """ 保存模型及其相关信息 参数: model: 训练好的模型 model_type: 模型类型 ('mlstm', 'transformer', 'kan') product_id: 产品ID optimizer: 优化器 train_loss: 训练损失历史 test_loss: 测试损失历史 scaler_X: 特征缩放器 scaler_y: 目标缩放器 features: 使用的特征列表 look_back: 回看天数 T: 预测天数 metrics: 模型评估指标 version: 模型版本(可选),如果不提供则使用时间戳 """ self._ensure_model_dir() # 设置版本 if version is None: version = datetime.now().strftime("%Y%m%d_%H%M%S") # 设置文件名 model_filename = f"{product_id}_{model_type}_model_v{version}.pt" model_path = os.path.join(self.models_dir, model_filename) # 准备要保存的数据 save_dict = { 'model_state_dict': model.state_dict(), 'model_type': model_type, 'product_id': product_id, 'version': version, 'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 'features': features, 'look_back': look_back, 'T': T } # 添加可选数据 if optimizer is not None: save_dict['optimizer_state_dict'] = optimizer.state_dict() if train_loss is not None: save_dict['train_loss'] = train_loss if test_loss is not None: save_dict['test_loss'] = test_loss if scaler_X is not None: save_dict['scaler_X'] = scaler_X if scaler_y is not None: save_dict['scaler_y'] = scaler_y if metrics is not None: save_dict['metrics'] = metrics try: # 保存模型 torch.save(save_dict, model_path) print(f"模型已成功保存到 {os.path.abspath(model_path)}") # 保存模型的元数据到JSON文件,便于查询 meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json") meta_dict = {k: str(v) if not isinstance(v, (int, float, bool, list, dict, type(None))) else v for k, v in save_dict.items() if k != 'model_state_dict' and k != 'optimizer_state_dict' and k != 'scaler_X' and k != 'scaler_y'} # 如果有评估指标,添加到元数据 if metrics is not None: meta_dict['metrics'] = metrics with open(meta_path, 'w') as f: json.dump(meta_dict, f, indent=4) return model_path except Exception as e: print(f"保存模型时出错: {str(e)}") raise def load_model(self, product_id, model_type='mlstm', version=None, device=None): """ 加载指定的模型 参数: product_id: 产品ID model_type: 模型类型 ('mlstm', 'transformer', 'kan') version: 模型版本,如果不指定则加载最新版本 device: 设备 (cuda/cpu) 返回: model: 加载的模型 checkpoint: 包含模型信息的字典 """ if device is None: device = get_device() # 查找匹配的模型文件 if version is None: # 查找最新版本 pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt") model_files = glob.glob(pattern) if not model_files: print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件") return None, None # 按照文件修改时间排序,获取最新的 model_path = max(model_files, key=os.path.getmtime) else: # 指定版本 model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt") if not os.path.exists(model_path): print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}") return None, None try: # 加载模型 checkpoint = torch.load(model_path, map_location=device) # 创建模型实例 if model_type == 'mlstm': model = MLSTMTransformer( num_features=len(checkpoint['features']), hidden_size=128, mlstm_layers=1, embed_dim=32, dense_dim=32, num_heads=4, dropout_rate=0.1, num_blocks=3, output_sequence_length=checkpoint['T'] ) elif model_type == 'transformer': model = TimeSeriesTransformer( num_features=len(checkpoint['features']), d_model=32, nhead=4, num_encoder_layers=3, dim_feedforward=32, dropout=0.1, output_sequence_length=checkpoint['T'] ) elif model_type == 'kan': model = KANForecaster( input_features=len(checkpoint['features']), hidden_sizes=[64, 128, 64], output_size=1, grid_size=5, spline_order=3, dropout_rate=0.1, output_sequence_length=checkpoint['T'] ) else: raise ValueError(f"不支持的模型类型: {model_type}") # 加载模型参数 model.load_state_dict(checkpoint['model_state_dict']) model = model.to(device) model.eval() print(f"模型已从 {os.path.abspath(model_path)} 成功加载") return model, checkpoint except Exception as e: print(f"加载模型时出错: {str(e)}") raise def list_models(self, product_id=None, model_type=None): """ 列出所有保存的模型 参数: product_id: 按产品ID筛选 (可选) model_type: 按模型类型筛选 (可选) 返回: models_list: 模型信息列表 """ self._ensure_model_dir() # 构建搜索模式 if product_id and model_type: pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt") elif product_id: pattern = os.path.join(self.models_dir, f"{product_id}_*_model_v*.pt") elif model_type: pattern = os.path.join(self.models_dir, f"*_{model_type}_model_v*.pt") else: pattern = os.path.join(self.models_dir, "*_model_v*.pt") model_files = glob.glob(pattern) if not model_files: print("未找到匹配的模型文件") return [] # 收集模型信息 models_list = [] for model_path in model_files: try: # 从文件名解析信息 filename = os.path.basename(model_path) parts = filename.split('_') if len(parts) < 4: continue product_id = parts[0] model_type = parts[1] version = parts[-1].replace('model_v', '').replace('.pt', '') # 查找对应的元数据文件 meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json") model_info = { 'product_id': product_id, 'model_type': model_type, 'version': version, 'file_path': model_path, 'created_at': datetime.fromtimestamp(os.path.getctime(model_path)).strftime("%Y-%m-%d %H:%M:%S"), 'file_size': f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB" } # 如果有元数据文件,添加更多信息 if os.path.exists(meta_path): with open(meta_path, 'r') as f: meta = json.load(f) model_info.update(meta) models_list.append(model_info) except Exception as e: print(f"解析模型文件 {model_path} 时出错: {str(e)}") # 按创建时间排序 models_list.sort(key=lambda x: x['created_at'], reverse=True) return models_list def delete_model(self, product_id, model_type, version=None): """ 删除指定的模型 参数: product_id: 产品ID model_type: 模型类型 version: 模型版本,如果不指定则删除所有版本 返回: success: 是否成功删除 """ self._ensure_model_dir() if version: # 删除特定版本 model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt") meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json") if not os.path.exists(model_path): print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}") return False try: os.remove(model_path) if os.path.exists(meta_path): os.remove(meta_path) print(f"已删除产品 {product_id} 的 {model_type} 模型版本 {version}") return True except Exception as e: print(f"删除模型时出错: {str(e)}") return False else: # 删除所有版本 pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt") meta_pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v*.json") model_files = glob.glob(pattern) meta_files = glob.glob(meta_pattern) if not model_files: print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件") return False try: for file_path in model_files: os.remove(file_path) for file_path in meta_files: os.remove(file_path) print(f"已删除产品 {product_id} 的所有 {model_type} 模型") return True except Exception as e: print(f"删除模型时出错: {str(e)}") return False def get_model_details(self, product_id, model_type, version=None): """ 获取模型的详细信息 参数: product_id: 产品ID model_type: 模型类型 version: 模型版本,如果不指定则获取最新版本 返回: details: 模型详细信息字典 """ # 查找匹配的模型文件 if version is None: # 查找最新版本 pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt") model_files = glob.glob(pattern) if not model_files: print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件") return None # 按照文件修改时间排序,获取最新的 model_path = max(model_files, key=os.path.getmtime) # 从文件名解析版本 filename = os.path.basename(model_path) version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '') # 查找元数据文件 meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json") if not os.path.exists(meta_path): print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version} 的元数据") return None try: with open(meta_path, 'r') as f: details = json.load(f) # 添加文件路径 model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt") details['file_path'] = model_path details['file_size'] = f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB" return details except Exception as e: print(f"获取模型详情时出错: {str(e)}") return None def predict_with_model(self, product_id, model_type='mlstm', version=None, future_days=7, product_df=None, features=None, visualize=True, save_results=True): """ 使用指定的模型进行预测 参数: product_id: 产品ID model_type: 模型类型 ('mlstm', 'transformer', 'kan') version: 模型版本,如果不指定则使用最新版本 future_days: 要预测的未来天数 product_df: 产品数据DataFrame features: 特征列表 visualize: 是否可视化结果 save_results: 是否保存结果 返回: predictions_df: 预测结果DataFrame """ # 获取设备 device = get_device() print(f"使用设备: {device} 进行预测") # 加载模型 model, checkpoint = self.load_model(product_id, model_type, version, device) if model is None or checkpoint is None: return None # 如果没有提供产品数据,则从Excel文件加载 if product_df is None: try: df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id].sort_values('date') except Exception as e: print(f"加载产品数据时出错: {str(e)}") return None product_name = product_df['product_name'].iloc[0] # 获取模型参数 features = checkpoint['features'] look_back = checkpoint['look_back'] T = checkpoint['T'] scaler_X = checkpoint['scaler_X'] scaler_y = checkpoint['scaler_y'] # 获取最近的look_back天数据 last_data = product_df[features].values[-look_back:] last_data_scaled = scaler_X.transform(last_data) # 准备输入数据 X_input = torch.Tensor(last_data_scaled).unsqueeze(0) # 添加批次维度 X_input = X_input.to(device) # 移动到设备上 # 进行预测 with torch.no_grad(): y_pred_scaled = model(X_input).squeeze(0).cpu().numpy() # 返回到CPU并转换为numpy # 反归一化预测结果 y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten() # 创建预测日期范围 last_date = product_df['date'].iloc[-1] future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=T, freq='D') # 创建预测结果DataFrame predictions_df = pd.DataFrame({ 'date': future_dates, 'product_id': product_id, 'product_name': product_name, 'predicted_sales': y_pred }) print(f"\n{product_name} 未来 {T} 天销售预测 (使用{model_type.upper()}模型):") print(predictions_df[['date', 'predicted_sales']]) # 可视化预测结果 if visualize: plt.figure(figsize=(12, 6)) # 显示历史数据和预测数据 history_days = 30 # 显示最近30天的历史数据 history_dates = product_df['date'].iloc[-history_days:].values history_sales = product_df['sales'].iloc[-history_days:].values plt.plot(history_dates, history_sales, 'b-', label='历史销量') plt.plot(future_dates, y_pred, 'r--', label=f'{model_type.upper()}预测销量') plt.title(f'{product_name} - {model_type.upper()}销量预测 (未来{T}天)') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() # 保存和显示图表 forecast_chart = f'{product_id}_{model_type}_forecast.png' plt.savefig(forecast_chart) print(f"预测图表已保存为: {forecast_chart}") # 保存预测结果到CSV if save_results: forecast_csv = f'{product_id}_{model_type}_forecast.csv' predictions_df.to_csv(forecast_csv, index=False) print(f"预测结果已保存到: {forecast_csv}") return predictions_df def compare_models(self, product_id, model_types=None, versions=None, product_df=None, visualize=True): """ 比较不同模型的预测结果 参数: product_id: 产品ID model_types: 要比较的模型类型列表 versions: 对应的模型版本列表,如果不指定则使用最新版本 product_df: 产品数据DataFrame visualize: 是否可视化结果 返回: 比较结果DataFrame """ if model_types is None: model_types = ['mlstm', 'transformer', 'kan'] if versions is None: versions = [None] * len(model_types) if len(versions) != len(model_types): print("错误: 模型类型和版本列表长度不匹配") return None # 如果没有提供产品数据,则从Excel文件加载 if product_df is None: try: df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id].sort_values('date') except Exception as e: print(f"加载产品数据时出错: {str(e)}") return None product_name = product_df['product_name'].iloc[0] # 存储所有模型的预测结果 predictions = {} # 对每个模型进行预测 for i, model_type in enumerate(model_types): version = versions[i] try: pred_df = self.predict_with_model( product_id, model_type=model_type, version=version, product_df=product_df, visualize=False, save_results=False ) if pred_df is not None: predictions[model_type] = pred_df except Exception as e: print(f"{model_type} 模型预测出错: {str(e)}") if not predictions: print("没有成功的预测结果") return None # 合并预测结果 result_df = predictions[list(predictions.keys())[0]][['date', 'product_id', 'product_name']].copy() for model_type, pred_df in predictions.items(): result_df[f'{model_type}_prediction'] = pred_df['predicted_sales'].values # 可视化比较结果 if visualize and len(predictions) > 0: plt.figure(figsize=(12, 6)) # 显示历史数据 history_days = 30 # 显示最近30天的历史数据 history_dates = product_df['date'].iloc[-history_days:].values history_sales = product_df['sales'].iloc[-history_days:].values plt.plot(history_dates, history_sales, 'k-', label='历史销量') # 显示预测数据 colors = ['r', 'g', 'b', 'c', 'm', 'y'] future_dates = result_df['date'].values for i, (model_type, pred_df) in enumerate(predictions.items()): color = colors[i % len(colors)] plt.plot(future_dates, pred_df['predicted_sales'].values, f'{color}--', label=f'{model_type.upper()}预测') plt.title(f'{product_name} - 不同模型预测结果比较') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() # 保存和显示图表 compare_chart = f'{product_id}_model_comparison.png' plt.savefig(compare_chart) print(f"比较图表已保存为: {compare_chart}") # 保存比较结果到CSV compare_csv = f'{product_id}_model_comparison.csv' result_df.to_csv(compare_csv, index=False) print(f"比较结果已保存到: {compare_csv}") return result_df def export_model(self, product_id, model_type, version=None, export_dir='exported_models'): """ 导出模型到指定目录 参数: product_id: 产品ID model_type: 模型类型 version: 模型版本,如果不指定则导出最新版本 export_dir: 导出目录 返回: export_path: 导出的文件路径 """ # 确保导出目录存在 if not os.path.exists(export_dir): os.makedirs(export_dir, exist_ok=True) # 查找匹配的模型文件 if version is None: # 查找最新版本 pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt") model_files = glob.glob(pattern) if not model_files: print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件") return None # 按照文件修改时间排序,获取最新的 model_path = max(model_files, key=os.path.getmtime) # 从文件名解析版本 filename = os.path.basename(model_path) version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '') else: model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt") if not os.path.exists(model_path): print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}") return None # 元数据文件 meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json") # 导出路径 export_model_path = os.path.join(export_dir, f"{product_id}_{model_type}_model_v{version}.pt") export_meta_path = os.path.join(export_dir, f"{product_id}_{model_type}_meta_v{version}.json") try: # 复制文件 shutil.copy2(model_path, export_model_path) if os.path.exists(meta_path): shutil.copy2(meta_path, export_meta_path) print(f"模型已导出到 {os.path.abspath(export_model_path)}") return export_model_path except Exception as e: print(f"导出模型时出错: {str(e)}") return None def import_model(self, import_file, overwrite=False): """ 导入模型文件 参数: import_file: 要导入的模型文件路径 overwrite: 如果存在同名文件是否覆盖 返回: import_path: 导入后的文件路径 """ self._ensure_model_dir() if not os.path.exists(import_file): print(f"错误: 导入文件 {import_file} 不存在") return None # 获取文件名 filename = os.path.basename(import_file) # 目标路径 target_path = os.path.join(self.models_dir, filename) # 检查是否存在同名文件 if os.path.exists(target_path) and not overwrite: print(f"错误: 目标文件 {target_path} 已存在,如需覆盖请设置overwrite=True") return None try: # 复制文件 shutil.copy2(import_file, target_path) # 如果有对应的元数据文件,也一并导入 meta_filename = filename.replace('_model_v', '_meta_v') meta_import_file = import_file.replace('_model_v', '_meta_v').replace('.pt', '.json') meta_target_path = os.path.join(self.models_dir, meta_filename.replace('.pt', '.json')) if os.path.exists(meta_import_file): shutil.copy2(meta_import_file, meta_target_path) print(f"模型已导入到 {os.path.abspath(target_path)}") return target_path except Exception as e: print(f"导入模型时出错: {str(e)}") return None