ShopTRAINING/test/test_training_integration.py
2025-07-02 11:05:23 +08:00

139 lines
4.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
"""
测试训练集成 - 验证PharmacyPredictor与多店铺训练功能的集成
"""
import sys
import os
sys.path.append('server')
from core.predictor import PharmacyPredictor
from utils.multi_store_data_utils import load_multi_store_data
def test_data_availability():
"""测试多店铺数据可用性"""
print("=== 测试数据可用性 ===")
try:
# 测试数据文件是否存在
data_file = 'pharmacy_sales_multi_store.csv'
if not os.path.exists(data_file):
print(f"❌ 数据文件不存在: {data_file}")
return False
# 测试数据加载
data = load_multi_store_data(data_file)
print(f"✅ 成功加载多店铺数据,总记录数: {len(data)}")
# 显示可用产品和店铺
products = data['product_id'].unique()
stores = data['store_id'].unique()
print(f"📊 可用产品数量: {len(products)}")
print(f"🏪 可用店铺数量: {len(stores)}")
print(f"📦 产品列表: {products[:5]}..." if len(products) > 5 else f"📦 产品列表: {products}")
print(f"🏪 店铺列表: {stores[:3]}..." if len(stores) > 3 else f"🏪 店铺列表: {stores}")
return True, products[0], stores[0] # 返回第一个产品和店铺用于测试
except Exception as e:
print(f"❌ 数据加载失败: {e}")
return False
def test_predictor_initialization():
"""测试PharmacyPredictor初始化"""
print("\n=== 测试PharmacyPredictor初始化 ===")
try:
predictor = PharmacyPredictor()
print("✅ PharmacyPredictor初始化成功")
if predictor.data is not None:
print(f"✅ 数据加载成功,记录数: {len(predictor.data)}")
return predictor
else:
print("❌ 数据未正确加载到predictor中")
return None
except Exception as e:
print(f"❌ PharmacyPredictor初始化失败: {e}")
return None
def test_training_modes(predictor, product_id, store_id):
"""测试不同训练模式"""
print(f"\n=== 测试训练模式 (产品: {product_id}, 店铺: {store_id}) ===")
# 测试产品训练模式 (使用较少轮次以快速测试)
print("\n📦 测试产品训练模式 (TCN, 5 epochs)...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='product'
)
if metrics:
print(f"✅ 产品训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 产品训练模式失败")
except Exception as e:
print(f"❌ 产品训练模式异常: {e}")
# 测试店铺训练模式
print(f"\n🏪 测试店铺训练模式 (TCN, 5 epochs, 店铺: {store_id})...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='store',
store_id=store_id
)
if metrics:
print(f"✅ 店铺训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 店铺训练模式失败")
except Exception as e:
print(f"❌ 店铺训练模式异常: {e}")
# 测试全局训练模式
print(f"\n🌍 测试全局训练模式 (TCN, 5 epochs, 聚合方法: sum)...")
try:
metrics = predictor.train_model(
product_id=product_id,
model_type='tcn',
epochs=5,
training_mode='global',
aggregation_method='sum'
)
if metrics:
print(f"✅ 全局训练模式成功RMSE: {metrics.get('rmse', 'N/A'):.4f}")
else:
print("❌ 全局训练模式失败")
except Exception as e:
print(f"❌ 全局训练模式异常: {e}")
def main():
"""主测试函数"""
print("🚀 开始训练集成测试")
# 测试数据可用性
data_result = test_data_availability()
if data_result is False:
print("❌ 数据不可用,终止测试")
return
_, product_id, store_id = data_result
# 测试PharmacyPredictor初始化
predictor = test_predictor_initialization()
if predictor is None:
print("❌ PharmacyPredictor初始化失败终止测试")
return
# 测试训练模式
test_training_modes(predictor, product_id, store_id)
print("\n🎉 训练集成测试完成")
if __name__ == "__main__":
main()