126 lines
3.3 KiB
Python
126 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
||
"""
创建包含评估指标的测试模型文件
(Create test model files that embed evaluation metrics.)
"""
|
||
import torch
|
||
import os
|
||
from datetime import datetime
|
||
|
||
def _save_model_bundle(saved_models_dir, model, metrics, config, model_info):
    """Serialize one test-model payload with ``torch.save`` and return its path.

    The file name follows the
    ``{model_type}_{training_mode}_{product_id}_{version}.pth`` convention and
    is derived from ``model_info``. This keeps the name from drifting away from
    the metadata stored inside the file.

    Args:
        saved_models_dir: Directory to write into (must already exist).
        model: Module whose ``state_dict()`` is stored under
            ``'model_state_dict'``.
        metrics: Evaluation-metrics dict stored under ``'metrics'``.
        config: Training-configuration dict stored under ``'config'``.
        model_info: Metadata dict stored under ``'model_info'``. It must
            contain ``model_type``, ``training_mode``, ``product_id`` and
            ``version``.

    Returns:
        The path of the written ``.pth`` file.
    """
    filename = (
        f"{model_info['model_type']}_{model_info['training_mode']}_"
        f"{model_info['product_id']}_{model_info['version']}.pth"
    )
    filepath = os.path.join(saved_models_dir, filename)
    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'metrics': metrics,
            'config': config,
            'model_info': model_info,
        },
        filepath,
    )
    return filepath


def create_test_model_with_metrics(saved_models_dir='saved_models'):
    """Create two test model files that embed full evaluation metrics.

    Writes ``transformer_product_P003_v1.pth`` and
    ``kan_optimized_product_P004_v1.pth`` into *saved_models_dir*. The
    directory is created if missing. A short summary of each file is printed.

    NOTE(review): both files store the weights of the same placeholder
    ``torch.nn.Linear(10, 1)``. That state dict does not correspond to the
    declared ``model_type``. These fixtures therefore only suit code that reads
    ``metrics`` / ``config`` / ``model_info``, not code that rebuilds and
    reloads the network.

    Args:
        saved_models_dir: Output directory. Defaults to ``'saved_models'``,
            relative to the current working directory.

    Returns:
        List of the written file paths, transformer fixture first.
    """
    # Placeholder network whose weights are stored in both fixtures.
    model = torch.nn.Linear(10, 1)

    # Make sure the output directory exists before any save.
    os.makedirs(saved_models_dir, exist_ok=True)

    # --- Fixture 1: transformer model for product P003 ---
    # Sample training metrics (MSE is consistent with RMSE**2).
    metrics = {
        'RMSE': 12.3456,
        'MAE': 8.9012,
        'R2': 0.8765,
        'MAPE': 15.23,
        'MSE': 152.414,
        'training_time': 45.67,
        'loss_curve': [0.5, 0.3, 0.2, 0.15, 0.12],
    }

    # Model configuration.
    config = {
        'model_type': 'transformer',
        'product_id': 'P003',
        'version': 'v1',
        'training_mode': 'product',
        'input_size': 10,
        'hidden_size': 64,
        'num_layers': 2,
        'epochs': 5,
        'learning_rate': 0.001,
        'batch_size': 32,
    }

    # Model metadata; also drives the on-disk file name.
    model_info = {
        'product_id': 'P003',
        'product_name': '阿司匹林片',
        'model_type': 'transformer',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }

    filepath = _save_model_bundle(
        saved_models_dir, model, metrics, config, model_info
    )

    print(f"已创建测试模型: {filepath}")
    print(f"包含评估指标: {list(metrics.keys())}")
    print(f"R² = {metrics['R2']:.4f}")
    print(f"RMSE = {metrics['RMSE']:.4f}")
    print(f"MAE = {metrics['MAE']:.4f}")
    print(f"MAPE = {metrics['MAPE']:.2f}%")

    # --- Fixture 2: KAN-type model for product P004 ---
    kan_metrics = {
        'RMSE': 9.8765,
        'MAE': 6.4321,
        'R2': 0.9123,
        'MAPE': 12.34,
        'MSE': 97.544,
        'training_time': 67.89,
    }

    kan_config = {
        'model_type': 'kan_optimized',
        'product_id': 'P004',
        'version': 'v1',
        'training_mode': 'product',
        'grid_size': 5,
        'spline_order': 3,
        'epochs': 10,
    }

    kan_info = {
        'product_id': 'P004',
        'product_name': '布洛芬胶囊',
        'model_type': 'kan_optimized',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }

    kan_filepath = _save_model_bundle(
        saved_models_dir, model, kan_metrics, kan_config, kan_info
    )

    print(f"\n已创建第二个测试模型: {kan_filepath}")
    print(f"包含评估指标: {list(kan_metrics.keys())}")
    print(f"R² = {kan_metrics['R2']:.4f}")
    print(f"RMSE = {kan_metrics['RMSE']:.4f}")

    return [filepath, kan_filepath]
|
||
if __name__ == "__main__":
    # Generate the fixtures only when run as a script, never on import.
    create_test_model_with_metrics()