ShopTRAINING/server/utils/multi_store_data_utils.py
"""
多店铺销售预测系统 - 数据处理工具函数
支持多店铺数据的加载、过滤和处理
"""
import pandas as pd
import numpy as np
import os
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
store_id: Optional[str] = None,
product_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None) -> pd.DataFrame:
"""
加载多店铺销售数据,支持按店铺、产品、时间范围过滤
参数:
file_path: 数据文件路径
store_id: 店铺ID为None时返回所有店铺数据
product_id: 产品ID为None时返回所有产品数据
start_date: 开始日期 (YYYY-MM-DD)
end_date: 结束日期 (YYYY-MM-DD)
返回:
DataFrame: 过滤后的销售数据
"""
    # Try several candidate file paths
    possible_paths = [
        file_path,
        f'../{file_path}',
        f'server/{file_path}',
        'pharmacy_sales_multi_store.csv',
        '../pharmacy_sales_multi_store.csv',
        'pharmacy_sales.xlsx',  # backward compatibility with the original single-store file
        '../pharmacy_sales.xlsx'
    ]
    df = None
    for path in possible_paths:
        try:
            if path.endswith('.csv'):
                df = pd.read_csv(path)
            elif path.endswith('.xlsx'):
                df = pd.read_excel(path)
                # Add default store information for the legacy Excel file
                if 'store_id' not in df.columns:
                    df['store_id'] = 'S001'
                    df['store_name'] = '默认店铺'
                    df['store_location'] = '未知位置'
                    df['store_type'] = 'standard'
            if df is not None:
                print(f"Successfully loaded data file: {path}")
                break
        except Exception:
            continue
    if df is None:
        raise FileNotFoundError(f"Could not find a data file; tried paths: {possible_paths}")
    # Make sure the date column has a datetime dtype
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
    # Filter by store
    if store_id:
        df = df[df['store_id'] == store_id].copy()
        print(f"Filtered by store: {store_id}, remaining records: {len(df)}")
    # Filter by product
    if product_id:
        df = df[df['product_id'] == product_id].copy()
        print(f"Filtered by product: {product_id}, remaining records: {len(df)}")
    # Filter by date range
    if start_date:
        start_date = pd.to_datetime(start_date)
        df = df[df['date'] >= start_date].copy()
        print(f"Filtered by start date: {start_date}, remaining records: {len(df)}")
    if end_date:
        end_date = pd.to_datetime(end_date)
        df = df[df['date'] <= end_date].copy()
        print(f"Filtered by end date: {end_date}, remaining records: {len(df)}")
    if len(df) == 0:
        print("Warning: no data left after filtering")
    # Standardize column names to match what the training code expects
    df = standardize_column_names(df)
    return df
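
# Usage sketch (illustrative): load one store's history for a single product over a date
# range. 'S001' is the default store ID used for the legacy Excel file; the product ID is a
# hypothetical example, not a value guaranteed to exist in the data.
#   df = load_multi_store_data(store_id='S001', product_id='P001',
#                              start_date='2024-01-01', end_date='2024-03-31')
#   print(df[['date', 'store_id', 'product_id', 'sales', 'price']].head())
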
def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
"""
标准化列名以匹配训练代码期望的格式
参数:
df: 原始DataFrame
返回:
DataFrame: 标准化列名后的DataFrame
"""
    df = df.copy()
    # Column name mapping: new name -> original name
    column_mapping = {
        'sales': 'quantity_sold',   # sales quantity
        'price': 'unit_price',      # unit price
        'weekday': 'day_of_week'    # day of the week
    }
    # Apply the column name mapping
    for new_name, old_name in column_mapping.items():
        if old_name in df.columns and new_name not in df.columns:
            df[new_name] = df[old_name]
    # Create missing feature columns
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        # Create a numeric weekday (0=Monday, 6=Sunday)
        if 'weekday' not in df.columns:
            df['weekday'] = df['date'].dt.dayofweek
        elif df['weekday'].dtype == 'object':
            # If weekday is stored as a string, convert it to a number
            weekday_map = {
                'Monday': 0, 'Tuesday': 1, 'Wednesday': 2, 'Thursday': 3,
                'Friday': 4, 'Saturday': 5, 'Sunday': 6
            }
            df['weekday'] = df['weekday'].map(weekday_map).fillna(df['date'].dt.dayofweek)
        # Add month information
        if 'month' not in df.columns:
            df['month'] = df['date'].dt.month
    # Add missing feature columns (filled with defaults when absent)
    default_features = {
        'is_holiday': False,    # whether the day is a holiday
        'is_weekend': None,     # whether the day is a weekend; computed from weekday
        'is_promotion': False,  # whether a promotion is running
        'temperature': 20.0     # default temperature
    }
    for feature, default_value in default_features.items():
        if feature not in df.columns:
            if feature == 'is_weekend' and 'weekday' in df.columns:
                # Weekend: Saturday (5) and Sunday (6)
                df['is_weekend'] = df['weekday'].isin([5, 6])
            else:
                df[feature] = default_value
    # Make sure numeric columns have numeric dtypes
    numeric_columns = ['sales', 'price', 'weekday', 'month', 'temperature']
    for col in numeric_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
    # Make sure boolean columns have boolean dtypes
    boolean_columns = ['is_holiday', 'is_weekend', 'is_promotion']
    for col in boolean_columns:
        if col in df.columns:
            df[col] = df[col].astype(bool)
    print(f"Column standardization complete, available feature columns: {[col for col in ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")
    return df
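
# Illustrative sketch of the standardization (column names mirror the mapping above; the
# values are made up):
#   raw = pd.DataFrame({'date': ['2024-01-01'], 'quantity_sold': [10],
#                       'unit_price': [5.0], 'day_of_week': ['Monday']})
#   std = standardize_column_names(raw)
#   # std additionally contains 'sales', 'price', 'weekday', 'month', 'is_holiday',
#   # 'is_weekend', 'is_promotion' and 'temperature'
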
def get_available_stores(file_path: str = 'pharmacy_sales_multi_store.csv') -> List[Dict[str, Any]]:
"""
获取可用的店铺列表
参数:
file_path: 数据文件路径
返回:
List[Dict]: 店铺信息列表
"""
try:
df = load_multi_store_data(file_path)
# 获取唯一店铺信息
stores = df[['store_id', 'store_name', 'store_location', 'store_type']].drop_duplicates()
return stores.to_dict('records')
except Exception as e:
print(f"获取店铺列表失败: {e}")
return []
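
# Each record returned by get_available_stores() carries the four store columns, roughly:
#   {'store_id': 'S001', 'store_name': ..., 'store_location': ..., 'store_type': 'standard'}
# (values shown here are illustrative)
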
def get_available_products(file_path: str = 'pharmacy_sales_multi_store.csv',
store_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
获取可用的产品列表
参数:
file_path: 数据文件路径
store_id: 店铺ID为None时返回所有产品
返回:
List[Dict]: 产品信息列表
"""
try:
df = load_multi_store_data(file_path, store_id=store_id)
# 获取唯一产品信息
product_columns = ['product_id', 'product_name']
if 'product_category' in df.columns:
product_columns.append('product_category')
if 'unit_price' in df.columns:
product_columns.append('unit_price')
products = df[product_columns].drop_duplicates()
return products.to_dict('records')
except Exception as e:
print(f"获取产品列表失败: {e}")
return []
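
# Usage sketch: products can be listed globally or per store ('S001' is the default/legacy
# store ID):
#   all_products = get_available_products()
#   store_products = get_available_products(store_id='S001')
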
def get_store_product_sales_data(store_id: str,
product_id: str,
file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
"""
获取特定店铺和产品的销售数据,用于模型训练
参数:
file_path: 数据文件路径
store_id: 店铺ID
product_id: 产品ID
返回:
DataFrame: 处理后的销售数据,包含模型需要的特征
"""
# 加载数据
df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
if len(df) == 0:
raise ValueError(f"没有找到店铺 {store_id} 产品 {product_id} 的销售数据")
# 确保数据按日期排序
df = df.sort_values('date').copy()
# 数据标准化已在load_multi_store_data中完成
# 验证必要的列是否存在
required_columns = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
print(f"警告: 数据标准化后仍缺少列 {missing_columns}")
raise ValueError(f"无法获取完整的特征数据,缺少列: {missing_columns}")
return df
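
# Usage sketch for building a training set (the product ID is a hypothetical example). The
# returned frame is sorted by date and guaranteed to contain the required feature columns
# listed above:
#   train_df = get_store_product_sales_data('S001', 'P001')
#   X = train_df[['price', 'weekday', 'month', 'is_holiday', 'is_weekend',
#                 'is_promotion', 'temperature']]
#   y = train_df['sales']
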
def aggregate_multi_store_data(product_id: str,
aggregation_method: str = 'sum',
file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
"""
聚合多个店铺的销售数据,用于全局模型训练
参数:
file_path: 数据文件路径
product_id: 产品ID
aggregation_method: 聚合方法 ('sum', 'mean', 'median')
返回:
DataFrame: 聚合后的销售数据
"""
# 加载所有店铺的产品数据
df = load_multi_store_data(file_path, product_id=product_id)
if len(df) == 0:
raise ValueError(f"没有找到产品 {product_id} 的销售数据")
# 按日期聚合(使用标准化后的列名)
agg_dict = {}
if aggregation_method == 'sum':
agg_dict = {
'sales': 'sum', # 标准化后的销量列
'sales_amount': 'sum',
'price': 'mean' # 标准化后的价格列,取平均值
}
elif aggregation_method == 'mean':
agg_dict = {
'sales': 'mean',
'sales_amount': 'mean',
'price': 'mean'
}
elif aggregation_method == 'median':
agg_dict = {
'sales': 'median',
'sales_amount': 'median',
'price': 'median'
}
# 确保列名存在
available_cols = df.columns.tolist()
agg_dict = {k: v for k, v in agg_dict.items() if k in available_cols}
# 聚合数据
aggregated_df = df.groupby('date').agg(agg_dict).reset_index()
# 获取产品信息(取第一个店铺的信息)
product_info = df[['product_id', 'product_name', 'product_category']].iloc[0]
for col, val in product_info.items():
aggregated_df[col] = val
# 添加店铺信息标识为全局
aggregated_df['store_id'] = 'GLOBAL'
aggregated_df['store_name'] = f'全部店铺-{aggregation_method.upper()}'
aggregated_df['store_location'] = '全局聚合'
aggregated_df['store_type'] = 'global'
# 对聚合后的数据进行标准化(添加缺失的特征列)
aggregated_df = aggregated_df.sort_values('date').copy()
aggregated_df = standardize_column_names(aggregated_df)
return aggregated_df
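
# Usage sketch (hypothetical product ID): summing one product's sales across every store
# yields a single 'GLOBAL' time series for global model training:
#   global_df = aggregate_multi_store_data('P001', aggregation_method='sum')
#   print(global_df[['date', 'store_id', 'sales']].head())
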
def get_sales_statistics(file_path: str = 'pharmacy_sales_multi_store.csv',
store_id: Optional[str] = None,
product_id: Optional[str] = None) -> Dict[str, Any]:
"""
获取销售数据统计信息
参数:
file_path: 数据文件路径
store_id: 店铺ID
product_id: 产品ID
返回:
Dict: 统计信息
"""
try:
df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
if len(df) == 0:
return {'error': '没有数据'}
stats = {
'total_records': len(df),
'date_range': {
'start': df['date'].min().strftime('%Y-%m-%d'),
'end': df['date'].max().strftime('%Y-%m-%d')
},
'stores': df['store_id'].nunique(),
'products': df['product_id'].nunique(),
'total_sales_amount': float(df['sales_amount'].sum()) if 'sales_amount' in df.columns else 0,
'total_quantity': int(df['quantity_sold'].sum()) if 'quantity_sold' in df.columns else 0,
'avg_daily_sales': float(df.groupby('date')['quantity_sold'].sum().mean()) if 'quantity_sold' in df.columns else 0
}
return stats
except Exception as e:
return {'error': str(e)}
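
# Usage sketch: statistics can be scoped by store and/or product (IDs are illustrative);
# failures are reported as {'error': ...} rather than raised:
#   print(get_sales_statistics(store_id='S001'))
#   print(get_sales_statistics(store_id='S001', product_id='P001'))
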
# Backward-compatible function
def load_data(file_path='pharmacy_sales.xlsx', store_id=None):
    """
    Backward-compatible data loading function.
    """
return load_multi_store_data(file_path, store_id=store_id)
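

# Minimal smoke test (an illustrative addition, not part of the original API). Running the
# module directly assumes 'pharmacy_sales_multi_store.csv' or the legacy
# 'pharmacy_sales.xlsx' can be found via the candidate paths above.
if __name__ == '__main__':
    # Both helpers swallow their own errors, so this prints [] / {'error': ...}
    # if no data file is present.
    print(get_available_stores())
    print(get_sales_statistics())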