#!/usr/bin/env python3 """ 测试修复后的 PharmacyPredictor 类 """ import sys import os sys.path.append(os.path.join(os.path.dirname(__file__), 'server')) # 测试数据加载 def test_multi_store_data_loading(): """测试多店铺数据加载""" print("测试多店铺数据加载...") try: from utils.multi_store_data_utils import load_multi_store_data # 测试加载数据 data = load_multi_store_data('pharmacy_sales_multi_store.csv') print(f"成功加载数据,记录数: {len(data)}") print(f"列名: {list(data.columns)}") print(f"店铺数量: {data['store_id'].nunique()}") print(f"产品数量: {data['product_id'].nunique()}") return True except Exception as e: print(f"数据加载失败: {e}") return False def test_predictor_initialization(): """测试 PharmacyPredictor 初始化""" print("\n测试 PharmacyPredictor 初始化...") try: from core.predictor import PharmacyPredictor predictor = PharmacyPredictor() print("PharmacyPredictor 初始化成功") if predictor.data is not None: print(f"数据加载成功,记录数: {len(predictor.data)}") return True else: print("数据加载失败") return False except Exception as e: print(f"PharmacyPredictor 初始化失败: {e}") return False def test_train_model_signature(): """测试 train_model 方法签名""" print("\n测试 train_model 方法签名...") try: from core.predictor import PharmacyPredictor import inspect predictor = PharmacyPredictor() # 获取 train_model 方法的签名 sig = inspect.signature(predictor.train_model) params = list(sig.parameters.keys()) print(f"train_model 方法参数: {params}") # 检查是否有 store_id 参数 if 'store_id' in params: print("✓ store_id 参数存在") else: print("✗ store_id 参数不存在") # 检查是否有 training_mode 参数 if 'training_mode' in params: print("✓ training_mode 参数存在") else: print("✗ training_mode 参数不存在") return 'store_id' in params and 'training_mode' in params except Exception as e: print(f"方法签名检查失败: {e}") return False def main(): print("=" * 50) print("PharmacyPredictor 修复测试") print("=" * 50) tests = [ test_multi_store_data_loading, test_predictor_initialization, test_train_model_signature ] results = [] for test in tests: try: result = test() results.append(result) except Exception as e: print(f"测试失败: {e}") results.append(False) print("\n" + "=" * 50) print("测试结果汇总:") print("=" * 50) test_names = [ "多店铺数据加载", "PharmacyPredictor 初始化", "train_model 方法签名" ] for i, (name, result) in enumerate(zip(test_names, results)): status = "✓ 通过" if result else "✗ 失败" print(f"{i+1}. {name}: {status}") all_passed = all(results) print(f"\n总体结果: {'✓ 所有测试通过' if all_passed else '✗ 存在失败的测试'}") return all_passed if __name__ == "__main__": main()