"""
多店铺销售预测系统 - 数据处理工具函数

Utilities for loading, filtering and aggregating multi-store sales data
(store list, product list, per-store/per-product sales, summary statistics).
"""
import pandas as pd
import numpy as np
import os
import json
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
from core.config import DEFAULT_DATA_PATH


def load_multi_store_data(file_path: str = None,
                          store_id: Optional[str] = None,
                          product_id: Optional[str] = None,
                          start_date: Optional[str] = None,
                          end_date: Optional[str] = None) -> pd.DataFrame:
    """Load multi-store sales data, optionally filtered by store, product and date range.

    Args:
        file_path: Parquet file to read; falls back to DEFAULT_DATA_PATH when None.
        store_id: Exact-match store filter (compared as strings).
        product_id: Exact-match product filter (compared as strings).
        start_date: Inclusive lower date bound (any format pd.to_datetime accepts).
        end_date: Inclusive upper date bound.

    Returns:
        A column-standardized, filtered DataFrame (may be empty).

    Raises:
        FileNotFoundError: If the resolved data file does not exist.
    """
    if file_path is None:
        file_path = DEFAULT_DATA_PATH
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"数据文件不存在: {file_path}")

    # Fix: read the resolved path directly. The previous
    # `try: ... except Exception as e: raise e` wrapper was a no-op, and the
    # `file_path or DEFAULT_DATA_PATH` fallback was dead code because
    # file_path is always set by this point.
    df = pd.read_parquet(file_path)
    df = standardize_column_names(df)

    # Compare ids as strings so int-typed source columns still match.
    if store_id:
        df['store_id'] = df['store_id'].astype(str)
        df = df[df['store_id'] == str(store_id)].copy()
    if product_id:
        df['product_id'] = df['product_id'].astype(str)
        df = df[df['product_id'] == str(product_id)].copy()
    if start_date:
        df = df[df['date'] >= pd.to_datetime(start_date)].copy()
    if end_date:
        df = df[df['date'] <= pd.to_datetime(end_date)].copy()
    return df


def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """Rename raw source columns to the names expected by the training code and API.

    The rename map is loaded from ``core/field_mapping.json`` (key ``"mapping"``).
    On any load/parse failure an empty map is used (best-effort) and diagnostics
    are printed.

    Raises:
        ValueError: If no 'date' column exists after renaming (i.e. neither
            'date' nor the raw 'kdrq' column was present).
    """
    df = df.copy()
    # Fix: bind the path before the try-block so the except-branch can always
    # report it; previously a failure while building the path would have left
    # `mapping_file_path` unbound inside the handler (NameError on print).
    mapping_file_path = None
    try:
        current_dir = os.path.dirname(os.path.abspath(__file__))
        mapping_file_path = os.path.join(current_dir, '..', 'core', 'field_mapping.json')
        with open(mapping_file_path, 'r', encoding='utf-8') as f:
            rename_map = json.load(f).get("mapping", {})
    except Exception as e:
        # Loud diagnostics: a missing mapping silently breaks every consumer
        # that expects standardized column names.
        print(f"!!! CRITICAL ERROR: 无法加载或解析 'field_mapping.json'. 错误: {e}")
        print(f" - 尝试加载的路径: {mapping_file_path}")
        print(f" - 将使用空的映射,列名可能不会被标准化。")
        rename_map = {}

    # Raw 'kdrq' is the authoritative date column; drop any pre-existing
    # 'date' column so the rename below cannot create a duplicate.
    if 'kdrq' in df.columns and 'date' in df.columns:
        df = df.drop(columns=['date'])

    df.rename(columns=rename_map, inplace=True)

    if 'date' in df.columns:
        # errors='coerce' turns unparseable dates into NaT, dropped just below.
        df['date'] = pd.to_datetime(df['date'], errors='coerce')
        df.dropna(subset=['date'], inplace=True)
    else:
        raise ValueError("数据中缺少 'date' (或 'kdrq') 列。")
    return df


def get_available_stores(file_path: str = None) -> List[Dict[str, Any]]:
    """Return the list of available stores as dicts for the frontend.

    Each record has keys: store_id, store_name, location, type, area,
    opening_date, status. Missing source columns are filled with 'N/A'.
    Returns an empty list on any failure (best-effort API).
    """
    try:
        df = pd.read_parquet(file_path or DEFAULT_DATA_PATH)

        # Raw column names: subbh=store id, kdrq=opening date.
        source_cols = ['subbh', 'kdrq', 'district_name', 'area_sq_km']
        existing_cols = [col for col in source_cols if col in df.columns]
        if 'subbh' not in df.columns:
            return []

        # Earliest kdrq per store is treated as the opening date; other
        # attributes are constant per store, so 'first' suffices.
        agg_dict = {'kdrq': 'min'}
        for col in existing_cols:
            if col not in ['subbh', 'kdrq']:
                agg_dict[col] = 'first'

        stores_df = df[existing_cols].groupby('subbh').agg(agg_dict).reset_index()
        stores_df.rename(columns={
            'subbh': 'store_id',
            'kdrq': 'opening_date',
            'district_name': 'location',
            'area_sq_km': 'area'
        }, inplace=True)

        # No separate display name in the source; reuse the id.
        stores_df['store_name'] = stores_df['store_id']
        if 'opening_date' in stores_df.columns:
            stores_df['opening_date'] = pd.to_datetime(stores_df['opening_date']).dt.strftime('%Y-%m-%d')
        stores_df['type'] = '综合'
        stores_df['status'] = '营业中'

        final_cols = ['store_id', 'store_name', 'location', 'type', 'area', 'opening_date', 'status']
        for col in final_cols:
            if col not in stores_df.columns:
                stores_df[col] = 'N/A'
        return stores_df[final_cols].to_dict('records')
    except Exception as e:
        print(f"获取店铺列表失败: {e}")
        return []


def get_all_unique_products(file_path: str = None) -> List[Dict[str, Any]]:
    """Efficiently return every unique product in the data source.

    Only the raw product-id column ('hh') is read from the parquet file.
    Returns an empty list on any failure (best-effort API).
    """
    try:
        # Column-pruned read keeps this cheap on large files.
        df = pd.read_parquet(file_path or DEFAULT_DATA_PATH, columns=['hh'])
        if 'hh' not in df.columns:
            return []
        unique_products = df['hh'].unique()
        # Shape expected by the frontend; no display name exists, reuse the id.
        product_list = [
            {'product_id': pid, 'product_name': pid}
            for pid in unique_products
        ]
        return product_list
    except Exception as e:
        print(f"获取所有唯一产品列表失败: {e}")
        return []


def get_available_products(file_path: str = None,
                           store_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """Return products (optionally for one store) with per-product sales stats.

    Reuses load_multi_store_data so filtering/standardization stay consistent.
    Each record has: product_id, product_name, category, total_sales,
    avg_price, last_sale_date. Returns an empty list on failure.
    """
    try:
        df = load_multi_store_data(file_path=file_path, store_id=store_id)
        if df.empty:
            print(f"!!! INFO: 为店铺ID '{store_id}' 加载数据后DataFrame为空。")
            return []

        # Column names here are the standardized ones; only aggregate what exists.
        agg_ops = {
            'date': 'max',
            'sales_quantity': 'sum',
            'gross_profit': 'sum',
            'net_sales_quantity': 'sum',
            'category': 'first',
        }
        valid_agg_ops = {k: v for k, v in agg_ops.items() if k in df.columns}
        if 'product_id' not in df.columns:
            print("!!! WARNING: 'product_id' 列在标准化后不存在,返回空列表。")
            return []

        products_df = df.groupby('product_id').agg(valid_agg_ops).reset_index()
        products_df.rename(columns={
            'date': 'last_sale_date',
            'sales_quantity': 'total_sales',
        }, inplace=True)

        # Derived unit price; `.where` guards against division by zero.
        if 'gross_profit' in products_df.columns and 'net_sales_quantity' in products_df.columns:
            products_df['avg_price'] = (
                products_df['gross_profit'] / products_df['net_sales_quantity']
            ).where(products_df['net_sales_quantity'] != 0, 0)

        products_df['product_name'] = products_df['product_id']
        if 'last_sale_date' in products_df.columns:
            products_df['last_sale_date'] = pd.to_datetime(products_df['last_sale_date']).dt.strftime('%Y-%m-%d')

        final_cols = ['product_id', 'product_name', 'category', 'total_sales', 'avg_price', 'last_sale_date']
        for col in final_cols:
            if col not in products_df.columns:
                if col in ['total_sales', 'avg_price']:
                    products_df[col] = 0
                else:
                    products_df[col] = 'N/A'

        # Replace NaN/inf with None before dict conversion so the result is
        # JSON-serializable downstream.
        products_df.replace([np.inf, -np.inf], np.nan, inplace=True)
        products_df = products_df.astype(object).where(products_df.notna(), None)

        return products_df[final_cols].to_dict('records')
    except Exception as e:
        print(f"获取产品列表失败: {e}")
        return []


def get_store_product_sales_data(store_id: str,
                                 product_id: str,
                                 file_path: str = None) -> pd.DataFrame:
    """Return the sales rows for one (store, product) pair.

    Raises:
        ValueError: If no rows match the given store/product combination.
    """
    df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
    if len(df) == 0:
        raise ValueError(f"没有找到店铺 {store_id} 产品 {product_id} 的销售数据")
    return df


def aggregate_multi_store_data(product_id: Optional[str] = None,
                               store_id: Optional[str] = None,
                               aggregation_method: str = 'sum',
                               file_path: str = None) -> pd.DataFrame:
    """Aggregate sales_quantity by date, filtered by store OR product.

    Note: store_id takes precedence over product_id; they are not combined.

    Args:
        aggregation_method: 'sum' or 'mean'; any other value yields an
            empty aggregation (no columns), mirroring historical behavior.

    Raises:
        ValueError: If the filtered DataFrame is empty.
    """
    if store_id:
        df = load_multi_store_data(file_path, store_id=store_id)
    elif product_id:
        df = load_multi_store_data(file_path, product_id=product_id)
    else:
        df = load_multi_store_data(file_path)

    if df.empty:
        raise ValueError("过滤后数据为空,无法聚合")

    agg_dict = {}
    if aggregation_method == 'sum':
        agg_dict = {'sales_quantity': 'sum'}
    elif aggregation_method == 'mean':
        agg_dict = {'sales_quantity': 'mean'}

    # Only aggregate columns that survived standardization.
    valid_agg_dict = {k: v for k, v in agg_dict.items() if k in df.columns}
    aggregated_df = df.groupby('date').agg(valid_agg_dict).reset_index()
    return aggregated_df


def get_sales_statistics(file_path: str = None,
                         store_id: Optional[str] = None,
                         product_id: Optional[str] = None) -> Dict[str, Any]:
    """Return summary statistics (record count, date range, distinct
    store/product counts) for the filtered data.

    Returns a dict with an 'error' key instead of raising on failure.
    """
    try:
        df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)
        if len(df) == 0:
            return {'error': '没有数据'}
        stats = {
            'total_records': len(df),
            'date_range': {
                'start': df['date'].min().strftime('%Y-%m-%d'),
                'end': df['date'].max().strftime('%Y-%m-%d')
            },
            'stores': df['store_id'].nunique(),
            'products': df['product_id'].nunique(),
        }
        return stats
    except Exception as e:
        return {'error': str(e)}


def load_data(file_path=None, store_id=None):
    """Backward-compatible alias for load_multi_store_data."""
    return load_multi_store_data(file_path, store_id=store_id)