<!-- Vue single-file component: per-product model training page (735 lines, 21 KiB in the original listing) -->
<template>
|
||
<div class="product-training-container">
|
||
<el-row :gutter="20">
|
||
<!-- 左侧:训练控制 -->
|
||
<el-col :span="8">
|
||
<el-card>
  <template #header>
    <div class="card-header">
      <span>按药品训练模型</span>
      <el-tag type="info">针对特定药品</el-tag>
    </div>
  </template>

  <div class="training-description">
    <p>为特定药品训练专门的预测模型,可选择使用单店铺数据或聚合多店铺数据。</p>
  </div>

  <el-form :model="form" label-width="100px">
    <!-- Product to train for; the submit button stays disabled until one is picked -->
    <el-form-item label="选择药品" required>
      <el-select
        v-model="form.product_id"
        filterable
        placeholder="请选择要训练的药品"
        style="width: 100%"
      >
        <el-option
          v-for="item in products"
          :key="item.product_id"
          :value="item.product_id"
          :label="`${item.product_name} (${item.product_id})`"
        />
      </el-select>
    </el-form-item>

    <!-- Data scope: all stores aggregated, or one specific store -->
    <!-- NOTE(review): newer Element Plus releases prefer `value` over `label` for the radio value; confirm against the installed version -->
    <el-form-item label="数据范围">
      <el-radio-group v-model="form.data_scope">
        <el-radio label="global">聚合所有店铺</el-radio>
        <el-radio label="specific">指定店铺</el-radio>
      </el-radio-group>
    </el-form-item>

    <!-- Store picker, rendered only for the "specific" scope -->
    <el-form-item
      v-if="form.data_scope === 'specific'"
      label="选择店铺"
    >
      <el-select
        v-model="form.store_id"
        filterable
        placeholder="选择店铺"
        style="width: 100%"
      >
        <el-option
          v-for="store in stores"
          :key="store.store_id"
          :value="store.store_id"
          :label="`${store.store_name} (${store.store_id})`"
        />
      </el-select>
    </el-form-item>

    <!-- Model type; changing it resets the training mode via onModelTypeChange -->
    <el-form-item label="模型类型" required>
      <el-select
        v-model="form.model_type"
        placeholder="请选择模型"
        style="width: 100%"
        @change="onModelTypeChange"
      >
        <el-option
          v-for="item in modelTypes"
          :key="item.id"
          :value="item.id"
          :label="item.name"
        />
      </el-select>
    </el-form-item>

    <!-- Training mode; "retrain" is only selectable when earlier versions exist -->
    <el-form-item label="训练模式">
      <el-radio-group v-model="form.training_type">
        <el-radio label="new">新训练</el-radio>
        <el-radio label="retrain" :disabled="!hasExistingVersions">
          继续训练
        </el-radio>
      </el-radio-group>
    </el-form-item>

    <!-- Base version picker, rendered only in "retrain" mode -->
    <el-form-item
      v-if="form.training_type === 'retrain'"
      label="基础版本"
    >
      <el-select
        v-model="form.base_version"
        placeholder="选择基础版本"
        style="width: 100%"
      >
        <el-option
          v-for="version in existingVersions"
          :key="version"
          :value="version"
          :label="version"
        />
      </el-select>
    </el-form-item>

    <!-- Number of epochs, limited to 1..1000 -->
    <el-form-item label="训练轮次">
      <el-input-number
        v-model="form.epochs"
        :min="1"
        :max="1000"
        style="width: 100%"
      />
    </el-form-item>

    <!-- Submit: needs both a product and a model type -->
    <el-form-item>
      <el-button
        type="primary"
        style="width: 100%"
        :loading="trainingLoading"
        :disabled="!form.product_id || !form.model_type"
        @click="startTraining"
      >
        <el-icon><Cpu /></el-icon>
        启动药品训练
      </el-button>
    </el-form-item>
  </el-form>
</el-card>
|
||
|
||
<!-- 实时训练状态卡片 -->
|
||
<el-card
  v-if="currentTraining"
  class="training-progress-container"
  style="margin-top: 20px"
>
  <template #header>
    <span>实时训练状态</span>
  </template>

  <div>
    <!-- Identity of the task being tracked -->
    <p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
    <p><strong>药品:</strong> {{ getProductName(currentTraining.product_id) }}</p>
    <p><strong>数据范围:</strong> {{ getDataScopeText(currentTraining) }}</p>
    <p><strong>模型:</strong> {{ getModelTypeName(currentTraining.model_type) }}</p>
    <p><strong>版本:</strong> {{ currentTraining.version }}</p>
    <p>
      <strong>状态:</strong>
      <el-tag :type="statusTag(currentTraining.status)">
        {{ statusText(currentTraining.status) }}
      </el-tag>
    </p>

    <!-- Progress bar, shown only while the task is running -->
    <el-progress
      v-if="currentTraining.status === 'running'"
      :format="formatProgress"
      :percentage="currentTraining.progress || 0"
    />

    <!-- Latest status message for the task, when there is one -->
    <div v-if="currentTraining.message" style="margin-top: 10px">
      <el-alert
        class="training-status-text"
        type="info"
        show-icon
        :closable="false"
        :title="currentTraining.message"
      />
    </div>

    <!-- Raw metrics dump, when metrics are present -->
    <div
      v-if="currentTraining.metrics"
      class="training-metrics"
      style="margin-top: 10px"
    >
      <h4>训练指标:</h4>
      <pre>{{ JSON.stringify(currentTraining.metrics, null, 2) }}</pre>
    </div>
  </div>
</el-card>
|
||
</el-col>
|
||
|
||
<!-- 右侧:任务状态 -->
|
||
<el-col :span="16">
|
||
<el-card>
  <template #header>
    <div class="card-header">
      <span>药品训练任务队列</span>
      <el-button size="small" @click="fetchTrainingTasks">
        <el-icon><Refresh /></el-icon>
        刷新
      </el-button>
    </div>
  </template>

  <!-- One row per product-level training task (see filteredTrainingTasks) -->
  <el-table :data="filteredTrainingTasks" stripe>
    <el-table-column
      prop="task_id"
      label="任务ID"
      width="120"
      show-overflow-tooltip
    />

    <el-table-column prop="product_id" label="药品" width="100">
      <template #default="{ row }">
        {{ getProductName(row.product_id) }}
      </template>
    </el-table-column>

    <el-table-column prop="store_id" label="范围" width="120">
      <template #default="{ row }">
        {{ getDataScopeText(row) }}
      </template>
    </el-table-column>

    <el-table-column prop="model_type" label="模型类型" width="120">
      <template #default="{ row }">
        {{ getModelTypeName(row.model_type) }}
      </template>
    </el-table-column>

    <el-table-column prop="version" label="版本" width="80" />

    <el-table-column prop="status" label="状态" width="100">
      <template #default="{ row }">
        <el-tag :type="statusTag(row.status)">
          {{ statusText(row.status) }}
        </el-tag>
      </template>
    </el-table-column>

    <el-table-column prop="start_time" label="创建时间">
      <template #default="{ row }">
        {{ formatDateTime(row.start_time) }}
      </template>
    </el-table-column>

    <!-- Hover popover whose content depends on the row status -->
    <el-table-column label="详情">
      <template #default="{ row }">
        <el-popover placement="left" trigger="hover" width="400">
          <template #reference>
            <el-button type="text" size="small">查看</el-button>
          </template>

          <!-- Finished successfully: metrics plus version details -->
          <div v-if="row.status === 'completed'">
            <h4>评估指标</h4>
            <pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
            <div v-if="row.version">
              <h4>版本信息</h4>
              <p><strong>版本:</strong> {{ row.version }}</p>
              <p><strong>模型路径:</strong> {{ row.model_path }}</p>
            </div>
          </div>

          <!-- Failed: error text reported for the task -->
          <div v-if="row.status === 'failed'">
            <h4>错误信息</h4>
            <p>{{ row.error }}</p>
          </div>

          <!-- Still queued or running -->
          <div v-if="row.status === 'running' || row.status === 'pending'">
            <p>任务正在进行中...</p>
            <!-- `!= null` skips both undefined and null, so a row without a
                 progress value never renders a bar with an empty percentage -->
            <div v-if="row.progress != null">
              <el-progress :percentage="row.progress" />
            </div>
          </div>
        </el-popover>
      </template>
    </el-table-column>
  </el-table>
</el-card>
|
||
</el-col>
|
||
</el-row>
|
||
</div>
|
||
</template>
|
||
|
||
<script setup>
|
||
import { ref, onMounted, onUnmounted, reactive, watch, computed } from "vue";
|
||
import axios from "axios";
|
||
import {
|
||
ElMessage,
|
||
ElPopover,
|
||
ElButton,
|
||
ElTag,
|
||
ElProgress,
|
||
ElAlert
|
||
} from "element-plus";
|
||
import { io } from "socket.io-client";
|
||
import { Cpu, Refresh } from '@element-plus/icons-vue';
|
||
|
||
// Option lists backing the product, store and model-type dropdowns.
const products = ref([]);
const stores = ref([]);
const modelTypes = ref([]);

// True while the start-training request is in flight (bound to the button's
// `loading` prop).
const trainingLoading = ref(false);

// Versions already available for the selected product/model pair, plus the
// flag that enables the "retrain" radio.
const existingVersions = ref([]);
const hasExistingVersions = ref(false);

// Task displayed in the live-status card; the card is hidden while null.
const currentTraining = ref(null);

// Model of the training form.
const form = reactive({
  product_id: "",
  store_id: "",
  data_scope: "global",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: ""
});

// Every task returned by the backend; the table shows a filtered view of it.
const trainingTasks = ref([]);

// Handles created in onMounted and released in onUnmounted.
let pollInterval = null;
let socket = null;
|
||
|
||
// Tasks listed in the queue table: a task is kept when it has a product_id and
// its training_mode is anything other than 'global'.
const filteredTrainingTasks = computed(() =>
  trainingTasks.value.filter(
    (entry) => entry.product_id && entry.training_mode !== 'global'
  )
);
|
||
|
||
/**
 * Open the Socket.IO connection used for live training events and register
 * every handler on it.
 *
 * The handlers keep two pieces of state in sync with the server: the task
 * shown in the live-status card (`currentTraining`) and the matching row of
 * `trainingTasks`.
 *
 * @param {string} [url="http://localhost:5000/training"] Socket.IO namespace
 *   URL. The default is the address that used to be hard-coded, so calling
 *   `initWebSocket()` without arguments behaves as before.
 */
const initWebSocket = (url = "http://localhost:5000/training") => {
  // Ids of tasks whose terminal toast was already shown. A task may be reported
  // as finished by both "training_update" and "training_completed"; this keeps
  // the user from seeing the same notification twice.
  const notifiedTasks = new Set();

  // True when `taskId` is the task currently shown in the status card.
  const isCurrentTask = (taskId) =>
    Boolean(currentTraining.value) &&
    currentTraining.value.task_id === taskId;

  // Shallow-merge the object returned by `buildPatch(row)` into the row with
  // the given task id; does nothing when the task is not in the list.
  const updateTask = (taskId, buildPatch) => {
    const idx = trainingTasks.value.findIndex((row) => row.task_id === taskId);
    if (idx === -1) return;
    const row = trainingTasks.value[idx];
    trainingTasks.value[idx] = { ...row, ...buildPatch(row) };
  };

  // Run `show` at most once per task id (always, when the id is missing).
  const notifyOnce = (taskId, show) => {
    if (taskId != null) {
      if (notifiedTasks.has(taskId)) return;
      notifiedTasks.add(taskId);
    }
    show();
  };

  // Progress to store: the incoming value when the event carries one (0
  // included), otherwise the previous value, otherwise 0.
  const nextProgress = (incoming, previous) => {
    if (incoming || incoming === 0) return incoming;
    return previous || 0;
  };

  socket = io(url, {
    transports: ["websocket", "polling"]
  });

  socket.on("connect", () => {
    console.log("WebSocket连接成功");

    // "connect" also fires after a reconnect. Re-announce the tracked task so
    // its events keep arriving; on the first connect nothing is tracked yet.
    if (currentTraining.value) {
      socket.emit("join_training", { task_id: currentTraining.value.task_id });
    }
  });

  // Generic task update: merge the payload, then handle terminal states.
  socket.on("training_update", (data) => {
    console.log("收到训练更新:", data);

    const tracked = isCurrentTask(data.task_id);
    if (tracked) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    updateTask(data.task_id, () => data);

    if (data.status === "completed") {
      notifyOnce(data.task_id, () => {
        ElMessage.success(
          `药品模型 ${data.model_type} 版本 ${data.version} 训练完成!`
        );
      });
      // Only hide the card when the finished task is the one it displays;
      // another task finishing must not hide the task being watched.
      if (tracked) currentTraining.value = null;
    } else if (data.status === "failed") {
      notifyOnce(data.task_id, () => {
        ElMessage.error(`药品模型训练失败: ${data.error}`);
      });
      if (tracked) currentTraining.value = null;
    }
  });

  // Fine-grained progress events (the original comment calls them epoch-level).
  socket.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);

    if (isCurrentTask(data.task_id)) {
      const current = currentTraining.value;
      // A late progress event must not pull a finished task back to "running".
      const finished =
        current.status === "completed" || current.status === "failed";
      const updated = {
        ...current,
        progress: nextProgress(data.progress, current.progress),
        message: data.message || current.message,
        status: finished ? current.status : "running"
      };
      if (data.metrics) {
        updated.metrics = data.metrics;
      }
      currentTraining.value = updated;
    }

    updateTask(data.task_id, (row) => ({
      progress: nextProgress(data.progress, row.progress),
      message: data.message || row.message
    }));
  });

  // Dedicated completion event.
  socket.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);

    const completed = { ...data, status: "completed", progress: 100 };

    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...completed };

      notifyOnce(data.task_id, () => {
        ElMessage.success(`药品模型 ${data.model_type} 训练完成!`);
      });

      // Leave the finished state visible for 2 seconds, then hide the card,
      // unless a different task has taken its place in the meantime.
      setTimeout(() => {
        if (isCurrentTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }

    updateTask(data.task_id, () => completed);

    // Reload the list from the backend as well.
    fetchTrainingTasks();
  });

  socket.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
|
||
|
||
/**
 * Load the product list from GET /api/products into `products`.
 * The list is replaced only when the response reports status "success"; any
 * thrown error shows an error toast and is logged.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") return;
    products.value = body.data;
  } catch (err) {
    ElMessage.error("获取药品列表失败");
    console.error(err);
  }
};
|
||
|
||
/**
 * Load the store list from GET /api/stores into `stores`.
 * The list is replaced only when the response reports status "success".
 * Failures are logged to the console without a toast.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status !== "success") return;
    stores.value = body.data;
  } catch (err) {
    console.error("获取店铺列表失败:", err);
  }
};
|
||
|
||
/**
 * Load the model types from GET /api/model_types into `modelTypes`.
 * When the list is non-empty and the form has no model type yet, the first
 * entry's id becomes the selected model type. Any thrown error shows an error
 * toast and is logged.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") return;

    modelTypes.value = body.data;

    // Preselect the first model type unless the user already has one chosen.
    if (!(modelTypes.value.length > 0)) return;
    if (form.model_type) return;
    form.model_type = modelTypes.value[0].id;
  } catch (err) {
    ElMessage.error("获取模型类型列表失败");
    console.error(err);
  }
};
|
||
|
||
/**
 * Refresh the versions available for the selected product/model pair.
 *
 * Requests GET /api/models/{product_id}/{model_type}/versions and stores the
 * result in `existingVersions` / `hasExistingVersions`.
 *
 * The version-dependent form fields are kept consistent with the result:
 * - `form.base_version` is replaced by the reported latest version whenever
 *   it is not one of the returned versions. This covers an empty value and a
 *   value left over from a previously selected product or model.
 * - When no versions are available (nothing selected, empty list,
 *   non-success response, request error), the base version is cleared and a
 *   selected "retrain" mode falls back to "new", since that radio is disabled
 *   without versions.
 *
 * A response that arrives after the selection has changed again is ignored,
 * so an older request cannot overwrite the state of a newer one.
 */
const fetchExistingVersions = async () => {
  // Snapshot of the selection this request is made for.
  const productId = form.product_id;
  const modelType = form.model_type;

  const selectionChanged = () =>
    form.product_id !== productId || form.model_type !== modelType;

  // Drop all version-related state.
  const clearVersions = () => {
    existingVersions.value = [];
    hasExistingVersions.value = false;
    form.base_version = "";
    if (form.training_type === "retrain") {
      form.training_type = "new";
    }
  };

  if (!productId || !modelType) {
    clearVersions();
    return;
  }

  try {
    // Ids are encoded so unusual characters cannot alter the request path.
    const response = await axios.get(
      `/api/models/${encodeURIComponent(productId)}/${encodeURIComponent(modelType)}/versions`
    );

    // The watcher on product/model starts a fresh request for the new selection.
    if (selectionChanged()) return;

    if (response.data.status !== "success") {
      clearVersions();
      return;
    }

    const versions = response.data.data.versions || [];
    if (versions.length === 0) {
      clearVersions();
      return;
    }

    existingVersions.value = versions;
    hasExistingVersions.value = true;

    if (!versions.includes(form.base_version)) {
      form.base_version = response.data.data.latest_version || "";
    }
  } catch (error) {
    console.error("获取现有版本失败:", error);
    if (!selectionChanged()) {
      clearVersions();
    }
  }
};
|
||
|
||
/**
 * Change handler of the model-type select: return to a fresh training with no
 * base version, then reload the versions for the new selection.
 */
const onModelTypeChange = () => {
  Object.assign(form, { training_type: "new", base_version: "" });
  fetchExistingVersions();
};
|
||
|
||
/**
 * Load all training tasks from GET /api/training into `trainingTasks`.
 * The list is replaced only when the response reports status "success".
 * Errors are always logged; the error toast is shown only while no poll timer
 * is registered (`pollInterval` is falsy).
 * NOTE(review): presumably this avoids repeating the toast on every poll
 * tick; confirm.
 */
const fetchTrainingTasks = async () => {
  try {
    const { data: body } = await axios.get("/api/training");
    if (body.status === "success") {
      trainingTasks.value = body.data;
    }
  } catch (err) {
    const pollingActive = Boolean(pollInterval);
    if (!pollingActive) {
      ElMessage.error("获取训练任务列表失败");
    }
    console.error("获取训练任务列表失败", err);
  }
};
|
||
|
||
/**
 * Validate the form and start a product-level training task.
 *
 * Posts to /api/training/retrain in "retrain" mode and to /api/training
 * otherwise. When the response carries a `task_id`, the task becomes the one
 * shown in the live-status card, its room is joined over the socket (if one is
 * open) and the task list is reloaded. Otherwise an error toast is shown.
 *
 * The submitted values are snapshotted once before the request. The status
 * card therefore describes what was actually sent, even if the user edits the
 * form while the request is in flight.
 */
const startTraining = async () => {
  // Validation rules in priority order; the first violated one warns and aborts.
  const rules = [
    [!form.product_id || !form.model_type, "请选择药品和模型类型"],
    [form.data_scope === 'specific' && !form.store_id, "请选择店铺"],
    [
      form.training_type === "retrain" && !form.base_version,
      "请选择基础版本进行继续训练"
    ]
  ];
  const violated = rules.find(([invalid]) => invalid);
  if (violated) {
    ElMessage.warning(violated[1]);
    return;
  }

  const isRetrain = form.training_type === "retrain";
  const endpoint = isRetrain ? "/api/training/retrain" : "/api/training";

  // Snapshot of what is being submitted.
  const submitted = {
    product_id: form.product_id,
    // Global scope is sent as a null store id.
    store_id: form.data_scope === 'global' ? null : form.store_id,
    model_type: form.model_type,
    epochs: form.epochs,
    training_mode: 'product' // marks the task as product-level training
  };
  const payload = isRetrain
    ? { ...submitted, base_version: form.base_version }
    : submitted;

  trainingLoading.value = true;
  try {
    const { data: result } = await axios.post(endpoint, payload);

    if (!result.task_id) {
      ElMessage.error(result.error || "启动训练失败");
      return;
    }

    ElMessage.success(`药品训练任务 ${result.task_id} 已启动`);

    // Track the new task in the live-status card.
    currentTraining.value = {
      task_id: result.task_id,
      product_id: submitted.product_id,
      store_id: submitted.store_id,
      model_type: submitted.model_type,
      version: result.new_version || "v1",
      status: "starting",
      progress: 0,
      message: "正在启动药品训练...",
      training_mode: 'product'
    };

    // Join the task's room so its events reach this client.
    if (socket) {
      socket.emit("join_training", { task_id: result.task_id });
    }

    fetchTrainingTasks();
  } catch (error) {
    const errorMsg = error.response?.data?.error || "启动训练请求失败";
    ElMessage.error(errorMsg);
    console.error(error);
  } finally {
    trainingLoading.value = false;
  }
};
|
||
|
||
// 辅助函数
|
||
/**
 * Display name for a product id: the `product_name` of the first entry in
 * `products` with that id, or the id itself when no entry matches.
 */
const getProductName = (productId) => {
  for (const candidate of products.value) {
    if (candidate.product_id === productId) {
      return candidate.product_name;
    }
  }
  return productId;
};
|
||
|
||
/**
 * Display name for a model-type id: the `name` of the first entry in
 * `modelTypes` with that id, or the id itself when no entry matches.
 */
const getModelTypeName = (modelType) => {
  const position = modelTypes.value.findIndex((entry) => entry.id === modelType);
  if (position === -1) {
    return modelType;
  }
  return modelTypes.value[position].name;
};
|
||
|
||
/**
 * Human-readable data scope of a task.
 * A task without a store id covers all stores. Otherwise the matching store's
 * name is returned, or a label built from the raw id when the store is not in
 * `stores`.
 */
const getDataScopeText = (task) => {
  const storeId = task.store_id;
  if (!storeId) return '全部店铺';

  const match = stores.value.find((entry) => entry.store_id === storeId);
  if (match) return match.store_name;

  return `店铺${storeId}`;
};
|
||
|
||
/**
 * Tag type (colour variant) used for a task status; unrecognised statuses get
 * "info".
 */
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
|
||
|
||
/**
 * Localised label for a task status.
 *
 * Only the statuses listed below are translated; anything else yields "未知".
 * The lookup is restricted to the table's own keys, so a status that happens
 * to match an inherited object property (for example "toString" or
 * "constructor") is treated as unknown instead of returning that property.
 */
const statusText = (status) => {
  const labels = {
    pending: "等待中",
    starting: "启动中",
    running: "进行中",
    completed: "已完成",
    failed: "失败"
  };
  const known = Object.prototype.hasOwnProperty.call(labels, status);
  return (known && labels[status]) || "未知";
};
|
||
|
||
/** Formatter passed to the progress bar: the percentage followed by "%". */
const formatProgress = (percentage) => "".concat(percentage, "%");
|
||
|
||
/**
 * Format a timestamp for display in the user's locale.
 *
 * @param {string} isoString Timestamp as delivered by the backend.
 * @returns {string} The localised date/time, or "N/A" when the value is
 *   missing or cannot be parsed as a date (instead of showing "Invalid Date").
 */
const formatDateTime = (isoString) => {
  if (!isoString) return "N/A";

  const parsed = new Date(isoString);
  if (Number.isNaN(parsed.getTime())) return "N/A";

  return parsed.toLocaleString();
};
|
||
|
||
// Reload the available versions whenever the selected product or model type
// changes.
watch(
  [() => form.product_id, () => form.model_type],
  () => {
    fetchExistingVersions();
  }
);

// Switching the data scope to "global" discards any previously chosen store.
watch(
  () => form.data_scope,
  (scope) => {
    if (scope !== 'global') return;
    form.store_id = '';
  }
);
|
||
|
||
// On mount: load the reference data and the task list, open the socket, then
// start polling the task list every 10 seconds.
onMounted(() => {
  const initialLoads = [
    fetchProducts,
    fetchStores,
    fetchModelTypes,
    fetchTrainingTasks
  ];
  initialLoads.forEach((load) => {
    load();
  });

  initWebSocket();

  const TASK_POLL_MS = 10000;
  pollInterval = setInterval(fetchTrainingTasks, TASK_POLL_MS);
});
|
||
|
||
// On unmount: stop the task-list polling and close the socket, when present.
onUnmounted(() => {
  if (pollInterval) clearInterval(pollInterval);
  if (socket) socket.disconnect();
});
|
||
</script>
|
||
|
||
<style scoped>
/* Page wrapper */
.product-training-container {
  padding: 20px;
}

/* Card header: title on the left, tag or button on the right */
.card-header {
  align-items: center;
  display: flex;
  justify-content: space-between;
}

/* Intro box above the form */
.training-description {
  background-color: #f5f7fa;
  border-left: 4px solid #409eff;
  border-radius: 6px;
  margin-bottom: 20px;
  padding: 15px;
}

.training-description p {
  color: #606266;
  font-size: 14px;
  line-height: 1.5;
  margin: 0;
}

/* Live-status card */
.training-progress-container {
  border-left: 4px solid #67c23a;
}

.training-status-text {
  margin-top: 10px;
}

/* Metrics dump inside the live-status card */
.training-metrics {
  background-color: #f5f7fa;
  border-radius: 4px;
  padding: 10px;
}

.training-metrics pre {
  font-size: 12px;
  line-height: 1.4;
  margin: 5px 0 0;
  white-space: pre-wrap;
  word-wrap: break-word;
}

/* Radio groups in the form */
.el-radio-group {
  width: 100%;
}

.el-radio {
  margin-right: 20px;
}

/* Narrow screens: tighter padding, spacing under each column */
@media (max-width: 768px) {
  .product-training-container {
    padding: 10px;
  }

  .el-col {
    margin-bottom: 20px;
  }
}
</style>