2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
|
|
|
|
药店销售预测系统 - 可视化工具函数
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import numpy as np
|
|
|
|
|
import os
|
2025-07-02 11:05:23 +08:00
|
|
|
|
|
|
|
|
|
# 修复导入路径问题
|
|
|
|
|
try:
|
|
|
|
|
from core.config import DEFAULT_MODEL_DIR
|
|
|
|
|
except ImportError:
|
|
|
|
|
try:
|
|
|
|
|
from ..core.config import DEFAULT_MODEL_DIR
|
|
|
|
|
except ImportError:
|
|
|
|
|
# 后备方案:使用默认值
|
|
|
|
|
DEFAULT_MODEL_DIR = "saved_models"
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
2025-07-24 17:55:10 +08:00
|
|
|
|
def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Plot training and validation loss curves and save them as a PNG file.

    When ``save_path`` is not given, a standardized filename is generated
    (kept aligned with ModelManager's naming):
    ``{model_type}_{scope}_{identifier}[_{version}]_loss_curve.png`` inside
    ``model_dir``.

    Args:
        train_losses: Sequence of per-epoch training losses.
        val_losses: Sequence of per-epoch validation losses.
        model_type: Model type (e.g. 'xgboost').
        scope: Training scope ('product', 'store', 'global').
        identifier: Identifier for the scope (product name, store ID, or
            aggregation method). Non-string values are converted with ``str``.
        version: Optional model version, appended to the generated filename
            to make it unique.
        save_path: Explicit output path; if None, the path is generated.
        model_dir: Directory used for the generated path.

    Returns:
        The path the figure was written to.
    """
    # Store IDs etc. may arrive as non-strings; normalize once.
    identifier = str(identifier)

    if save_path:
        full_path = save_path
    else:
        version_str = f"_{version}" if version else ""
        # Replace spaces and path separators so the identifier stays a single
        # filename component.
        safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_')
        filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png"
        full_path = os.path.join(model_dir, filename)

    # Ensure the target directory exists (also for an explicit save_path) so
    # savefig does not fail on a missing parent; skip when there is no parent.
    parent_dir = os.path.dirname(full_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(train_losses, label='训练损失')
        ax.plot(val_losses, label='验证损失')

        # Build the title dynamically from the scope identifier.
        title_identifier = identifier.replace('_', ' ')
        ax.set_title(f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)

        fig.savefig(full_path)
    finally:
        # Always release the figure, even if plotting or saving raised, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    return full_path
|
|
|
|
|
|
|
|
|
|
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Plot predicted values against actual values and save the chart as a PNG.

    When ``save_path`` is not given, the file is written to
    ``{model_dir}/{model_type}_product_{product_name}_prediction.png``. Spaces
    and path separators in the product name are replaced with ``_``.

    Args:
        y_true: Actual values.
        y_pred: Predicted values.
        product_name: Product name (non-string values are converted with ``str``).
        model_type: Model type.
        dates: Optional sequence of dates used as the x axis.
        save_path: Explicit output path; if None, the path is generated.
        model_dir: Directory used for the generated path (defaults to the
            configured DEFAULT_MODEL_DIR).

    Returns:
        The path the figure was written to.
    """
    product_name = str(product_name)

    if save_path:
        # A full path was provided: use it as-is.
        full_path = save_path
    else:
        # Build the default path: {model_type}_product_{product_name}_prediction.png
        # '/' and '\\' would otherwise be interpreted as sub-directories.
        safe_name = product_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
        filename = f"{model_type}_product_{safe_name}_prediction.png"
        full_path = os.path.join(model_dir, filename)

    # Ensure the target directory exists (also for an explicit save_path).
    parent_dir = os.path.dirname(full_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        if dates is not None:
            ax.plot(dates, y_true, 'b-', label='真实销量')
            ax.plot(dates, y_pred, 'r--', label='预测销量')
            # Rotate date labels so they do not overlap.
            ax.tick_params(axis='x', labelrotation=45)
        else:
            ax.plot(y_true, 'b-', label='真实销量')
            ax.plot(y_pred, 'r--', label='预测销量')

        ax.set_title(f'{product_name} - {model_type}模型预测结果')
        ax.set_xlabel('日期')
        ax.set_ylabel('销量')
        ax.legend()
        ax.grid(True)
        fig.tight_layout()

        fig.savefig(full_path)
    finally:
        # Release the figure even when plotting/saving fails.
        plt.close(fig)

    return full_path
|
|
|
|
|
|
|
|
|
|
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Plot the predictions of several models on one chart and save it as a PNG.

    When ``save_path`` is not given, the file is written to
    ``{model_dir}/{product_name}_model_comparison.png``. Spaces and path
    separators in the product name are replaced with ``_``.

    Args:
        predictions_dict: Mapping of model name -> predicted values; each entry
            becomes one labeled line.
        product_name: Product name (non-string values are converted with ``str``).
        dates: Optional sequence of dates used as the x axis.
        save_path: Explicit output path; if None, the path is generated.
        model_dir: Directory used for the generated path (defaults to the
            configured DEFAULT_MODEL_DIR).

    Returns:
        The path the figure was written to.
    """
    product_name = str(product_name)

    if save_path:
        # A full path was provided: use it as-is.
        full_path = save_path
    else:
        # Build the default path: {product_name}_model_comparison.png
        # '/' and '\\' would otherwise be interpreted as sub-directories.
        safe_name = product_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
        filename = f"{safe_name}_model_comparison.png"
        full_path = os.path.join(model_dir, filename)

    # Ensure the target directory exists (also for an explicit save_path).
    parent_dir = os.path.dirname(full_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # matplotlib format strings (colour + line style), cycled if there are
    # more models than entries.
    line_styles = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        for i, (model_name, values) in enumerate(predictions_dict.items()):
            style = line_styles[i % len(line_styles)]
            if dates is not None:
                ax.plot(dates, values, style, label=model_name)
            else:
                ax.plot(values, style, label=model_name)

        ax.set_title(f'{product_name} - 多模型预测结果对比')
        ax.set_xlabel('日期')
        ax.set_ylabel('销量')
        # With no models there are no labeled lines; skip the (empty) legend
        # to avoid matplotlib's "no artists with labels" warning.
        if predictions_dict:
            ax.legend()
        ax.grid(True)
        fig.tight_layout()

        fig.savefig(full_path)
    finally:
        # Release the figure even when plotting/saving fails.
        plt.close(fig)

    return full_path
|
|
|
|
|
|
|
|
|
|
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
    """
    Draw a horizontal bar chart of feature importances and save it as a PNG.

    Bars are sorted by score in ascending order, so the most important feature
    appears at the top. When ``save_path`` is not given, the file is written to
    ``{model_dir}/{model_type}_product_{product_name}_feature_importance.png``.
    Spaces and path separators in the product name are replaced with ``_``.

    Args:
        feature_names: Sequence of feature names.
        importance_scores: Sequence/array of importance scores, positionally
            aligned with ``feature_names``.
        product_name: Product name (non-string values are converted with ``str``).
        model_type: Model type.
        save_path: Explicit output path; if None, the path is generated.
        model_dir: Directory used for the generated path (defaults to the
            configured DEFAULT_MODEL_DIR).

    Returns:
        The path the figure was written to.

    Raises:
        ValueError: If ``feature_names`` and ``importance_scores`` differ in length.
    """
    names = list(feature_names)
    # np.asarray makes indexing positional (e.g. a pandas Series with a custom
    # index would otherwise be indexed by label).
    scores = np.asarray(importance_scores)
    if len(names) != len(scores):
        raise ValueError(
            f"feature_names has {len(names)} entries but importance_scores has {len(scores)}"
        )

    # Sort by importance (ascending -> most important bar drawn at the top).
    order = np.argsort(scores)
    sorted_names = [str(names[i]) for i in order]
    sorted_scores = scores[order]
    # Explicit bar positions keep the sort order even when names are numeric
    # or duplicated (raw names passed to barh would be used as y values/categories).
    positions = np.arange(len(order))

    product_name = str(product_name)
    if save_path:
        # A full path was provided: use it as-is.
        full_path = save_path
    else:
        # Build the default path: {model_type}_product_{product_name}_feature_importance.png
        safe_name = product_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
        filename = f"{model_type}_product_{safe_name}_feature_importance.png"
        full_path = os.path.join(model_dir, filename)

    # Ensure the target directory exists (also for an explicit save_path).
    parent_dir = os.path.dirname(full_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.barh(positions, sorted_scores)
        ax.set_yticks(positions)
        ax.set_yticklabels(sorted_names)
        ax.set_title(f'{product_name} - {model_type}模型特征重要性')
        ax.set_xlabel('重要性分数')
        ax.set_ylabel('特征')
        fig.tight_layout()

        fig.savefig(full_path)
    finally:
        # Release the figure even when plotting/saving fails.
        plt.close(fig)

    return full_path
|