**日期**: 2025-07-18
**主题**: 统一训练页面UI显示并修复后端数据传递

### 问题描述
1.  在“按店铺训练”和“全局模型训练”页面的任务列表中,模型版本号前缺少 'v' 前缀,与“按品训练”页面不一致。
2.  在“全局模型训练”页面的任务列表中,“聚合方式”一列始终为空,无法显示数据。

### 根本原因
1.  **UI层面**: `UI/src/views/StoreTrainingView.vue` 和 `UI/src/views/training/GlobalTrainingView.vue` 在渲染版本号时,没有像 `ProductTrainingView.vue` 一样添加 'v' 前缀的模板。
2.  **后端层面**: `server/utils/training_process_manager.py` 中的 `TrainingTask` 数据类缺少 `aggregation_method` 字段,导致从任务提交到数据返回的整个流程中,该信息都丢失了。

### 解决方案
1.  **修复前端UI**:
    *   **文件**: `UI/src/views/StoreTrainingView.vue`, `UI/src/views/training/GlobalTrainingView.vue`
    *   **操作**: 修改了 `el-table-column` for `version`,为其添加了 `<template>`,使用 `<el-tag>v{{ row.version }}</el-tag>` 来渲染版本号,确保了显示格式的统一。

2.  **修复后端数据流**:
    *   **文件**: `server/utils/training_process_manager.py`
    *   **操作**:
        1.  在 `TrainingTask` 数据类中增加了 `aggregation_method: Optional[str] = None` 字段。
        2.  修改 `submit_task` 方法,使其在创建 `TrainingTask` 对象时能接收并设置 `aggregation_method`。
        3.  修改 `run_training_task` 方法,在调用 `predictor.train_model` 时,将 `task.aggregation_method` 传递下去。

### 最终结果
通过前后端的协同修复,现在所有训练页面的UI表现完全一致,并且全局训练的“聚合方式”能够被正确记录和显示。
This commit is contained in:
xz2000 2025-07-18 18:18:50 +08:00
parent 5b2cdfa74a
commit 341d8d179c
5 changed files with 86 additions and 12 deletions

View File

@ -244,7 +244,12 @@
prop="version"
label="版本"
width="80"
/>
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">

View File

@ -213,7 +213,12 @@
prop="version"
label="版本"
width="80"
/>
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">
@ -428,8 +433,8 @@ const initWebSocket = () => {
};
}
//
fetchTrainingTasks();
// (WebSocket)
// fetchTrainingTasks();
});
socket.on("disconnect", () => {
@ -563,7 +568,7 @@ const startTraining = async () => {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
version: response.data.new_version || "v1",
version: response.data.path_info?.version || response.data.new_version || "v1",
status: "starting",
progress: 0,
message: "正在启动药品训练...",

View File

@ -228,7 +228,12 @@
prop="version"
label="版本"
width="80"
/>
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">

View File

@ -45,6 +45,7 @@ class TrainingTask:
model_type: str
training_mode: str
store_id: Optional[str] = None
aggregation_method: Optional[str] = None # 新增:聚合方式
epochs: int = 100
status: str = "pending" # pending, running, completed, failed
start_time: Optional[str] = None
@ -55,6 +56,7 @@ class TrainingTask:
metrics: Optional[Dict[str, Any]] = None
process_id: Optional[int] = None
path_info: Optional[Dict[str, Any]] = None # 新增字段
version: Optional[int] = None # 新增版本字段
class TrainingWorker:
"""训练工作进程"""
@ -146,6 +148,7 @@ class TrainingWorker:
epochs=task.epochs,
store_id=task.store_id,
training_mode=task.training_mode,
aggregation_method=task.aggregation_method, # 传递聚合方式
socketio=None, # 子进程中不能直接使用socketio
task_id=task.task_id,
progress_callback=progress_callback, # 传递进度回调函数
@ -301,6 +304,7 @@ class TrainingProcessManager:
training_mode=training_params.get('training_mode', 'product'),
store_id=training_params.get('store_id'),
epochs=training_params.get('epochs', 100),
aggregation_method=training_params.get('aggregation_method'), # 新增
path_info=path_info # 存储路径信息
)
@ -350,13 +354,14 @@ class TrainingProcessManager:
with self.lock:
if task_id in self.tasks:
task = self.tasks[task_id]
# 使用转换后的数据更新任务状态
for key, value in serializable_task_data.items():
setattr(self.tasks[task_id], key, value)
if hasattr(task, key):
setattr(task, key, value)
# 如果任务成功完成,则更新版本文件
# 如果任务成功完成,则更新版本文件和任务对象中的版本号
if action == 'complete':
task = self.tasks[task_id]
if task.path_info:
identifier = task.path_info.get('identifier')
version = task.path_info.get('version')
@ -364,6 +369,7 @@ class TrainingProcessManager:
try:
self.path_manager.save_version_info(identifier, version)
self.logger.info(f"✅ 版本信息已更新: identifier={identifier}, version={version}")
task.version = version # 关键修复:将版本号保存到任务对象中
except Exception as e:
self.logger.error(f"❌ 更新版本文件失败: {e}")
@ -371,6 +377,13 @@ class TrainingProcessManager:
if self.websocket_callback:
try:
if action == 'complete':
# 从任务信息中获取版本号
version = None
with self.lock:
task = self.tasks.get(task_id)
if task and task.path_info:
version = task.path_info.get('version')
# 训练完成 - 发送完成状态
self.websocket_callback('training_update', {
'task_id': task_id,
@ -381,7 +394,8 @@ class TrainingProcessManager:
'metrics': serializable_task_data.get('metrics'),
'end_time': serializable_task_data.get('end_time'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
'model_type': serializable_task_data.get('model_type'),
'version': version # 添加版本号
})
# 额外发送一个完成事件,确保前端能收到
self.websocket_callback('training_completed', {
@ -391,7 +405,8 @@ class TrainingProcessManager:
'message': serializable_task_data.get('message', '训练完成'),
'metrics': serializable_task_data.get('metrics'),
'product_id': serializable_task_data.get('product_id'),
'model_type': serializable_task_data.get('model_type')
'model_type': serializable_task_data.get('model_type'),
'version': version # 添加版本号
})
elif action == 'error':
# 训练失败

View File

@ -1021,4 +1021,48 @@
5. **原因**: 使 `custom` 模式下的路径生成逻辑与 `selected_stores``selected_products` 模式保持一致在只选择一个ID时优先使用ID本身提高了路径的可读性和一致性。
6. **更新测试**:
* **文件**: `test/test_file_save_logic.py`
* **操作**: 增加了新的测试用例,专门验证“全局训练-自定义范围-单ID”场景下的路径生成是否正确。
* **操作**: 增加了新的测试用例,专门验证“全局训练-自定义范围-单ID”场景下的路径生成是否正确。
####
---
**日期**: 2025-07-18
**主题**: 统一训练页面UI显示并修复后端数据传递
### 问题描述
1. 在“按店铺训练”和“全局模型训练”页面的任务列表中,模型版本号前缺少 'v' 前缀,与“按品训练”页面不一致。
2. 在“全局模型训练”页面的任务列表中,“聚合方式”一列始终为空,无法显示数据。
### 根本原因
1. **UI层面**: `UI/src/views/StoreTrainingView.vue``UI/src/views/training/GlobalTrainingView.vue` 在渲染版本号时,没有像 `ProductTrainingView.vue` 一样添加 'v' 前缀的模板。
2. **后端层面**: `server/utils/training_process_manager.py` 中的 `TrainingTask` 数据类缺少 `aggregation_method` 字段,导致从任务提交到数据返回的整个流程中,该信息都丢失了。
### 解决方案
1. **修复前端UI**:
* **文件**: `UI/src/views/StoreTrainingView.vue`, `UI/src/views/training/GlobalTrainingView.vue`
* **操作**: 修改了 `el-table-column` for `version`,为其添加了 `<template>`,使用 `<el-tag>v{{ row.version }}</el-tag>` 来渲染版本号,确保了显示格式的统一。
2. **修复后端数据流**:
* **文件**: `server/utils/training_process_manager.py`
* **操作**:
1. 在 `TrainingTask` 数据类中增加了 `aggregation_method: Optional[str] = None` 字段。
2. 修改 `submit_task` 方法,使其在创建 `TrainingTask` 对象时能接收并设置 `aggregation_method`
3. 修改 `run_training_task` 方法,在调用 `predictor.train_model` 时,将 `task.aggregation_method` 传递下去。
### 最终结果
通过前后端的协同修复现在所有训练页面的UI表现完全一致并且全局训练的“聚合方式”能够被正确记录和显示。