#!/usr/bin/env python """ 测试训练集成 - 验证PharmacyPredictor与多店铺训练功能的集成 """ import sys import os sys.path.append('server') from core.predictor import PharmacyPredictor from utils.multi_store_data_utils import load_multi_store_data def test_data_availability(): """测试多店铺数据可用性""" print("=== 测试数据可用性 ===") try: # 测试数据文件是否存在 data_file = 'pharmacy_sales_multi_store.csv' if not os.path.exists(data_file): print(f"❌ 数据文件不存在: {data_file}") return False # 测试数据加载 data = load_multi_store_data(data_file) print(f"✅ 成功加载多店铺数据,总记录数: {len(data)}") # 显示可用产品和店铺 products = data['product_id'].unique() stores = data['store_id'].unique() print(f"📊 可用产品数量: {len(products)}") print(f"🏪 可用店铺数量: {len(stores)}") print(f"📦 产品列表: {products[:5]}..." if len(products) > 5 else f"📦 产品列表: {products}") print(f"🏪 店铺列表: {stores[:3]}..." if len(stores) > 3 else f"🏪 店铺列表: {stores}") return True, products[0], stores[0] # 返回第一个产品和店铺用于测试 except Exception as e: print(f"❌ 数据加载失败: {e}") return False def test_predictor_initialization(): """测试PharmacyPredictor初始化""" print("\n=== 测试PharmacyPredictor初始化 ===") try: predictor = PharmacyPredictor() print("✅ PharmacyPredictor初始化成功") if predictor.data is not None: print(f"✅ 数据加载成功,记录数: {len(predictor.data)}") return predictor else: print("❌ 数据未正确加载到predictor中") return None except Exception as e: print(f"❌ PharmacyPredictor初始化失败: {e}") return None def test_training_modes(predictor, product_id, store_id): """测试不同训练模式""" print(f"\n=== 测试训练模式 (产品: {product_id}, 店铺: {store_id}) ===") # 测试产品训练模式 (使用较少轮次以快速测试) print("\n📦 测试产品训练模式 (TCN, 5 epochs)...") try: metrics = predictor.train_model( product_id=product_id, model_type='tcn', epochs=5, training_mode='product' ) if metrics: print(f"✅ 产品训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}") else: print("❌ 产品训练模式失败") except Exception as e: print(f"❌ 产品训练模式异常: {e}") # 测试店铺训练模式 print(f"\n🏪 测试店铺训练模式 (TCN, 5 epochs, 店铺: {store_id})...") try: metrics = predictor.train_model( product_id=product_id, model_type='tcn', epochs=5, training_mode='store', store_id=store_id ) if metrics: print(f"✅ 店铺训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}") else: print("❌ 店铺训练模式失败") except Exception as e: print(f"❌ 店铺训练模式异常: {e}") # 测试全局训练模式 print(f"\n🌍 测试全局训练模式 (TCN, 5 epochs, 聚合方法: sum)...") try: metrics = predictor.train_model( product_id=product_id, model_type='tcn', epochs=5, training_mode='global', aggregation_method='sum' ) if metrics: print(f"✅ 全局训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}") else: print("❌ 全局训练模式失败") except Exception as e: print(f"❌ 全局训练模式异常: {e}") def main(): """主测试函数""" print("🚀 开始训练集成测试") # 测试数据可用性 data_result = test_data_availability() if data_result is False: print("❌ 数据不可用,终止测试") return _, product_id, store_id = data_result # 测试PharmacyPredictor初始化 predictor = test_predictor_initialization() if predictor is None: print("❌ PharmacyPredictor初始化失败,终止测试") return # 测试训练模式 test_training_modes(predictor, product_id, store_id) print("\n🎉 训练集成测试完成") if __name__ == "__main__": main()