ShopTRAINING/server/utils/visualization.py
2025-07-02 11:05:23 +08:00

184 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 可视化工具函数
"""
import matplotlib.pyplot as plt
import numpy as np
import os
# 修复导入路径问题
try:
from core.config import DEFAULT_MODEL_DIR
except ImportError:
try:
from ..core.config import DEFAULT_MODEL_DIR
except ImportError:
# 后备方案:使用默认值
DEFAULT_MODEL_DIR = "saved_models"
def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制训练和验证损失曲线
参数:
train_losses: 训练损失列表
val_losses: 验证损失列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='训练损失')
plt.plot(val_losses, label='验证损失')
plt.title(f'{product_name} - {model_type}模型训练和验证损失')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_loss_curve.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制预测结果与真实值的对比图
参数:
y_true: 真实值
y_pred: 预测值
product_name: 产品名称
model_type: 模型类型
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
if dates is not None:
plt.plot(dates, y_true, 'b-', label='真实销量')
plt.plot(dates, y_pred, 'r--', label='预测销量')
plt.xticks(rotation=45)
else:
plt.plot(y_true, 'b-', label='真实销量')
plt.plot(y_pred, 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型预测结果')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_prediction.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制多个模型的预测结果对比图
参数:
predictions_dict: 字典,键为模型名称,值为预测结果
product_name: 产品名称
dates: 日期列表如果提供则用作x轴
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
plt.figure(figsize=(12, 6))
colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']
for i, (model_name, values) in enumerate(predictions_dict.items()):
if dates is not None:
plt.plot(dates, values, colors[i % len(colors)], label=model_name)
else:
plt.plot(values, colors[i % len(colors)], label=model_name)
plt.title(f'{product_name} - 多模型预测结果对比')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名产品名_model_comparison.png
filename = f"{product_name.replace(' ', '_')}_model_comparison.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
绘制特征重要性条形图
参数:
feature_names: 特征名称列表
importance_scores: 重要性分数列表
product_name: 产品名称
model_type: 模型类型
save_path: 保存路径如果为None则自动生成路径
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
"""
# 按重要性排序
sorted_idx = np.argsort(importance_scores)
sorted_names = [feature_names[i] for i in sorted_idx]
sorted_scores = [importance_scores[i] for i in sorted_idx]
plt.figure(figsize=(10, 6))
plt.barh(sorted_names, sorted_scores)
plt.title(f'{product_name} - {model_type}模型特征重要性')
plt.xlabel('重要性分数')
plt.ylabel('特征')
plt.tight_layout()
if save_path:
# 如果提供了完整路径,直接使用
full_path = save_path
else:
# 否则生成默认路径
# 确保模型目录存在
os.makedirs(model_dir, exist_ok=True)
# 构建文件名模型类型_产品名_feature_importance.png
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png"
full_path = os.path.join(model_dir, filename)
plt.savefig(full_path)
plt.close()
return full_path