#!/usr/bin/env python """ 集成测试 - 验证训练函数调用是否正确 """ import sys import os sys.path.append('server') def test_import_structure(): """测试导入结构""" print("=== 测试导入结构 ===") try: # 测试训练函数导入 from trainers.tcn_trainer import train_product_model_with_tcn from trainers.kan_trainer import train_product_model_with_kan from trainers.transformer_trainer import train_product_model_with_transformer from trainers.mlstm_trainer import train_product_model_with_mlstm print("OK: 所有训练函数导入成功") # 测试多店铺数据工具导入 from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data print("OK: 多店铺数据工具导入成功") return True except Exception as e: print(f"FAIL: 导入失败: {e}") return False def test_function_signatures(): """测试函数签名""" print("\n=== 测试函数签名 ===") try: from trainers.tcn_trainer import train_product_model_with_tcn from trainers.kan_trainer import train_product_model_with_kan from trainers.transformer_trainer import train_product_model_with_transformer from trainers.mlstm_trainer import train_product_model_with_mlstm import inspect # 检查TCN函数签名 tcn_sig = inspect.signature(train_product_model_with_tcn) tcn_params = list(tcn_sig.parameters.keys()) expected_params = ['product_id', 'store_id', 'training_mode', 'aggregation_method'] missing_params = [p for p in expected_params if p not in tcn_params] if missing_params: print(f"FAIL: TCN函数缺少参数: {missing_params}") print(f"实际参数: {tcn_params}") else: print("OK: TCN函数签名正确") # 检查KAN函数签名 kan_sig = inspect.signature(train_product_model_with_kan) kan_params = list(kan_sig.parameters.keys()) missing_params = [p for p in expected_params if p not in kan_params] if missing_params: print(f"FAIL: KAN函数缺少参数: {missing_params}") print(f"实际参数: {kan_params}") else: print("OK: KAN函数签名正确") # 检查Transformer函数签名 transformer_sig = inspect.signature(train_product_model_with_transformer) transformer_params = list(transformer_sig.parameters.keys()) missing_params = [p for p in expected_params if p not in transformer_params] if missing_params: print(f"FAIL: Transformer函数缺少参数: {missing_params}") print(f"实际参数: {transformer_params}") else: print("OK: Transformer函数签名正确") # 检查mLSTM函数签名 mlstm_sig = inspect.signature(train_product_model_with_mlstm) mlstm_params = list(mlstm_sig.parameters.keys()) missing_params = [p for p in expected_params if p not in mlstm_params] if missing_params: print(f"FAIL: mLSTM函数缺少参数: {missing_params}") print(f"实际参数: {mlstm_params}") else: print("OK: mLSTM函数签名正确") return len([p for p in expected_params]) == 0 except Exception as e: print(f"FAIL: 签名检查失败: {e}") return False def test_data_file_existence(): """测试数据文件是否存在""" print("\n=== 测试数据文件 ===") data_file = 'pharmacy_sales_multi_store.csv' if os.path.exists(data_file): print(f"OK: 多店铺数据文件存在: {data_file}") # 检查文件大小 file_size = os.path.getsize(data_file) print(f"文件大小: {file_size} bytes") return True else: print(f"FAIL: 多店铺数据文件不存在: {data_file}") return False def test_predictor_import(): """测试PharmacyPredictor导入(不初始化)""" print("\n=== 测试PharmacyPredictor导入 ===") try: # 只导入类,不初始化(避免matplotlib依赖) from core.predictor import PharmacyPredictor print("OK: PharmacyPredictor类导入成功") # 检查train_model方法是否存在 if hasattr(PharmacyPredictor, 'train_model'): print("OK: train_model方法存在") # 检查方法签名 import inspect sig = inspect.signature(PharmacyPredictor.train_model) params = list(sig.parameters.keys()) expected_params = ['store_id', 'training_mode', 'aggregation_method'] missing_params = [p for p in expected_params if p not in params] if missing_params: print(f"FAIL: train_model方法缺少参数: {missing_params}") print(f"实际参数: {params}") else: print("OK: train_model方法签名正确") else: print("FAIL: train_model方法不存在") return True except Exception as e: print(f"FAIL: PharmacyPredictor导入失败: {e}") return False def main(): """主测试函数""" print("开始集成测试") tests_passed = 0 total_tests = 4 # 测试导入结构 if test_import_structure(): tests_passed += 1 # 测试函数签名 if test_function_signatures(): tests_passed += 1 # 测试数据文件 if test_data_file_existence(): tests_passed += 1 # 测试PharmacyPredictor导入 if test_predictor_import(): tests_passed += 1 print(f"\n测试结果: {tests_passed}/{total_tests} 项测试通过") if tests_passed == total_tests: print("SUCCESS: 所有测试通过!训练集成应该正常工作") else: print("WARNING: 部分测试失败,需要进一步检查") if __name__ == "__main__": main()