ShopTRAINING/UI/src/views/training/GlobalTrainingView.vue

846 lines
24 KiB
Vue
Raw Normal View History

2025-07-02 11:05:23 +08:00
<template>
  <div class="global-training-container">
    <el-row :gutter="20">
      <!-- Left column: global training configuration and live status -->
      <el-col :span="8">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局模型训练</span>
              <el-tag type="success">跨店铺通用</el-tag>
            </div>
          </template>
          <div class="training-description">
            <p>使用所有店铺的历史数据训练通用预测模型，可用于新店铺或数据不足的场景</p>
          </div>
          <el-form :model="form" label-width="100px">
            <!-- Which slice of the data the global model is trained on -->
            <el-form-item label="训练范围">
              <el-radio-group v-model="form.training_scope">
                <el-radio label="all_stores_all_products">所有店铺所有药品</el-radio>
                <el-radio label="selected_stores">选择店铺</el-radio>
                <el-radio label="selected_products">选择药品</el-radio>
                <el-radio label="custom">自定义范围</el-radio>
              </el-radio-group>
            </el-form-item>
            <!-- Store picker: only for scopes that restrict stores -->
            <el-form-item
              label="选择店铺"
              v-if="form.training_scope === 'selected_stores' || form.training_scope === 'custom'"
            >
              <el-select
                v-model="form.store_ids"
                placeholder="选择参与训练的店铺"
                multiple
                filterable
                style="width: 100%"
              >
                <el-option
                  v-for="store in stores"
                  :key="store.store_id"
                  :label="`${store.store_name} (${store.location})`"
                  :value="store.store_id"
                />
              </el-select>
            </el-form-item>
            <!-- Product picker: only for scopes that restrict products -->
            <el-form-item
              label="选择药品"
              v-if="form.training_scope === 'selected_products' || form.training_scope === 'custom'"
            >
              <el-select
                v-model="form.product_ids"
                placeholder="选择参与训练的药品"
                multiple
                filterable
                style="width: 100%"
              >
                <el-option
                  v-for="product in products"
                  :key="product.product_id"
                  :label="`${product.product_name} (${product.product_id})`"
                  :value="product.product_id"
                />
              </el-select>
            </el-form-item>
            <el-form-item label="聚合方式">
              <el-select
                v-model="form.aggregation_method"
                placeholder="选择数据聚合方式"
                style="width: 100%"
              >
                <el-option label="求和 (Sum)" value="sum" />
                <el-option label="平均值 (Mean)" value="mean" />
                <el-option label="加权平均 (Weighted)" value="weighted" />
              </el-select>
            </el-form-item>
            <el-form-item label="模型类型" required>
              <el-select
                v-model="form.model_type"
                placeholder="请选择模型"
                @change="onModelTypeChange"
                style="width: 100%"
              >
                <el-option
                  v-for="item in modelTypes"
                  :key="item.id"
                  :label="item.name"
                  :value="item.id"
                />
              </el-select>
            </el-form-item>
            <!-- "Continue training" is only available once a saved version exists -->
            <el-form-item label="训练模式">
              <el-radio-group v-model="form.training_type">
                <el-radio label="new">新训练</el-radio>
                <el-radio label="retrain" :disabled="!hasExistingVersions">
                  继续训练
                </el-radio>
              </el-radio-group>
            </el-form-item>
            <el-form-item
              label="基础版本"
              v-if="form.training_type === 'retrain'"
            >
              <el-select v-model="form.base_version" placeholder="选择基础版本" style="width: 100%">
                <el-option
                  v-for="version in existingVersions"
                  :key="version"
                  :label="version"
                  :value="version"
                />
              </el-select>
            </el-form-item>
            <el-form-item label="训练轮次">
              <el-input-number v-model="form.epochs" :min="1" :max="1000" style="width: 100%" />
            </el-form-item>
            <el-form-item>
              <el-button
                type="primary"
                @click="startTraining"
                :loading="trainingLoading"
                :disabled="!form.model_type"
                style="width: 100%"
              >
                <el-icon><Operation /></el-icon>
                启动全局训练
              </el-button>
            </el-form-item>
          </el-form>
          <!-- Statistics of the data selected by the current form settings -->
          <el-card v-if="trainingStats" style="margin-top: 20px" shadow="never">
            <template #header>
              <span>训练数据统计</span>
            </template>
            <div class="training-stats">
              <p><strong>涉及店铺:</strong> {{ trainingStats.stores_count }} </p>
              <p><strong>涉及药品:</strong> {{ trainingStats.products_count }} </p>
              <p><strong>数据记录:</strong> {{ trainingStats.records_count }} </p>
              <p><strong>时间范围:</strong> {{ trainingStats.date_range }}</p>
            </div>
          </el-card>
        </el-card>
        <!-- Live status of the task started from this page -->
        <el-card
          v-if="currentTraining"
          style="margin-top: 20px"
          class="training-progress-container"
        >
          <template #header>
            <span>实时训练状态</span>
          </template>
          <div>
            <p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
            <p><strong>训练范围:</strong> {{ getTrainingScopeText(currentTraining) }}</p>
            <p><strong>聚合方式:</strong> {{ getAggregationText(currentTraining.aggregation_method) }}</p>
            <p><strong>模型:</strong> {{ getModelTypeName(currentTraining.model_type) }}</p>
            <p><strong>版本:</strong> {{ currentTraining.version }}</p>
            <p>
              <strong>状态:</strong>
              <el-tag :type="statusTag(currentTraining.status)">
                {{ statusText(currentTraining.status) }}
              </el-tag>
            </p>
            <el-progress
              v-if="currentTraining.status === 'running'"
              :percentage="currentTraining.progress || 0"
              :format="formatProgress"
            />
            <div v-if="currentTraining.message" style="margin-top: 10px">
              <el-alert
                :title="currentTraining.message"
                type="info"
                show-icon
                :closable="false"
                class="training-status-text"
              />
            </div>
            <div
              v-if="currentTraining.metrics"
              style="margin-top: 10px"
              class="training-metrics"
            >
              <h4>训练指标:</h4>
              <pre>{{ JSON.stringify(currentTraining.metrics, null, 2) }}</pre>
            </div>
          </div>
        </el-card>
      </el-col>
      <!-- Right column: queue of global training tasks -->
      <el-col :span="16">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局训练任务队列</span>
              <el-button size="small" @click="fetchTrainingTasks">
                <el-icon><Refresh /></el-icon>
                刷新
              </el-button>
            </div>
          </template>
          <el-table :data="filteredTrainingTasks" stripe>
            <el-table-column
              prop="task_id"
              label="任务ID"
              width="120"
              show-overflow-tooltip
            />
            <el-table-column
              label="训练范围"
              width="150"
            >
              <template #default="{ row }">
                {{ getTrainingScopeText(row) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="aggregation_method"
              label="聚合方式"
              width="100"
            >
              <template #default="{ row }">
                {{ getAggregationText(row.aggregation_method) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="model_type"
              label="模型类型"
              width="120"
            >
              <template #default="{ row }">
                {{ getModelTypeName(row.model_type) }}
              </template>
            </el-table-column>
            <el-table-column
              prop="version"
              label="版本"
              width="80"
            />
            <el-table-column prop="status" label="状态" width="100">
              <template #default="{ row }">
                <el-tag :type="statusTag(row.status)">
                  {{ statusText(row.status) }}
                </el-tag>
              </template>
            </el-table-column>
            <el-table-column prop="start_time" label="创建时间">
              <template #default="{ row }">
                {{ formatDateTime(row.start_time) }}
              </template>
            </el-table-column>
            <el-table-column label="详情">
              <template #default="{ row }">
                <el-popover placement="left" trigger="hover" width="400">
                  <template #reference>
                    <!-- `link` replaces the deprecated type="text" button style -->
                    <el-button link type="primary" size="small">查看</el-button>
                  </template>
                  <div v-if="row.status === 'completed'">
                    <h4>评估指标</h4>
                    <pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
                    <div v-if="row.version">
                      <h4>版本信息</h4>
                      <p><strong>版本:</strong> {{ row.version }}</p>
                      <p><strong>模型路径:</strong> {{ row.model_path }}</p>
                    </div>
                  </div>
                  <div v-if="row.status === 'failed'">
                    <h4>错误信息</h4>
                    <p>{{ row.error }}</p>
                  </div>
                  <div
                    v-if="row.status === 'running' || row.status === 'pending'"
                  >
                    <p>任务正在进行中...</p>
                    <!-- `!= null` also skips a null progress, which el-progress cannot render -->
                    <div v-if="row.progress != null">
                      <el-progress :percentage="row.progress" />
                    </div>
                  </div>
                </el-popover>
              </template>
            </el-table-column>
          </el-table>
        </el-card>
      </el-col>
    </el-row>
  </div>
</template>
<script setup>
import { ref, onMounted, onUnmounted, reactive, watch, computed } from "vue";
import axios from "axios";
import {
ElMessage,
ElPopover,
ElButton,
ElTag,
ElProgress,
ElAlert
} from "element-plus";
import { io } from "socket.io-client";
import { Operation, Refresh } from '@element-plus/icons-vue';
// --- Reference lists loaded from the backend --------------------------------
const stores = ref([]);       // options for the store selector
const products = ref([]);     // options for the product selector
const modelTypes = ref([]);   // options for the model-type selector

// --- Saved versions of the selected model type ------------------------------
const existingVersions = ref([]);
const hasExistingVersions = ref(false);

// --- Form state --------------------------------------------------------------
const form = reactive({
  training_scope: "all_stores_all_products",
  store_ids: [],
  product_ids: [],
  aggregation_method: "sum",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: ""
});
const trainingLoading = ref(false);  // true while the start request is pending
const trainingStats = ref(null);     // statistics card data; null hides the card

// --- Tasks -------------------------------------------------------------------
const currentTraining = ref(null);   // task started from this page, if any
const trainingTasks = ref([]);       // every task returned by /api/training

// --- Runtime handles (assigned on mount, released on unmount) ---------------
let pollInterval = null;
let socket = null;
// Only tasks launched in global mode (training_mode === 'global') belong in
// this view's queue; tasks from other training modes are hidden.
const filteredTrainingTasks = computed(() =>
  trainingTasks.value.filter(({ training_mode }) => training_mode === 'global')
);
/**
 * Open the Socket.IO connection to the "/training" namespace and keep
 * `currentTraining` (the status card) and `trainingTasks` (the queue table)
 * in sync with server-pushed events.
 *
 * Events handled:
 *  - "training_update":    generic status change; toasts on completed/failed.
 *  - "training_progress":  progress/message/metrics updates during training.
 *  - "training_completed": final event; marks the task done and re-fetches the list.
 *
 * NOTE(review): the server URL is hard-coded to localhost:5000, unlike the
 * relative "/api/..." URLs used with axios. Confirm this is intended outside
 * local development.
 */
const initWebSocket = () => {
  socket = io("http://localhost:5000/training", {
    transports: ["websocket", "polling"]
  });

  // True when `taskId` is the task currently shown in the status card.
  const isCurrentTask = (taskId) =>
    Boolean(currentTraining.value) && currentTraining.value.task_id === taskId;

  // Shallow-merge the patch built by `makePatch(task)` into the task with
  // `taskId` in the queue; events for tasks not in the list are ignored.
  const patchTask = (taskId, makePatch) => {
    const index = trainingTasks.value.findIndex((task) => task.task_id === taskId);
    if (index !== -1) {
      const task = trainingTasks.value[index];
      trainingTasks.value[index] = { ...task, ...makePatch(task) };
    }
  };

  socket.on("connect", () => {
    console.log("WebSocket连接成功");
    // Socket.IO rooms are left on disconnect, so after a reconnect we must
    // re-join the room of the task we are still tracking, or its updates stop.
    // (Duplicate joins are assumed to be harmless server-side — verify.)
    if (currentTraining.value) {
      socket.emit("join_training", { task_id: currentTraining.value.task_id });
    }
  });

  socket.on("training_update", (data) => {
    console.log("收到训练更新:", data);
    const isCurrent = isCurrentTask(data.task_id);
    if (isCurrent) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    patchTask(data.task_id, () => data);

    if (data.status === "completed") {
      ElMessage.success(
        `全局模型 ${data.model_type} 版本 ${data.version} 训练完成!`
      );
      // Only close the status card when it shows THIS task. An update for an
      // unrelated task must not wipe the card of a task that is still running.
      if (isCurrent) currentTraining.value = null;
    } else if (data.status === "failed") {
      ElMessage.error(`全局模型训练失败: ${data.error}`);
      if (isCurrent) currentTraining.value = null;
    }
  });

  // Epoch-level progress: progress/message (and optionally metrics).
  socket.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);
    if (isCurrentTask(data.task_id)) {
      const previous = currentTraining.value;
      currentTraining.value = {
        ...previous,
        progress: data.progress || 0,
        message: data.message || previous.message,
        status: 'running', // a progress event implies the task is running
        ...(data.metrics ? { metrics: data.metrics } : {})
      };
    }
    patchTask(data.task_id, (task) => ({
      progress: data.progress || 0,
      message: data.message || task.message
    }));
  });

  // Dedicated completion event: force status/progress to their final values.
  socket.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);
    const finished = { ...data, status: "completed", progress: 100 };
    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...finished };
      ElMessage.success(
        `全局模型 ${data.model_type} 训练完成!`
      );
      // Keep the finished card visible for 2 s, unless another task replaced it.
      setTimeout(() => {
        if (isCurrentTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }
    patchTask(data.task_id, () => finished);
    fetchTrainingTasks();
  });

  socket.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
/**
 * Load the store list for the store selector.
 * Only a "success" response replaces the list; failures are logged and the
 * previous list is kept.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status !== "success") return;
    stores.value = body.data;
  } catch (err) {
    console.error("获取店铺列表失败:", err);
  }
};
/**
 * Load the product list for the product selector.
 * Only a "success" response replaces the list; failures are logged and the
 * previous list is kept.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") return;
    products.value = body.data;
  } catch (err) {
    console.error("获取药品列表失败:", err);
  }
};
/**
 * Load the available model types. If no model is selected yet, preselect
 * the first one; this assignment also triggers the `form.model_type` watcher.
 * Failures show an error toast and are logged.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") return;
    modelTypes.value = body.data;
    const shouldPreselect = modelTypes.value.length > 0 && !form.model_type;
    if (shouldPreselect) {
      form.model_type = modelTypes.value[0].id;
    }
  } catch (error) {
    ElMessage.error("获取模型类型列表失败");
    console.error(error);
  }
};
// Id of the most recently issued statistics request. Responses of older
// requests are dropped so an out-of-order reply cannot overwrite fresher stats.
let latestStatsRequestId = 0;

/**
 * Fetch statistics of the data selected by the form (scope, aggregation and
 * optional store/product filters) into `trainingStats`.
 * Store/product ids are sent as comma-separated strings, and only when
 * non-empty. On error `trainingStats` is reset to null, which hides the
 * statistics card (it is rendered with v-if="trainingStats").
 */
const fetchTrainingStats = async () => {
  const requestId = ++latestStatsRequestId;
  try {
    const params = {
      training_scope: form.training_scope,
      aggregation_method: form.aggregation_method
    };
    if (form.store_ids.length > 0) {
      params.store_ids = form.store_ids.join(',');
    }
    if (form.product_ids.length > 0) {
      params.product_ids = form.product_ids.join(',');
    }
    const response = await axios.get("/api/training/global/stats", { params });
    // A newer request was issued while this one was in flight: its result wins.
    if (requestId !== latestStatsRequestId) return;
    if (response.data.status === "success") {
      trainingStats.value = response.data.data;
    }
  } catch (error) {
    console.error("获取训练统计信息失败:", error);
    // Do not let a stale failure hide statistics a newer request may have loaded.
    if (requestId === latestStatsRequestId) {
      trainingStats.value = null;
    }
  }
};
/**
 * Load the saved versions of the selected global model type into
 * `existingVersions` / `hasExistingVersions`. When no base version is chosen
 * yet, preselect the backend-reported `latest_version`.
 *
 * The previous model's versions are cleared before the request, so they are
 * never offered as base versions for a different model. A response that
 * arrives after the user switched to another model type is discarded.
 */
const fetchExistingVersions = async () => {
  const requestedModelType = form.model_type;
  existingVersions.value = [];
  hasExistingVersions.value = false;
  if (!requestedModelType) {
    return;
  }
  try {
    const response = await axios.get(
      `/api/models/global/${encodeURIComponent(requestedModelType)}/versions`
    );
    // Stale response: the selection changed and a newer request owns the state.
    if (form.model_type !== requestedModelType) return;
    if (response.data.status === "success") {
      existingVersions.value = response.data.data.versions || [];
      hasExistingVersions.value = existingVersions.value.length > 0;
      if (hasExistingVersions.value && !form.base_version) {
        form.base_version = response.data.data.latest_version;
      }
    }
  } catch (error) {
    // Lists were already reset above; only report the failure.
    console.error("获取现有版本失败:", error);
  }
};
/**
 * Handler for a user-initiated model-type change: versions from the previous
 * model no longer apply, so fall back to a fresh training run with no base
 * version.
 *
 * The version list itself is reloaded by the `form.model_type` watcher, which
 * also covers programmatic changes. Calling fetchExistingVersions() here as
 * well would send a second, identical request.
 */
const onModelTypeChange = () => {
  form.training_type = "new";
  form.base_version = "";
};
// Whether the previous task-list fetch failed. Used to raise a single error
// toast per outage instead of one every 10 s polling cycle.
// (The original `!pollInterval` check could never be true: pollInterval is
// assigned synchronously on mount, before any request can fail.)
let trainingTasksFetchFailed = false;

/**
 * Reload the training task queue from /api/training.
 * Called on mount, by the 10 s poll, by the refresh button and after
 * starting or finishing a task. A non-array payload is treated as an empty
 * list, because the queue table's computed filter requires an array.
 */
const fetchTrainingTasks = async () => {
  try {
    const response = await axios.get("/api/training");
    if (response.data.status === "success") {
      const tasks = response.data.data;
      trainingTasks.value = Array.isArray(tasks) ? tasks : [];
    }
    trainingTasksFetchFailed = false;
  } catch (error) {
    if (!trainingTasksFetchFailed) {
      ElMessage.error("获取训练任务列表失败");
    }
    trainingTasksFetchFailed = true;
    console.error("获取训练任务列表失败", error);
  }
};
// Return the warning for the first form problem that blocks training, or null
// when the form is complete. Checks run in the same order as the form fields.
const getStartTrainingError = () => {
  const scope = form.training_scope;
  if (!form.model_type) {
    return "请选择模型类型";
  }
  if ((scope === 'selected_stores' || scope === 'custom') && form.store_ids.length === 0) {
    return "请选择参与训练的店铺";
  }
  if ((scope === 'selected_products' || scope === 'custom') && form.product_ids.length === 0) {
    return "请选择参与训练的药品";
  }
  if (form.training_type === "retrain" && !form.base_version) {
    return "请选择基础版本进行继续训练";
  }
  return null;
};

/**
 * Validate the form and submit a global-mode training job.
 * POSTs to /api/training, or to /api/training/retrain when continuing from a
 * base version. On success it fills the live status card, joins the task's
 * WebSocket room and refreshes the queue. Errors are shown as toasts.
 */
const startTraining = async () => {
  const validationError = getStartTrainingError();
  if (validationError) {
    ElMessage.warning(validationError);
    return;
  }

  trainingLoading.value = true;
  try {
    // Snapshot the form BEFORE awaiting. The form stays editable while the
    // request is in flight, and the status card must describe what was
    // actually submitted, not whatever the form holds when the reply arrives.
    const isRetrain = form.training_type === "retrain";
    const storeIds = [...form.store_ids];
    const productIds = [...form.product_ids];
    const payload = {
      model_type: form.model_type,
      epochs: form.epochs,
      training_mode: 'global', // marks this as a global training job
      training_scope: form.training_scope,
      aggregation_method: form.aggregation_method
    };
    if (storeIds.length > 0) {
      payload.store_ids = storeIds;
    }
    if (productIds.length > 0) {
      payload.product_ids = productIds;
    }
    if (isRetrain) {
      payload.base_version = form.base_version;
    }

    const endpoint = isRetrain ? "/api/training/retrain" : "/api/training";
    const response = await axios.post(endpoint, payload);
    const taskId = response.data.task_id;
    if (taskId) {
      ElMessage.success(`全局训练任务 ${taskId} 已启动`);
      currentTraining.value = {
        task_id: taskId,
        model_type: payload.model_type,
        // NOTE(review): "v1" is only a display fallback when the backend
        // omits new_version; it may not match the real version.
        version: response.data.new_version || "v1",
        status: "starting",
        progress: 0,
        message: "正在启动全局训练...",
        training_mode: 'global',
        training_scope: payload.training_scope,
        aggregation_method: payload.aggregation_method,
        store_ids: storeIds.length > 0 ? storeIds : null,
        product_ids: productIds.length > 0 ? productIds : null
      };
      // Subscribe to this task's push updates.
      if (socket) {
        socket.emit("join_training", { task_id: taskId });
      }
      fetchTrainingTasks();
    } else {
      ElMessage.error(response.data.error || "启动训练失败");
    }
  } catch (error) {
    const errorMsg = error.response?.data?.error || "启动训练请求失败";
    ElMessage.error(errorMsg);
    console.error(error);
  } finally {
    trainingLoading.value = false;
  }
};
// ---- Display helpers ----------------------------------------------------------

/**
 * Map a model-type id to its display name from `modelTypes`.
 * Unknown ids are returned unchanged.
 */
const getModelTypeName = (modelType) => {
  const match = modelTypes.value.find(({ id }) => id === modelType);
  return match === undefined ? modelType : match.name;
};
/**
 * Human-readable description of a task's training scope. Store and product
 * scopes show how many ids were selected (0 when the list is missing).
 */
const getTrainingScopeText = (task) => {
  switch (task.training_scope) {
    case 'all_stores_all_products':
      return '全部数据';
    case 'selected_stores':
      return `${task.store_ids?.length || 0} 个店铺`;
    case 'selected_products':
      return `${task.product_ids?.length || 0} 种药品`;
    case 'custom':
      return '自定义范围';
    default:
      return '未知范围';
  }
};
/**
 * Display label for an aggregation method; unknown values pass through as-is.
 */
const getAggregationText = (method) => {
  switch (method) {
    case 'sum':
      return '求和';
    case 'mean':
      return '平均值';
    case 'weighted':
      return '加权平均';
    default:
      return method;
  }
};
/**
 * el-tag `type` for a task status; unrecognised statuses use "info".
 */
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
/**
 * Chinese display label for a task status; unrecognised statuses show "未知".
 */
const statusText = (status) => {
  switch (status) {
    case "pending":
      return "等待中";
    case "starting":
      return "启动中";
    case "running":
      return "进行中";
    case "completed":
      return "已完成";
    case "failed":
      return "失败";
    default:
      return "未知";
  }
};
// Label formatter for el-progress: renders the percentage followed by "%".
const formatProgress = (percentage) => `${percentage}%`;
/**
 * Format a timestamp string for display in the user's locale.
 * Returns "N/A" for an empty/missing value. Returns the raw value when it
 * cannot be parsed, instead of the unhelpful "Invalid Date".
 */
const formatDateTime = (isoString) => {
  if (!isoString) return "N/A";
  const date = new Date(isoString);
  if (Number.isNaN(date.getTime())) return String(isoString);
  return date.toLocaleString();
};
// When the training scope changes, drop selections the new scope does not use.
// Registered BEFORE the statistics watcher on purpose: this relies on Vue
// running same-component watchers in trigger (registration) order. The
// selections are therefore already cleared when the statistics watcher runs,
// and it sends one request with the final parameters. The original order sent
// a stale request first and then a corrected one.
watch(() => form.training_scope, (newVal) => {
  if (newVal === 'all_stores_all_products') {
    form.store_ids = [];
    form.product_ids = [];
  } else if (newVal === 'selected_stores') {
    form.product_ids = [];
  } else if (newVal === 'selected_products') {
    form.store_ids = [];
  }
});

// Re-fetch the training-data statistics whenever any of their inputs changes.
// `deep` also catches in-place edits of the id arrays.
watch(
  [
    () => form.training_scope,
    () => form.store_ids,
    () => form.product_ids,
    () => form.aggregation_method
  ],
  () => {
    fetchTrainingStats();
  },
  { deep: true }
);

// Reload the saved versions whenever the model type changes. This includes
// the programmatic default chosen after the model list loads.
watch(() => form.model_type, () => {
  fetchExistingVersions();
});
// On mount: load the reference data, task list and statistics, connect the
// socket, and poll the task list every 10 s.
onMounted(() => {
  const initialLoads = [
    fetchStores,
    fetchProducts,
    fetchModelTypes,
    fetchTrainingTasks,
    fetchTrainingStats
  ];
  initialLoads.forEach((load) => load());
  initWebSocket();
  pollInterval = setInterval(fetchTrainingTasks, 10000);
});

// On unmount: stop polling and close the socket connection.
onUnmounted(() => {
  if (pollInterval) clearInterval(pollInterval);
  socket?.disconnect();
});
</script>
<style scoped>
/* Page wrapper */
.global-training-container {
  padding: 20px;
}

/* Card header: title on the left, tag/button on the right */
.card-header {
  align-items: center;
  display: flex;
  justify-content: space-between;
}

/* Intro banner above the form */
.training-description {
  background-color: #f0f9ff;
  border-left: 4px solid #67c23a;
  border-radius: 6px;
  margin-bottom: 20px;
  padding: 15px;
}
.training-description p {
  color: #606266;
  font-size: 14px;
  line-height: 1.5;
  margin: 0;
}

/* Training-data statistics card */
.training-stats {
  font-size: 14px;
}
.training-stats p {
  color: #606266;
  margin: 8px 0;
}

/* Live training status card */
.training-progress-container {
  border-left: 4px solid #67c23a;
}
.training-status-text {
  margin-top: 10px;
}
.training-metrics {
  background-color: #f5f7fa;
  border-radius: 4px;
  padding: 10px;
}
.training-metrics pre {
  font-size: 12px;
  line-height: 1.4;
  margin: 5px 0 0;
  white-space: pre-wrap;
  word-wrap: break-word;
}

/* Radio groups wrap across the full form width */
.el-radio-group {
  width: 100%;
}
.el-radio {
  margin-bottom: 10px;
  margin-right: 20px;
}

/* Narrow screens: tighter padding, stacked columns and one radio per line */
@media (max-width: 768px) {
  .global-training-container {
    padding: 10px;
  }
  .el-col {
    margin-bottom: 20px;
  }
  .el-radio {
    display: block;
    margin-bottom: 10px;
    margin-right: 0;
  }
}
</style>