ShopTRAINING/UI/src/views/TrainingView.vue
2025-07-02 11:05:23 +08:00

589 lines
17 KiB
Vue
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<template>
<el-row :gutter="20">
<!-- 左侧训练控制 -->
<el-col :span="8">
<el-card>
<template #header>
<span>启动模型训练</span>
</template>
<el-form :model="form" label-width="100px">
<el-form-item label="产品">
<el-select
v-model="form.product_id"
placeholder="请选择产品"
filterable
>
<el-option
v-for="item in products"
:key="item.product_id"
:label="item.product_name"
:value="item.product_id"
/>
</el-select>
</el-form-item>
<el-form-item label="店铺">
<el-select
v-model="form.store_id"
placeholder="选择店铺(留空为全局模型)"
clearable
filterable
>
<el-option label="全局模型(聚合所有店铺)" value=""></el-option>
<el-option
v-for="store in stores"
:key="store.store_id"
:label="store.store_name"
:value="store.store_id"
/>
</el-select>
</el-form-item>
<el-form-item label="模型类型">
<el-select
v-model="form.model_type"
placeholder="请选择模型"
@change="onModelTypeChange"
>
<el-option
v-for="item in modelTypes"
:key="item.id"
:label="item.name"
:value="item.id"
/>
</el-select>
</el-form-item>
<el-form-item label="训练类型">
<el-radio-group v-model="form.training_type">
<el-radio label="new">新训练</el-radio>
<el-radio label="retrain" :disabled="!hasExistingVersions"
>继续训练</el-radio
>
</el-radio-group>
</el-form-item>
<el-form-item
label="基础版本"
v-if="form.training_type === 'retrain'"
>
<el-select v-model="form.base_version" placeholder="选择基础版本">
<el-option
v-for="version in existingVersions"
:key="version"
:label="version"
:value="version"
/>
</el-select>
</el-form-item>
<el-form-item label="训练轮次">
<el-input-number v-model="form.epochs" :min="1" :max="1000" />
</el-form-item>
<el-form-item>
<el-button
type="primary"
@click="startTraining"
:loading="trainingLoading"
>启动训练</el-button
>
</el-form-item>
</el-form>
</el-card>
<!-- 增强的实时训练状态 -->
<EnhancedTrainingProgress
v-if="currentTraining"
:training-data="currentTraining.detailed_progress || currentTraining"
style="margin-top: 20px"
/>
<!-- 后备的简单训练状态卡片 -->
<el-card
v-if="currentTraining && !currentTraining.detailed_progress"
style="margin-top: 20px"
class="training-progress-container"
>
<template #header>
<span>训练状态</span>
</template>
<div>
<p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
<p><strong>产品:</strong> {{ currentTraining.product_id }}</p>
<p><strong>店铺:</strong> {{ currentTraining.store_id || '全局模型' }}</p>
<p><strong>模型:</strong> {{ currentTraining.model_type }}</p>
<p><strong>版本:</strong> {{ currentTraining.version }}</p>
<p>
<strong>状态:</strong>
<el-tag :type="statusTag(currentTraining.status)">
{{ statusText(currentTraining.status) }}
</el-tag>
</p>
<el-progress
v-if="currentTraining.status === 'running'"
:percentage="currentTraining.progress || 0"
:format="formatProgress"
/>
<div v-if="currentTraining.message" style="margin-top: 10px">
<el-alert
:title="currentTraining.message"
type="info"
show-icon
:closable="false"
class="training-status-text"
/>
</div>
</div>
</el-card>
</el-col>
<!-- 右侧:任务状态 -->
<el-col :span="16">
<el-card>
<template #header>
<span>训练任务队列</span>
</template>
<el-table :data="trainingTasks" stripe>
<el-table-column
prop="task_id"
label="任务ID"
width="120"
show-overflow-tooltip
></el-table-column>
<el-table-column
prop="product_id"
label="产品ID"
width="100"
></el-table-column>
<el-table-column
prop="model_type"
label="模型类型"
width="120"
></el-table-column>
<el-table-column
prop="version"
label="版本"
width="80"
></el-table-column>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">{{
statusText(row.status)
}}</el-tag>
</template>
</el-table-column>
<el-table-column prop="start_time" label="创建时间">
<template #default="{ row }">
{{ formatDateTime(row.start_time) }}
</template>
</el-table-column>
<el-table-column label="详情">
<template #default="{ row }">
<el-popover placement="left" trigger="hover" width="400">
<template #reference>
<el-button type="text" size="small">查看</el-button>
</template>
<div v-if="row.status === 'completed'">
<h4>评估指标</h4>
<pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
<div v-if="row.version">
<h4>版本信息</h4>
<p><strong>版本:</strong> {{ row.version }}</p>
<p><strong>模型路径:</strong> {{ row.model_path }}</p>
</div>
</div>
<div v-if="row.status === 'failed'">
<h4>错误信息</h4>
<p>{{ row.error }}</p>
</div>
<div
v-if="row.status === 'running' || row.status === 'pending'"
>
<p>任务正在进行中...</p>
<div v-if="row.progress !== undefined">
<el-progress :percentage="row.progress" />
</div>
</div>
</el-popover>
</template>
</el-table-column>
</el-table>
</el-card>
</el-col>
</el-row>
</template>
<script setup>
import { ref, onMounted, onUnmounted, reactive, watch } from "vue";
import axios from "axios";
import {
ElMessage,
ElPopover,
ElButton,
ElTag,
ElProgress,
ElAlert
} from "element-plus";
import { io } from "socket.io-client";
import EnhancedTrainingProgress from "@/components/EnhancedTrainingProgress.vue";
// ---------------------------------------------------------------------------
// Component state
// ---------------------------------------------------------------------------

// Option lists rendered by the product / store / model-type selects.
const products = ref([]);
const stores = ref([]);
const modelTypes = ref([]);

// Model of the "start training" form. When a new (non-retrain) run is started
// this object itself is posted as the request body, so key order is kept as is.
const form = reactive({
  product_id: "",
  store_id: "",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: "",
});

// Versions already trained for the selected product/model pair; the flag gates
// the "retrain" radio option in the template.
const existingVersions = ref([]);
const hasExistingVersions = ref(false);

// Drives the loading spinner on the submit button.
const trainingLoading = ref(false);

// Task shown in the status panel (null when nothing is being tracked) and the
// rows of the task queue table.
const currentTraining = ref(null);
const trainingTasks = ref([]);

// Handles created on mount and released on unmount.
let socket = null;
let pollInterval = null;
// WebSocket connection: real-time training updates pushed by the backend.

/**
 * Reports whether `taskId` identifies the task currently shown in the
 * status panel.
 */
const isCurrentTask = (taskId) =>
  Boolean(currentTraining.value) && currentTraining.value.task_id === taskId;

/**
 * Shallow-merges a patch into the queue row whose `task_id` equals `taskId`.
 * Does nothing when no such row exists.
 *
 * @param {*} taskId - Task identifier to look up in `trainingTasks`.
 * @param {(row: object) => object} buildPatch - Receives the existing row and
 *   returns the fields to merge over it.
 */
const patchQueuedTask = (taskId, buildPatch) => {
  const rowIndex = trainingTasks.value.findIndex(
    (task) => task.task_id === taskId
  );
  if (rowIndex === -1) {
    return;
  }
  const row = trainingTasks.value[rowIndex];
  trainingTasks.value[rowIndex] = { ...row, ...buildPatch(row) };
};

/**
 * Opens the Socket.IO connection and registers every training event handler.
 * The created socket is stored in the module-level `socket` variable so other
 * code (task start, unmount cleanup) can use it.
 */
const initWebSocket = () => {
  // NOTE(review): the endpoint is hard-coded to a local backend while the REST
  // calls in this file use relative URLs — confirm this is intended outside
  // local development.
  socket = io("http://localhost:5000/training", {
    transports: ["websocket", "polling"]
  });

  // "connect" fires again after every reconnect. The task subscription is
  // requested once in startTraining, so it is requested again here for the
  // task still being tracked. The very first connect is skipped because
  // nothing has been subscribed through this socket yet.
  // NOTE(review): assumes the server treats a repeated join_training as harmless.
  let hasConnectedBefore = false;
  socket.on("connect", () => {
    console.log("WebSocket连接成功");
    if (hasConnectedBefore && currentTraining.value) {
      socket.emit("join_training", {
        task_id: currentTraining.value.task_id
      });
    }
    hasConnectedBefore = true;
  });

  // Surface connection failures instead of failing silently.
  socket.on("connect_error", (error) => {
    console.error("WebSocket连接失败:", error);
  });

  // Generic status update for a task.
  socket.on("training_update", (data) => {
    console.log("收到训练更新:", data);

    // Decide up front whether this update concerns the tracked task, so that
    // a terminal status of some OTHER task cannot clear the status panel.
    const concernsCurrentTask = isCurrentTask(data.task_id);
    if (concernsCurrentTask) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    patchQueuedTask(data.task_id, () => data);

    const isCompleted = data.status === "completed";
    const isFailed = data.status === "failed";
    if (isCompleted) {
      ElMessage.success(
        `模型 ${data.model_type} 版本 ${data.version} 训练完成!`
      );
    } else if (isFailed) {
      ElMessage.error(`模型训练失败: ${data.error}`);
    }
    if ((isCompleted || isFailed) && concernsCurrentTask) {
      currentTraining.value = null;
    }
  });

  // Epoch-level progress update.
  socket.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);

    if (isCurrentTask(data.task_id)) {
      const tracked = currentTraining.value;
      currentTraining.value = {
        ...tracked,
        progress: data.progress || 0,
        message: data.message || tracked.message,
        status: 'running',
        // Metrics are only overwritten when the event actually carries them.
        ...(data.metrics ? { metrics: data.metrics } : {})
      };
    }

    patchQueuedTask(data.task_id, (row) => ({
      progress: data.progress || 0,
      message: data.message || row.message
    }));
  });

  // Dedicated completion event.
  socket.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);
    const completedPatch = { ...data, status: "completed", progress: 100 };

    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...completedPatch };
      ElMessage.success(
        `模型 ${data.model_type} 训练完成!`
      );
      // Keep the finished state visible for 2 seconds, then clear the panel
      // unless another task has taken its place in the meantime.
      setTimeout(() => {
        if (isCurrentTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }

    patchQueuedTask(data.task_id, () => completedPatch);
    // Reload the queue from the server after a completion event.
    fetchTrainingTasks();
  });

  // Detailed progress payload; note this event identifies the task through
  // `training_id` rather than `task_id`.
  socket.on("training_progress_detailed", (data) => {
    console.log("收到详细训练进度:", data);
    if (isCurrentTask(data.training_id)) {
      currentTraining.value = {
        ...currentTraining.value,
        detailed_progress: data,
        status: 'training'
      };
    }
  });

  socket.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
/**
 * Loads the product list from `/api/products` into `products`.
 * A response whose status is not "success" is ignored; a failed request
 * shows an error toast and is logged.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") {
      return;
    }
    products.value = body.data;
  } catch (err) {
    ElMessage.error("获取产品列表失败");
    console.error(err);
  }
};
/**
 * Loads the available model types from `/api/model_types` into `modelTypes`
 * and, when the form has no model type chosen yet, preselects the first one.
 * A failed request shows an error toast and is logged.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") {
      return;
    }
    modelTypes.value = body.data;

    const loadedTypes = modelTypes.value;
    const shouldPreselect = loadedTypes.length > 0 && !form.model_type;
    if (shouldPreselect) {
      form.model_type = loadedTypes[0].id;
    }
  } catch (err) {
    ElMessage.error("获取模型类型列表失败");
    console.error(err);
  }
};
/**
 * Loads the store list from `/api/stores` into `stores`.
 * Failures are only logged: the store is an optional selection, so no
 * error toast is shown to the user.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status !== "success") {
      return;
    }
    stores.value = body.data;
  } catch (err) {
    console.error("获取店铺列表失败:", err);
  }
};
// Sequence number of the most recent version lookup. A response belonging to
// an older lookup is discarded so it cannot overwrite newer results.
let versionLookupSeq = 0;

/**
 * Stores the versions known for the current product/model selection and keeps
 * the form consistent with them.
 *
 * - With no versions, "retrain" is impossible: the base version is cleared and
 *   a selected "retrain" training type falls back to "new".
 * - With versions, a base version that is not in the list (for example one
 *   left over from a previously selected product) is replaced by the latest.
 *
 * @param {Array} versions - Version identifiers returned by the backend.
 * @param {*} latestVersion - Backend-reported latest version, if any.
 */
const applyKnownVersions = (versions, latestVersion) => {
  existingVersions.value = versions;
  hasExistingVersions.value = versions.length > 0;

  if (!hasExistingVersions.value) {
    form.base_version = "";
    if (form.training_type === "retrain") {
      form.training_type = "new";
    }
    return;
  }
  if (!versions.includes(form.base_version)) {
    form.base_version = latestVersion ?? "";
  }
};

/**
 * Looks up which versions already exist for the selected product and model
 * type and updates `existingVersions`, `hasExistingVersions` and the form's
 * base version accordingly. Without a product or model type the version list
 * is simply emptied. Lookup errors are logged and treated as "no versions".
 */
const fetchExistingVersions = async () => {
  const lookupId = ++versionLookupSeq;
  const isSuperseded = () => lookupId !== versionLookupSeq;

  if (!form.product_id || !form.model_type) {
    applyKnownVersions([], undefined);
    return;
  }

  try {
    const response = await axios.get(
      `/api/models/${form.product_id}/${form.model_type}/versions`
    );
    // The selection changed while this request was in flight; a newer lookup
    // owns the state now.
    if (isSuperseded()) {
      return;
    }
    if (response.data.status === "success") {
      const { versions, latest_version: latestVersion } = response.data.data;
      applyKnownVersions(versions || [], latestVersion);
    }
  } catch (error) {
    console.error("获取现有版本失败:", error);
    if (!isSuperseded()) {
      applyKnownVersions([], undefined);
    }
  }
};
/**
 * Change handler of the model-type select: returns the form to a fresh
 * "new training" state (no base version) and reloads the versions available
 * for the new selection.
 */
const onModelTypeChange = () => {
  Object.assign(form, { training_type: "new", base_version: "" });
  fetchExistingVersions();
};
/**
 * Reloads the training task queue from `/api/training` into `trainingTasks`.
 * Errors are always logged; the error toast is shown only while no polling
 * timer is set, so repeated polling failures do not produce repeated toasts.
 */
const fetchTrainingTasks = async () => {
  try {
    const { data: body } = await axios.get("/api/training");
    if (body.status !== "success") {
      return;
    }
    trainingTasks.value = body.data;
  } catch (err) {
    const pollingActive = Boolean(pollInterval);
    if (!pollingActive) {
      ElMessage.error("获取训练任务列表失败");
    }
    console.error("获取训练任务列表失败", err);
  }
};
/**
 * Submit handler of the training form.
 *
 * Validates the selection, posts either a new-training or a retrain request,
 * and on success starts tracking the created task: the status panel is
 * initialised, the task's WebSocket room is joined and the queue is reloaded.
 * The submit button's loading flag is always released at the end.
 */
const startTraining = async () => {
  const isRetrain = form.training_type === "retrain";

  // Guard clauses: nothing is sent until the form is complete.
  if (!form.product_id || !form.model_type) {
    ElMessage.warning("请选择产品和模型类型");
    return;
  }
  if (isRetrain && !form.base_version) {
    ElMessage.warning("请选择基础版本进行继续训练");
    return;
  }

  trainingLoading.value = true;
  try {
    // A new training posts the form object itself; a retrain posts a reduced
    // payload that carries the base version instead of the training type.
    let endpoint = "/api/training";
    let payload = form;
    if (isRetrain) {
      endpoint = "/api/training/retrain";
      payload = {
        product_id: form.product_id,
        store_id: form.store_id,
        model_type: form.model_type,
        epochs: form.epochs,
        base_version: form.base_version
      };
    }

    const { data: result } = await axios.post(endpoint, payload);
    const taskId = result.task_id;
    if (!taskId) {
      ElMessage.error(result.error || "启动训练失败");
      return;
    }

    ElMessage.success(`训练任务 ${taskId} 已启动`);

    // Seed the status panel; later socket events refine this object.
    currentTraining.value = {
      task_id: taskId,
      product_id: form.product_id,
      store_id: form.store_id,
      model_type: form.model_type,
      version: result.new_version || "v1",
      status: "starting",
      progress: 0,
      message: "正在启动训练..."
    };

    // Subscribe to this task's updates when a socket is available.
    if (socket) {
      socket.emit("join_training", { task_id: taskId });
    }
    fetchTrainingTasks();
  } catch (err) {
    const fallbackMessage = "启动训练请求失败";
    ElMessage.error(err.response?.data?.error || fallbackMessage);
    console.error(err);
  } finally {
    trainingLoading.value = false;
  }
};
/**
 * Maps a task status to the `type` of the tag that displays it.
 *
 * "training" is handled like "running" because this view itself assigns that
 * status when detailed progress events arrive. Any unrecognised status maps
 * to "info".
 *
 * @param {string} status - Task status.
 * @returns {string} One of "success", "primary", "warning", "danger", "info".
 */
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "training":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
// Display labels per task status. A Map is used instead of an object literal
// so that names inherited from Object.prototype (e.g. "constructor",
// "toString") are not mistaken for known statuses.
const STATUS_LABELS = new Map([
  ["pending", "等待中"],
  ["starting", "启动中"],
  ["running", "进行中"],
  // This view assigns "training" itself when detailed progress arrives.
  ["training", "训练中"],
  ["completed", "已完成"],
  ["failed", "失败"]
]);

/**
 * Returns the display label for a task status.
 *
 * @param {string} status - Task status.
 * @returns {string} The label, or "未知" for any unrecognised status.
 */
const statusText = (status) => STATUS_LABELS.get(status) ?? "未知";
/**
 * Formatter for the progress bar label: the percentage followed by "%".
 *
 * @param {number} percentage - Current progress value.
 * @returns {string} For example "42%".
 */
const formatProgress = (percentage) => `${percentage}%`;
/**
 * Formats a timestamp for display in the task table using the runtime's
 * locale.
 *
 * @param {string} isoString - Timestamp as delivered by the backend.
 * @returns {string} The localised date/time, or "N/A" when the value is
 *   missing or cannot be parsed as a date.
 */
const formatDateTime = (isoString) => {
  if (!isoString) {
    return "N/A";
  }
  const parsed = new Date(isoString);
  // An unparseable input yields a Date whose time value is NaN; without this
  // check the table would show the literal text "Invalid Date".
  if (Number.isNaN(parsed.getTime())) {
    return "N/A";
  }
  return parsed.toLocaleString();
};
// Whenever the selected product or model type changes, reload the versions
// that already exist for the new combination.
const versionLookupInputs = [() => form.product_id, () => form.model_type];
watch(versionLookupInputs, () => {
  fetchExistingVersions();
});
// On mount: load every list once, open the socket, then start polling.
onMounted(() => {
  const initialLoaders = [
    fetchProducts,
    fetchStores,
    fetchModelTypes,
    fetchTrainingTasks
  ];
  initialLoaders.forEach((load) => {
    load();
  });

  initWebSocket();

  // The queue is polled only every 10 seconds because the socket already
  // pushes live updates.
  const POLL_EVERY_MS = 10000;
  pollInterval = setInterval(fetchTrainingTasks, POLL_EVERY_MS);
});
// On unmount: stop the polling timer and close the socket, each only if it
// was actually created.
onUnmounted(() => {
  if (pollInterval) clearInterval(pollInterval);
  if (socket) socket.disconnect();
});
</script>