174 lines
6.1 KiB
Python
174 lines
6.1 KiB
Python
#!/usr/bin/env python
|
||
"""
|
||
集成测试 - 验证训练函数调用是否正确
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
sys.path.append('server')
|
||
|
||
def test_import_structure():
|
||
"""测试导入结构"""
|
||
print("=== 测试导入结构 ===")
|
||
|
||
try:
|
||
# 测试训练函数导入
|
||
from trainers.tcn_trainer import train_product_model_with_tcn
|
||
from trainers.kan_trainer import train_product_model_with_kan
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
||
print("OK: 所有训练函数导入成功")
|
||
|
||
# 测试多店铺数据工具导入
|
||
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
|
||
print("OK: 多店铺数据工具导入成功")
|
||
|
||
return True
|
||
except Exception as e:
|
||
print(f"FAIL: 导入失败: {e}")
|
||
return False
|
||
|
||
def test_function_signatures():
|
||
"""测试函数签名"""
|
||
print("\n=== 测试函数签名 ===")
|
||
|
||
try:
|
||
from trainers.tcn_trainer import train_product_model_with_tcn
|
||
from trainers.kan_trainer import train_product_model_with_kan
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
||
import inspect
|
||
|
||
# 检查TCN函数签名
|
||
tcn_sig = inspect.signature(train_product_model_with_tcn)
|
||
tcn_params = list(tcn_sig.parameters.keys())
|
||
expected_params = ['product_id', 'store_id', 'training_mode', 'aggregation_method']
|
||
|
||
missing_params = [p for p in expected_params if p not in tcn_params]
|
||
if missing_params:
|
||
print(f"FAIL: TCN函数缺少参数: {missing_params}")
|
||
print(f"实际参数: {tcn_params}")
|
||
else:
|
||
print("OK: TCN函数签名正确")
|
||
|
||
# 检查KAN函数签名
|
||
kan_sig = inspect.signature(train_product_model_with_kan)
|
||
kan_params = list(kan_sig.parameters.keys())
|
||
|
||
missing_params = [p for p in expected_params if p not in kan_params]
|
||
if missing_params:
|
||
print(f"FAIL: KAN函数缺少参数: {missing_params}")
|
||
print(f"实际参数: {kan_params}")
|
||
else:
|
||
print("OK: KAN函数签名正确")
|
||
|
||
# 检查Transformer函数签名
|
||
transformer_sig = inspect.signature(train_product_model_with_transformer)
|
||
transformer_params = list(transformer_sig.parameters.keys())
|
||
|
||
missing_params = [p for p in expected_params if p not in transformer_params]
|
||
if missing_params:
|
||
print(f"FAIL: Transformer函数缺少参数: {missing_params}")
|
||
print(f"实际参数: {transformer_params}")
|
||
else:
|
||
print("OK: Transformer函数签名正确")
|
||
|
||
# 检查mLSTM函数签名
|
||
mlstm_sig = inspect.signature(train_product_model_with_mlstm)
|
||
mlstm_params = list(mlstm_sig.parameters.keys())
|
||
|
||
missing_params = [p for p in expected_params if p not in mlstm_params]
|
||
if missing_params:
|
||
print(f"FAIL: mLSTM函数缺少参数: {missing_params}")
|
||
print(f"实际参数: {mlstm_params}")
|
||
else:
|
||
print("OK: mLSTM函数签名正确")
|
||
|
||
return len([p for p in expected_params]) == 0
|
||
except Exception as e:
|
||
print(f"FAIL: 签名检查失败: {e}")
|
||
return False
|
||
|
||
def test_data_file_existence():
|
||
"""测试数据文件是否存在"""
|
||
print("\n=== 测试数据文件 ===")
|
||
|
||
data_file = 'pharmacy_sales_multi_store.csv'
|
||
if os.path.exists(data_file):
|
||
print(f"OK: 多店铺数据文件存在: {data_file}")
|
||
|
||
# 检查文件大小
|
||
file_size = os.path.getsize(data_file)
|
||
print(f"文件大小: {file_size} bytes")
|
||
|
||
return True
|
||
else:
|
||
print(f"FAIL: 多店铺数据文件不存在: {data_file}")
|
||
return False
|
||
|
||
def test_predictor_import():
|
||
"""测试PharmacyPredictor导入(不初始化)"""
|
||
print("\n=== 测试PharmacyPredictor导入 ===")
|
||
|
||
try:
|
||
# 只导入类,不初始化(避免matplotlib依赖)
|
||
from core.predictor import PharmacyPredictor
|
||
print("OK: PharmacyPredictor类导入成功")
|
||
|
||
# 检查train_model方法是否存在
|
||
if hasattr(PharmacyPredictor, 'train_model'):
|
||
print("OK: train_model方法存在")
|
||
|
||
# 检查方法签名
|
||
import inspect
|
||
sig = inspect.signature(PharmacyPredictor.train_model)
|
||
params = list(sig.parameters.keys())
|
||
|
||
expected_params = ['store_id', 'training_mode', 'aggregation_method']
|
||
missing_params = [p for p in expected_params if p not in params]
|
||
|
||
if missing_params:
|
||
print(f"FAIL: train_model方法缺少参数: {missing_params}")
|
||
print(f"实际参数: {params}")
|
||
else:
|
||
print("OK: train_model方法签名正确")
|
||
else:
|
||
print("FAIL: train_model方法不存在")
|
||
|
||
return True
|
||
except Exception as e:
|
||
print(f"FAIL: PharmacyPredictor导入失败: {e}")
|
||
return False
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
print("开始集成测试")
|
||
|
||
tests_passed = 0
|
||
total_tests = 4
|
||
|
||
# 测试导入结构
|
||
if test_import_structure():
|
||
tests_passed += 1
|
||
|
||
# 测试函数签名
|
||
if test_function_signatures():
|
||
tests_passed += 1
|
||
|
||
# 测试数据文件
|
||
if test_data_file_existence():
|
||
tests_passed += 1
|
||
|
||
# 测试PharmacyPredictor导入
|
||
if test_predictor_import():
|
||
tests_passed += 1
|
||
|
||
print(f"\n测试结果: {tests_passed}/{total_tests} 项测试通过")
|
||
|
||
if tests_passed == total_tests:
|
||
print("SUCCESS: 所有测试通过!训练集成应该正常工作")
|
||
else:
|
||
print("WARNING: 部分测试失败,需要进一步检查")
|
||
|
||
if __name__ == "__main__":
|
||
main() |