ShopTRAINING/test/test_predictor_fix.py
2025-07-02 11:05:23 +08:00

125 lines
3.5 KiB
Python

#!/usr/bin/env python3
"""
测试修复后的 PharmacyPredictor 类
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'server'))
# 测试数据加载
def test_multi_store_data_loading():
"""测试多店铺数据加载"""
print("测试多店铺数据加载...")
try:
from utils.multi_store_data_utils import load_multi_store_data
# 测试加载数据
data = load_multi_store_data('pharmacy_sales_multi_store.csv')
print(f"成功加载数据,记录数: {len(data)}")
print(f"列名: {list(data.columns)}")
print(f"店铺数量: {data['store_id'].nunique()}")
print(f"产品数量: {data['product_id'].nunique()}")
return True
except Exception as e:
print(f"数据加载失败: {e}")
return False
def test_predictor_initialization():
"""测试 PharmacyPredictor 初始化"""
print("\n测试 PharmacyPredictor 初始化...")
try:
from core.predictor import PharmacyPredictor
predictor = PharmacyPredictor()
print("PharmacyPredictor 初始化成功")
if predictor.data is not None:
print(f"数据加载成功,记录数: {len(predictor.data)}")
return True
else:
print("数据加载失败")
return False
except Exception as e:
print(f"PharmacyPredictor 初始化失败: {e}")
return False
def test_train_model_signature():
"""测试 train_model 方法签名"""
print("\n测试 train_model 方法签名...")
try:
from core.predictor import PharmacyPredictor
import inspect
predictor = PharmacyPredictor()
# 获取 train_model 方法的签名
sig = inspect.signature(predictor.train_model)
params = list(sig.parameters.keys())
print(f"train_model 方法参数: {params}")
# 检查是否有 store_id 参数
if 'store_id' in params:
print("✓ store_id 参数存在")
else:
print("✗ store_id 参数不存在")
# 检查是否有 training_mode 参数
if 'training_mode' in params:
print("✓ training_mode 参数存在")
else:
print("✗ training_mode 参数不存在")
return 'store_id' in params and 'training_mode' in params
except Exception as e:
print(f"方法签名检查失败: {e}")
return False
def main():
print("=" * 50)
print("PharmacyPredictor 修复测试")
print("=" * 50)
tests = [
test_multi_store_data_loading,
test_predictor_initialization,
test_train_model_signature
]
results = []
for test in tests:
try:
result = test()
results.append(result)
except Exception as e:
print(f"测试失败: {e}")
results.append(False)
print("\n" + "=" * 50)
print("测试结果汇总:")
print("=" * 50)
test_names = [
"多店铺数据加载",
"PharmacyPredictor 初始化",
"train_model 方法签名"
]
for i, (name, result) in enumerate(zip(test_names, results)):
status = "✓ 通过" if result else "✗ 失败"
print(f"{i+1}. {name}: {status}")
all_passed = all(results)
print(f"\n总体结果: {'✓ 所有测试通过' if all_passed else '✗ 存在失败的测试'}")
return all_passed
if __name__ == "__main__":
main()