184 lines
6.0 KiB
Python
184 lines
6.0 KiB
Python
"""
|
||
药店销售预测系统 - 可视化工具函数
|
||
"""
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import os
|
||
|
||
# 修复导入路径问题
|
||
try:
|
||
from core.config import DEFAULT_MODEL_DIR
|
||
except ImportError:
|
||
try:
|
||
from ..core.config import DEFAULT_MODEL_DIR
|
||
except ImportError:
|
||
# 后备方案:使用默认值
|
||
DEFAULT_MODEL_DIR = "saved_models"
|
||
|
||
def plot_loss_curve(train_losses, val_losses, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
||
"""
|
||
绘制训练和验证损失曲线
|
||
|
||
参数:
|
||
train_losses: 训练损失列表
|
||
val_losses: 验证损失列表
|
||
product_name: 产品名称
|
||
model_type: 模型类型
|
||
save_path: 保存路径,如果为None则自动生成路径
|
||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
||
"""
|
||
plt.figure(figsize=(10, 5))
|
||
plt.plot(train_losses, label='训练损失')
|
||
plt.plot(val_losses, label='验证损失')
|
||
plt.title(f'{product_name} - {model_type}模型训练和验证损失')
|
||
plt.xlabel('Epoch')
|
||
plt.ylabel('Loss')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
|
||
if save_path:
|
||
# 如果提供了完整路径,直接使用
|
||
full_path = save_path
|
||
else:
|
||
# 否则生成默认路径
|
||
# 确保模型目录存在
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
# 构建文件名:模型类型_产品名_loss_curve.png
|
||
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_loss_curve.png"
|
||
full_path = os.path.join(model_dir, filename)
|
||
|
||
plt.savefig(full_path)
|
||
plt.close()
|
||
return full_path
|
||
|
||
def plot_prediction_results(y_true, y_pred, product_name, model_type, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
||
"""
|
||
绘制预测结果与真实值的对比图
|
||
|
||
参数:
|
||
y_true: 真实值
|
||
y_pred: 预测值
|
||
product_name: 产品名称
|
||
model_type: 模型类型
|
||
dates: 日期列表,如果提供则用作x轴
|
||
save_path: 保存路径,如果为None则自动生成路径
|
||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
||
"""
|
||
plt.figure(figsize=(12, 6))
|
||
|
||
if dates is not None:
|
||
plt.plot(dates, y_true, 'b-', label='真实销量')
|
||
plt.plot(dates, y_pred, 'r--', label='预测销量')
|
||
plt.xticks(rotation=45)
|
||
else:
|
||
plt.plot(y_true, 'b-', label='真实销量')
|
||
plt.plot(y_pred, 'r--', label='预测销量')
|
||
|
||
plt.title(f'{product_name} - {model_type}模型预测结果')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.tight_layout()
|
||
|
||
if save_path:
|
||
# 如果提供了完整路径,直接使用
|
||
full_path = save_path
|
||
else:
|
||
# 否则生成默认路径
|
||
# 确保模型目录存在
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
# 构建文件名:模型类型_产品名_prediction.png
|
||
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_prediction.png"
|
||
full_path = os.path.join(model_dir, filename)
|
||
|
||
plt.savefig(full_path)
|
||
plt.close()
|
||
return full_path
|
||
|
||
def plot_multiple_predictions(predictions_dict, product_name, dates=None, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
||
"""
|
||
绘制多个模型的预测结果对比图
|
||
|
||
参数:
|
||
predictions_dict: 字典,键为模型名称,值为预测结果
|
||
product_name: 产品名称
|
||
dates: 日期列表,如果提供则用作x轴
|
||
save_path: 保存路径,如果为None则自动生成路径
|
||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
||
"""
|
||
plt.figure(figsize=(12, 6))
|
||
|
||
colors = ['b-', 'r--', 'g-.', 'm:', 'c-', 'y--']
|
||
|
||
for i, (model_name, values) in enumerate(predictions_dict.items()):
|
||
if dates is not None:
|
||
plt.plot(dates, values, colors[i % len(colors)], label=model_name)
|
||
else:
|
||
plt.plot(values, colors[i % len(colors)], label=model_name)
|
||
|
||
plt.title(f'{product_name} - 多模型预测结果对比')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.tight_layout()
|
||
|
||
if save_path:
|
||
# 如果提供了完整路径,直接使用
|
||
full_path = save_path
|
||
else:
|
||
# 否则生成默认路径
|
||
# 确保模型目录存在
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
# 构建文件名:产品名_model_comparison.png
|
||
filename = f"{product_name.replace(' ', '_')}_model_comparison.png"
|
||
full_path = os.path.join(model_dir, filename)
|
||
|
||
plt.savefig(full_path)
|
||
plt.close()
|
||
return full_path
|
||
|
||
def plot_feature_importance(feature_names, importance_scores, product_name, model_type, save_path=None, model_dir=DEFAULT_MODEL_DIR):
|
||
"""
|
||
绘制特征重要性条形图
|
||
|
||
参数:
|
||
feature_names: 特征名称列表
|
||
importance_scores: 重要性分数列表
|
||
product_name: 产品名称
|
||
model_type: 模型类型
|
||
save_path: 保存路径,如果为None则自动生成路径
|
||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
||
"""
|
||
# 按重要性排序
|
||
sorted_idx = np.argsort(importance_scores)
|
||
sorted_names = [feature_names[i] for i in sorted_idx]
|
||
sorted_scores = [importance_scores[i] for i in sorted_idx]
|
||
|
||
plt.figure(figsize=(10, 6))
|
||
plt.barh(sorted_names, sorted_scores)
|
||
plt.title(f'{product_name} - {model_type}模型特征重要性')
|
||
plt.xlabel('重要性分数')
|
||
plt.ylabel('特征')
|
||
plt.tight_layout()
|
||
|
||
if save_path:
|
||
# 如果提供了完整路径,直接使用
|
||
full_path = save_path
|
||
else:
|
||
# 否则生成默认路径
|
||
# 确保模型目录存在
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
|
||
# 构建文件名:模型类型_产品名_feature_importance.png
|
||
filename = f"{model_type}_product_{product_name.replace(' ', '_')}_feature_importance.png"
|
||
full_path = os.path.join(model_dir, filename)
|
||
|
||
plt.savefig(full_path)
|
||
plt.close()
|
||
return full_path |