ShopTRAINING/server/data/parquet_source.py

46 lines
1.5 KiB
Python

from .base_source import IDataSource
import pandas as pd
from typing import List, Optional
class ParquetDataSource(IDataSource):
"""
一个从Parquet文件加载数据的数据源实现。
"""
def __init__(self, file_path: str):
"""
初始化Parquet数据源。
Args:
file_path: Parquet文件的路径。
"""
self.file_path = file_path
try:
self._df = pd.read_parquet(self.file_path)
except FileNotFoundError:
print(f"警告: Parquet文件未找到于 {self.file_path}。将使用空DataFrame。")
self._df = pd.DataFrame()
def get_data(self, store_ids: Optional[List[str]] = None, product_ids: Optional[List[str]] = None, **kwargs) -> pd.DataFrame:
"""
从Parquet文件中筛选并返回数据。
Args:
store_ids: 要筛选的店铺ID列表。
product_ids: 要筛选的药品ID列表。
**kwargs: 其他预留的筛选参数。
Returns:
一个经过筛选的pandas DataFrame。
"""
if self._df.empty:
return self._df
filtered_df = self._df.copy()
if store_ids and 'store_id' in filtered_df.columns:
filtered_df = filtered_df[filtered_df['store_id'].isin(store_ids)]
if product_ids and 'product_id' in filtered_df.columns:
filtered_df = filtered_df[filtered_df['product_id'].isin(product_ids)]
return filtered_df