47 lines
1.3 KiB
Python
47 lines
1.3 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
快速训练测试 - 验证模型管理系统
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append('server')
|
|
|
|
def quick_test():
|
|
"""快速训练测试"""
|
|
from core.predictor import PharmacyPredictor
|
|
|
|
predictor = PharmacyPredictor()
|
|
print('开始快速训练测试 - TCN模型, 5轮次')
|
|
|
|
try:
|
|
metrics = predictor.train_model(
|
|
product_id='P001',
|
|
model_type='tcn',
|
|
epochs=5,
|
|
training_mode='product'
|
|
)
|
|
print('✅ 训练成功完成')
|
|
|
|
# 检查模型是否正确保存
|
|
from utils.model_manager import model_manager
|
|
models = model_manager.list_models(product_id='P001', model_type='tcn')
|
|
print(f'✅ 找到 {len(models)} 个TCN模型')
|
|
|
|
if models:
|
|
model = models[0]
|
|
print(f'最新模型: {model.get("filename")}')
|
|
print(f'产品: {model.get("product_name")}')
|
|
print(f'训练模式: {model.get("training_mode")}')
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f'训练异常: {e}')
|
|
if "数据不足" in str(e):
|
|
print('这是预期的数据不足错误')
|
|
return True
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
quick_test() |