ShopTRAINING/server/utils/visualization.py

184 lines
6.0 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - 可视化工具函数
"""
import matplotlib.pyplot as plt
import numpy as np
import os
2025-07-02 11:05:23 +08:00
# 修复导入路径问题
try:
from core.config import DEFAULT_MODEL_DIR
except ImportError:
try:
from ..core.config import DEFAULT_MODEL_DIR
except ImportError:
# 后备方案:使用默认值
DEFAULT_MODEL_DIR = "saved_models"
def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制训练和验证损失曲线
参数:
train_losses: 训练损失列表
val_losses: 验证损失列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
plt.title(f'{product_name} - {model_type}模型训练和验证损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_loss_curve.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制预测结果与真实值的对比图
参数:
y_true: 真实值
y_pred: 预测值
product_name: 产品名称
model_type: 模型类型
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
if dates is not None:
plt.plot(dates, y_true, 'b-', label='真实销量')
plt.plot(dates, y_pred, 'r--', label='预测销量')
plt.xticks(rotation=45)
else:
plt.plot(y_true, 'b-', label='真实销量')
plt.plot(y_pred, 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型预测结果')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_prediction.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制多个模型的预测结果对比图
参数:
predictions_dict: 字典键为模型名称值为预测结果
product_name: 产品名称
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']
for i, (model_name, values) in enumerate(predictions_dict.items()):
if dates is not None:
plt.plot(dates, values, colors[i % len(colors)], label=model_name)
else:
plt.plot(values, colors[i % len(colors)], label=model_name)
plt.title(f'{product_name} - 多模型预测结果对比')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名产品名_model_comparison.png
filename = f"{product_name.replace(' ', '_')}_model_comparison.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制特征重要性条形图
参数:
feature_names: 特征名称列表
importance_scores: 重要性分数列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
# 按重要性排序
sorted_idx = np.argsort(importance_scores)
sorted_names = [feature_names[i] for i in sorted_idx]
sorted_scores = [importance_scores[i] for i in sorted_idx]
plt.figure(figsize=(10, 6))
plt.barh(sorted_names, sorted_scores)
plt.title(f'{product_name} - {model_type}模型特征重要性')
plt.xlabel('重要性分数')
plt.ylabel('特征')
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_feature_importance.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path