ShopTRAINING/server/models/data_utils.py

127 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
# 定义数据集类
class PharmacyDataset(Dataset):
    """Map-style dataset pairing each input sample with its target.

    ``data_X`` and ``data_Y`` are indexed in parallel; the dataset length is
    taken from ``data_X`` alone.
    """

    def __init__(self, data_X, data_Y):
        # Keep references to the containers as given; nothing is copied or converted.
        self.data_X, self.data_Y = data_X, data_Y

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        sample = self.data_X[index]
        target = self.data_Y[index]
        return sample, target

    def __len__(self):
        """Return the number of samples, as reported by ``data_X``."""
        return len(self.data_X)
# 定义用于时间序列预测的数据处理函数
def create_dataset(datasetX, datasetY, look_back=1, predict_steps=1):
    """Slice parallel series into sliding input windows and their targets.

    Args:
        datasetX: 2-D array-like indexed as ``[row, column]``; each window is
            ``look_back`` consecutive rows with all columns.
        datasetY: Target series aligned row-for-row with ``datasetX``.
        look_back: Number of rows in each input window.
        predict_steps: Number of rows following each window to use as target.

    Returns:
        ``(windows, targets)`` as numpy arrays. With ``predict_steps == 1`` each
        target is the entire row ``datasetY[start + look_back]``; otherwise it
        is column 0 of the next ``predict_steps`` rows (the original comment
        describes column 0 as the sales column). When no full window fits,
        both arrays are empty.
    """
    n_windows = len(datasetX) - look_back - predict_steps + 1
    starts = range(n_windows)

    windows = [datasetX[start:start + look_back, :] for start in starts]
    if predict_steps == 1:
        targets = [datasetY[start + look_back] for start in starts]
    else:
        targets = [
            datasetY[start + look_back:start + look_back + predict_steps, 0]
            for start in starts
        ]
    return np.array(windows), np.array(targets)
# 评估函数
def evaluate_model(y_true, y_pred):
    """Compute standard regression metrics for a set of predictions.

    Args:
        y_true: Ground-truth values (any array-like; converted with np.asarray).
        y_pred: Predicted values with the same shape as ``y_true``.

    Returns:
        dict with keys:
            'MSE', 'RMSE', 'MAE': error metrics from sklearn / numpy.
            '': R^2 score (legacy key, kept for existing callers).
            'R2': the same R^2 score under a readable key.
            'MAPE(%)': mean absolute percentage error over points where
                ``y_true != 0``; NaN when every true value is zero.
    """
    # Convert up front so the boolean masking below works for plain lists too.
    # For a list, `y_true != 0` is a single bool and `y_true[True]` would
    # silently pick element 1 instead of masking.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    # MAPE is undefined where the true value is 0, so those points are excluded.
    mask = y_true != 0
    if np.any(mask):
        y_true_masked = y_true[mask]
        y_pred_masked = y_pred[mask]
        mape = np.mean(np.abs((y_true_masked - y_pred_masked) / y_true_masked)) * 100
    else:
        # No usable points: report NaN explicitly rather than letting np.mean
        # of an empty slice emit a RuntimeWarning.
        mape = np.nan

    return {
        'MSE': mse,
        'RMSE': rmse,
        'MAE': mae,
        # NOTE(review): the empty-string key looks like a mangled 'R²' label;
        # kept for backward compatibility, prefer 'R2' in new code.
        '': r2,
        'R2': r2,
        'MAPE(%)': mape
    }
# 为优化版KAN模型添加的数据准备函数
def prepare_data(product_data, sequence_length=30, forecast_horizon=7, test_size=0.2, random_state=42):
    """Prepare windowed time-series data for training and evaluating a model.

    Args:
        product_data: Data for a single product. It must support
            ``sort_values('date')`` and column selection for every name in
            ``features`` below (presumably a pandas DataFrame; confirm with callers).
        sequence_length: Number of consecutive rows in each input window.
        forecast_horizon: Number of rows after each window to predict
            (described as days in the original documentation).
        test_size: Held-out share passed to ``train_test_split``.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        X, y, X_train, X_val, y_train, y_val, scaler_X, scaler_y
        ``X`` has shape (n_windows, sequence_length, n_features) and ``y`` has
        shape (n_windows, forecast_horizon); both are in standardized units.
        Use ``scaler_y.inverse_transform`` to map sales predictions back.

    Raises:
        ValueError: if there are fewer than ``sequence_length + forecast_horizon``
            rows, so not even one window can be built (sklearn would also raise
            ValueError here, but with a less specific message).
    """
    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Windows must be built over chronologically ordered rows.
    product_data = product_data.sort_values('date')
    X_data = product_data[features].values
    y_data = product_data[['sales']].values

    n_windows = len(X_data) - sequence_length - forecast_horizon + 1
    if n_windows < 1:
        raise ValueError(
            f"prepare_data needs at least sequence_length + forecast_horizon = "
            f"{sequence_length + forecast_horizon} rows, got {len(X_data)}"
        )

    # NOTE(review): both scalers are fit on the full series, and
    # train_test_split shuffles overlapping windows. Both leak validation
    # information into training; this is kept to preserve current outputs.
    scaler_X = StandardScaler()
    scaler_y = StandardScaler()
    X_scaled = scaler_X.fit_transform(X_data)
    y_scaled = scaler_y.fit_transform(y_data)

    X = np.stack([X_scaled[i:i + sequence_length] for i in range(n_windows)])
    # y_scaled has a single column, so selecting column 0 yields
    # (forecast_horizon,) per window and y comes out 2-D directly.
    y = np.stack([
        y_scaled[i + sequence_length:i + sequence_length + forecast_horizon, 0]
        for i in range(n_windows)
    ])

    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=test_size, random_state=random_state)
    return X, y, X_train, X_val, y_train, y_val, scaler_X, scaler_y
def prepare_sequences(X, y, batch_size=32, shuffle=True):
    """Wrap features and targets in a DataLoader for mini-batch training.

    Args:
        X: Input features, anything ``torch.tensor`` can convert.
        y: Target values, indexed in parallel with ``X``.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle samples every epoch. Defaults to True, which matches
            the previously hard-coded behaviour. Pass False for validation or
            test loaders so batches arrive in a fixed order.

    Returns:
        A ``DataLoader`` over a ``PharmacyDataset``. Each batch pairs stacked
        float32 ``X`` and ``y`` tensors.
    """
    # Copy into float32 tensors, the dtype used for training.
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y, dtype=torch.float32)

    dataset = PharmacyDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)