""" 药店销售预测系统 - 可视化工具函数 """ import matplotlib.pyplot as plt import numpy as np import os # 修复导入路径问题 try: from core.config import DEFAULT_MODEL_DIR except ImportError: try: from ..core.config import DEFAULT_MODEL_DIR except ImportError: # 后备方案:使用默认值 DEFAULT_MODEL_DIR = "saved_models" def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制训练和验证损失曲线,并根据scope和identifier生成标准化的文件名。 参数: train_losses: 训练损失列表 val_losses: 验证损失列表 model_type: 模型类型 (e.g., 'xgboost') scope: 训练范围 ('product', 'store', 'global') identifier: 范围对应的标识符 (产品名, 店铺ID, 或聚合方法) version: (可选) 模型版本号,用于生成唯一的文件名 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录 """ plt.figure(figsize=(10, 5)) plt.plot(train_losses, label='训练损失') plt.plot(val_losses, label='验证损失') # 动态生成标题 title_identifier = identifier.replace('_', ' ') title = f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失' plt.title(title) plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.grid(True) if save_path: full_path = save_path else: os.makedirs(model_dir, exist_ok=True) # 标准化文件名生成逻辑,与 ModelManager 对齐 version_str = f"_{version}" if version else "" # 清理标识符中的非法字符 safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_') filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制预测结果与真实值的对比图 参数: y_true: 真实值 y_pred: 预测值 product_name: 产品名称 model_type: 模型类型 dates: 日期列表,如果提供则用作x轴 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ plt.figure(figsize=(12, 6)) if dates is not None: plt.plot(dates, y_true, 'b-', label='真实销量') plt.plot(dates, y_pred, 'r--', label='预测销量') plt.xticks(rotation=45) else: plt.plot(y_true, 'b-', label='真实销量') plt.plot(y_pred, 'r--', label='预测销量') plt.title(f'{product_name} - {model_type}模型预测结果') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:模型类型_产品名_prediction.png filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制多个模型的预测结果对比图 参数: predictions_dict: 字典,键为模型名称,值为预测结果 product_name: 产品名称 dates: 日期列表,如果提供则用作x轴 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ plt.figure(figsize=(12, 6)) colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--'] for i, (model_name, values) in enumerate(predictions_dict.items()): if dates is not None: plt.plot(dates, values, colors[i % len(colors)], label=model_name) else: plt.plot(values, colors[i % len(colors)], label=model_name) plt.title(f'{product_name} - 多模型预测结果对比') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:产品名_model_comparison.png filename = f"{product_name.replace(' ', '_')}_model_comparison.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制特征重要性条形图 参数: feature_names: 特征名称列表 importance_scores: 重要性分数列表 product_name: 产品名称 model_type: 模型类型 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ # 按重要性排序 sorted_idx = np.argsort(importance_scores) sorted_names = [feature_names[i] for i in sorted_idx] sorted_scores = [importance_scores[i] for i in sorted_idx] plt.figure(figsize=(10, 6)) plt.barh(sorted_names, sorted_scores) plt.title(f'{product_name} - {model_type}模型特征重要性') plt.xlabel('重要性分数') plt.ylabel('特征') plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:模型类型_产品名_feature_importance.png filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path