ShopTRAINING/test/test_training_integration.py

139 lines
4.6 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python
"""
测试训练集成 - 验证PharmacyPredictor与多店铺训练功能的集成
"""
import sys
import os
sys.path.append('server')
from core.predictor import PharmacyPredictor
from utils.multi_store_data_utils import load_multi_store_data
def test_data_availability():
"""测试多店铺数据可用性"""
print("=== 测试数据可用性 ===")
try:
# 测试数据文件是否存在
data_file = 'pharmacy_sales_multi_store.csv'
if not os.path.exists(data_file):
print(f"❌ 数据文件不存在: {data_file}")
return False
# 测试数据加载
data = load_multi_store_data(data_file)
print(f"✅ 成功加载多店铺数据,总记录数: {len(data)}")
# 显示可用产品和店铺
products = data['product_id'].unique()
stores = data['store_id'].unique()
print(f"📊 可用产品数量: {len(products)}")
print(f"🏪 可用店铺数量: {len(stores)}")
print(f"📦 产品列表: {products[:5]}..." if len(products) > 5 else f"📦 产品列表: {products}")
print(f"🏪 店铺列表: {stores[:3]}..." if len(stores) > 3 else f"🏪 店铺列表: {stores}")
return True, products[0], stores[0] # 返回第一个产品和店铺用于测试
except Exception as e:
print(f"❌ 数据加载失败: {e}")
return False
def test_predictor_initialization():
"""测试PharmacyPredictor初始化"""
print("\n=== 测试PharmacyPredictor初始化 ===")
try:
predictor = PharmacyPredictor()
print("✅ PharmacyPredictor初始化成功")
if predictor.data is not None:
print(f"✅ 数据加载成功,记录数: {len(predictor.data)}")
return predictor
else:
print("❌ 数据未正确加载到predictor中")
return None
except Exception as e:
print(f"❌ PharmacyPredictor初始化失败: {e}")
return None
def test_training_modes(predictor, product_id, store_id):
"""测试不同训练模式"""
print(f"\n=== 测试训练模式 (产品: {product_id}, 店铺: {store_id}) ===")
# 测试产品训练模式 (使用较少轮次以快速测试)
print("\n📦 测试产品训练模式 (TCN, 5 epochs)...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='product'
)
if metrics:
print(f"✅ 产品训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 产品训练模式失败")
except Exception as e:
print(f"❌ 产品训练模式异常: {e}")
# 测试店铺训练模式
print(f"\n🏪 测试店铺训练模式 (TCN, 5 epochs, 店铺: {store_id})...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='store',
store_id=store_id
)
if metrics:
print(f"✅ 店铺训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 店铺训练模式失败")
except Exception as e:
print(f"❌ 店铺训练模式异常: {e}")
# 测试全局训练模式
print(f"\n🌍 测试全局训练模式 (TCN, 5 epochs, 聚合方法: sum)...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='global',
aggregation_method='sum'
)
if metrics:
print(f"✅ 全局训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 全局训练模式失败")
except Exception as e:
print(f"❌ 全局训练模式异常: {e}")
def main():
"""主测试函数"""
print("🚀 开始训练集成测试")
# 测试数据可用性
data_result = test_data_availability()
if data_result is False:
print("❌ 数据不可用,终止测试")
return
_, product_id, store_id = data_result
# 测试PharmacyPredictor初始化
predictor = test_predictor_initialization()
if predictor is None:
print("❌ PharmacyPredictor初始化失败终止测试")
return
# 测试训练模式
test_training_modes(predictor, product_id, store_id)
print("\n🎉 训练集成测试完成")
if __name__ == "__main__":
main()