126 lines
3.3 KiB
Python
126 lines
3.3 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
创建包含评估指标的测试模型文件
|
|||
|
"""
|
|||
|
import torch
|
|||
|
import os
|
|||
|
from datetime import datetime
|
|||
|
|
|||
|
def _save_test_checkpoint(directory, state_dict, metrics, config, model_info):
    """Save one test checkpoint into *directory* and return its path.

    The file name follows the ``<model_type>_<training_mode>_<product_id>_<version>.pth``
    scheme and is built from *config*, so the name can never disagree with the
    metadata stored inside the file.

    Args:
        directory: Existing directory to write into.
        state_dict: Model weights, stored under ``'model_state_dict'``.
        metrics: Evaluation metrics, stored under ``'metrics'``.
        config: Training configuration, stored under ``'config'``. Must contain
            ``model_type``, ``training_mode``, ``product_id`` and ``version``.
        model_info: Descriptive metadata, stored under ``'model_info'``.

    Returns:
        The path of the written ``.pth`` file.

    Raises:
        KeyError: If *config* is missing one of the keys used in the file name.
    """
    # str.format ignores the extra config keys; only the four naming keys are used.
    filename = '{model_type}_{training_mode}_{product_id}_{version}.pth'.format(**config)
    filepath = os.path.join(directory, filename)
    torch.save(
        {
            'model_state_dict': state_dict,
            'metrics': metrics,
            'config': config,
            'model_info': model_info,
        },
        filepath,
    )
    return filepath


def create_test_model_with_metrics(saved_models_dir='saved_models'):
    """Create two test model checkpoints that carry full evaluation metrics.

    Writes ``transformer_product_P003_v1.pth`` and
    ``kan_optimized_product_P004_v1.pth`` into *saved_models_dir*, creating the
    directory if needed. Each file is a ``torch.save``'d dict with the keys
    ``'model_state_dict'``, ``'metrics'``, ``'config'`` and ``'model_info'``.
    A short summary of each model's metrics is printed after saving.

    Args:
        saved_models_dir: Output directory. Defaults to ``'saved_models'``,
            resolved relative to the current working directory.

    Returns:
        A list with the two written file paths, transformer model first.
    """
    # A tiny stand-in network; only its weights end up in the checkpoints.
    model = torch.nn.Linear(10, 1)
    state_dict = model.state_dict()

    # Make sure the output directory exists.
    os.makedirs(saved_models_dir, exist_ok=True)

    # ---- Model 1: transformer ----
    # Sample training metrics.
    metrics = {
        'RMSE': 12.3456,
        'MAE': 8.9012,
        'R2': 0.8765,
        'MAPE': 15.23,
        'MSE': 152.414,
        'training_time': 45.67,
        'loss_curve': [0.5, 0.3, 0.2, 0.15, 0.12],
    }
    # Model configuration.
    config = {
        'model_type': 'transformer',
        'product_id': 'P003',
        'version': 'v1',
        'training_mode': 'product',
        'input_size': 10,
        'hidden_size': 64,
        'num_layers': 2,
        'epochs': 5,
        'learning_rate': 0.001,
        'batch_size': 32,
    }
    # Model metadata.
    model_info = {
        'product_id': 'P003',
        'product_name': '阿司匹林片',
        'model_type': 'transformer',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }
    filepath = _save_test_checkpoint(
        saved_models_dir, state_dict, metrics, config, model_info
    )

    print(f"已创建测试模型: {filepath}")
    print(f"包含评估指标: {list(metrics.keys())}")
    print(f"R² = {metrics['R2']:.4f}")
    print(f"RMSE = {metrics['RMSE']:.4f}")
    print(f"MAE = {metrics['MAE']:.4f}")
    print(f"MAPE = {metrics['MAPE']:.2f}%")

    # ---- Model 2: KAN ----
    kan_metrics = {
        'RMSE': 9.8765,
        'MAE': 6.4321,
        'R2': 0.9123,
        'MAPE': 12.34,
        'MSE': 97.544,
        'training_time': 67.89,
    }
    kan_config = {
        'model_type': 'kan_optimized',
        'product_id': 'P004',
        'version': 'v1',
        'training_mode': 'product',
        'grid_size': 5,
        'spline_order': 3,
        'epochs': 10,
    }
    kan_info = {
        'product_id': 'P004',
        'product_name': '布洛芬胶囊',
        'model_type': 'kan_optimized',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }
    # NOTE(review): this checkpoint reuses the Linear(10, 1) weights, so its
    # state_dict does not match any real KAN architecture. That is presumably
    # fine for a metrics-display fixture, but a loader that rebuilds a KAN model
    # and calls load_state_dict() on it would likely reject it — confirm usage.
    kan_filepath = _save_test_checkpoint(
        saved_models_dir, state_dict, kan_metrics, kan_config, kan_info
    )

    print(f"\n已创建第二个测试模型: {kan_filepath}")
    print(f"包含评估指标: {list(kan_metrics.keys())}")
    print(f"R² = {kan_metrics['R2']:.4f}")
    print(f"RMSE = {kan_metrics['RMSE']:.4f}")

    return [filepath, kan_filepath]
|
if __name__ == "__main__":
    # Script entry point: write both test checkpoints into ./saved_models.
    create_test_model_with_metrics()