ShopTRAINING/UI/src/views/training/GlobalTrainingView.vue
xz2000 a9a0e51769
# 修改记录日志 (日期: 2025-07-16)
## 1. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时,没有正确传递 `product_ids` 和 `store_ids` 参数,导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
    - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用,确保 `product_ids` 和 `store_ids` 被显式传递。
    - 此前已修复 `train_model` 的函数签名,使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题,根除了由此引发的 `NameError`。

## 2. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程(`threading.Thread`)的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程(`multiprocessing`)的 `TrainingProcessManager` 完全取代。旧代码中还包含大量仅用于调试的 `thread_safe_print` 日志,这些日志现已无用。
- **结果**: `start_training` 端点的逻辑变得更加清晰,只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中,移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`,仅用于在没有实际训练逻辑时模拟进度更新。现在实际的训练器会通过回调函数报告真实进度,因此该模拟代码已不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器,不再有模拟延迟,代码更贴近生产环境。

## 3. 启动依赖

- **Python**: 3.x
- **主要库**:
    - Flask
    - Flask-SocketIO
    - Flasgger
    - pandas
    - numpy
    - torch
    - scikit-learn
    - matplotlib
- **启动命令**: `python server/api.py`
2025-07-16 15:34:57 +08:00

848 lines
25 KiB
Vue
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<template>
  <div class="global-training-container">
    <el-row :gutter="20">
      <!-- Left column: training form, data statistics and live status of the task started here -->
      <el-col :span="8">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局模型训练</span>
              <el-tag type="success">跨店铺通用</el-tag>
            </div>
          </template>

          <div class="training-description">
            <p>使用所有店铺的历史数据训练通用预测模型可用于新店铺或数据不足的场景</p>
          </div>

          <el-form :model="form" label-width="100px">
            <!-- Training scope: decides which of the store/product pickers below are shown -->
            <el-form-item label="训练范围">
              <el-radio-group v-model="form.training_scope">
                <el-radio label="all_stores_all_products">所有店铺所有药品</el-radio>
                <el-radio label="selected_stores">选择店铺</el-radio>
                <el-radio label="selected_products">选择药品</el-radio>
                <el-radio label="custom">自定义范围</el-radio>
              </el-radio-group>
            </el-form-item>

            <!-- Store picker: only for the "selected stores" and "custom" scopes -->
            <el-form-item
              v-if="['selected_stores', 'custom'].includes(form.training_scope)"
              label="选择店铺"
            >
              <el-select
                v-model="form.store_ids"
                multiple
                filterable
                placeholder="选择参与训练的店铺"
                style="width: 100%"
              >
                <el-option
                  v-for="store in stores"
                  :key="store.store_id"
                  :label="`${store.store_name} (${store.location})`"
                  :value="store.store_id"
                />
              </el-select>
            </el-form-item>

            <!-- Product picker: only for the "selected products" and "custom" scopes -->
            <el-form-item
              v-if="['selected_products', 'custom'].includes(form.training_scope)"
              label="选择药品"
            >
              <el-select
                v-model="form.product_ids"
                multiple
                filterable
                placeholder="选择参与训练的药品"
                style="width: 100%"
              >
                <el-option
                  v-for="product in products"
                  :key="product.product_id"
                  :label="`${product.product_name} (${product.product_id})`"
                  :value="product.product_id"
                />
              </el-select>
            </el-form-item>

            <el-form-item label="聚合方式">
              <el-select
                v-model="form.aggregation_method"
                placeholder="选择数据聚合方式"
                style="width: 100%"
              >
                <el-option label="求和 (Sum)" value="sum" />
                <el-option label="平均值 (Mean)" value="mean" />
                <el-option label="加权平均 (Weighted)" value="weighted" />
              </el-select>
            </el-form-item>

            <el-form-item label="模型类型" required>
              <el-select
                v-model="form.model_type"
                placeholder="请选择模型"
                style="width: 100%"
                @change="onModelTypeChange"
              >
                <el-option
                  v-for="item in modelTypes"
                  :key="item.id"
                  :label="item.name"
                  :value="item.id"
                />
              </el-select>
            </el-form-item>

            <!-- "Continue training" is only selectable when the model type already has versions -->
            <el-form-item label="训练模式">
              <el-radio-group v-model="form.training_type">
                <el-radio label="new">新训练</el-radio>
                <el-radio label="retrain" :disabled="!hasExistingVersions">
                  继续训练
                </el-radio>
              </el-radio-group>
            </el-form-item>

            <el-form-item
              v-if="form.training_type === 'retrain'"
              label="基础版本"
            >
              <el-select v-model="form.base_version" placeholder="选择基础版本" style="width: 100%">
                <el-option
                  v-for="version in existingVersions"
                  :key="version"
                  :label="version"
                  :value="version"
                />
              </el-select>
            </el-form-item>

            <el-form-item label="训练轮次">
              <el-input-number v-model="form.epochs" :min="1" :max="1000" style="width: 100%" />
            </el-form-item>

            <el-form-item>
              <el-button
                type="primary"
                :loading="trainingLoading"
                :disabled="!form.model_type"
                style="width: 100%"
                @click="startTraining"
              >
                <el-icon><Operation /></el-icon>
                启动全局训练
              </el-button>
            </el-form-item>
          </el-form>

          <!-- Statistics of the data covered by the current scope selection -->
          <el-card v-if="trainingStats" shadow="never" style="margin-top: 20px">
            <template #header>
              <span>训练数据统计</span>
            </template>
            <div class="training-stats">
              <p><strong>涉及店铺:</strong> {{ trainingStats.stores_count }} 个</p>
              <p><strong>涉及药品:</strong> {{ trainingStats.products_count }} 种</p>
              <p><strong>数据记录:</strong> {{ trainingStats.records_count }} 条</p>
              <p><strong>时间范围:</strong> {{ trainingStats.date_range }}</p>
            </div>
          </el-card>
        </el-card>

        <!-- Live status of the task started from this page (fed by WebSocket events) -->
        <el-card
          v-if="currentTraining"
          class="training-progress-container"
          style="margin-top: 20px"
        >
          <template #header>
            <span>实时训练状态</span>
          </template>
          <div>
            <p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
            <p><strong>训练范围:</strong> {{ getTrainingScopeText(currentTraining) }}</p>
            <p><strong>聚合方式:</strong> {{ getAggregationText(currentTraining.aggregation_method) }}</p>
            <p><strong>模型:</strong> {{ getModelTypeName(currentTraining.model_type) }}</p>
            <p><strong>版本:</strong> {{ currentTraining.version }}</p>
            <p>
              <strong>状态:</strong>
              <el-tag :type="statusTag(currentTraining.status)">
                {{ statusText(currentTraining.status) }}
              </el-tag>
            </p>
            <el-progress
              v-if="currentTraining.status === 'running'"
              :percentage="currentTraining.progress || 0"
              :format="formatProgress"
            />
            <div v-if="currentTraining.message" style="margin-top: 10px">
              <el-alert
                class="training-status-text"
                :title="currentTraining.message"
                type="info"
                show-icon
                :closable="false"
              />
            </div>
            <div
              v-if="currentTraining.metrics"
              class="training-metrics"
              style="margin-top: 10px"
            >
              <h4>训练指标:</h4>
              <pre>{{ JSON.stringify(currentTraining.metrics, null, 2) }}</pre>
            </div>
          </div>
        </el-card>
      </el-col>

      <!-- Right column: queue of global training tasks -->
      <el-col :span="16">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局训练任务队列</span>
              <el-button size="small" @click="fetchTrainingTasks">
                <el-icon><Refresh /></el-icon>
                刷新
              </el-button>
            </div>
          </template>
          <el-table :data="filteredTrainingTasks" stripe>
            <el-table-column
              prop="task_id"
              label="任务ID"
              width="120"
              show-overflow-tooltip
            />
            <el-table-column
              label="训练范围"
              width="150"
            >
              <template #default="{ row }">
                {{ getTrainingScopeText(row) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="aggregation_method"
              label="聚合方式"
              width="100"
            >
              <template #default="{ row }">
                {{ getAggregationText(row.aggregation_method) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="model_type"
              label="模型类型"
              width="120"
            >
              <template #default="{ row }">
                {{ getModelTypeName(row.model_type) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="version"
              label="版本"
              width="80"
            />
            <el-table-column prop="status" label="状态" width="100">
              <template #default="{ row }">
                <el-tag :type="statusTag(row.status)">
                  {{ statusText(row.status) }}
                </el-tag>
              </template>
            </el-table-column>
            <el-table-column prop="start_time" label="创建时间">
              <template #default="{ row }">
                {{ formatDateTime(row.start_time) }}
              </template>
            </el-table-column>
            <!-- Per-row details popover: metrics / error / progress depending on status -->
            <el-table-column label="详情">
              <template #default="{ row }">
                <el-popover placement="left" trigger="hover" width="400">
                  <template #reference>
                    <el-button type="text" size="small">查看</el-button>
                  </template>
                  <div v-if="row.status === 'completed'">
                    <h4>评估指标</h4>
                    <pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
                    <div v-if="row.version">
                      <h4>版本信息</h4>
                      <p><strong>版本:</strong> {{ row.version }}</p>
                      <p><strong>模型路径:</strong> {{ row.model_path }}</p>
                    </div>
                  </div>
                  <div v-if="row.status === 'failed'">
                    <h4>错误信息</h4>
                    <p>{{ row.error }}</p>
                  </div>
                  <div v-if="['running', 'pending'].includes(row.status)">
                    <p>任务正在进行中...</p>
                    <div v-if="row.progress !== undefined">
                      <el-progress :percentage="row.progress" />
                    </div>
                  </div>
                </el-popover>
              </template>
            </el-table-column>
          </el-table>
        </el-card>
      </el-col>
    </el-row>
  </div>
</template>
<script setup>
import { ref, onMounted, onUnmounted, reactive, watch, computed } from "vue";
import axios from "axios";
import {
ElMessage,
ElPopover,
ElButton,
ElTag,
ElProgress,
ElAlert
} from "element-plus";
import { io } from "socket.io-client";
import { Operation, Refresh } from '@element-plus/icons-vue';
// Reference data that populates the selectors.
const stores = ref([]);
const products = ref([]);
const modelTypes = ref([]);

// Versions available for the selected model type (used by "continue training").
const existingVersions = ref([]);
const hasExistingVersions = ref(false);

// UI and live-training state.
const trainingLoading = ref(false);
const currentTraining = ref(null);
const trainingStats = ref(null);
const trainingTasks = ref([]);

// Form model; these values are the initial selection shown in the UI.
const form = reactive({
  training_scope: "all_stores_all_products",
  store_ids: [],
  product_ids: [],
  aggregation_method: "sum",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: ""
});

// Handles owned by this component; released in onUnmounted.
let pollInterval = null; // id returned by setInterval for task-list polling
let socket = null; // socket.io client connected to the "/training" namespace
// Rows for the task-queue table: only tasks whose training_mode is 'global'.
const filteredTrainingTasks = computed(() =>
  trainingTasks.value.filter(({ training_mode }) => training_mode === 'global')
);
/**
 * Open the socket.io connection to the backend "/training" namespace and
 * register the handlers that merge server-pushed training events into
 * `currentTraining` (the task started from this page) and into the matching
 * row of `trainingTasks` (the task-queue table).
 *
 * Events handled: "connect", "training_update", "training_progress",
 * "training_completed" and "disconnect". The client is stored in the
 * module-level `socket` so startTraining() can join task rooms and
 * onUnmounted() can disconnect it.
 */
const initWebSocket = () => {
  // NOTE(review): the backend origin is hard-coded here, while every REST call
  // in this file uses a relative "/api/..." URL — confirm this matches the
  // deployed setup.
  const client = io("http://localhost:5000/training", {
    transports: ["websocket", "polling"]
  });
  socket = client;

  // Index of the task-list row for `taskId`, or -1 when that row is not loaded.
  const findTaskIndex = (taskId) =>
    trainingTasks.value.findIndex((task) => task.task_id === taskId);

  // True when `taskId` is the task shown in the live-status card.
  const isCurrentTask = (taskId) =>
    Boolean(currentTraining.value) && currentTraining.value.task_id === taskId;

  client.on("connect", () => {
    console.log("WebSocket连接成功");
    // "connect" also fires after an automatic reconnect, and server-side room
    // membership does not survive a dropped connection. Re-join the room of the
    // task this page is tracking so progress events keep arriving.
    // NOTE(review): if startTraining() already emitted the join while the socket
    // was still connecting, this repeats it — expected to be a no-op for a
    // Socket.IO room, confirm for the backend "join_training" handler.
    const trackedTaskId = currentTraining.value?.task_id;
    if (trackedTaskId) {
      client.emit("join_training", { task_id: trackedTaskId });
    }
  });

  // General status update: merge every field, then handle terminal states.
  client.on("training_update", (data) => {
    console.log("收到训练更新:", data);
    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    const taskIndex = findTaskIndex(data.task_id);
    if (taskIndex !== -1) {
      trainingTasks.value[taskIndex] = {
        ...trainingTasks.value[taskIndex],
        ...data
      };
    }
    if (data.status === "completed") {
      ElMessage.success(
        `全局模型 ${data.model_type} 版本 ${data.version} 训练完成!`
      );
      currentTraining.value = null;
    } else if (data.status === "failed") {
      ElMessage.error(`全局模型训练失败: ${data.error}`);
      currentTraining.value = null;
    }
  });

  // Fine-grained (epoch-level) progress: only progress/message/metrics are taken.
  client.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);
    if (isCurrentTask(data.task_id)) {
      const previous = currentTraining.value;
      currentTraining.value = {
        ...previous,
        // An event without a progress value keeps the last known progress
        // instead of snapping the bar back to 0.
        progress: data.progress ?? previous.progress ?? 0,
        message: data.message || previous.message,
        status: 'running' // a progress event means the task is running
      };
      if (data.metrics) {
        currentTraining.value.metrics = data.metrics;
      }
    }
    const taskIndex = findTaskIndex(data.task_id);
    if (taskIndex !== -1) {
      const row = trainingTasks.value[taskIndex];
      trainingTasks.value[taskIndex] = {
        ...row,
        progress: data.progress ?? row.progress ?? 0,
        message: data.message || row.message
      };
    }
  });

  // Dedicated completion event: force the completed state and 100% progress.
  client.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);
    if (isCurrentTask(data.task_id)) {
      currentTraining.value = {
        ...currentTraining.value,
        ...data,
        status: "completed",
        progress: 100
      };
      ElMessage.success(
        `全局模型 ${data.model_type} 训练完成!`
      );
      // Keep the completed state visible for 2 s, then clear the card unless a
      // different task has replaced it in the meantime.
      setTimeout(() => {
        if (isCurrentTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }
    const taskIndex = findTaskIndex(data.task_id);
    if (taskIndex !== -1) {
      trainingTasks.value[taskIndex] = {
        ...trainingTasks.value[taskIndex],
        ...data,
        status: "completed",
        progress: 100
      };
    }
    // Re-sync the task table with the server.
    fetchTrainingTasks();
  });

  client.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
/**
 * Load the store list backing the store selector (GET /api/stores).
 * The list is replaced only when the response reports status "success";
 * failures are logged to the console and otherwise ignored.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status !== "success") return;
    stores.value = body.data;
  } catch (error) {
    console.error("获取店铺列表失败:", error);
  }
};
/**
 * Load the product list backing the product selector (GET /api/products).
 * The list is replaced only when the response reports status "success";
 * failures are logged to the console and otherwise ignored.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") return;
    products.value = body.data;
  } catch (error) {
    console.error("获取药品列表失败:", error);
  }
};
/**
 * Load the selectable model types (GET /api/model_types). When the response
 * reports success and the form has no model type yet, the first entry's `id`
 * is preselected. Any failure shows an error toast and is logged.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") return;
    modelTypes.value = body.data;
    const hasChoices = modelTypes.value.length > 0;
    if (hasChoices && !form.model_type) {
      form.model_type = modelTypes.value[0].id;
    }
  } catch (error) {
    ElMessage.error("获取模型类型列表失败");
    console.error(error);
  }
};
// Sequence number of the most recent statistics request. The form watchers and
// onMounted can fire fetchTrainingStats() in quick succession, and responses may
// arrive out of order; only the latest request may update `trainingStats`.
let trainingStatsRequestSeq = 0;

/**
 * Load statistics for the data covered by the current scope selection
 * (GET /api/training/global/stats).
 *
 * Query params: `training_scope` and `aggregation_method` always;
 * `store_ids` / `product_ids` as comma-separated lists when non-empty.
 * On success `trainingStats` receives `response.data.data`; on a request error
 * it is reset to null. Responses of superseded requests are ignored.
 */
const fetchTrainingStats = async () => {
  const requestSeq = ++trainingStatsRequestSeq;
  const isSuperseded = () => requestSeq !== trainingStatsRequestSeq;
  try {
    const params = {
      training_scope: form.training_scope,
      aggregation_method: form.aggregation_method
    };
    if (form.store_ids.length > 0) {
      params.store_ids = form.store_ids.join(',');
    }
    if (form.product_ids.length > 0) {
      params.product_ids = form.product_ids.join(',');
    }
    const response = await axios.get("/api/training/global/stats", { params });
    // A newer request was issued while this one was in flight: drop the result.
    if (isSuperseded()) return;
    if (response.data.status === "success") {
      trainingStats.value = response.data.data;
    }
  } catch (error) {
    console.error("获取训练统计信息失败:", error);
    if (!isSuperseded()) {
      trainingStats.value = null;
    }
  }
};
/**
 * Load the existing global-model versions for `form.model_type`
 * (GET /api/models/global/{model_type}/versions).
 *
 * On success `existingVersions` receives `data.versions` (or []), and
 * `form.base_version` is preselected to `data.latest_version` when empty.
 * With no model type, a non-success response or a request error, the version
 * list is cleared. When no versions remain, a selected "retrain" mode is reset
 * to "new", because the template disables that radio without versions.
 * A response that arrives after the model type has changed is ignored, so it
 * cannot overwrite the list for the newly selected type.
 */
const fetchExistingVersions = async () => {
  const requestedModelType = form.model_type;

  // Drop the current version list and leave "retrain" if it is selected.
  const clearVersions = () => {
    existingVersions.value = [];
    hasExistingVersions.value = false;
    if (form.training_type === "retrain") {
      form.training_type = "new";
    }
  };

  if (!requestedModelType) {
    clearVersions();
    return;
  }
  try {
    const response = await axios.get(
      `/api/models/global/${requestedModelType}/versions`
    );
    // The user picked another model type while this request was in flight.
    if (requestedModelType !== form.model_type) return;
    if (response.data.status !== "success") {
      clearVersions();
      return;
    }
    existingVersions.value = response.data.data.versions || [];
    hasExistingVersions.value = existingVersions.value.length > 0;
    if (!hasExistingVersions.value) {
      clearVersions();
    } else if (!form.base_version) {
      form.base_version = response.data.data.latest_version;
    }
  } catch (error) {
    console.error("获取现有版本失败:", error);
    if (requestedModelType === form.model_type) {
      clearVersions();
    }
  }
};
/**
 * `@change` handler of the model-type select: switch back to a fresh training
 * run and forget the previously chosen base version.
 *
 * The version list is reloaded by the `watch(() => form.model_type, ...)`
 * watcher, which fires for the same change. Also calling
 * fetchExistingVersions() here issued a second, duplicate request.
 */
const onModelTypeChange = () => {
  form.training_type = "new";
  form.base_version = "";
};
/**
 * Reload the training task list (GET /api/training) into `trainingTasks`.
 *
 * The stored value is always an array, because `filteredTrainingTasks` and the
 * socket handlers call array methods on it; a missing or non-array `data`
 * becomes [] instead of breaking the table render.
 *
 * NOTE(review): the toast is meant to be suppressed only during polling, but
 * onMounted assigns `pollInterval` right after the first (asynchronous) call,
 * so by the time any request fails it is already set. In practice the toast
 * never appears, not even for the manual refresh button; only the console log
 * remains.
 */
const fetchTrainingTasks = async () => {
  try {
    const response = await axios.get("/api/training");
    if (response.data.status === "success") {
      const tasks = response.data.data;
      trainingTasks.value = Array.isArray(tasks) ? tasks : [];
    }
  } catch (error) {
    if (!pollInterval) ElMessage.error("获取训练任务列表失败");
    console.error("获取训练任务列表失败", error);
  }
};
/**
 * Check the form before submission.
 * @returns {string|null} the warning to show for the first missing input, or
 *   null when the form can be submitted.
 */
const validateTrainingForm = () => {
  if (!form.model_type) return "请选择模型类型";
  const scope = form.training_scope;
  if ((scope === 'selected_stores' || scope === 'custom') && form.store_ids.length === 0) {
    return "请选择参与训练的店铺";
  }
  if ((scope === 'selected_products' || scope === 'custom') && form.product_ids.length === 0) {
    return "请选择参与训练的药品";
  }
  if (form.training_type === "retrain" && !form.base_version) {
    return "请选择基础版本进行继续训练";
  }
  return null;
};

/**
 * Build the POST body for a global training request.
 * The id lists are always sent as arrays. Their first element is also sent as
 * `store_id` / `product_id`, because the backend is given a single
 * representative id as well. `base_version` is added only for "retrain".
 */
const buildTrainingPayload = () => {
  const storeIds = form.store_ids || [];
  const productIds = form.product_ids || [];
  const payload = {
    model_type: form.model_type,
    epochs: form.epochs,
    training_mode: 'global',
    training_scope: form.training_scope,
    aggregation_method: form.aggregation_method,
    store_ids: storeIds,
    product_ids: productIds
  };
  if (storeIds.length > 0) payload.store_id = storeIds[0];
  if (productIds.length > 0) payload.product_id = productIds[0];
  if (form.training_type === "retrain") payload.base_version = form.base_version;
  return payload;
};

/**
 * Submit a global training task (POST /api/training, or /api/training/retrain
 * in "retrain" mode).
 *
 * On a response carrying `task_id`, this:
 * - shows a toast;
 * - fills `currentTraining` with a "starting" placeholder;
 * - joins the task's WebSocket room;
 * - refreshes the task list.
 *
 * Otherwise, or on a request error, it shows the server's `error` text or a
 * generic message. `trainingLoading` is set while the request is pending.
 */
const startTraining = async () => {
  const warning = validateTrainingForm();
  if (warning) {
    ElMessage.warning(warning);
    return;
  }

  trainingLoading.value = true;
  try {
    const isRetrain = form.training_type === "retrain";
    const endpoint = isRetrain ? "/api/training/retrain" : "/api/training";
    const response = await axios.post(endpoint, buildTrainingPayload());
    const taskId = response.data.task_id;

    if (!taskId) {
      ElMessage.error(response.data.error || "启动训练失败");
      return;
    }

    ElMessage.success(`全局训练任务 ${taskId} 已启动`);
    // Placeholder status until the first WebSocket event for this task arrives.
    currentTraining.value = {
      task_id: taskId,
      model_type: form.model_type,
      version: response.data.new_version || "v1",
      status: "starting",
      progress: 0,
      message: "正在启动全局训练...",
      training_mode: 'global',
      training_scope: form.training_scope,
      aggregation_method: form.aggregation_method,
      store_ids: form.store_ids.length > 0 ? form.store_ids : null,
      product_ids: form.product_ids.length > 0 ? form.product_ids : null
    };
    if (socket) socket.emit("join_training", { task_id: taskId });
    fetchTrainingTasks();
  } catch (error) {
    const errorMsg = error.response?.data?.error || "启动训练请求失败";
    ElMessage.error(errorMsg);
    console.error(error);
  } finally {
    trainingLoading.value = false;
  }
};
// ---- Display helpers used by the template ----

// Display name of a model-type id; unknown ids are shown as-is.
const getModelTypeName = (modelType) => {
  const match = modelTypes.value.find(({ id }) => id === modelType);
  return match ? match.name : modelType;
};
/**
 * Short description of a task's training scope. "selected_stores" and
 * "selected_products" include the number of selected ids; an unknown scope
 * yields '未知范围'.
 */
const getTrainingScopeText = (task) => {
  const storeCount = task.store_ids?.length || 0;
  const productCount = task.product_ids?.length || 0;
  const labels = {
    all_stores_all_products: '全部数据',
    selected_stores: `${storeCount} 个店铺`,
    selected_products: `${productCount} 种药品`,
    custom: '自定义范围'
  };
  const label = labels[task.training_scope];
  return label ? label : '未知范围';
};
// Display labels for the aggregation methods offered in the form.
const AGGREGATION_LABELS = {
  sum: '求和',
  mean: '平均值',
  weighted: '加权平均'
};

// Label for an aggregation method; unknown methods are returned unchanged.
const getAggregationText = (method) => AGGREGATION_LABELS[method] || method;
// el-tag `type` for a task status; anything unrecognised is shown as "info".
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
// Display text for each known task status.
const STATUS_LABELS = {
  pending: "等待中",
  starting: "启动中",
  running: "进行中",
  completed: "已完成",
  failed: "失败"
};

// Display text for a task status; unknown statuses map to "未知".
const statusText = (status) => STATUS_LABELS[status] || "未知";
// el-progress `format` callback: render the percentage with a trailing "%".
const formatProgress = (percentage) => `${percentage}%`;
/**
 * Format a task timestamp for display in the user's locale.
 *
 * @param {string|number|null|undefined} isoString value passed to `new Date()`;
 *   an ISO-8601 string is presumably what the backend sends — TODO confirm.
 * @returns {string} "N/A" for a falsy value; the raw value (as a string) when
 *   it cannot be parsed as a date, instead of the literal "Invalid Date";
 *   otherwise `Date#toLocaleString()`.
 */
const formatDateTime = (isoString) => {
  if (!isoString) return "N/A";
  const date = new Date(isoString);
  // An unparseable input produces an Invalid Date, whose getTime() is NaN.
  if (Number.isNaN(date.getTime())) return String(isoString);
  return date.toLocaleString();
};
// ---- Watchers ----
// Order matters: the selection-clearing watcher is registered before the
// statistics watcher. Both react to `form.training_scope`, and the pre-flush
// watchers of one component run in registration order.
// With the clearing watcher first, the statistics watcher (already queued by the
// scope change, and not queued again when the ids are cleared) fetches once with
// the final selection. In the reverse order it fetched once with stale ids and
// again after they were cleared.

// When the training scope changes, clear the selections the new scope does not use.
watch(() => form.training_scope, (newVal) => {
  if (newVal === 'all_stores_all_products') {
    form.store_ids = [];
    form.product_ids = [];
  } else if (newVal === 'selected_stores') {
    form.product_ids = [];
  } else if (newVal === 'selected_products') {
    form.store_ids = [];
  }
});

// Refresh the training-data statistics whenever scope, selections or aggregation change.
watch([
  () => form.training_scope,
  () => form.store_ids,
  () => form.product_ids,
  () => form.aggregation_method
], () => {
  fetchTrainingStats();
}, { deep: true });

// Reload the available versions whenever the model type changes, including
// when fetchModelTypes() preselects the first model type.
watch(() => form.model_type, () => {
  fetchExistingVersions();
});
// ---- Lifecycle ----

// On mount: load reference data, task list and statistics, open the WebSocket,
// then poll the task list every 10000 ms.
onMounted(() => {
  const initialLoads = [
    fetchStores,
    fetchProducts,
    fetchModelTypes,
    fetchTrainingTasks,
    fetchTrainingStats
  ];
  initialLoads.forEach((load) => {
    load();
  });
  initWebSocket();
  pollInterval = setInterval(fetchTrainingTasks, 10000);
});

// On unmount: stop polling and close the socket connection.
onUnmounted(() => {
  if (pollInterval) clearInterval(pollInterval);
  if (socket) socket.disconnect();
});
</script>
<style scoped>
/* Page wrapper */
.global-training-container {
  padding: 20px;
}

/* Card header: title on the left, tag/button on the right */
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
}

/* Intro box above the training form */
.training-description {
  margin-bottom: 20px;
  padding: 15px;
  border-left: 4px solid #67c23a;
  border-radius: 6px;
  background-color: #f0f9ff;
}

.training-description p {
  margin: 0;
  color: #606266;
  font-size: 14px;
  line-height: 1.5;
}

/* Training-data statistics card */
.training-stats {
  font-size: 14px;
}

.training-stats p {
  margin: 8px 0;
  color: #606266;
}

/* Live training status card */
.training-progress-container {
  border-left: 4px solid #67c23a;
}

.training-status-text {
  margin-top: 10px;
}

.training-metrics {
  padding: 10px;
  border-radius: 4px;
  background-color: #f5f7fa;
}

.training-metrics pre {
  margin: 5px 0 0;
  font-size: 12px;
  line-height: 1.4;
  white-space: pre-wrap;
  overflow-wrap: break-word;
}

/* Radio groups span the full width and wrap their options */
.el-radio-group {
  width: 100%;
}

.el-radio {
  margin-right: 20px;
  margin-bottom: 10px;
}

/* Narrow screens: tighter padding, stacked columns and one radio per line */
@media (max-width: 768px) {
  .global-training-container {
    padding: 10px;
  }

  .el-col {
    margin-bottom: 20px;
  }

  .el-radio {
    display: block;
    margin-right: 0;
    margin-bottom: 10px;
  }
}
</style>