ShopTRAINING/server/utils/data_utils.py
xz2000 28bae35783 # 扁平化模型数据处理规范 (最终版)
**版本**: 4.0 (最终版)
**核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。

---

## 一、 文件保存规则

### 1.1. 核心原则

所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_metadata.json`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 二、 文件读取规则

1.  **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。
2.  **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。
3.  **定位文件**:
    -   要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。
    -   要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。

---

## 三、 数据库存储规则

数据库用于索引,应存储足以重构文件名前缀的关键元数据。

#### **`models` 表结构建议:**

| 字段名 | 类型 | 描述 | 示例 |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | 主键 | 1 |
| `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` |
| `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` |
| `version` | INTEGER | 版本号 | `2` |
| `status` | TEXT | 模型状态 | `completed`, `training`, `failed` |
| `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` |
| `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` |

#### **保存逻辑:**
-   训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。

---

## 四、 版本记录规则

版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。

-   **文件名**: `versions.json`
-   **位置**: `saved_models/versions.json`
-   **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。
    -   **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`)
    -   **Value**: `Integer`

#### **`versions.json` 示例:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **版本管理流程:**

1.  **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。
2.  **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。

调试完成药品预测和店铺预测
2025-07-21 16:39:52 +08:00

106 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 数据处理工具函数
"""
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
class PharmacyDataset(Dataset):
"""
药店销售数据集类用于PyTorch数据加载
"""
def __init__(self, data_X, data_Y):
self.data_X = data_X
self.data_Y = data_Y
def __getitem__(self, index):
return self.data_X[index], self.data_Y[index]
def __len__(self):
return len(self.data_X)
def create_dataset(datasetX, datasetY, look_back=1, predict_steps=1):
"""
将时间序列数据转换为监督学习问题的格式
参数:
datasetX: 输入特征数据
datasetY: 目标变量数据
look_back: 使用过去多少天的数据作为输入
predict_steps: 预测未来多少天的数据
返回:
dataX: 输入特征,形状为 (样本数, 时间步, 特征数)
dataY: 目标变量,形状为 (样本数, 预测步数)
"""
dataX, dataY = [], []
for i in range(len(datasetX) - look_back - predict_steps + 1):
x = datasetX[i:(i + look_back)]
dataX.append(x)
y = datasetY[(i + look_back):(i + look_back + predict_steps)]
dataY.append(y.flatten())
return np.array(dataX), np.array(dataY)
def prepare_data(product_data, sequence_length=30, forecast_horizon=7):
"""
准备训练和验证数据
参数:
product_data: 产品销售数据DataFrame
sequence_length: 输入序列长度
forecast_horizon: 预测天数
返回:
X, y: 全部特征和目标
X_train, X_val: 训练和验证特征
y_train, y_val: 训练和验证目标
scaler_X, scaler_y: 特征和目标的归一化器
"""
# 创建特征和目标变量
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 预处理数据
X_raw = product_data[features].values
y_raw = product_data[['sales']].values # 保持为二维数组
# 归一化数据
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X_raw)
y_scaled = scaler_y.fit_transform(y_raw)
# 创建时间序列数据
X, y = create_dataset(X_scaled, y_scaled, sequence_length, forecast_horizon)
# 划分训练集和验证集80% 训练20% 验证)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, shuffle=False)
return X, y, X_train, X_val, y_train, y_val, scaler_X, scaler_y
def prepare_sequences(X, y, batch_size=32):
"""
将数据转换为DataLoader对象用于批量训练
参数:
X: 输入特征
y: 目标变量
batch_size: 批次大小
返回:
DataLoader对象
"""
# 转换为PyTorch张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)
# 创建数据集
dataset = PharmacyDataset(X_tensor, y_tensor)
# 创建数据加载器
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return data_loader