ShopTRAINING/server/utils/data_utils.py

106 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 数据处理工具函数
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
class PharmacyDataset(Dataset):
    """Map-style PyTorch dataset pairing each input sample with its target.

    Indexing returns ``(data_X[index], data_Y[index])``; the dataset length
    is taken from ``data_X``. The containers are stored as given (no copy,
    no type conversion).
    """

    def __init__(self, data_X, data_Y):
        self.data_X, self.data_Y = data_X, data_Y

    def __getitem__(self, index):
        features = self.data_X[index]
        target = self.data_Y[index]
        return features, target

    def __len__(self):
        return len(self.data_X)
def create_dataset(datasetX, datasetY, look_back=1, predict_steps=1):
    """Turn aligned time series into supervised-learning windows.

    For every start position ``i`` the input window is
    ``datasetX[i : i + look_back]`` and the target window is the
    ``predict_steps`` rows of ``datasetY`` that immediately follow it,
    ``datasetY[i + look_back : i + look_back + predict_steps]``.

    Args:
        datasetX: Input features, array-like indexed by time along axis 0
            (e.g. shape ``(n_rows, n_features)``).
        datasetY: Targets, array-like indexed by time along axis 0. Must have
            at least as many rows as ``datasetX``.
        look_back: Number of past rows used as the input window (>= 1).
        predict_steps: Number of future rows to predict (>= 1).

    Returns:
        ``(dataX, dataY)`` as NumPy arrays with
        ``dataX.shape == (n_windows, look_back) + datasetX.shape[1:]`` and
        ``dataY.shape == (n_windows, predict_steps) + datasetY.shape[1:]``,
        where ``n_windows = max(0, len(datasetX) - look_back - predict_steps + 1)``.
        A 2-D target of shape ``(n_rows, 1)`` therefore yields ``dataY`` of
        shape ``(n_windows, predict_steps, 1)``.

    Raises:
        ValueError: if ``look_back`` or ``predict_steps`` is < 1, or if
            ``datasetY`` has fewer rows than ``datasetX``.
    """
    if look_back < 1 or predict_steps < 1:
        raise ValueError(
            f"look_back and predict_steps must be >= 1, "
            f"got look_back={look_back}, predict_steps={predict_steps}"
        )

    x_arr = np.asarray(datasetX)
    y_arr = np.asarray(datasetY)
    if len(y_arr) < len(x_arr):
        # Otherwise the trailing target windows come out short and NumPy cannot
        # stack them into a rectangular array.
        raise ValueError(
            f"datasetY has {len(y_arr)} rows but datasetX has {len(x_arr)}; "
            f"targets must cover every input row"
        )

    n_windows = len(x_arr) - look_back - predict_steps + 1
    if n_windows <= 0:
        # Too few rows for a single window: return correctly-shaped empty arrays
        # rather than a shapeless ``array([])``.
        return (
            np.empty((0, look_back) + x_arr.shape[1:], dtype=x_arr.dtype),
            np.empty((0, predict_steps) + y_arr.shape[1:], dtype=y_arr.dtype),
        )

    dataX = np.array([x_arr[i:i + look_back] for i in range(n_windows)])
    dataY = np.array(
        [y_arr[i + look_back:i + look_back + predict_steps] for i in range(n_windows)]
    )
    return dataX, dataY
def prepare_data(product_data, sequence_length=30, forecast_horizon=7):
    """Build scaled, windowed train/validation data for one product.

    The split is chronological (no shuffling): the first ~80% of windows are
    used for training and the remaining ~20% for validation, exactly as
    ``train_test_split(..., test_size=0.2, shuffle=False)`` decides it.

    Both ``MinMaxScaler``s are fit only on the raw rows touched by training
    windows (their inputs and their targets), then applied to the whole
    series. Fitting on the full series would leak the min/max of
    validation-only rows into training and bias validation metrics.
    Validation values may therefore fall slightly outside ``[0, 1]``.

    Args:
        product_data: DataFrame with columns ``sales``, ``price``, ``weekday``,
            ``month``, ``is_holiday``, ``is_weekend``, ``is_promotion`` and
            ``temperature``, presumably ordered by date (TODO confirm with
            callers).
        sequence_length: Number of past rows in each input window.
        forecast_horizon: Number of future ``sales`` values predicted per window.

    Returns:
        ``(X, y, X_train, X_val, y_train, y_val, scaler_X, scaler_y)`` where
        ``X``/``y`` hold all windows (shapes as produced by ``create_dataset``,
        ``y`` being ``(n_windows, forecast_horizon, 1)``), the train/val arrays
        are their chronological split, and the scalers are the fitted
        ``MinMaxScaler`` objects for features and target.

    Raises:
        KeyError: if a required column is missing from ``product_data``.
        ValueError: if there are too few rows to form at least one training
            and one validation window.
    """
    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X_raw = product_data[features].values
    y_raw = product_data[['sales']].values  # keep 2-D: MinMaxScaler expects (n_rows, n_columns)

    n_windows = len(X_raw) - sequence_length - forecast_horizon + 1
    if n_windows < 2:
        raise ValueError(
            f"need at least {sequence_length + forecast_horizon + 1} rows to build "
            f"train and validation windows, got {len(X_raw)}"
        )

    # Decide the chronological split on window indices first, with the same
    # train_test_split call as before, so the train/val boundary is unchanged.
    train_idx, _ = train_test_split(np.arange(n_windows), test_size=0.2, shuffle=False)
    n_train = len(train_idx)
    # Training window i reads rows up to i + sequence_length + forecast_horizon - 1,
    # so the training windows together touch rows [0, fit_rows).
    fit_rows = n_train + sequence_length + forecast_horizon - 1

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    scaler_X.fit(X_raw[:fit_rows])
    scaler_y.fit(y_raw[:fit_rows])
    X_scaled = scaler_X.transform(X_raw)
    y_scaled = scaler_y.transform(y_raw)

    X, y = create_dataset(X_scaled, y_scaled, sequence_length, forecast_horizon)

    # Chronological split: earliest windows train, latest windows validate.
    X_train, X_val = X[:n_train], X[n_train:]
    y_train, y_val = y[:n_train], y[n_train:]
    return X, y, X_train, X_val, y_train, y_val, scaler_X, scaler_y
def prepare_sequences(X, y, batch_size=32, *, shuffle=True):
    """Wrap feature/target arrays in a DataLoader for mini-batch training.

    Args:
        X: Input features; anything ``torch.tensor`` accepts (e.g. a NumPy
            array). It is copied into a float32 tensor.
        y: Targets aligned with ``X`` along the first axis; also copied into a
            float32 tensor.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles samples each epoch. Defaults to
            ``True`` (the previous fixed behaviour). Pass ``False`` for
            validation/test loaders so evaluation order is deterministic.

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``PharmacyDataset`` built
        from the two float32 tensors.
    """
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y, dtype=torch.float32)
    dataset = PharmacyDataset(X_tensor, y_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)