139 lines
4.6 KiB
Python
139 lines
4.6 KiB
Python
![]() |
#!/usr/bin/env python
|
|||
|
"""
|
|||
|
测试训练集成 - 验证PharmacyPredictor与多店铺训练功能的集成
|
|||
|
"""
|
|||
|
|
|||
|
import sys
|
|||
|
import os
|
|||
|
sys.path.append('server')
|
|||
|
|
|||
|
from core.predictor import PharmacyPredictor
|
|||
|
from utils.multi_store_data_utils import load_multi_store_data
|
|||
|
|
|||
|
def test_data_availability():
|
|||
|
"""测试多店铺数据可用性"""
|
|||
|
print("=== 测试数据可用性 ===")
|
|||
|
|
|||
|
try:
|
|||
|
# 测试数据文件是否存在
|
|||
|
data_file = 'pharmacy_sales_multi_store.csv'
|
|||
|
if not os.path.exists(data_file):
|
|||
|
print(f"❌ 数据文件不存在: {data_file}")
|
|||
|
return False
|
|||
|
|
|||
|
# 测试数据加载
|
|||
|
data = load_multi_store_data(data_file)
|
|||
|
print(f"✅ 成功加载多店铺数据,总记录数: {len(data)}")
|
|||
|
|
|||
|
# 显示可用产品和店铺
|
|||
|
products = data['product_id'].unique()
|
|||
|
stores = data['store_id'].unique()
|
|||
|
print(f"📊 可用产品数量: {len(products)}")
|
|||
|
print(f"🏪 可用店铺数量: {len(stores)}")
|
|||
|
print(f"📦 产品列表: {products[:5]}..." if len(products) > 5 else f"📦 产品列表: {products}")
|
|||
|
print(f"🏪 店铺列表: {stores[:3]}..." if len(stores) > 3 else f"🏪 店铺列表: {stores}")
|
|||
|
|
|||
|
return True, products[0], stores[0] # 返回第一个产品和店铺用于测试
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 数据加载失败: {e}")
|
|||
|
return False
|
|||
|
|
|||
|
def test_predictor_initialization():
|
|||
|
"""测试PharmacyPredictor初始化"""
|
|||
|
print("\n=== 测试PharmacyPredictor初始化 ===")
|
|||
|
|
|||
|
try:
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
print("✅ PharmacyPredictor初始化成功")
|
|||
|
|
|||
|
if predictor.data is not None:
|
|||
|
print(f"✅ 数据加载成功,记录数: {len(predictor.data)}")
|
|||
|
return predictor
|
|||
|
else:
|
|||
|
print("❌ 数据未正确加载到predictor中")
|
|||
|
return None
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ PharmacyPredictor初始化失败: {e}")
|
|||
|
return None
|
|||
|
|
|||
|
def test_training_modes(predictor, product_id, store_id):
|
|||
|
"""测试不同训练模式"""
|
|||
|
print(f"\n=== 测试训练模式 (产品: {product_id}, 店铺: {store_id}) ===")
|
|||
|
|
|||
|
# 测试产品训练模式 (使用较少轮次以快速测试)
|
|||
|
print("\n📦 测试产品训练模式 (TCN, 5 epochs)...")
|
|||
|
try:
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type='tcn',
|
|||
|
epochs=5,
|
|||
|
training_mode='product'
|
|||
|
)
|
|||
|
if metrics:
|
|||
|
print(f"✅ 产品训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
|||
|
else:
|
|||
|
print("❌ 产品训练模式失败")
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 产品训练模式异常: {e}")
|
|||
|
|
|||
|
# 测试店铺训练模式
|
|||
|
print(f"\n🏪 测试店铺训练模式 (TCN, 5 epochs, 店铺: {store_id})...")
|
|||
|
try:
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type='tcn',
|
|||
|
epochs=5,
|
|||
|
training_mode='store',
|
|||
|
store_id=store_id
|
|||
|
)
|
|||
|
if metrics:
|
|||
|
print(f"✅ 店铺训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
|||
|
else:
|
|||
|
print("❌ 店铺训练模式失败")
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 店铺训练模式异常: {e}")
|
|||
|
|
|||
|
# 测试全局训练模式
|
|||
|
print(f"\n🌍 测试全局训练模式 (TCN, 5 epochs, 聚合方法: sum)...")
|
|||
|
try:
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type='tcn',
|
|||
|
epochs=5,
|
|||
|
training_mode='global',
|
|||
|
aggregation_method='sum'
|
|||
|
)
|
|||
|
if metrics:
|
|||
|
print(f"✅ 全局训练模式成功,RMSE: {metrics.get('rmse', 'N/A'):.4f}")
|
|||
|
else:
|
|||
|
print("❌ 全局训练模式失败")
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 全局训练模式异常: {e}")
|
|||
|
|
|||
|
def main():
|
|||
|
"""主测试函数"""
|
|||
|
print("🚀 开始训练集成测试")
|
|||
|
|
|||
|
# 测试数据可用性
|
|||
|
data_result = test_data_availability()
|
|||
|
if data_result is False:
|
|||
|
print("❌ 数据不可用,终止测试")
|
|||
|
return
|
|||
|
|
|||
|
_, product_id, store_id = data_result
|
|||
|
|
|||
|
# 测试PharmacyPredictor初始化
|
|||
|
predictor = test_predictor_initialization()
|
|||
|
if predictor is None:
|
|||
|
print("❌ PharmacyPredictor初始化失败,终止测试")
|
|||
|
return
|
|||
|
|
|||
|
# 测试训练模式
|
|||
|
test_training_modes(predictor, product_id, store_id)
|
|||
|
|
|||
|
print("\n🎉 训练集成测试完成")
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|