ShopTRAINING/test/create_test_model_with_metrics.py

126 lines
3.3 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
"""
创建包含评估指标的测试模型文件
"""
import torch
import os
from datetime import datetime
def _build_model_info(config, product_name):
    """Return the ``model_info`` metadata dict for a fixture checkpoint.

    The identity fields (product id, model type, version, training mode) are
    copied from ``config`` so the two metadata dicts stored in one checkpoint
    always agree. ``created_at`` is a naive local-time ISO-8601 timestamp taken
    when this function is called.
    """
    return {
        'product_id': config['product_id'],
        'product_name': product_name,
        'model_type': config['model_type'],
        'version': config['version'],
        'training_mode': config['training_mode'],
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }


def _save_checkpoint(saved_models_dir, state_dict, metrics, config, model_info):
    """Bundle weights plus metadata into one dict and ``torch.save`` it.

    The filename follows the
    ``{model_type}_{training_mode}_{product_id}_{version}.pth`` naming scheme
    and is built from ``config`` (not hard-coded), so it always matches the
    metadata stored inside the file.

    Returns:
        The path of the written ``.pth`` file.
    """
    filename = (
        f"{config['model_type']}_{config['training_mode']}_"
        f"{config['product_id']}_{config['version']}.pth"
    )
    filepath = os.path.join(saved_models_dir, filename)
    torch.save(
        {
            'model_state_dict': state_dict,
            'metrics': metrics,
            'config': config,
            'model_info': model_info,
        },
        filepath,
    )
    return filepath


def create_test_model_with_metrics(saved_models_dir='saved_models'):
    """Create two test model checkpoints that carry full evaluation metrics.

    Writes ``transformer_product_P003_v1.pth`` and
    ``kan_optimized_product_P004_v1.pth`` into ``saved_models_dir`` (created
    if missing) and prints a short summary of each one.

    Args:
        saved_models_dir: Target directory. Defaults to ``'saved_models'``,
            which is resolved relative to the current working directory.

    Returns:
        List of the written file paths, transformer checkpoint first.
    """
    # A tiny stand-in network; only its state dict ends up in the files.
    model = torch.nn.Linear(10, 1)

    # --- First fixture: transformer model for product P003 ---
    # Sample training metrics.
    metrics = {
        'RMSE': 12.3456,
        'MAE': 8.9012,
        'R2': 0.8765,
        'MAPE': 15.23,
        'MSE': 152.414,
        'training_time': 45.67,
        'loss_curve': [0.5, 0.3, 0.2, 0.15, 0.12],
    }
    # Model configuration.
    config = {
        'model_type': 'transformer',
        'product_id': 'P003',
        'version': 'v1',
        'training_mode': 'product',
        'input_size': 10,
        'hidden_size': 64,
        'num_layers': 2,
        'epochs': 5,
        'learning_rate': 0.001,
        'batch_size': 32,
    }
    model_info = _build_model_info(config, '阿司匹林片')

    # Make sure the output directory exists before the first save.
    os.makedirs(saved_models_dir, exist_ok=True)

    filepath = _save_checkpoint(
        saved_models_dir, model.state_dict(), metrics, config, model_info
    )
    print(f"已创建测试模型: {filepath}")
    print(f"包含评估指标: {list(metrics.keys())}")
    print(f"R² = {metrics['R2']:.4f}")
    print(f"RMSE = {metrics['RMSE']:.4f}")
    print(f"MAE = {metrics['MAE']:.4f}")
    print(f"MAPE = {metrics['MAPE']:.2f}%")

    # --- Second fixture: KAN-type model for product P004 ---
    kan_metrics = {
        'RMSE': 9.8765,
        'MAE': 6.4321,
        'R2': 0.9123,
        'MAPE': 12.34,
        'MSE': 97.544,
        'training_time': 67.89,
    }
    kan_config = {
        'model_type': 'kan_optimized',
        'product_id': 'P004',
        'version': 'v1',
        'training_mode': 'product',
        'grid_size': 5,
        'spline_order': 3,
        'epochs': 10,
    }
    kan_info = _build_model_info(kan_config, '布洛芬胶囊')

    # NOTE(review): this reuses the Linear(10, 1) weights from above, which are
    # not built from grid_size/spline_order. Presumably acceptable because the
    # fixture only exercises metric/metadata handling — confirm before loading
    # it into a real KAN model.
    kan_filepath = _save_checkpoint(
        saved_models_dir, model.state_dict(), kan_metrics, kan_config, kan_info
    )
    print(f"\n已创建第二个测试模型: {kan_filepath}")
    print(f"包含评估指标: {list(kan_metrics.keys())}")
    print(f"R² = {kan_metrics['R2']:.4f}")
    print(f"RMSE = {kan_metrics['RMSE']:.4f}")

    return [filepath, kan_filepath]


if __name__ == "__main__":
    create_test_model_with_metrics()