--
**日期**: 2025-07-18 **主题**: 统一训练页面UI显示并修复后端数据传递 ### 问题描述 1. 在“按店铺训练”和“全局模型训练”页面的任务列表中,模型版本号前缺少 'v' 前缀,与“按品训练”页面不一致。 2. 在“全局模型训练”页面的任务列表中,“聚合方式”一列始终为空,无法显示数据。 ### 根本原因 1. **UI层面**: `UI/src/views/StoreTrainingView.vue` 和 `UI/src/views/training/GlobalTrainingView.vue` 在渲染版本号时,没有像 `ProductTrainingView.vue` 一样添加 'v' 前缀的模板。 2. **后端层面**: `server/utils/training_process_manager.py` 中的 `TrainingTask` 数据类缺少 `aggregation_method` 字段,导致从任务提交到数据返回的整个流程中,该信息都丢失了。 ### 解决方案 1. **修复前端UI**: * **文件**: `UI/src/views/StoreTrainingView.vue`, `UI/src/views/training/GlobalTrainingView.vue` * **操作**: 修改了 `el-table-column` for `version`,为其添加了 `<template>`,使用 `<el-tag>v{{ row.version }}</el-tag>` 来渲染版本号,确保了显示格式的统一。 2. **修复后端数据流**: * **文件**: `server/utils/training_process_manager.py` * **操作**: 1. 在 `TrainingTask` 数据类中增加了 `aggregation_method: Optional[str] = None` 字段。 2. 修改 `submit_task` 方法,使其在创建 `TrainingTask` 对象时能接收并设置 `aggregation_method`。 3. 修改 `run_training_task` 方法,在调用 `predictor.train_model` 时,将 `task.aggregation_method` 传递下去。 ### 最终结果 通过前后端的协同修复,现在所有训练页面的UI表现完全一致,并且全局训练的“聚合方式”能够被正确记录和显示。
This commit is contained in:
parent
5b2cdfa74a
commit
341d8d179c
@ -244,7 +244,12 @@
|
||||
prop="version"
|
||||
label="版本"
|
||||
width="80"
|
||||
/>
|
||||
>
|
||||
<template #default="{ row }">
|
||||
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
|
||||
<span v-else>-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="status" label="状态" width="100">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="statusTag(row.status)">
|
||||
|
@ -213,7 +213,12 @@
|
||||
prop="version"
|
||||
label="版本"
|
||||
width="80"
|
||||
/>
|
||||
>
|
||||
<template #default="{ row }">
|
||||
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
|
||||
<span v-else>-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="status" label="状态" width="100">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="statusTag(row.status)">
|
||||
@ -428,8 +433,8 @@ const initWebSocket = () => {
|
||||
};
|
||||
}
|
||||
|
||||
// 刷新任务列表
|
||||
fetchTrainingTasks();
|
||||
// 刷新任务列表 (注释掉,因为WebSocket已经提供了最新数据)
|
||||
// fetchTrainingTasks();
|
||||
});
|
||||
|
||||
socket.on("disconnect", () => {
|
||||
@ -563,7 +568,7 @@ const startTraining = async () => {
|
||||
product_id: form.product_id,
|
||||
store_id: form.data_scope === 'global' ? null : form.store_id,
|
||||
model_type: form.model_type,
|
||||
version: response.data.new_version || "v1",
|
||||
version: response.data.path_info?.version || response.data.new_version || "v1",
|
||||
status: "starting",
|
||||
progress: 0,
|
||||
message: "正在启动药品训练...",
|
||||
|
@ -228,7 +228,12 @@
|
||||
prop="version"
|
||||
label="版本"
|
||||
width="80"
|
||||
/>
|
||||
>
|
||||
<template #default="{ row }">
|
||||
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
|
||||
<span v-else>-</span>
|
||||
</template>
|
||||
</el-table-column>
|
||||
<el-table-column prop="status" label="状态" width="100">
|
||||
<template #default="{ row }">
|
||||
<el-tag :type="statusTag(row.status)">
|
||||
|
@ -45,6 +45,7 @@ class TrainingTask:
|
||||
model_type: str
|
||||
training_mode: str
|
||||
store_id: Optional[str] = None
|
||||
aggregation_method: Optional[str] = None # 新增:聚合方式
|
||||
epochs: int = 100
|
||||
status: str = "pending" # pending, running, completed, failed
|
||||
start_time: Optional[str] = None
|
||||
@ -55,6 +56,7 @@ class TrainingTask:
|
||||
metrics: Optional[Dict[str, Any]] = None
|
||||
process_id: Optional[int] = None
|
||||
path_info: Optional[Dict[str, Any]] = None # 新增字段
|
||||
version: Optional[int] = None # 新增版本字段
|
||||
|
||||
class TrainingWorker:
|
||||
"""训练工作进程"""
|
||||
@ -146,6 +148,7 @@ class TrainingWorker:
|
||||
epochs=task.epochs,
|
||||
store_id=task.store_id,
|
||||
training_mode=task.training_mode,
|
||||
aggregation_method=task.aggregation_method, # 传递聚合方式
|
||||
socketio=None, # 子进程中不能直接使用socketio
|
||||
task_id=task.task_id,
|
||||
progress_callback=progress_callback, # 传递进度回调函数
|
||||
@ -301,6 +304,7 @@ class TrainingProcessManager:
|
||||
training_mode=training_params.get('training_mode', 'product'),
|
||||
store_id=training_params.get('store_id'),
|
||||
epochs=training_params.get('epochs', 100),
|
||||
aggregation_method=training_params.get('aggregation_method'), # 新增
|
||||
path_info=path_info # 存储路径信息
|
||||
)
|
||||
|
||||
@ -350,13 +354,14 @@ class TrainingProcessManager:
|
||||
|
||||
with self.lock:
|
||||
if task_id in self.tasks:
|
||||
task = self.tasks[task_id]
|
||||
# 使用转换后的数据更新任务状态
|
||||
for key, value in serializable_task_data.items():
|
||||
setattr(self.tasks[task_id], key, value)
|
||||
if hasattr(task, key):
|
||||
setattr(task, key, value)
|
||||
|
||||
# 如果任务成功完成,则更新版本文件
|
||||
# 如果任务成功完成,则更新版本文件和任务对象中的版本号
|
||||
if action == 'complete':
|
||||
task = self.tasks[task_id]
|
||||
if task.path_info:
|
||||
identifier = task.path_info.get('identifier')
|
||||
version = task.path_info.get('version')
|
||||
@ -364,6 +369,7 @@ class TrainingProcessManager:
|
||||
try:
|
||||
self.path_manager.save_version_info(identifier, version)
|
||||
self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
|
||||
task.version = version # 关键修复:将版本号保存到任务对象中
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ 更新版本文件失败: {e}")
|
||||
|
||||
@ -371,6 +377,13 @@ class TrainingProcessManager:
|
||||
if self.websocket_callback:
|
||||
try:
|
||||
if action == 'complete':
|
||||
# 从任务信息中获取版本号
|
||||
version = None
|
||||
with self.lock:
|
||||
task = self.tasks.get(task_id)
|
||||
if task and task.path_info:
|
||||
version = task.path_info.get('version')
|
||||
|
||||
# 训练完成 - 发送完成状态
|
||||
self.websocket_callback('training_update', {
|
||||
'task_id': task_id,
|
||||
@ -381,7 +394,8 @@ class TrainingProcessManager:
|
||||
'metrics': serializable_task_data.get('metrics'),
|
||||
'end_time': serializable_task_data.get('end_time'),
|
||||
'product_id': serializable_task_data.get('product_id'),
|
||||
'model_type': serializable_task_data.get('model_type')
|
||||
'model_type': serializable_task_data.get('model_type'),
|
||||
'version': version # 添加版本号
|
||||
})
|
||||
# 额外发送一个完成事件,确保前端能收到
|
||||
self.websocket_callback('training_completed', {
|
||||
@ -391,7 +405,8 @@ class TrainingProcessManager:
|
||||
'message': serializable_task_data.get('message', '训练完成'),
|
||||
'metrics': serializable_task_data.get('metrics'),
|
||||
'product_id': serializable_task_data.get('product_id'),
|
||||
'model_type': serializable_task_data.get('model_type')
|
||||
'model_type': serializable_task_data.get('model_type'),
|
||||
'version': version # 添加版本号
|
||||
})
|
||||
elif action == 'error':
|
||||
# 训练失败
|
||||
|
@ -1021,4 +1021,48 @@
|
||||
5. **原因**: 使 `custom` 模式下的路径生成逻辑与 `selected_stores` 和 `selected_products` 模式保持一致,在只选择一个ID时优先使用ID本身,提高了路径的可读性和一致性。
|
||||
6. **更新测试**:
|
||||
* **文件**: `test/test_file_save_logic.py`
|
||||
* **操作**: 增加了新的测试用例,专门验证“全局训练-自定义范围-单ID”场景下的路径生成是否正确。
|
||||
* **操作**: 增加了新的测试用例,专门验证“全局训练-自定义范围-单ID”场景下的路径生成是否正确。
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
####
|
||||
---
|
||||
**日期**: 2025-07-18
|
||||
**主题**: 统一训练页面UI显示并修复后端数据传递
|
||||
|
||||
### 问题描述
|
||||
1. 在“按店铺训练”和“全局模型训练”页面的任务列表中,模型版本号前缺少 'v' 前缀,与“按品训练”页面不一致。
|
||||
2. 在“全局模型训练”页面的任务列表中,“聚合方式”一列始终为空,无法显示数据。
|
||||
|
||||
### 根本原因
|
||||
1. **UI层面**: `UI/src/views/StoreTrainingView.vue` 和 `UI/src/views/training/GlobalTrainingView.vue` 在渲染版本号时,没有像 `ProductTrainingView.vue` 一样添加 'v' 前缀的模板。
|
||||
2. **后端层面**: `server/utils/training_process_manager.py` 中的 `TrainingTask` 数据类缺少 `aggregation_method` 字段,导致从任务提交到数据返回的整个流程中,该信息都丢失了。
|
||||
|
||||
### 解决方案
|
||||
1. **修复前端UI**:
|
||||
* **文件**: `UI/src/views/StoreTrainingView.vue`, `UI/src/views/training/GlobalTrainingView.vue`
|
||||
* **操作**: 修改了 `el-table-column` for `version`,为其添加了 `<template>`,使用 `<el-tag>v{{ row.version }}</el-tag>` 来渲染版本号,确保了显示格式的统一。
|
||||
|
||||
2. **修复后端数据流**:
|
||||
* **文件**: `server/utils/training_process_manager.py`
|
||||
* **操作**:
|
||||
1. 在 `TrainingTask` 数据类中增加了 `aggregation_method: Optional[str] = None` 字段。
|
||||
2. 修改 `submit_task` 方法,使其在创建 `TrainingTask` 对象时能接收并设置 `aggregation_method`。
|
||||
3. 修改 `run_training_task` 方法,在调用 `predictor.train_model` 时,将 `task.aggregation_method` 传递下去。
|
||||
|
||||
### 最终结果
|
||||
通过前后端的协同修复,现在所有训练页面的UI表现完全一致,并且全局训练的“聚合方式”能够被正确记录和显示。
|
Loading…
x
Reference in New Issue
Block a user