46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
from .base_source import IDataSource
|
|
import pandas as pd
|
|
from typing import List, Optional
|
|
|
|
class ParquetDataSource(IDataSource):
|
|
"""
|
|
一个从Parquet文件加载数据的数据源实现。
|
|
"""
|
|
def __init__(self, file_path: str):
|
|
"""
|
|
初始化Parquet数据源。
|
|
|
|
Args:
|
|
file_path: Parquet文件的路径。
|
|
"""
|
|
self.file_path = file_path
|
|
try:
|
|
self._df = pd.read_parquet(self.file_path)
|
|
except FileNotFoundError:
|
|
print(f"警告: Parquet文件未找到于 {self.file_path}。将使用空DataFrame。")
|
|
self._df = pd.DataFrame()
|
|
|
|
def get_data(self, store_ids: Optional[List[str]] = None, product_ids: Optional[List[str]] = None, **kwargs) -> pd.DataFrame:
|
|
"""
|
|
从Parquet文件中筛选并返回数据。
|
|
|
|
Args:
|
|
store_ids: 要筛选的店铺ID列表。
|
|
product_ids: 要筛选的药品ID列表。
|
|
**kwargs: 其他预留的筛选参数。
|
|
|
|
Returns:
|
|
一个经过筛选的pandas DataFrame。
|
|
"""
|
|
if self._df.empty:
|
|
return self._df
|
|
|
|
filtered_df = self._df.copy()
|
|
|
|
if store_ids and 'store_id' in filtered_df.columns:
|
|
filtered_df = filtered_df[filtered_df['store_id'].isin(store_ids)]
|
|
|
|
if product_ids and 'product_id' in filtered_df.columns:
|
|
filtered_df = filtered_df[filtered_df['product_id'].isin(product_ids)]
|
|
|
|
return filtered_df |