From ab8110e59b7890c92a70a87b47429639746e7476 Mon Sep 17 00:00:00 2001 From: xz2000 Date: Wed, 23 Jul 2025 18:15:10 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A7=A3=E5=86=B3mlstm=E4=BF=9D=E5=AD=98?= =?UTF-8?q?=E9=94=99=E8=AF=AF=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/trainers/cnn_bilstm_attention_trainer.py | 2 +- server/trainers/mlstm_trainer.py | 16 +++------------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/server/trainers/cnn_bilstm_attention_trainer.py b/server/trainers/cnn_bilstm_attention_trainer.py index c8d4147..47b100f 100644 --- a/server/trainers/cnn_bilstm_attention_trainer.py +++ b/server/trainers/cnn_bilstm_attention_trainer.py @@ -132,7 +132,7 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st loss_curve_path = plot_loss_curve( loss_history['train'], loss_history['val'], - product_name, + model_identifier, 'cnn_bilstm_attention', model_dir=model_dir ) diff --git a/server/trainers/mlstm_trainer.py b/server/trainers/mlstm_trainer.py index d3d0bf1..1cb0233 100644 --- a/server/trainers/mlstm_trainer.py +++ b/server/trainers/mlstm_trainer.py @@ -393,19 +393,9 @@ def train_product_model_with_mlstm( emit_progress("生成损失曲线...", progress=95) - # 确定模型保存目录(支持多店铺) - if store_id: - # 为特定店铺创建子目录 - store_model_dir = os.path.join(model_dir, 'mlstm', store_id) - os.makedirs(store_model_dir, exist_ok=True) - loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png" - loss_curve_path = os.path.join(store_model_dir, loss_curve_filename) - else: - # 全局模型保存在global目录 - global_model_dir = os.path.join(model_dir, 'mlstm', 'global') - os.makedirs(global_model_dir, exist_ok=True) - loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png" - loss_curve_path = os.path.join(global_model_dir, loss_curve_filename) + # 确定模型保存目录 + loss_curve_filename = f"{model_identifier}_mlstm_{version}_loss_curve.png" + loss_curve_path = os.path.join(model_dir, loss_curve_filename) # 绘制损失曲线并保存到模型目录 plt.figure(figsize=(10, 6))