125 lines
3.5 KiB
Python
125 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
测试修复后的 PharmacyPredictor 类
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'server'))
|
|
|
|
# 测试数据加载
|
|
def test_multi_store_data_loading():
|
|
"""测试多店铺数据加载"""
|
|
print("测试多店铺数据加载...")
|
|
|
|
try:
|
|
from utils.multi_store_data_utils import load_multi_store_data
|
|
|
|
# 测试加载数据
|
|
data = load_multi_store_data('pharmacy_sales_multi_store.csv')
|
|
print(f"成功加载数据,记录数: {len(data)}")
|
|
print(f"列名: {list(data.columns)}")
|
|
print(f"店铺数量: {data['store_id'].nunique()}")
|
|
print(f"产品数量: {data['product_id'].nunique()}")
|
|
|
|
return True
|
|
except Exception as e:
|
|
print(f"数据加载失败: {e}")
|
|
return False
|
|
|
|
def test_predictor_initialization():
|
|
"""测试 PharmacyPredictor 初始化"""
|
|
print("\n测试 PharmacyPredictor 初始化...")
|
|
|
|
try:
|
|
from core.predictor import PharmacyPredictor
|
|
|
|
predictor = PharmacyPredictor()
|
|
print("PharmacyPredictor 初始化成功")
|
|
|
|
if predictor.data is not None:
|
|
print(f"数据加载成功,记录数: {len(predictor.data)}")
|
|
return True
|
|
else:
|
|
print("数据加载失败")
|
|
return False
|
|
|
|
except Exception as e:
|
|
print(f"PharmacyPredictor 初始化失败: {e}")
|
|
return False
|
|
|
|
def test_train_model_signature():
|
|
"""测试 train_model 方法签名"""
|
|
print("\n测试 train_model 方法签名...")
|
|
|
|
try:
|
|
from core.predictor import PharmacyPredictor
|
|
import inspect
|
|
|
|
predictor = PharmacyPredictor()
|
|
|
|
# 获取 train_model 方法的签名
|
|
sig = inspect.signature(predictor.train_model)
|
|
params = list(sig.parameters.keys())
|
|
|
|
print(f"train_model 方法参数: {params}")
|
|
|
|
# 检查是否有 store_id 参数
|
|
if 'store_id' in params:
|
|
print("✓ store_id 参数存在")
|
|
else:
|
|
print("✗ store_id 参数不存在")
|
|
|
|
# 检查是否有 training_mode 参数
|
|
if 'training_mode' in params:
|
|
print("✓ training_mode 参数存在")
|
|
else:
|
|
print("✗ training_mode 参数不存在")
|
|
|
|
return 'store_id' in params and 'training_mode' in params
|
|
|
|
except Exception as e:
|
|
print(f"方法签名检查失败: {e}")
|
|
return False
|
|
|
|
def main():
|
|
print("=" * 50)
|
|
print("PharmacyPredictor 修复测试")
|
|
print("=" * 50)
|
|
|
|
tests = [
|
|
test_multi_store_data_loading,
|
|
test_predictor_initialization,
|
|
test_train_model_signature
|
|
]
|
|
|
|
results = []
|
|
for test in tests:
|
|
try:
|
|
result = test()
|
|
results.append(result)
|
|
except Exception as e:
|
|
print(f"测试失败: {e}")
|
|
results.append(False)
|
|
|
|
print("\n" + "=" * 50)
|
|
print("测试结果汇总:")
|
|
print("=" * 50)
|
|
|
|
test_names = [
|
|
"多店铺数据加载",
|
|
"PharmacyPredictor 初始化",
|
|
"train_model 方法签名"
|
|
]
|
|
|
|
for i, (name, result) in enumerate(zip(test_names, results)):
|
|
status = "✓ 通过" if result else "✗ 失败"
|
|
print(f"{i+1}. {name}: {status}")
|
|
|
|
all_passed = all(results)
|
|
print(f"\n总体结果: {'✓ 所有测试通过' if all_passed else '✗ 存在失败的测试'}")
|
|
|
|
return all_passed
|
|
|
|
if __name__ == "__main__":
|
|
main() |