""" 药店销售预测系统 - 可视化工具函数 """ import matplotlib.pyplot as plt import numpy as np import os from core.config import DEFAULT_MODEL_DIR def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制训练和验证损失曲线 参数: train_losses: 训练损失列表 val_losses: 验证损失列表 product_name: 产品名称 model_type: 模型类型 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ plt.figure(figsize=(10, 5)) plt.plot(train_losses, label='训练损失') plt.plot(val_losses, label='验证损失') plt.title(f'{product_name} - {model_type}模型训练和验证损失') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend() plt.grid(True) if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:模型类型_产品名_loss_curve.png filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制预测结果与真实值的对比图 参数: y_true: 真实值 y_pred: 预测值 product_name: 产品名称 model_type: 模型类型 dates: 日期列表,如果提供则用作x轴 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ plt.figure(figsize=(12, 6)) if dates is not None: plt.plot(dates, y_true, 'b-', label='真实销量') plt.plot(dates, y_pred, 'r--', label='预测销量') plt.xticks(rotation=45) else: plt.plot(y_true, 'b-', label='真实销量') plt.plot(y_pred, 'r--', label='预测销量') plt.title(f'{product_name} - {model_type}模型预测结果') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:模型类型_产品名_prediction.png filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制多个模型的预测结果对比图 参数: predictions_dict: 字典,键为模型名称,值为预测结果 product_name: 产品名称 dates: 日期列表,如果提供则用作x轴 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ plt.figure(figsize=(12, 6)) colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--'] for i, (model_name, values) in enumerate(predictions_dict.items()): if dates is not None: plt.plot(dates, values, colors[i % len(colors)], label=model_name) else: plt.plot(values, colors[i % len(colors)], label=model_name) plt.title(f'{product_name} - 多模型预测结果对比') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:产品名_model_comparison.png filename = f"{product_name.replace(' ', '_')}_model_comparison.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR): """ 绘制特征重要性条形图 参数: feature_names: 特征名称列表 importance_scores: 重要性分数列表 product_name: 产品名称 model_type: 模型类型 save_path: 保存路径,如果为None则自动生成路径 model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR """ # 按重要性排序 sorted_idx = np.argsort(importance_scores) sorted_names = [feature_names[i] for i in sorted_idx] sorted_scores = [importance_scores[i] for i in sorted_idx] plt.figure(figsize=(10, 6)) plt.barh(sorted_names, sorted_scores) plt.title(f'{product_name} - {model_type}模型特征重要性') plt.xlabel('重要性分数') plt.ylabel('特征') plt.tight_layout() if save_path: # 如果提供了完整路径,直接使用 full_path = save_path else: # 否则生成默认路径 # 确保模型目录存在 os.makedirs(model_dir, exist_ok=True) # 构建文件名:模型类型_产品名_feature_importance.png filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png" full_path = os.path.join(model_dir, filename) plt.savefig(full_path) plt.close() return full_path