ShopTRAINING/test/create_test_model_with_metrics.py
2025-07-02 11:05:23 +08:00

126 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
创建包含评估指标的测试模型文件
"""
import torch
import os
from datetime import datetime
def _save_test_model(saved_models_dir, state_dict, metrics, config, model_info):
    """Bundle one fixture checkpoint and write it to disk with ``torch.save``.

    The filename is derived from ``config`` as
    ``{model_type}_{training_mode}_{product_id}_{version}.pth`` so the name on
    disk always agrees with the metadata stored inside the checkpoint.

    Args:
        saved_models_dir: Existing directory to write the checkpoint into.
        state_dict: Model weights to store under ``'model_state_dict'``.
        metrics: Evaluation-metric dict stored under ``'metrics'``.
        config: Model configuration stored under ``'config'``; must contain
            ``model_type``, ``training_mode``, ``product_id`` and ``version``.
        model_info: Descriptive metadata stored under ``'model_info'``.

    Returns:
        str: The path the checkpoint was written to.
    """
    filename = (
        f"{config['model_type']}_{config['training_mode']}_"
        f"{config['product_id']}_{config['version']}.pth"
    )
    filepath = os.path.join(saved_models_dir, filename)
    torch.save(
        {
            'model_state_dict': state_dict,
            'metrics': metrics,
            'config': config,
            'model_info': model_info,
        },
        filepath,
    )
    return filepath


def create_test_model_with_metrics(saved_models_dir='saved_models'):
    """Create two fixture model checkpoints that carry full evaluation metrics.

    Writes ``transformer_product_P003_v1.pth`` and
    ``kan_optimized_product_P004_v1.pth`` into ``saved_models_dir``, creating
    the directory if it is missing, and prints a short summary of each.

    Args:
        saved_models_dir: Output directory. Defaults to ``'saved_models'``,
            which is resolved relative to the current working directory.

    Returns:
        list[str]: Paths of the written checkpoints, in creation order.
    """
    # A tiny placeholder network. NOTE: both checkpoints reuse these same
    # Linear(10, 1) weights, so the stored state dict does not match the
    # architecture described by either config -- only the metadata
    # (metrics / config / model_info) is meant to be meaningful.
    model = torch.nn.Linear(10, 1)

    # --- First fixture: transformer model for product P003 ---
    # Example training metrics.
    metrics = {
        'RMSE': 12.3456,
        'MAE': 8.9012,
        'R2': 0.8765,
        'MAPE': 15.23,
        'MSE': 152.414,
        'training_time': 45.67,
        'loss_curve': [0.5, 0.3, 0.2, 0.15, 0.12],
    }
    # Model configuration.
    config = {
        'model_type': 'transformer',
        'product_id': 'P003',
        'version': 'v1',
        'training_mode': 'product',
        'input_size': 10,
        'hidden_size': 64,
        'num_layers': 2,
        'epochs': 5,
        'learning_rate': 0.001,
        'batch_size': 32,
    }
    # Model info ('阿司匹林片' = aspirin tablets).
    model_info = {
        'product_id': 'P003',
        'product_name': '阿司匹林片',
        'model_type': 'transformer',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }

    # Make sure the output directory exists before the first save.
    os.makedirs(saved_models_dir, exist_ok=True)

    filepath = _save_test_model(
        saved_models_dir, model.state_dict(), metrics, config, model_info
    )
    print(f"已创建测试模型: {filepath}")
    print(f"包含评估指标: {list(metrics.keys())}")
    print(f"R² = {metrics['R2']:.4f}")
    print(f"RMSE = {metrics['RMSE']:.4f}")
    print(f"MAE = {metrics['MAE']:.4f}")
    print(f"MAPE = {metrics['MAPE']:.2f}%")

    # --- Second fixture: KAN-type model for product P004 ---
    kan_metrics = {
        'RMSE': 9.8765,
        'MAE': 6.4321,
        'R2': 0.9123,
        'MAPE': 12.34,
        'MSE': 97.544,
        'training_time': 67.89,
    }
    kan_config = {
        'model_type': 'kan_optimized',
        'product_id': 'P004',
        'version': 'v1',
        'training_mode': 'product',
        'grid_size': 5,
        'spline_order': 3,
        'epochs': 10,
    }
    # '布洛芬胶囊' = ibuprofen capsules.
    kan_info = {
        'product_id': 'P004',
        'product_name': '布洛芬胶囊',
        'model_type': 'kan_optimized',
        'version': 'v1',
        'training_mode': 'product',
        'created_at': datetime.now().isoformat(),
        'store_id': None,
        'aggregation_method': None,
    }

    kan_filepath = _save_test_model(
        saved_models_dir, model.state_dict(), kan_metrics, kan_config, kan_info
    )
    print(f"\n已创建第二个测试模型: {kan_filepath}")
    print(f"包含评估指标: {list(kan_metrics.keys())}")
    print(f"R² = {kan_metrics['R2']:.4f}")
    print(f"RMSE = {kan_metrics['RMSE']:.4f}")

    return [filepath, kan_filepath]
if __name__ == "__main__":
    # Script entry point: generate the fixture checkpoints only when this file
    # is executed directly, not when it is imported.
    create_test_model_with_metrics()