ShopTRAINING/server/utils/multi_store_data_utils.py
xz2000 28bae35783 # 扁平化模型数据处理规范 (最终版)
**版本**: 4.0 (最终版)
**核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。

---

## 一、 文件保存规则

### 1.1. 核心原则

所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_metadata.json`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 二、 文件读取规则

1.  **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。
2.  **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。
3.  **定位文件**:
    -   要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。
    -   要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。

---

## 三、 数据库存储规则

数据库用于索引,应存储足以重构文件名前缀的关键元数据。

#### **`models` 表结构建议:**

| 字段名 | 类型 | 描述 | 示例 |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | 主键 | 1 |
| `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` |
| `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` |
| `version` | INTEGER | 版本号 | `2` |
| `status` | TEXT | 模型状态 | `completed`, `training`, `failed` |
| `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` |
| `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` |

#### **保存逻辑:**
-   训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。

---

## 四、 版本记录规则

版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。

-   **文件名**: `versions.json`
-   **位置**: `saved_models/versions.json`
-   **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。
    -   **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`)
    -   **Value**: `Integer`

#### **`versions.json` 示例:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **版本管理流程:**

1.  **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。
2.  **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。

调试完成药品预测和店铺预测
2025-07-21 16:39:52 +08:00

416 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
多店铺销售预测系统 - 数据处理工具函数
支持多店铺数据的加载、过滤和处理
"""
import pandas as pd
import numpy as np
import os
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
from core.config import DEFAULT_DATA_PATH
def load_multi_store_data(file_path: str = None,
store_id: Optional[str] = None,
product_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
加载多店铺销售数据,支持按店铺、产品、时间范围过滤
参数:
file_path: 数据文件路径 (支持 .csv, .xlsx, .parquet)。如果为None则使用config中定义的默认路径。
store_id: 店铺ID为None时返回所有店铺数据
product_id: 产品ID为None时返回所有产品数据
start_date: 开始日期 (YYYY-MM-DD)
end_date: 结束日期 (YYYY-MM-DD)
返回:
DataFrame: 过滤后的销售数据
"""
# 如果未提供文件路径,则使用配置文件中的默认路径
if file_path is None:
file_path = DEFAULT_DATA_PATH
if not os.path.exists(file_path):
raise FileNotFoundError(f"数据文件不存在: {file_path}")
try:
if file_path.endswith('.csv'):
df = pd.read_csv(file_path)
elif file_path.endswith('.xlsx'):
df = pd.read_excel(file_path)
elif file_path.endswith('.parquet'):
df = pd.read_parquet(file_path)
else:
raise ValueError(f"不支持的文件格式: {file_path}")
print(f"成功加载数据文件: {file_path}")
except Exception as e:
print(f"加载文件 {file_path} 失败: {e}")
raise
# 按店铺过滤
if store_id:
df = df[df['store_id'] == store_id].copy()
print(f"按店铺过滤: {store_id}, 剩余记录数: {len(df)}")
# 按产品过滤
if product_id:
df = df[df['product_id'] == product_id].copy()
print(f"按产品过滤: {product_id}, 剩余记录数: {len(df)}")
# 标准化列名和数据类型
df = standardize_column_names(df)
# 在标准化之后进行时间范围过滤
if start_date:
try:
start_date_dt = pd.to_datetime(start_date)
# 确保比较是在datetime对象之间
if 'date' in df.columns:
df = df[df['date'] >= start_date_dt].copy()
print(f"开始日期过滤: {start_date_dt}, 剩余记录数: {len(df)}")
except (ValueError, TypeError):
print(f"警告: 无效的开始日期格式 '{start_date}',已忽略。")
if end_date:
try:
end_date_dt = pd.to_datetime(end_date)
# 确保比较是在datetime对象之间
if 'date' in df.columns:
df = df[df['date'] <= end_date_dt].copy()
print(f"结束日期过滤: {end_date_dt}, 剩余记录数: {len(df)}")
except (ValueError, TypeError):
print(f"警告: 无效的结束日期格式 '{end_date}',已忽略。")
if len(df) == 0:
print("警告: 过滤后没有数据")
return df
def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
"""
标准化列名以匹配训练代码和API期望的格式
参数:
df: 原始DataFrame
返回:
DataFrame: 标准化列名后的DataFrame
"""
df = df.copy()
# 定义列名映射并强制重命名
rename_map = {
'sales_quantity': 'sales', # 修复:匹配原始列名
'temperature_2m_mean': 'temperature', # 新增:处理温度列
'dayofweek': 'weekday' # 修复:匹配原始列名
}
df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True)
# 确保date列是datetime类型
if 'date' in df.columns:
df['date'] = pd.to_datetime(df['date'], errors='coerce')
df.dropna(subset=['date'], inplace=True) # 移除无法解析的日期行
else:
# 如果没有date列无法继续返回空DataFrame
return pd.DataFrame()
# 计算 sales_amount
# 由于没有price列sales_amount的计算逻辑需要调整或移除
# 这里我们注释掉它因为原始数据中已有sales_amount
# if 'sales_amount' not in df.columns and 'sales' in df.columns and 'price' in df.columns:
# # 先确保sales和price是数字
# df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
# df['price'] = pd.to_numeric(df['price'], errors='coerce')
# df['sales_amount'] = df['sales'] * df['price']
# 创建缺失的特征列
if 'weekday' not in df.columns:
df['weekday'] = df['date'].dt.dayofweek
if 'month' not in df.columns:
df['month'] = df['date'].dt.month
# 添加缺失的元数据列
meta_columns = {
'store_name': 'Unknown Store',
'store_location': 'Unknown Location',
'store_type': 'Unknown',
'product_name': 'Unknown Product',
'product_category': 'Unknown Category'
}
for col, default in meta_columns.items():
if col not in df.columns:
df[col] = default
# 添加缺失的布尔特征列
default_features = {
'is_holiday': False,
'is_weekend': None,
'is_promotion': False,
'temperature': 20.0
}
for feature, default_value in default_features.items():
if feature not in df.columns:
if feature == 'is_weekend':
df['is_weekend'] = df['weekday'].isin([5, 6])
else:
df[feature] = default_value
# 确保数值类型正确
numeric_columns = ['sales', 'sales_amount', 'weekday', 'month', 'temperature']
for col in numeric_columns:
if col in df.columns:
df[col] = pd.to_numeric(df[col], errors='coerce')
# 确保布尔类型正确
boolean_columns = ['is_holiday', 'is_weekend', 'is_promotion']
for col in boolean_columns:
if col in df.columns:
df[col] = df[col].astype(bool)
print(f"数据标准化完成,可用特征列: {[col for col in ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")
return df
def get_available_stores(file_path: str = None) -> List[Dict[str, Any]]:
"""
获取可用的店铺列表
参数:
file_path: 数据文件路径
返回:
List[Dict]: 店铺信息列表
"""
try:
df = load_multi_store_data(file_path)
if 'store_id' not in df.columns:
print("数据文件中缺少 'store_id'")
return []
# 智能地获取店铺信息,即使某些列缺失
store_info = []
# 使用drop_duplicates获取唯一的店铺组合
stores_df = df.drop_duplicates(subset=['store_id'])
for _, row in stores_df.iterrows():
store_info.append({
'store_id': row['store_id'],
'store_name': row.get('store_name', f"店铺 {row['store_id']}"),
'location': row.get('store_location', '未知位置'),
'type': row.get('store_type', '标准'),
'opening_date': row.get('opening_date', '未知'),
})
return store_info
except Exception as e:
print(f"获取店铺列表失败: {e}")
return []
def get_available_products(file_path: str = None,
store_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
获取可用的产品列表
参数:
file_path: 数据文件路径
store_id: 店铺ID为None时返回所有产品
返回:
List[Dict]: 产品信息列表
"""
try:
df = load_multi_store_data(file_path, store_id=store_id)
# 获取唯一产品信息
product_columns = ['product_id', 'product_name']
if 'product_category' in df.columns:
product_columns.append('product_category')
if 'unit_price' in df.columns:
product_columns.append('unit_price')
products = df[product_columns].drop_duplicates()
return products.to_dict('records')
except Exception as e:
print(f"获取产品列表失败: {e}")
return []
def get_store_product_sales_data(store_id: str,
product_id: str,
file_path: str = None) -> pd.DataFrame:
"""
获取特定店铺和产品的销售数据,用于模型训练
参数:
file_path: 数据文件路径
store_id: 店铺ID
product_id: 产品ID
返回:
DataFrame: 处理后的销售数据,包含模型需要的特征
"""
# 加载数据
df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
if len(df) == 0:
raise ValueError(f"没有找到店铺 {store_id} 产品 {product_id} 的销售数据")
# 确保数据按日期排序
df = df.sort_values('date').copy()
# 数据标准化已在load_multi_store_data中完成
# 验证必要的列是否存在
required_columns = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
print(f"警告: 数据标准化后仍缺少列 {missing_columns}")
raise ValueError(f"无法获取完整的特征数据,缺少列: {missing_columns}")
# 定义模型训练所需的所有列(特征 + 目标)
final_columns = [
'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
]
# 筛选出DataFrame中实际存在的列
existing_columns = [col for col in final_columns if col in df.columns]
# 返回只包含这些必需列的DataFrame
return df[existing_columns]
def aggregate_multi_store_data(product_id: Optional[str] = None,
store_id: Optional[str] = None,
aggregation_method: str = 'sum',
file_path: str = None) -> pd.DataFrame:
"""
聚合销售数据,可按产品(全局)或按店铺(所有产品)
参数:
file_path: 数据文件路径
product_id: 产品ID (用于全局模型)
store_id: 店铺ID (用于店铺聚合模型)
aggregation_method: 聚合方法 ('sum', 'mean', 'median')
返回:
DataFrame: 聚合后的销售数据
"""
# 根据是全局聚合、店铺聚合还是真正全局聚合来加载数据
if store_id:
# 店铺聚合:加载该店铺的所有数据
df = load_multi_store_data(file_path, store_id=store_id)
if len(df) == 0:
raise ValueError(f"没有找到店铺 {store_id} 的销售数据")
grouping_entity = f"店铺 {store_id}"
elif product_id:
# 按产品聚合:加载该产品在所有店铺的数据
df = load_multi_store_data(file_path, product_id=product_id)
if len(df) == 0:
raise ValueError(f"没有找到产品 {product_id} 的销售数据")
grouping_entity = f"产品 {product_id}"
else:
# 真正全局聚合:加载所有数据
df = load_multi_store_data(file_path)
if len(df) == 0:
raise ValueError("数据文件为空,无法进行全局聚合")
grouping_entity = "所有产品"
# 按日期聚合(使用标准化后的列名)
# 定义一个更健壮的聚合规范,以保留所有特征
agg_spec = {
'sales': aggregation_method,
'sales_amount': aggregation_method,
'price': 'mean',
'weekday': 'first',
'month': 'first',
'is_holiday': 'first',
'is_weekend': 'first',
'is_promotion': 'first',
'temperature': 'mean'
}
# 只聚合DataFrame中存在的列
agg_dict = {k: v for k, v in agg_spec.items() if k in df.columns}
# 聚合数据
aggregated_df = df.groupby('date').agg(agg_dict).reset_index()
# 获取产品信息(取第一个店铺的信息)
product_info = df[['product_id', 'product_name', 'product_category']].iloc[0]
for col, val in product_info.items():
aggregated_df[col] = val
# 添加店铺信息标识为全局
aggregated_df['store_id'] = 'GLOBAL'
aggregated_df['store_name'] = f'全部店铺-{aggregation_method.upper()}'
aggregated_df['store_location'] = '全局聚合'
aggregated_df['store_type'] = 'global'
# 对聚合后的数据进行标准化(添加缺失的特征列)
aggregated_df = aggregated_df.sort_values('date').copy()
aggregated_df = standardize_column_names(aggregated_df)
# 定义模型训练所需的所有列(特征 + 目标)
final_columns = [
'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
]
# 筛选出DataFrame中实际存在的列
existing_columns = [col for col in final_columns if col in aggregated_df.columns]
# 返回只包含这些必需列的DataFrame
return aggregated_df[existing_columns]
def get_sales_statistics(file_path: str = None,
store_id: Optional[str] = None,
product_id: Optional[str] = None) -> Dict[str, Any]:
"""
获取销售数据统计信息
参数:
file_path: 数据文件路径
store_id: 店铺ID
product_id: 产品ID
返回:
Dict: 统计信息
"""
try:
df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
if len(df) == 0:
return {'error': '没有数据'}
stats = {
'total_records': len(df),
'date_range': {
'start': df['date'].min().strftime('%Y-%m-%d'),
'end': df['date'].max().strftime('%Y-%m-%d')
},
'stores': df['store_id'].nunique(),
'products': df['product_id'].nunique(),
'total_sales_amount': float(df['sales_amount'].sum()) if 'sales_amount' in df.columns else 0,
'total_quantity': int(df['quantity_sold'].sum()) if 'quantity_sold' in df.columns else 0,
'avg_daily_sales': float(df.groupby('date')['quantity_sold'].sum().mean()) if 'quantity_sold' in df.columns else 0
}
return stats
except Exception as e:
return {'error': str(e)}
# 向后兼容的函数
def load_data(file_path=None, store_id=None):
"""
向后兼容的数据加载函数
"""
return load_multi_store_data(file_path, store_id=store_id)