ShopTRAINING/server/utils/visualization.py

192 lines
6.5 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - 可视化工具函数
"""
import matplotlib.pyplot as plt
import numpy as np
import os
2025-07-02 11:05:23 +08:00
# 修复导入路径问题
try:
from core.config import DEFAULT_MODEL_DIR
except ImportError:
try:
from ..core.config import DEFAULT_MODEL_DIR
except ImportError:
# 后备方案:使用默认值
DEFAULT_MODEL_DIR = "saved_models"
def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制训练和验证损失曲线并根据scope和identifier生成标准化的文件名
参数:
train_losses: 训练损失列表
val_losses: 验证损失列表
model_type: 模型类型 (e.g., 'xgboost')
scope: 训练范围 ('product', 'store', 'global')
identifier: 范围对应的标识符 (产品名, 店铺ID, 或聚合方法)
version: (可选) 模型版本号用于生成唯一的文件名
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录
"""
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
# 动态生成标题
title_identifier = identifier.replace('_', ' ')
title = f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失'
plt.title(title)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
if save_path:
full_path = save_path
else:
os.makedirs(model_dir, exist_ok=True)
# 标准化文件名生成逻辑,与 ModelManager 对齐
version_str = f"_{version}" if version else ""
# 清理标识符中的非法字符
safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_')
filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制预测结果与真实值的对比图
参数:
y_true: 真实值
y_pred: 预测值
product_name: 产品名称
model_type: 模型类型
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
if dates is not None:
plt.plot(dates, y_true, 'b-', label='真实销量')
plt.plot(dates, y_pred, 'r--', label='预测销量')
plt.xticks(rotation=45)
else:
plt.plot(y_true, 'b-', label='真实销量')
plt.plot(y_pred, 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型预测结果')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_prediction.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制多个模型的预测结果对比图
参数:
predictions_dict: 字典键为模型名称值为预测结果
product_name: 产品名称
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']
for i, (model_name, values) in enumerate(predictions_dict.items()):
if dates is not None:
plt.plot(dates, values, colors[i % len(colors)], label=model_name)
else:
plt.plot(values, colors[i % len(colors)], label=model_name)
plt.title(f'{product_name} - 多模型预测结果对比')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名产品名_model_comparison.png
filename = f"{product_name.replace(' ', '_')}_model_comparison.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制特征重要性条形图
参数:
feature_names: 特征名称列表
importance_scores: 重要性分数列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
# 按重要性排序
sorted_idx = np.argsort(importance_scores)
sorted_names = [feature_names[i] for i in sorted_idx]
sorted_scores = [importance_scores[i] for i in sorted_idx]
plt.figure(figsize=(10, 6))
plt.barh(sorted_names, sorted_scores)
plt.title(f'{product_name} - {model_type}模型特征重要性')
plt.xlabel('重要性分数')
plt.ylabel('特征')
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_feature_importance.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path