# ShopTRAINING/server/utils/multi_store_data_utils.py
"""
多店铺销售预测系统 - 数据处理工具函数
支持多店铺数据的加载过滤和处理
"""
import pandas as pd
import numpy as np
import os
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
2025-07-14 19:48:00 +08:00
import logging
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def load_multi_store_data(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
                          store_id: Optional[str] = None,
                          product_id: Optional[str] = None,
                          start_date: Optional[str] = None,
                          end_date: Optional[str] = None) -> pd.DataFrame:
    """
    Load multi-store sales data with optional store / product / date filters.

    Args:
        file_path: Data file path relative to the project root
            (.parquet, .csv or .xlsx).
        store_id: Store id; None keeps every store.
        product_id: Product id; None keeps every product.
        start_date: Inclusive start date (YYYY-MM-DD).
        end_date: Inclusive end date (YYYY-MM-DD).

    Returns:
        DataFrame: filtered sales data with standardized column names.

    Raises:
        FileNotFoundError: the resolved data file does not exist.
        ValueError: unsupported file extension.
    """
    logger.info("\n[DEBUG-UTIL] ---> Entering load_multi_store_data function.")
    logger.info(f"[DEBUG-UTIL] Initial params: file_path={file_path}, store_id={store_id}, product_id={product_id}, start_date={start_date}, end_date={end_date}")

    try:
        # Resolve the file relative to the project root (two levels above this
        # module) so the function works regardless of the current working dir.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(os.path.dirname(script_dir))
        absolute_file_path = os.path.join(project_root, file_path)
        logger.info(f"[DEBUG-UTIL] Constructed absolute path to data file: {absolute_file_path}")

        if not os.path.exists(absolute_file_path):
            logger.error(f"[DEBUG-UTIL] Data file not found at absolute path: {absolute_file_path}")
            error_message = f"V3-FIX: 无法找到数据文件: {absolute_file_path}"
            raise FileNotFoundError(error_message)

        logger.info(f"[DEBUG-UTIL] Attempting to load file from: {absolute_file_path}")
        if absolute_file_path.endswith('.parquet'):
            df = pd.read_parquet(absolute_file_path)
        elif absolute_file_path.endswith('.csv'):
            df = pd.read_csv(absolute_file_path)
        elif absolute_file_path.endswith('.xlsx'):
            df = pd.read_excel(absolute_file_path)
        else:
            raise ValueError(f"不支持的文件格式: {absolute_file_path}")
        logger.info(f"[DEBUG-UTIL] Successfully loaded data file. Shape: {df.shape}")
    except Exception as e:
        logger.error(f"[DEBUG-UTIL] !!! An error occurred during file loading: {e}")
        import traceback
        logger.error(traceback.format_exc())
        # FIX: bare `raise` (was `raise e`) re-raises with the original
        # traceback intact instead of rebinding it here.
        raise

    logger.info(f"[DEBUG-UTIL] Initial DataFrame columns: {df.columns.tolist()}")

    # Ensure the date column is datetime typed before any range filtering.
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        logger.info("[DEBUG-UTIL] Converted 'date' column to datetime objects.")

    # Filter by store.
    if store_id:
        df = df[df['store_id'] == store_id].copy()
        logger.info(f"[DEBUG-UTIL] Filtered by store_id='{store_id}'. Records remaining: {len(df)}")

    # Filter by product.
    if product_id:
        df = df[df['product_id'] == product_id].copy()
        logger.info(f"[DEBUG-UTIL] Filtered by product_id='{product_id}'. Records remaining: {len(df)}")

    # FIX: date-range filters are now guarded on the 'date' column existing —
    # previously the datetime conversion above was guarded but these filters
    # were not, so date-less data raised a KeyError here.
    if start_date and 'date' in df.columns:
        start_date_dt = pd.to_datetime(start_date)
        df = df[df['date'] >= start_date_dt].copy()
        logger.info(f"[DEBUG-UTIL] Filtered by start_date>='{start_date}'. Records remaining: {len(df)}")

    if end_date and 'date' in df.columns:
        end_date_dt = pd.to_datetime(end_date)
        df = df[df['date'] <= end_date_dt].copy()
        logger.info(f"[DEBUG-UTIL] Filtered by end_date<='{end_date}'. Records remaining: {len(df)}")

    if len(df) == 0:
        logger.warning("[DEBUG-UTIL] Warning: DataFrame is empty after filtering.")

    # Standardize column names to the format the training code expects.
    logger.info("[DEBUG-UTIL] Calling standardize_column_names...")
    df = standardize_column_names(df)
    logger.info(f"[DEBUG-UTIL] DataFrame columns after standardization: {df.columns.tolist()}")

    logger.info("[DEBUG-UTIL] <--- Exiting load_multi_store_data function.")
    return df
def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize column names to the format expected by the training code.

    Args:
        df: Raw DataFrame.

    Returns:
        DataFrame with standardized columns and derived feature columns
        ('weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion',
        'temperature').
    """
    df = df.copy()

    # Column mapping: new name -> source name. Insertion order matters:
    # 'quantity_sold' is derived first so 'sales' can chain off it.
    column_mapping = {
        'quantity_sold': 'sales_quantity',  # sold quantity
        # FIX: downstream code (get_store_product_sales_data) requires a
        # 'sales' column that was never created here; derive it from the
        # quantity column when absent.
        'sales': 'quantity_sold',
        'price': 'unit_price',              # unit price
        'weekday': 'day_of_week'            # day of week
    }
    for new_name, old_name in column_mapping.items():
        if old_name in df.columns and new_name not in df.columns:
            df[new_name] = df[old_name]

    # Derive calendar features from the date column.
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])

        # Numeric weekday (0=Monday .. 6=Sunday).
        if 'weekday' not in df.columns:
            df['weekday'] = df['date'].dt.dayofweek
        elif df['weekday'].dtype == 'object':
            # Map English day names to numbers; anything unmapped falls back
            # to the value derived from the date itself.
            weekday_map = {
                'Monday': 0, 'Tuesday': 1, 'Wednesday': 2, 'Thursday': 3,
                'Friday': 4, 'Saturday': 5, 'Sunday': 6
            }
            df['weekday'] = df['weekday'].map(weekday_map).fillna(df['date'].dt.dayofweek)

        # Month number (1-12).
        if 'month' not in df.columns:
            df['month'] = df['date'].dt.month

    # Default feature columns added when missing.
    default_features = {
        'is_holiday': False,    # holiday flag
        'is_weekend': None,     # computed from weekday below
        'is_promotion': False,  # promotion flag
        'temperature': 20.0     # default temperature
    }
    for feature, default_value in default_features.items():
        if feature not in df.columns:
            if feature == 'is_weekend' and 'weekday' in df.columns:
                # Weekend: Saturday (5) and Sunday (6).
                df['is_weekend'] = df['weekday'].isin([5, 6])
            else:
                df[feature] = default_value

    # Coerce numeric feature columns; unparseable values become NaN.
    numeric_columns = ['sales', 'price', 'weekday', 'month', 'temperature']
    for col in numeric_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    # Coerce boolean feature columns.
    boolean_columns = ['is_holiday', 'is_weekend', 'is_promotion']
    for col in boolean_columns:
        if col in df.columns:
            df[col] = df[col].astype(bool)

    print(f"数据标准化完成,可用特征列: {[col for col in ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")

    return df
def get_available_stores(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> List[Dict[str, Any]]:
    """
    Return the list of stores present in the data file.

    Args:
        file_path: Data file path.

    Returns:
        One dict per store (store_id, store_name, store_location,
        store_type); an empty list when loading fails.
    """
    store_columns = ['store_id', 'store_name', 'store_location', 'store_type']
    try:
        data = load_multi_store_data(file_path)
        unique_stores = data[store_columns].drop_duplicates()
        return unique_stores.to_dict('records')
    except Exception as e:
        # Best-effort helper: report the problem and degrade to an empty list.
        print(f"获取店铺列表失败: {e}")
        return []
def get_available_products(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
                           store_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Return the list of products present in the data file.

    Args:
        file_path: Data file path.
        store_id: Restrict to one store; None returns products of all stores.

    Returns:
        One dict per product; an empty list when loading fails.
    """
    try:
        data = load_multi_store_data(file_path, store_id=store_id)
        # Mandatory identification columns plus whichever optional ones exist.
        wanted = ['product_id', 'product_name']
        for optional_col in ('product_category', 'unit_price'):
            if optional_col in data.columns:
                wanted.append(optional_col)
        return data[wanted].drop_duplicates().to_dict('records')
    except Exception as e:
        # Best-effort helper: report the problem and degrade to an empty list.
        print(f"获取产品列表失败: {e}")
        return []
def get_store_product_sales_data(store_id: str,
                                 product_id: str,
                                 file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> pd.DataFrame:
    """
    Return the sales history of one product in one store, ready for training.

    Args:
        store_id: Store id.
        product_id: Product id.
        file_path: Data file path.

    Returns:
        Date-sorted DataFrame containing the feature columns the model needs.

    Raises:
        ValueError: no rows matched, or required feature columns are still
            missing after standardization.
    """
    sales_df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
    if sales_df.empty:
        raise ValueError(f"没有找到店铺 {store_id} 产品 {product_id} 的销售数据")

    # Chronological order is required by the time-series training code.
    sales_df = sales_df.sort_values('date').copy()

    # Column standardization already ran inside load_multi_store_data;
    # here we only verify every required feature column is present.
    required_columns = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    missing_columns = [col for col in required_columns if col not in sales_df.columns]
    if missing_columns:
        print(f"警告: 数据标准化后仍缺少列 {missing_columns}")
        raise ValueError(f"无法获取完整的特征数据,缺少列: {missing_columns}")

    return sales_df
def aggregate_multi_store_data(product_id: str,
                               aggregation_method: str = 'sum',
                               file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet') -> pd.DataFrame:
    """
    Aggregate one product's sales across all stores for global model training.

    Args:
        product_id: Product id.
        aggregation_method: 'sum', 'mean' or 'median'.
        file_path: Data file path.

    Returns:
        DataFrame aggregated by date, tagged with a synthetic 'GLOBAL' store.

    Raises:
        ValueError: no data for the product, or an unknown aggregation method.
    """
    # Load this product's rows for every store.
    df = load_multi_store_data(file_path, product_id=product_id)
    if len(df) == 0:
        raise ValueError(f"没有找到产品 {product_id} 的销售数据")

    # Per-method aggregation spec (standardized column names). Unit price is
    # always averaged under 'sum' because summing prices across stores is
    # meaningless.
    method_specs = {
        'sum': {'sales': 'sum', 'sales_amount': 'sum', 'price': 'mean'},
        'mean': {'sales': 'mean', 'sales_amount': 'mean', 'price': 'mean'},
        'median': {'sales': 'median', 'sales_amount': 'median', 'price': 'median'},
    }
    if aggregation_method not in method_specs:
        # FIX: an unknown method previously left the aggregation dict empty
        # and failed later inside groupby().agg() with an obscure error.
        raise ValueError(f"不支持的聚合方法: {aggregation_method}")

    # Keep only the spec entries whose columns actually exist in the data.
    agg_dict = {k: v for k, v in method_specs[aggregation_method].items() if k in df.columns}

    # Aggregate per calendar date.
    aggregated_df = df.groupby('date').agg(agg_dict).reset_index()

    # Copy product info from the first row. FIX: guard optional columns
    # ('product_category' is optional elsewhere in this module) — previously
    # this indexing raised a hard KeyError when one was absent.
    info_cols = [c for c in ('product_id', 'product_name', 'product_category') if c in df.columns]
    product_info = df[info_cols].iloc[0]
    for col, val in product_info.items():
        aggregated_df[col] = val

    # Tag the result as a synthetic global "store".
    aggregated_df['store_id'] = 'GLOBAL'
    aggregated_df['store_name'] = f'全部店铺-{aggregation_method.upper()}'
    aggregated_df['store_location'] = '全局聚合'
    aggregated_df['store_type'] = 'global'

    # Re-standardize to restore any feature columns lost during aggregation.
    aggregated_df = aggregated_df.sort_values('date').copy()
    aggregated_df = standardize_column_names(aggregated_df)

    return aggregated_df
def get_sales_statistics(file_path: str = 'data/timeseries_training_data_sample_10s50p.parquet',
                         store_id: Optional[str] = None,
                         product_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Compute summary statistics over the (optionally filtered) sales data.

    Args:
        file_path: Data file path.
        store_id: Optional store filter.
        product_id: Optional product filter.

    Returns:
        Dict of statistics, or {'error': ...} when loading fails or no
        rows match the filters.
    """
    try:
        data = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
        if len(data) == 0:
            return {'error': '没有数据'}

        has_amount = 'sales_amount' in data.columns
        has_quantity = 'quantity_sold' in data.columns

        return {
            'total_records': len(data),
            'date_range': {
                'start': data['date'].min().strftime('%Y-%m-%d'),
                'end': data['date'].max().strftime('%Y-%m-%d'),
            },
            'stores': data['store_id'].nunique(),
            'products': data['product_id'].nunique(),
            'total_sales_amount': float(data['sales_amount'].sum()) if has_amount else 0,
            'total_quantity': int(data['quantity_sold'].sum()) if has_quantity else 0,
            'avg_daily_sales': float(data.groupby('date')['quantity_sold'].sum().mean()) if has_quantity else 0,
        }
    except Exception as e:
        # Best-effort helper: surface the failure in the payload.
        return {'error': str(e)}
# Backward-compatible helpers
def load_data(file_path='data/timeseries_training_data_sample_10s50p.parquet', store_id=None):
    """
    Backward-compatible wrapper around load_multi_store_data.

    Kept so older callers that only know file_path/store_id keep working.
    """
    return load_multi_store_data(file_path=file_path, store_id=store_id)