解决mlstm保存错误问题
This commit is contained in:
parent
4ed92a1bc6
commit
ab8110e59b
@ -132,7 +132,7 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
||||
loss_curve_path = plot_loss_curve(
|
||||
loss_history['train'],
|
||||
loss_history['val'],
|
||||
product_name,
|
||||
model_identifier,
|
||||
'cnn_bilstm_attention',
|
||||
model_dir=model_dir
|
||||
)
|
||||
|
@ -393,19 +393,9 @@ def train_product_model_with_mlstm(
|
||||
|
||||
emit_progress("生成损失曲线...", progress=95)
|
||||
|
||||
# 确定模型保存目录(支持多店铺)
|
||||
if store_id:
|
||||
# 为特定店铺创建子目录
|
||||
store_model_dir = os.path.join(model_dir, 'mlstm', store_id)
|
||||
os.makedirs(store_model_dir, exist_ok=True)
|
||||
loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
|
||||
loss_curve_path = os.path.join(store_model_dir, loss_curve_filename)
|
||||
else:
|
||||
# 全局模型保存在global目录
|
||||
global_model_dir = os.path.join(model_dir, 'mlstm', 'global')
|
||||
os.makedirs(global_model_dir, exist_ok=True)
|
||||
loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
|
||||
loss_curve_path = os.path.join(global_model_dir, loss_curve_filename)
|
||||
# 确定模型保存目录
|
||||
loss_curve_filename = f"{model_identifier}_mlstm_{version}_loss_curve.png"
|
||||
loss_curve_path = os.path.join(model_dir, loss_curve_filename)
|
||||
|
||||
# 绘制损失曲线并保存到模型目录
|
||||
plt.figure(figsize=(10, 6))
|
||||
|
Loading…
x
Reference in New Issue
Block a user