139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
#!/usr/bin/env python
|
||
"""
|
||
测试训练集成 - 验证PharmacyPredictor与多店铺训练功能的集成
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
sys.path.append('server')
|
||
|
||
from core.predictor import PharmacyPredictor
|
||
from utils.multi_store_data_utils import load_multi_store_data
|
||
|
||
def test_data_availability():
|
||
"""测试多店铺数据可用性"""
|
||
print("=== 测试数据可用性 ===")
|
||
|
||
try:
|
||
# 测试数据文件是否存在
|
||
data_file = 'pharmacy_sales_multi_store.csv'
|
||
if not os.path.exists(data_file):
|
||
print(f"❌ 数据文件不存在: {data_file}")
|
||
return False
|
||
|
||
# 测试数据加载
|
||
data = load_multi_store_data(data_file)
|
||
print(f"✅ 成功加载多店铺数据,总记录数: {len(data)}")
|
||
|
||
# 显示可用产品和店铺
|
||
products = data['product_id'].unique()
|
||
stores = data['store_id'].unique()
|
||
print(f"📊 可用产品数量: {len(products)}")
|
||
print(f"🏪 可用店铺数量: {len(stores)}")
|
||
print(f"📦 产品列表: {products[:5]}..." if len(products) > 5 else f"📦 产品列表: {products}")
|
||
print(f"🏪 店铺列表: {stores[:3]}..." if len(stores) > 3 else f"🏪 店铺列表: {stores}")
|
||
|
||
return True, products[0], stores[0] # 返回第一个产品和店铺用于测试
|
||
|
||
except Exception as e:
|
||
print(f"❌ 数据加载失败: {e}")
|
||
return False
|
||
|
||
def test_predictor_initialization():
|
||
"""测试PharmacyPredictor初始化"""
|
||
print("\n=== 测试PharmacyPredictor初始化 ===")
|
||
|
||
try:
|
||
predictor = PharmacyPredictor()
|
||
print("✅ PharmacyPredictor初始化成功")
|
||
|
||
if predictor.data is not None:
|
||
print(f"✅ 数据加载成功,记录数: {len(predictor.data)}")
|
||
return predictor
|
||
else:
|
||
print("❌ 数据未正确加载到predictor中")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"❌ PharmacyPredictor初始化失败: {e}")
|
||
return None
|
||
|
||
def test_training_modes(predictor, product_id, store_id):
|
||
"""测试不同训练模式"""
|
||
print(f"\n=== 测试训练模式 (产品: {product_id}, 店铺: {store_id}) ===")
|
||
|
||
# 测试产品训练模式 (使用较少轮次以快速测试)
|
||
print("\n📦 测试产品训练模式 (TCN, 5 epochs)...")
|
||
try:
|
||
metrics = predictor.train_model(
|
||
product_id=product_id,
|
||
model_type='tcn',
|
||
epochs=5,
|
||
training_mode='product'
|
||
)
|
||
if metrics:
|
||
print(f"✅ 产品训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
||
else:
|
||
print("❌ 产品训练模式失败")
|
||
except Exception as e:
|
||
print(f"❌ 产品训练模式异常: {e}")
|
||
|
||
# 测试店铺训练模式
|
||
print(f"\n🏪 测试店铺训练模式 (TCN, 5 epochs, 店铺: {store_id})...")
|
||
try:
|
||
metrics = predictor.train_model(
|
||
product_id=product_id,
|
||
model_type='tcn',
|
||
epochs=5,
|
||
training_mode='store',
|
||
store_id=store_id
|
||
)
|
||
if metrics:
|
||
print(f"✅ 店铺训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
||
else:
|
||
print("❌ 店铺训练模式失败")
|
||
except Exception as e:
|
||
print(f"❌ 店铺训练模式异常: {e}")
|
||
|
||
# 测试全局训练模式
|
||
print(f"\n🌍 测试全局训练模式 (TCN, 5 epochs, 聚合方法: sum)...")
|
||
try:
|
||
metrics = predictor.train_model(
|
||
product_id=product_id,
|
||
model_type='tcn',
|
||
epochs=5,
|
||
training_mode='global',
|
||
aggregation_method='sum'
|
||
)
|
||
if metrics:
|
||
print(f"✅ 全局训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
||
else:
|
||
print("❌ 全局训练模式失败")
|
||
except Exception as e:
|
||
print(f"❌ 全局训练模式异常: {e}")
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
print("🚀 开始训练集成测试")
|
||
|
||
# 测试数据可用性
|
||
data_result = test_data_availability()
|
||
if data_result is False:
|
||
print("❌ 数据不可用,终止测试")
|
||
return
|
||
|
||
_, product_id, store_id = data_result
|
||
|
||
# 测试PharmacyPredictor初始化
|
||
predictor = test_predictor_initialization()
|
||
if predictor is None:
|
||
print("❌ PharmacyPredictor初始化失败,终止测试")
|
||
return
|
||
|
||
# 测试训练模式
|
||
test_training_modes(predictor, product_id, store_id)
|
||
|
||
print("\n🎉 训练集成测试完成")
|
||
|
||
if __name__ == "__main__":
|
||
main() |