ShopTRAINING/test/test_simple_integration.py
2025-07-02 11:05:23 +08:00

169 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
简单集成测试 - 验证训练函数调用是否正确
"""
import sys
import os
sys.path.append('server')
def test_import_structure():
"""测试导入结构"""
print("=== 测试导入结构 ===")
try:
# 测试训练函数导入
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.kan_trainer import train_product_model_with_kan
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.mlstm_trainer import train_product_model_with_mlstm
print("✅ 所有训练函数导入成功")
# 测试多店铺数据工具导入
from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
print("✅ 多店铺数据工具导入成功")
return True
except Exception as e:
print(f"❌ 导入失败: {e}")
return False
def test_function_signatures():
"""测试函数签名"""
print("\n=== 测试函数签名 ===")
try:
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.kan_trainer import train_product_model_with_kan
from trainers.transformer_trainer import train_product_model_with_transformer
from trainers.mlstm_trainer import train_product_model_with_mlstm
import inspect
# 检查TCN函数签名
tcn_sig = inspect.signature(train_product_model_with_tcn)
tcn_params = list(tcn_sig.parameters.keys())
expected_params = ['product_id', 'store_id', 'training_mode', 'aggregation_method']
missing_params = [p for p in expected_params if p not in tcn_params]
if missing_params:
print(f"❌ TCN函数缺少参数: {missing_params}")
else:
print("✅ TCN函数签名正确")
# 检查KAN函数签名
kan_sig = inspect.signature(train_product_model_with_kan)
kan_params = list(kan_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in kan_params]
if missing_params:
print(f"❌ KAN函数缺少参数: {missing_params}")
else:
print("✅ KAN函数签名正确")
# 检查Transformer函数签名
transformer_sig = inspect.signature(train_product_model_with_transformer)
transformer_params = list(transformer_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in transformer_params]
if missing_params:
print(f"❌ Transformer函数缺少参数: {missing_params}")
else:
print("✅ Transformer函数签名正确")
# 检查mLSTM函数签名
mlstm_sig = inspect.signature(train_product_model_with_mlstm)
mlstm_params = list(mlstm_sig.parameters.keys())
missing_params = [p for p in expected_params if p not in mlstm_params]
if missing_params:
print(f"❌ mLSTM函数缺少参数: {missing_params}")
else:
print("✅ mLSTM函数签名正确")
return True
except Exception as e:
print(f"❌ 签名检查失败: {e}")
return False
def test_data_file_existence():
"""测试数据文件是否存在"""
print("\n=== 测试数据文件 ===")
data_file = 'pharmacy_sales_multi_store.csv'
if os.path.exists(data_file):
print(f"✅ 多店铺数据文件存在: {data_file}")
# 检查文件大小
file_size = os.path.getsize(data_file)
print(f"📊 文件大小: {file_size} bytes")
return True
else:
print(f"❌ 多店铺数据文件不存在: {data_file}")
return False
def test_predictor_import():
"""测试PharmacyPredictor导入不初始化"""
print("\n=== 测试PharmacyPredictor导入 ===")
try:
# 只导入类不初始化避免matplotlib依赖
from core.predictor import PharmacyPredictor
print("✅ PharmacyPredictor类导入成功")
# 检查train_model方法是否存在
if hasattr(PharmacyPredictor, 'train_model'):
print("✅ train_model方法存在")
# 检查方法签名
import inspect
sig = inspect.signature(PharmacyPredictor.train_model)
params = list(sig.parameters.keys())
expected_params = ['store_id', 'training_mode', 'aggregation_method']
missing_params = [p for p in expected_params if p not in params]
if missing_params:
print(f"❌ train_model方法缺少参数: {missing_params}")
else:
print("✅ train_model方法签名正确")
else:
print("❌ train_model方法不存在")
return True
except Exception as e:
print(f"❌ PharmacyPredictor导入失败: {e}")
return False
def main():
"""主测试函数"""
print("🚀 开始简单集成测试")
tests_passed = 0
total_tests = 4
# 测试导入结构
if test_import_structure():
tests_passed += 1
# 测试函数签名
if test_function_signatures():
tests_passed += 1
# 测试数据文件
if test_data_file_existence():
tests_passed += 1
# 测试PharmacyPredictor导入
if test_predictor_import():
tests_passed += 1
print(f"\n📊 测试结果: {tests_passed}/{total_tests} 项测试通过")
if tests_passed == total_tests:
print("🎉 所有测试通过!训练集成应该正常工作")
else:
print("⚠️ 部分测试失败,需要进一步检查")
if __name__ == "__main__":
main()