ShopTRAINING/UI/src/views/training/ProductTrainingView.vue
2025-07-02 11:05:23 +08:00

735 lines
21 KiB
Vue
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

<template>
  <div class="product-training-container">
    <el-row :gutter="20">
      <!-- Left column: training form and the live status of the task started here -->
      <el-col :span="8">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>按药品训练模型</span>
              <el-tag type="info">针对特定药品</el-tag>
            </div>
          </template>
          <div class="training-description">
            <p>为特定药品训练专门的预测模型可选择使用单店铺数据或聚合多店铺数据</p>
          </div>
          <el-form label-width="100px" :model="form">
            <!-- Product to train -->
            <el-form-item required label="选择药品">
              <el-select
                v-model="form.product_id"
                style="width: 100%"
                filterable
                placeholder="请选择要训练的药品"
              >
                <el-option
                  v-for="item in products"
                  :key="item.product_id"
                  :value="item.product_id"
                  :label="`${item.product_name} (${item.product_id})`"
                />
              </el-select>
            </el-form-item>
            <!-- Data scope: all stores aggregated, or one chosen store -->
            <el-form-item label="数据范围">
              <el-radio-group v-model="form.data_scope">
                <el-radio label="global">聚合所有店铺</el-radio>
                <el-radio label="specific">指定店铺</el-radio>
              </el-radio-group>
            </el-form-item>
            <el-form-item v-if="form.data_scope === 'specific'" label="选择店铺">
              <el-select
                v-model="form.store_id"
                style="width: 100%"
                filterable
                placeholder="选择店铺"
              >
                <el-option
                  v-for="store in stores"
                  :key="store.store_id"
                  :value="store.store_id"
                  :label="`${store.store_name} (${store.store_id})`"
                />
              </el-select>
            </el-form-item>
            <!-- Model type; changing it resets the training mode -->
            <el-form-item required label="模型类型">
              <el-select
                v-model="form.model_type"
                style="width: 100%"
                placeholder="请选择模型"
                @change="onModelTypeChange"
              >
                <el-option
                  v-for="item in modelTypes"
                  :key="item.id"
                  :value="item.id"
                  :label="item.name"
                />
              </el-select>
            </el-form-item>
            <!-- New run, or continue from a saved version (only when versions exist) -->
            <el-form-item label="训练模式">
              <el-radio-group v-model="form.training_type">
                <el-radio label="new">新训练</el-radio>
                <el-radio :disabled="!hasExistingVersions" label="retrain">
                  继续训练
                </el-radio>
              </el-radio-group>
            </el-form-item>
            <el-form-item v-if="form.training_type === 'retrain'" label="基础版本">
              <el-select
                v-model="form.base_version"
                style="width: 100%"
                placeholder="选择基础版本"
              >
                <el-option
                  v-for="version in existingVersions"
                  :key="version"
                  :value="version"
                  :label="version"
                />
              </el-select>
            </el-form-item>
            <el-form-item label="训练轮次">
              <el-input-number
                v-model="form.epochs"
                style="width: 100%"
                :min="1"
                :max="1000"
              />
            </el-form-item>
            <el-form-item>
              <el-button
                type="primary"
                style="width: 100%"
                :loading="trainingLoading"
                :disabled="!form.product_id || !form.model_type"
                @click="startTraining"
              >
                <el-icon><Cpu /></el-icon>
                启动药品训练
              </el-button>
            </el-form-item>
          </el-form>
        </el-card>
        <!-- Live status card, shown while currentTraining is set -->
        <el-card
          v-if="currentTraining"
          class="training-progress-container"
          style="margin-top: 20px"
        >
          <template #header>
            <span>实时训练状态</span>
          </template>
          <div>
            <p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
            <p><strong>药品:</strong> {{ getProductName(currentTraining.product_id) }}</p>
            <p><strong>数据范围:</strong> {{ getDataScopeText(currentTraining) }}</p>
            <p><strong>模型:</strong> {{ getModelTypeName(currentTraining.model_type) }}</p>
            <p><strong>版本:</strong> {{ currentTraining.version }}</p>
            <p>
              <strong>状态:</strong>
              <el-tag :type="statusTag(currentTraining.status)">
                {{ statusText(currentTraining.status) }}
              </el-tag>
            </p>
            <el-progress
              v-if="currentTraining.status === 'running'"
              :percentage="currentTraining.progress || 0"
              :format="formatProgress"
            />
            <div v-if="currentTraining.message" style="margin-top: 10px">
              <el-alert
                class="training-status-text"
                type="info"
                :title="currentTraining.message"
                :closable="false"
                show-icon
              />
            </div>
            <div
              v-if="currentTraining.metrics"
              class="training-metrics"
              style="margin-top: 10px"
            >
              <h4>训练指标:</h4>
              <pre>{{ JSON.stringify(currentTraining.metrics, null, 2) }}</pre>
            </div>
          </div>
        </el-card>
      </el-col>
      <!-- Right column: product-training task queue -->
      <el-col :span="16">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>药品训练任务队列</span>
              <el-button size="small" @click="fetchTrainingTasks">
                <el-icon><Refresh /></el-icon>
                刷新
              </el-button>
            </div>
          </template>
          <el-table stripe :data="filteredTrainingTasks">
            <el-table-column
              prop="task_id"
              label="任务ID"
              width="120"
              show-overflow-tooltip
            />
            <el-table-column prop="product_id" label="药品" width="100">
              <template #default="{ row }">
                {{ getProductName(row.product_id) }}
              </template>
            </el-table-column>
            <el-table-column prop="store_id" label="范围" width="120">
              <template #default="{ row }">
                {{ getDataScopeText(row) }}
              </template>
            </el-table-column>
            <el-table-column prop="model_type" label="模型类型" width="120">
              <template #default="{ row }">
                {{ getModelTypeName(row.model_type) }}
              </template>
            </el-table-column>
            <el-table-column prop="version" label="版本" width="80" />
            <el-table-column prop="status" label="状态" width="100">
              <template #default="{ row }">
                <el-tag :type="statusTag(row.status)">
                  {{ statusText(row.status) }}
                </el-tag>
              </template>
            </el-table-column>
            <el-table-column prop="start_time" label="创建时间">
              <template #default="{ row }">
                {{ formatDateTime(row.start_time) }}
              </template>
            </el-table-column>
            <el-table-column label="详情">
              <template #default="{ row }">
                <el-popover width="400" placement="left" trigger="hover">
                  <template #reference>
                    <el-button size="small" type="text">查看</el-button>
                  </template>
                  <div v-if="row.status === 'completed'">
                    <h4>评估指标</h4>
                    <pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
                    <div v-if="row.version">
                      <h4>版本信息</h4>
                      <p><strong>版本:</strong> {{ row.version }}</p>
                      <p><strong>模型路径:</strong> {{ row.model_path }}</p>
                    </div>
                  </div>
                  <div v-if="row.status === 'failed'">
                    <h4>错误信息</h4>
                    <p>{{ row.error }}</p>
                  </div>
                  <div v-if="row.status === 'running' || row.status === 'pending'">
                    <p>任务正在进行中...</p>
                    <div v-if="row.progress !== undefined">
                      <el-progress :percentage="row.progress" />
                    </div>
                  </div>
                </el-popover>
              </template>
            </el-table-column>
          </el-table>
        </el-card>
      </el-col>
    </el-row>
  </div>
</template>
<script setup>
import { ref, onMounted, onUnmounted, reactive, watch, computed } from "vue";
import axios from "axios";
import {
ElMessage,
ElPopover,
ElButton,
ElTag,
ElProgress,
ElAlert
} from "element-plus";
import { io } from "socket.io-client";
import { Cpu, Refresh } from '@element-plus/icons-vue';
// ---- Reactive page state ----
// Option lists for the selectors, loaded on mount.
const products = ref([]);
const stores = ref([]);
const modelTypes = ref([]);

// Training-control state.
const trainingLoading = ref(false);
const existingVersions = ref([]);
const hasExistingVersions = ref(false);
// Task shown in the live-status card; null when nothing is being tracked.
const currentTraining = ref(null);

// Model bound to the training form.
const form = reactive({
  product_id: "",
  store_id: "",
  data_scope: "global",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: ""
});

// All tasks returned by the task-list endpoint.
const trainingTasks = ref([]);

// Timer and socket handles, released in onUnmounted.
let pollInterval = null;
let socket = null;

// Keep only product-training tasks: a product_id is present and the task is
// not tagged with the 'global' training mode.
const filteredTrainingTasks = computed(() =>
  trainingTasks.value.filter(
    ({ product_id, training_mode }) => Boolean(product_id && training_mode !== 'global')
  )
);
// ---- Live training updates over Socket.IO ----

// Statuses after which a task is not expected to change any more.
const TERMINAL_TRAINING_STATUSES = ["completed", "failed"];

// True when `taskId` is the task currently shown in the live-status card.
const isTrackedTask = (taskId) =>
  Boolean(currentTraining.value) && currentTraining.value.task_id === taskId;

// Replace the matching trainingTasks entry with `{ ...entry, ...buildPatch(entry) }`.
// No-op when the task is not (yet) in the list.
const patchTrainingTask = (taskId, buildPatch) => {
  const index = trainingTasks.value.findIndex((task) => task.task_id === taskId);
  if (index === -1) return;
  const entry = trainingTasks.value[index];
  trainingTasks.value[index] = { ...entry, ...buildPatch(entry) };
};

/**
 * Open the Socket.IO connection used for live training updates and register
 * its handlers. The socket is stored in the module-level `socket` variable,
 * which onUnmounted disconnects.
 *
 * @param {string} [url] Socket.IO namespace URL. Defaults to the local backend
 *   address this page has always used.
 */
const initWebSocket = (url = "http://localhost:5000/training") => {
  socket = io(url, {
    transports: ["websocket", "polling"]
  });

  socket.on("connect", () => {
    console.log("WebSocket连接成功");
    // After a reconnect the server may no longer have this client in the
    // task's room, so re-join it for the task still being tracked.
    // Joining a room twice is harmless.
    if (currentTraining.value) {
      socket.emit("join_training", { task_id: currentTraining.value.task_id });
    }
  });

  // General status update; may carry the final "completed"/"failed" status.
  socket.on("training_update", (data) => {
    console.log("收到训练更新:", data);
    const tracked = isTrackedTask(data.task_id);
    if (tracked) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    patchTrainingTask(data.task_id, () => data);

    if (data.status === "completed") {
      ElMessage.success(
        `药品模型 ${data.model_type} 版本 ${data.version} 训练完成!`
      );
      // Only clear the card if it shows this task; an update about some other
      // task must not wipe a run that is still in progress.
      if (tracked) currentTraining.value = null;
    } else if (data.status === "failed") {
      ElMessage.error(`药品模型训练失败: ${data.error}`);
      if (tracked) currentTraining.value = null;
    }
  });

  // Fine-grained (per-epoch) progress updates.
  socket.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);
    const current = currentTraining.value;
    // Ignore late progress for a task that already finished, so the card does
    // not flip back to "running" during the post-completion delay.
    if (
      isTrackedTask(data.task_id) &&
      !TERMINAL_TRAINING_STATUSES.includes(current.status)
    ) {
      currentTraining.value = {
        ...current,
        // Keep the last known progress when an event omits it.
        progress: data.progress ?? current.progress ?? 0,
        message: data.message || current.message,
        status: "running",
        ...(data.metrics ? { metrics: data.metrics } : {})
      };
    }
    patchTrainingTask(data.task_id, (task) =>
      TERMINAL_TRAINING_STATUSES.includes(task.status)
        ? {}
        : {
            progress: data.progress ?? task.progress ?? 0,
            message: data.message || task.message
          }
    );
  });

  // Dedicated completion event.
  socket.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);
    if (isTrackedTask(data.task_id)) {
      currentTraining.value = {
        ...currentTraining.value,
        ...data,
        status: "completed",
        progress: 100
      };
      ElMessage.success(
        `药品模型 ${data.model_type} 训练完成!`
      );
      // Show the finished state briefly, then clear the card unless another
      // task has started being tracked in the meantime.
      setTimeout(() => {
        if (isTrackedTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }
    patchTrainingTask(data.task_id, () => ({
      ...data,
      status: "completed",
      progress: 100
    }));
    // Pull the authoritative task list from the server.
    fetchTrainingTasks();
  });

  socket.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
/**
 * Load the product list for the product selector.
 * Shows an error toast if the request (or reading its payload) fails.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") return;
    products.value = body.data;
  } catch (err) {
    ElMessage.error("获取药品列表失败");
    console.error(err);
  }
};
/**
 * Load the store list for the "specific store" selector.
 * Failures are only logged to the console.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status === "success") stores.value = body.data;
  } catch (err) {
    console.error("获取店铺列表失败:", err);
  }
};
/**
 * Load the available model types. If the user has not chosen one yet, the
 * first entry is preselected. Shows an error toast on failure.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") return;
    modelTypes.value = body.data;
    const available = modelTypes.value;
    if (available.length > 0 && !form.model_type) {
      form.model_type = available[0].id;
    }
  } catch (err) {
    ElMessage.error("获取模型类型列表失败");
    console.error(err);
  }
};
// Id of the most recent versions request. Responses from superseded requests
// (an earlier product/model selection) are ignored when they arrive late.
let versionsRequestId = 0;

/**
 * Load the saved model versions for the selected product + model type, and
 * keep the "continue training" controls consistent with them:
 *  - with no versions, the retrain option is disabled in the template, so the
 *    form is forced back to a new run with no base version;
 *  - a base version that is not in the fetched list (empty, or left over from
 *    a previous selection) is replaced by the server-reported latest version.
 */
const fetchExistingVersions = async () => {
  const requestId = ++versionsRequestId;

  const resetVersions = () => {
    existingVersions.value = [];
    hasExistingVersions.value = false;
    form.training_type = "new";
    form.base_version = "";
  };

  if (!form.product_id || !form.model_type) {
    resetVersions();
    return;
  }

  try {
    const response = await axios.get(
      `/api/models/${encodeURIComponent(form.product_id)}/${encodeURIComponent(form.model_type)}/versions`
    );
    if (requestId !== versionsRequestId) return; // superseded by a newer selection
    if (response.data.status === "success") {
      const versions = response.data.data.versions || [];
      existingVersions.value = versions;
      hasExistingVersions.value = versions.length > 0;
      if (!hasExistingVersions.value) {
        form.training_type = "new";
        form.base_version = "";
      } else if (!versions.includes(form.base_version)) {
        form.base_version = response.data.data.latest_version;
      }
    }
  } catch (error) {
    if (requestId !== versionsRequestId) return;
    resetVersions();
    console.error("获取现有版本失败:", error);
  }
};
/**
 * `change` handler of the model-type selector. Each model type has its own
 * version history, so the form falls back to a fresh run with no base version.
 * The version list is reloaded by the watcher on form.product_id /
 * form.model_type; calling fetchExistingVersions here as well only fired a
 * second, concurrent request for the same data.
 */
const onModelTypeChange = () => {
  form.training_type = "new";
  form.base_version = "";
};
// True while task-list requests keep failing. Used to toast only the first
// failure of a streak instead of one toast every poll.
let trainingTasksLoadFailed = false;

/**
 * Refresh the task table from GET /api/training.
 * Called on mount, by the 10 s poll, by the refresh button and after a task is
 * started or completes. Every failure is logged; the user sees one error toast
 * per failure streak.
 *
 * The previous `if (!pollInterval)` guard never fired: onMounted assigns
 * pollInterval synchronously right after the first call, before any request
 * can reject. As a result, even a failed manual refresh was silent.
 */
const fetchTrainingTasks = async () => {
  try {
    const response = await axios.get("/api/training");
    if (response.data.status === "success") {
      trainingTasks.value = response.data.data;
    }
    trainingTasksLoadFailed = false;
  } catch (error) {
    if (!trainingTasksLoadFailed) ElMessage.error("获取训练任务列表失败");
    trainingTasksLoadFailed = true;
    console.error("获取训练任务列表失败", error);
  }
};
/**
 * Validate the form, submit a product-training job (or a retrain from a base
 * version), and start tracking it in the live-status card and its socket room.
 */
const startTraining = async () => {
  // The first failing check wins; its warning is shown and nothing is sent.
  const validationMessage =
    !form.product_id || !form.model_type
      ? "请选择药品和模型类型"
      : form.data_scope === 'specific' && !form.store_id
        ? "请选择店铺"
        : form.training_type === "retrain" && !form.base_version
          ? "请选择基础版本进行继续训练"
          : "";
  if (validationMessage) {
    ElMessage.warning(validationMessage);
    return;
  }

  // "All stores" is sent as a null store id. Read the form each time it is needed.
  const scopedStoreId = () => (form.data_scope === 'global' ? null : form.store_id);

  trainingLoading.value = true;
  try {
    const isRetrain = form.training_type === "retrain";
    const payload = {
      product_id: form.product_id,
      store_id: scopedStoreId(),
      model_type: form.model_type,
      epochs: form.epochs,
      training_mode: 'product' // marks this as a product-training job
    };
    if (isRetrain) payload.base_version = form.base_version;

    const { data: result } = await axios.post(
      isRetrain ? "/api/training/retrain" : "/api/training",
      payload
    );

    if (!result.task_id) {
      ElMessage.error(result.error || "启动训练失败");
      return;
    }

    const taskId = result.task_id;
    ElMessage.success(`药品训练任务 ${taskId} 已启动`);
    currentTraining.value = {
      task_id: taskId,
      product_id: form.product_id,
      store_id: scopedStoreId(),
      model_type: form.model_type,
      version: result.new_version || "v1",
      status: "starting",
      progress: 0,
      message: "正在启动药品训练...",
      training_mode: 'product'
    };
    // Subscribe to this task's room for live socket updates.
    if (socket) socket.emit("join_training", { task_id: taskId });
    fetchTrainingTasks();
  } catch (error) {
    ElMessage.error(error.response?.data?.error || "启动训练请求失败");
    console.error(error);
  } finally {
    trainingLoading.value = false;
  }
};
// ---- Display helpers ----
/** Display name of a product; falls back to the raw id when it is not in the list. */
const getProductName = (productId) => {
  const match = products.value.find(({ product_id }) => product_id === productId);
  if (!match) return productId;
  return match.product_name;
};
/** Display name of a model type id; falls back to the id itself when unknown. */
const getModelTypeName = (modelType) => {
  for (const entry of modelTypes.value) {
    if (entry.id === modelType) return entry.name;
  }
  return modelType;
};
/**
 * Data-scope label for a task: "全部店铺" when it has no store_id, otherwise
 * the store's name, or "店铺<id>" if the store is not in the loaded list.
 */
const getDataScopeText = (task) => {
  if (task.store_id) {
    const store = stores.value.find(({ store_id }) => store_id === task.store_id);
    return store ? store.store_name : `店铺${task.store_id}`;
  }
  return '全部店铺';
};
/** Element Plus tag type used to colour a task status; "info" for anything unrecognised. */
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
/** Chinese label for a task status; "未知" when the status is not recognised. */
const statusText = (status) => {
  const labels = {
    pending: "等待中",
    starting: "启动中",
    running: "进行中",
    completed: "已完成",
    failed: "失败"
  };
  const label = labels[status];
  return label ? label : "未知";
};
/** Label shown inside el-progress: the percentage followed by "%". */
const formatProgress = (percentage) => "".concat(percentage, "%");
/**
 * Render a timestamp in the browser's locale.
 * Returns "N/A" for empty or missing values. A value the Date constructor
 * cannot parse is echoed back as-is rather than shown as "Invalid Date".
 */
const formatDateTime = (isoString) => {
  if (!isoString) return "N/A";
  const date = new Date(isoString);
  return Number.isNaN(date.getTime()) ? String(isoString) : date.toLocaleString();
};
// Reload the saved versions whenever the product or the model type changes.
watch(
  [() => form.product_id, () => form.model_type],
  () => {
    fetchExistingVersions();
  }
);

// Switching back to "all stores" drops any previously chosen store.
watch(
  () => form.data_scope,
  (scope) => {
    if (scope !== 'global') return;
    form.store_id = '';
  }
);
// On mount: load the selector data and the task list, open the socket, and
// poll the task list every 10 seconds.
onMounted(() => {
  [fetchProducts, fetchStores, fetchModelTypes, fetchTrainingTasks].forEach((load) => load());
  initWebSocket();
  pollInterval = setInterval(fetchTrainingTasks, 10000);
});

// On unmount: stop polling and close the socket.
onUnmounted(() => {
  if (pollInterval) clearInterval(pollInterval);
  if (socket) socket.disconnect();
});
</script>
<style scoped>
/* Page padding around the two-column layout. */
.product-training-container {
  padding: 20px;
}

/* Card headers: title on the left, tag or button on the right. */
.card-header {
  display: flex;
  align-items: center;
  justify-content: space-between;
}

/* Highlighted description box above the training form. */
.training-description {
  margin-bottom: 20px;
  padding: 15px;
  border-left: 4px solid #409eff;
  border-radius: 6px;
  background-color: #f5f7fa;
}

.training-description p {
  margin: 0;
  font-size: 14px;
  line-height: 1.5;
  color: #606266;
}

/* Green accent on the live-status card. */
.training-progress-container {
  border-left: 4px solid #67c23a;
}

.training-status-text {
  margin-top: 10px;
}

/* JSON metrics dump in the live-status card. */
.training-metrics {
  padding: 10px;
  border-radius: 4px;
  background-color: #f5f7fa;
}

.training-metrics pre {
  margin: 5px 0 0;
  font-size: 12px;
  line-height: 1.4;
  white-space: pre-wrap;
  word-wrap: break-word;
}

/* Radio groups span the full form width. */
.el-radio-group {
  width: 100%;
}

.el-radio {
  margin-right: 20px;
}

/* Narrow screens: tighter padding and spacing between stacked columns. */
@media (max-width: 768px) {
  .product-training-container {
    padding: 10px;
  }

  .el-col {
    margin-bottom: 20px;
  }
}
</style>