82 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 评估指标计算函数
"""
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
def evaluate_model(y_true, y_pred):
"""
计算模型评估指标
参数:
y_true: 真实值
y_pred: 预测值
返回:
包含各种评估指标的字典
"""
# 确保输入是一维数组
y_true = np.array(y_true).flatten()
y_pred = np.array(y_pred).flatten()
# 计算均方误差 (MSE)
mse = mean_squared_error(y_true, y_pred)
# 计算均方根误差 (RMSE)
rmse = np.sqrt(mse)
# 计算平均绝对误差 (MAE)
mae = mean_absolute_error(y_true, y_pred)
# 计算决定系数 (R^2)
r2 = r2_score(y_true, y_pred)
# 计算平均绝对百分比误差 (MAPE)
# 避免除以零
mask = y_true != 0
if np.any(mask):
mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
else:
# 如果所有真实值都为0无法计算MAPE返回0
mape = 0.0
return {
'mse': mse,
'rmse': rmse,
'mae': mae,
'r2': r2,
'mape': mape
}
def compare_models(results_dict):
"""
比较多个模型的性能
参数:
results_dict: 字典,键为模型名称,值为评估指标字典
返回:
比较结果的字典
"""
comparison = {}
# 获取所有指标名称
metrics = list(next(iter(results_dict.values())).keys())
# 对每个指标,找出最佳模型
for metric in metrics:
if metric in ['mse', 'rmse', 'mae', 'mape']: # 这些指标越小越好
best_model = min(results_dict.items(), key=lambda x: x[1][metric])[0]
best_value = min(model[metric] for model in results_dict.values())
else: # r2 越大越好
best_model = max(results_dict.items(), key=lambda x: x[1][metric])[0]
best_value = max(model[metric] for model in results_dict.values())
comparison[metric] = {
'best_model': best_model,
'best_value': best_value,
'all_values': {model_name: results[metric] for model_name, results in results_dict.items()}
}
return comparison