Compare commits
2 Commits
87cc7b4d03
...
ab8110e59b
Author | SHA1 | Date | |
---|---|---|---|
ab8110e59b | |||
4ed92a1bc6 |
@ -116,12 +116,12 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
return modelList.value.filter(model => {
|
return modelList.value.filter(model => {
|
||||||
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type
|
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
|
||||||
return modelTypeMatch
|
return modelTypeMatch
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -125,13 +125,13 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
return modelList.value.filter(model => {
|
return modelList.value.filter(model => {
|
||||||
const productMatch = !filters.product_id || model.product_id === filters.product_id
|
const productMatch = !filters.product_id || model.product_id === filters.product_id
|
||||||
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type
|
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
|
||||||
return productMatch && modelTypeMatch
|
return productMatch && modelTypeMatch
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -126,7 +126,7 @@ const filters = reactive({
|
|||||||
|
|
||||||
const pagination = reactive({
|
const pagination = reactive({
|
||||||
currentPage: 1,
|
currentPage: 1,
|
||||||
pageSize: 8
|
pageSize: 12
|
||||||
})
|
})
|
||||||
|
|
||||||
const storeNameMap = computed(() => {
|
const storeNameMap = computed(() => {
|
||||||
@ -146,7 +146,7 @@ const modelsWithNames = computed(() => {
|
|||||||
const filteredModelList = computed(() => {
|
const filteredModelList = computed(() => {
|
||||||
return modelsWithNames.value.filter(model => {
|
return modelsWithNames.value.filter(model => {
|
||||||
const storeMatch = !filters.store_id || model.store_id === filters.store_id
|
const storeMatch = !filters.store_id || model.store_id === filters.store_id
|
||||||
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type
|
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
|
||||||
return storeMatch && modelTypeMatch
|
return storeMatch && modelTypeMatch
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
@ -132,7 +132,7 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
|
|||||||
loss_curve_path = plot_loss_curve(
|
loss_curve_path = plot_loss_curve(
|
||||||
loss_history['train'],
|
loss_history['train'],
|
||||||
loss_history['val'],
|
loss_history['val'],
|
||||||
product_name,
|
model_identifier,
|
||||||
'cnn_bilstm_attention',
|
'cnn_bilstm_attention',
|
||||||
model_dir=model_dir
|
model_dir=model_dir
|
||||||
)
|
)
|
||||||
|
@ -393,19 +393,9 @@ def train_product_model_with_mlstm(
|
|||||||
|
|
||||||
emit_progress("生成损失曲线...", progress=95)
|
emit_progress("生成损失曲线...", progress=95)
|
||||||
|
|
||||||
# 确定模型保存目录(支持多店铺)
|
# 确定模型保存目录
|
||||||
if store_id:
|
loss_curve_filename = f"{model_identifier}_mlstm_{version}_loss_curve.png"
|
||||||
# 为特定店铺创建子目录
|
loss_curve_path = os.path.join(model_dir, loss_curve_filename)
|
||||||
store_model_dir = os.path.join(model_dir, 'mlstm', store_id)
|
|
||||||
os.makedirs(store_model_dir, exist_ok=True)
|
|
||||||
loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
|
|
||||||
loss_curve_path = os.path.join(store_model_dir, loss_curve_filename)
|
|
||||||
else:
|
|
||||||
# 全局模型保存在global目录
|
|
||||||
global_model_dir = os.path.join(model_dir, 'mlstm', 'global')
|
|
||||||
os.makedirs(global_model_dir, exist_ok=True)
|
|
||||||
loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
|
|
||||||
loss_curve_path = os.path.join(global_model_dir, loss_curve_filename)
|
|
||||||
|
|
||||||
# 绘制损失曲线并保存到模型目录
|
# 绘制损失曲线并保存到模型目录
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user