2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
|
|
|
|
药店销售预测系统 - mLSTM模型训练函数
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
|
|
import time
|
|
|
|
|
import pandas as pd
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from sklearn.preprocessing import MinMaxScaler
|
|
|
|
|
import matplotlib.pyplot as plt
|
2025-07-15 20:09:05 +08:00
|
|
|
|
from datetime import datetime
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
|
|
|
|
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
|
|
|
|
|
from utils.data_utils import create_dataset, PharmacyDataset
|
|
|
|
|
from analysis.metrics import evaluate_model
|
2025-07-02 11:05:23 +08:00
|
|
|
|
from core.config import (
|
2025-07-15 20:09:05 +08:00
|
|
|
|
DEVICE, LOOK_BACK, FORECAST_HORIZON
|
2025-07-02 11:05:23 +08:00
|
|
|
|
)
|
|
|
|
|
from utils.training_progress import progress_manager
|
2025-07-15 20:09:05 +08:00
|
|
|
|
from utils.model_manager import model_manager
|
2025-07-16 15:34:48 +08:00
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
def convert_numpy_types(obj: Any) -> Any:
    """
    Recursively convert NumPy values nested in dicts, lists and tuples into
    native Python types so the result can be passed to ``json.dumps``.

    Conversions:
        * dict          -> dict with converted keys and values
        * list          -> list with converted elements
        * tuple         -> tuple (namedtuples keep their type) with converted elements
        * np.bool_      -> bool
        * np.floating   -> float
        * np.integer    -> int
        * np.ndarray    -> nested Python list via ``ndarray.tolist()``
    Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        # Keys are converted too: json.dumps rejects NumPy scalar keys.
        return {convert_numpy_types(k): convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(elem) for elem in obj]
    if isinstance(obj, tuple):
        converted = [convert_numpy_types(elem) for elem in obj]
        # namedtuples take their fields positionally; plain tuples take one iterable.
        if hasattr(obj, '_fields'):
            return type(obj)(*converted)
        return tuple(converted)
    if isinstance(obj, np.bool_):
        # np.bool_ is neither np.integer nor JSON-serializable.
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
|
2025-07-02 11:05:23 +08:00
|
|
|
|
|
|
|
|
|
def _build_fallback_scope(training_mode, product_id, store_id):
    """Build a model scope string with the legacy naming scheme when the caller supplied none.

    Returns None for an unrecognised training_mode (same as the legacy inline logic).
    """
    if training_mode == 'store':
        current_product_id = product_id if product_id and product_id not in ['unknown', 'all'] else 'all'
        return f"{store_id}_{current_product_id}"
    if training_mode == 'product':
        return f"{product_id}_{store_id or 'all'}"
    if training_mode == 'global':
        return product_id if product_id else "all"
    return None


def _describe_training_scope(training_mode, store_id, aggregation_method):
    """Return the human-readable training-scope label used in logs, plots and metadata."""
    if training_mode == 'store' and store_id:
        return f"店铺 {store_id}"
    if training_mode == 'global':
        return f"全局聚合({aggregation_method})"
    # Mainly the 'product' mode: a specific store or all stores.
    return f"店铺 {store_id}" if store_id else "所有店铺"


def train_product_model_with_mlstm(
    product_id,
    product_df,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    scope=None,
    epochs=50,
    socketio=None,
    task_id=None,
    continue_training=False,
    progress_callback=None,
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
):
    """
    Train an mLSTM (MLSTMTransformer) sales-forecasting model and save its artifacts.

    Args:
        product_id: Product ID; 'all' is labelled as "all products" in the metadata.
        product_df: DataFrame of training rows. Must contain 'product_name' and the
            feature columns 'sales', 'weekday', 'month', 'is_holiday', 'is_weekend',
            'is_promotion', 'temperature'. The train/test split is positional, so rows
            are presumably in chronological order — TODO confirm with the caller.
        store_id: Store ID; None means data across all stores.
        training_mode: 'product', 'store' or 'global'.
        aggregation_method: Aggregation method ('sum', 'mean', 'weighted'); part of the
            model identifier.
        scope: Scope string built by the caller. If None, a legacy fallback is built
            and a warning is emitted.
        epochs: Maximum number of training epochs (must be >= 1).
        socketio: Socket.IO instance for live progress pushes (used only with task_id).
        task_id: Task ID attached to every progress message.
        continue_training: Accepted for compatibility; training always starts from scratch.
        progress_callback: Callable receiving each progress dict (multi-process training).
        patience: Early-stopping patience in epochs; the LR scheduler uses patience // 2.
        learning_rate: Adam learning rate.
        clip_norm: Max gradient norm for clipping; a falsy value disables clipping.

    Returns:
        tuple: (model, metrics, version, model_version_path)

    Raises:
        ValueError: if epochs < 1, if product_df has fewer than
            LOOK_BACK + FORECAST_HORIZON rows, or if the 80/20 split yields no
            training or no test windows.
    """
    import sys  # used by emit_progress to flush the output streams

    # Progress reporter: forwards to progress_callback (multi-process), Socket.IO and stdout.
    def emit_progress(message, progress=None, metrics=None):
        """Send a training-progress message to all configured sinks."""
        progress_data = {
            'task_id': task_id,
            'message': message,
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics

        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                # Best effort: a failing consumer must not abort training.
                print(f"[mLSTM] 进度回调失败: {e}")

        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[mLSTM] WebSocket发送失败: {e}")

        print(f"[mLSTM] {message}", flush=True)
        # Flush both streams so output from worker processes shows up promptly.
        sys.stdout.flush()
        sys.stderr.flush()

    if epochs < 1:
        # Without this guard, later code would fail with NameError on the loop variable.
        error_msg = f"epochs 必须至少为 1, 实际为 {epochs}。"
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    emit_progress("开始mLSTM模型训练...")

    # 1. Determine the model identifier and version.
    model_type = 'mlstm'

    # Prefer the scope built by the caller; the fallback is not recommended.
    if scope is None:
        scope = _build_fallback_scope(training_mode, product_id, store_id)
        # No `progress` argument: that field carries a numeric percentage.
        emit_progress(f"警告: Scope未由调用方提供,已自动构建为 '{scope}'")

    model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
    version = model_manager.get_next_version_number(model_identifier)
    emit_progress(f"开始训练 mLSTM 模型 v{version}")

    # 2. Resolve the output directory for this model version.
    model_version_path = model_manager.get_model_version_path(
        model_type=model_type,
        training_mode=training_mode,
        version=version,
        aggregation_method=aggregation_method,
        product_id=product_id,
        store_id=store_id
    )
    emit_progress(f"模型将保存到: {model_version_path}")

    training_scope = _describe_training_scope(training_mode, store_id, aggregation_method)

    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    if len(product_df) < min_required_samples:
        error_msg = f"数据不足: 需要 {min_required_samples} 天, 实际 {len(product_df)} 天。"
        print(error_msg)
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    product_name = product_df['product_name'].iloc[0]

    print(f"[mLSTM] 使用mLSTM模型训练 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
    print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
    print(f"[mLSTM] 版本: v{version}", flush=True)
    print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
    print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
    emit_progress(f"训练产品: {product_name} (ID: {product_id}) - {training_scope}")

    # Feature columns; 'sales' is both an input feature and the target.
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    print(f"[mLSTM] 开始数据预处理,特征: {features}", flush=True)

    X = product_df[features].values
    y = product_df[['sales']].values
    print(f"[mLSTM] 特征矩阵形状: {X.shape}, 目标矩阵形状: {y.shape}", flush=True)
    emit_progress("数据预处理中...")

    # NOTE(review): the scalers are fit on the full series (train + test), so test
    # min/max leaks into training. Kept as-is to preserve the ranges of the saved scalers.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)
    print("[mLSTM] 数据归一化完成", flush=True)

    # Positional (unshuffled) 80/20 split into train and test segments.
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)

    # The length check above covers the whole series only. Each split must still yield
    # at least one window, otherwise averaging the epoch loss divides by zero.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"数据不足: 80/20 划分后训练样本 {len(trainX)} 个, 测试样本 {len(testX)} 个, "
            f"两者都至少需要 1 个 (共 {len(product_df)} 天)。"
        )
        print(error_msg)
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=32, shuffle=True)
    test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=32, shuffle=False)

    total_batches = len(train_loader)
    total_samples = len(train_loader.dataset)
    print(f"[mLSTM] 数据加载器创建完成 - 批次数: {total_batches}, 样本数: {total_samples}", flush=True)
    emit_progress(f"数据加载器准备完成 - 批次数: {total_batches}, 样本数: {total_samples}")

    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    hidden_size, num_heads, dropout_rate, num_blocks, embed_dim, dense_dim = 128, 4, 0.1, 3, 32, 32

    model = MatrixLSTM(
        num_features=input_dim, hidden_size=hidden_size, mlstm_layers=2, embed_dim=embed_dim,
        dense_dim=dense_dim, num_heads=num_heads, dropout_rate=dropout_rate,
        num_blocks=num_blocks, output_sequence_length=output_dim
    ).to(DEVICE)
    print("[mLSTM] 模型创建完成", flush=True)
    emit_progress("mLSTM模型初始化完成")

    if continue_training:
        emit_progress("继续训练模式启动,但当前重构版本将从头开始。")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    emit_progress("数据预处理完成,开始模型训练...", progress=10)

    train_losses, test_losses = [], []
    start_time = time.time()
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    epochs_no_improve = 0
    last_epoch = 0  # 1-based number of the last completed epoch
    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")

    for epoch in range(epochs):
        emit_progress(f"开始训练 Epoch {epoch+1}/{epochs}")

        # Training pass.
        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            if clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()
            epoch_loss += loss.item()

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Evaluation pass on the test split.
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        # Learning-rate schedule driven by the test loss.
        scheduler.step(test_loss)
        last_epoch = epoch + 1

        emit_progress(f"Epoch {epoch+1}/{epochs} 完成 - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                      progress=10 + ((epoch + 1) / epochs) * 85)

        # 3. Checkpoints: periodic ones plus a running "best" by test loss.
        checkpoint_data = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
        }
        if (epoch + 1) % checkpoint_interval == 0:
            model_manager.save_model_artifact(checkpoint_data, f"checkpoint_epoch_{epoch+1}.pth", model_version_path)
            emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")

        if test_loss < best_loss:
            best_loss = test_loss
            model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            emit_progress(f"连续 {patience} 个epoch测试损失未改善,提前停止训练。")
            break

    training_time = time.time() - start_time

    # Loss curve.
    loss_fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title(f'mLSTM 损失曲线 - {product_name} (v{version}) - {training_scope}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
    plt.close(loss_fig)
    print(f"损失曲线已保存到: {os.path.join(model_version_path, 'loss_curve.png')}")

    # Final evaluation of the last-epoch model on the test windows, in original units.
    model.eval()
    with torch.no_grad():
        test_pred = model(torch.Tensor(testX).to(DEVICE)).cpu().numpy()

    metrics = evaluate_model(scaler_y.inverse_transform(testY), scaler_y.inverse_transform(test_pred))
    metrics['training_time'] = training_time

    # Avoids 'Object of type float32 is not JSON serializable'.
    metrics = convert_numpy_types(metrics)

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}秒")

    final_model_data = {
        'epoch': last_epoch,
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
    }
    model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)

    metadata = {
        'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
        'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
        'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
        'product_scope': '所有药品' if product_id == 'all' else product_name,
        'timestamp': datetime.now().isoformat(), 'metrics': metrics,
        'config': {
            'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
            'num_heads': num_heads, 'dropout': dropout_rate, 'num_blocks': num_blocks,
            'embed_dim': embed_dim, 'dense_dim': dense_dim,
            'sequence_length': LOOK_BACK, 'forecast_horizon': FORECAST_HORIZON,
        }
    }
    model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)

    # 6. Record the new version.
    model_manager.update_version(model_identifier, version)

    emit_progress(f"✅ mLSTM模型 v{version} 训练完成!", progress=100, metrics=metrics)

    return model, metrics, version, model_version_path
|