ShopTRAINING/server/utils/visualization.py

192 lines
6.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 可视化工具函数
"""
import matplotlib.pyplot as plt
import numpy as np
import os
# 修复导入路径问题
try:
from core.config import DEFAULT_MODEL_DIR
except ImportError:
try:
from ..core.config import DEFAULT_MODEL_DIR
except ImportError:
# 后备方案:使用默认值
DEFAULT_MODEL_DIR = "saved_models"
def plot_loss_curve(train_losses, val_losses, model_type: str, scope: str, identifier: str, version: str = None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制训练和验证损失曲线并根据scope和identifier生成标准化的文件名。
参数:
train_losses: 训练损失列表
val_losses: 验证损失列表
model_type: 模型类型 (e.g., 'xgboost')
scope: 训练范围 ('product', 'store', 'global')
identifier: 范围对应的标识符 (产品名, 店铺ID, 或聚合方法)
version: (可选) 模型版本号,用于生成唯一的文件名
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录
"""
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
# 动态生成标题
title_identifier = identifier.replace('_', ' ')
title = f'{title_identifier} - {model_type} ({scope}) 模型训练和验证损失'
plt.title(title)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
if save_path:
full_path = save_path
else:
os.makedirs(model_dir, exist_ok=True)
# 标准化文件名生成逻辑,与 ModelManager 对齐
version_str = f"_{version}" if version else ""
# 清理标识符中的非法字符
safe_identifier = identifier.replace(' ', '_').replace('/', '_').replace('\\', '_')
filename = f"{model_type}_{scope}_{safe_identifier}{version_str}_loss_curve.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制预测结果与真实值的对比图
参数:
y_true: 真实值
y_pred: 预测值
product_name: 产品名称
model_type: 模型类型
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
if dates is not None:
plt.plot(dates, y_true, 'b-', label='真实销量')
plt.plot(dates, y_pred, 'r--', label='预测销量')
plt.xticks(rotation=45)
else:
plt.plot(y_true, 'b-', label='真实销量')
plt.plot(y_pred, 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型预测结果')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_prediction.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制多个模型的预测结果对比图
参数:
predictions_dict: 字典,键为模型名称,值为预测结果
product_name: 产品名称
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']
for i, (model_name, values) in enumerate(predictions_dict.items()):
if dates is not None:
plt.plot(dates, values, colors[i % len(colors)], label=model_name)
else:
plt.plot(values, colors[i % len(colors)], label=model_name)
plt.title(f'{product_name} - 多模型预测结果对比')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名产品名_model_comparison.png
filename = f"{product_name.replace(' ', '_')}_model_comparison.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制特征重要性条形图
参数:
feature_names: 特征名称列表
importance_scores: 重要性分数列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
# 按重要性排序
sorted_idx = np.argsort(importance_scores)
sorted_names = [feature_names[i] for i in sorted_idx]
sorted_scores = [importance_scores[i] for i in sorted_idx]
plt.figure(figsize=(10, 6))
plt.barh(sorted_names, sorted_scores)
plt.title(f'{product_name} - {model_type}模型特征重要性')
plt.xlabel('重要性分数')
plt.ylabel('特征')
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_feature_importance.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path