ShopTRAINING/test/test_integration.py

174 lines
6.1 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python
"""
集成测试 - 验证训练函数调用是否正确
"""
import sys
import os
sys.path.append('server')
def test_import_structure():
"""测试导入结构"""
print("=== 测试导入结构 ===")
try:
# 测试训练函数导入
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.kan_trainer import train_product_model_with_kan
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.mlstm_trainer import train_product_model_with_mlstm
print("OK: 所有训练函数导入成功")
# 测试多店铺数据工具导入
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
print("OK: 多店铺数据工具导入成功")
return True
except Exception as e:
print(f"FAIL: 导入失败: {e}")
return False
def test_function_signatures():
"""测试函数签名"""
print("\n=== 测试函数签名 ===")
try:
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.kan_trainer import train_product_model_with_kan
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.mlstm_trainer import train_product_model_with_mlstm
import inspect
# 检查TCN函数签名
tcn_sig = inspect.signature(train_product_model_with_tcn)
tcn_params = list(tcn_sig.parameters.keys())
expected_params = ['product_id', 'store_id', 'training_mode', 'aggregation_method']
missing_params = [p for p in expected_params if p not in tcn_params]
if missing_params:
print(f"FAIL: TCN函数缺少参数: {missing_params}")
print(f"实际参数: {tcn_params}")
else:
print("OK: TCN函数签名正确")
# 检查KAN函数签名
kan_sig = inspect.signature(train_product_model_with_kan)
kan_params = list(kan_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in kan_params]
if missing_params:
print(f"FAIL: KAN函数缺少参数: {missing_params}")
print(f"实际参数: {kan_params}")
else:
print("OK: KAN函数签名正确")
# 检查Transformer函数签名
transformer_sig = inspect.signature(train_product_model_with_transformer)
transformer_params = list(transformer_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in transformer_params]
if missing_params:
print(f"FAIL: Transformer函数缺少参数: {missing_params}")
print(f"实际参数: {transformer_params}")
else:
print("OK: Transformer函数签名正确")
# 检查mLSTM函数签名
mlstm_sig = inspect.signature(train_product_model_with_mlstm)
mlstm_params = list(mlstm_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in mlstm_params]
if missing_params:
print(f"FAIL: mLSTM函数缺少参数: {missing_params}")
print(f"实际参数: {mlstm_params}")
else:
print("OK: mLSTM函数签名正确")
return len([p for p in expected_params]) == 0
except Exception as e:
print(f"FAIL: 签名检查失败: {e}")
return False
def test_data_file_existence():
"""测试数据文件是否存在"""
print("\n=== 测试数据文件 ===")
data_file = 'pharmacy_sales_multi_store.csv'
if os.path.exists(data_file):
print(f"OK: 多店铺数据文件存在: {data_file}")
# 检查文件大小
file_size = os.path.getsize(data_file)
print(f"文件大小: {file_size} bytes")
return True
else:
print(f"FAIL: 多店铺数据文件不存在: {data_file}")
return False
def test_predictor_import():
"""测试PharmacyPredictor导入不初始化"""
print("\n=== 测试PharmacyPredictor导入 ===")
try:
# 只导入类不初始化避免matplotlib依赖
from core.predictor import PharmacyPredictor
print("OK: PharmacyPredictor类导入成功")
# 检查train_model方法是否存在
if hasattr(PharmacyPredictor, 'train_model'):
print("OK: train_model方法存在")
# 检查方法签名
import inspect
sig = inspect.signature(PharmacyPredictor.train_model)
params = list(sig.parameters.keys())
expected_params = ['store_id', 'training_mode', 'aggregation_method']
missing_params = [p for p in expected_params if p not in params]
if missing_params:
print(f"FAIL: train_model方法缺少参数: {missing_params}")
print(f"实际参数: {params}")
else:
print("OK: train_model方法签名正确")
else:
print("FAIL: train_model方法不存在")
return True
except Exception as e:
print(f"FAIL: PharmacyPredictor导入失败: {e}")
return False
def main():
"""主测试函数"""
print("开始集成测试")
tests_passed = 0
total_tests = 4
# 测试导入结构
if test_import_structure():
tests_passed += 1
# 测试函数签名
if test_function_signatures():
tests_passed += 1
# 测试数据文件
if test_data_file_existence():
tests_passed += 1
# 测试PharmacyPredictor导入
if test_predictor_import():
tests_passed += 1
print(f"\n测试结果: {tests_passed}/{total_tests} 项测试通过")
if tests_passed == total_tests:
print("SUCCESS: 所有测试通过!训练集成应该正常工作")
else:
print("WARNING: 部分测试失败,需要进一步检查")
if __name__ == "__main__":
main()