## 1. 核心 Bug 修复

### 文件: `server/core/predictor.py`

- **问题**: 在 `train_model` 方法中调用内部辅助函数 `_prepare_training_params` 时，没有正确传递 `product_ids` 和 `store_ids` 参数，导致在 `_prepare_training_params` 内部发生 `NameError`。
- **修复**:
  - 修正了 `train_model` 方法内部对 `_prepare_training_params` 的调用，确保 `product_ids` 和 `store_ids` 被显式传递。
  - 此前已修复 `train_model` 的函数签名，使其能正确接收 `store_ids`。
- **结果**: 彻底解决了训练流程中的参数传递问题，根除了由此引发的 `NameError`。

## 2. 代码清理与重构

### 文件: `server/api.py`

- **内容**: 移除了在 `start_training` API 端点中遗留的旧版、基于线程（`threading.Thread`）的训练逻辑。
- **原因**: 该代码块已被新的、基于多进程（`multiprocessing`）的 `TrainingProcessManager` 完全取代。旧代码中包含了大量用于调试的 `thread_safe_print` 日志，已无用处。
- **结果**: `start_training` 端点的逻辑变得更加清晰，只负责参数校验和向 `TrainingProcessManager` 提交任务。

### 文件: `server/utils/training_process_manager.py`

- **内容**: 在 `TrainingWorker` 的 `run_training_task` 方法中，移除了一个用于模拟训练进度的 `for` 循环。
- **原因**: 该循环包含 `time.sleep(1)`，仅用于在没有实际训练逻辑时模拟进度更新；现在实际的训练器会通过回调函数报告真实进度，因此该模拟代码不再需要。
- **结果**: `TrainingWorker` 现在直接调用实际的训练器，不再有模拟延迟，代码更贴近生产环境。

## 3. 启动依赖

- **Python**: 3.x
- **主要库**:
  - Flask
  - Flask-SocketIO
  - Flasgger
  - pandas
  - numpy
  - torch
  - scikit-learn
  - matplotlib
- **启动命令**: `python server/api.py`
848 lines
25 KiB
Vue
848 lines
25 KiB
Vue
<template>
  <div class="global-training-container">
    <el-row :gutter="20">
      <!-- Left column: training controls. :xs="24" stacks the columns below 768px,
           which is what the mobile rules in the scoped styles expect. -->
      <el-col :xs="24" :span="8">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局模型训练</span>
              <el-tag type="success">跨店铺通用</el-tag>
            </div>
          </template>

          <div class="training-description">
            <p>使用所有店铺的历史数据训练通用预测模型,可用于新店铺或数据不足的场景。</p>
          </div>

          <el-form :model="form" label-width="100px">
            <el-form-item label="训练范围">
              <el-radio-group v-model="form.training_scope">
                <el-radio label="all_stores_all_products">所有店铺所有药品</el-radio>
                <el-radio label="selected_stores">选择店铺</el-radio>
                <el-radio label="selected_products">选择药品</el-radio>
                <el-radio label="custom">自定义范围</el-radio>
              </el-radio-group>
            </el-form-item>

            <!-- Store picker: only for the "selected stores" and "custom" scopes -->
            <el-form-item
              label="选择店铺"
              v-if="form.training_scope === 'selected_stores' || form.training_scope === 'custom'"
            >
              <el-select
                v-model="form.store_ids"
                placeholder="选择参与训练的店铺"
                multiple
                filterable
                style="width: 100%"
              >
                <el-option
                  v-for="store in stores"
                  :key="store.store_id"
                  :label="`${store.store_name} (${store.location})`"
                  :value="store.store_id"
                />
              </el-select>
            </el-form-item>

            <!-- Product picker: only for the "selected products" and "custom" scopes -->
            <el-form-item
              label="选择药品"
              v-if="form.training_scope === 'selected_products' || form.training_scope === 'custom'"
            >
              <el-select
                v-model="form.product_ids"
                placeholder="选择参与训练的药品"
                multiple
                filterable
                style="width: 100%"
              >
                <el-option
                  v-for="product in products"
                  :key="product.product_id"
                  :label="`${product.product_name} (${product.product_id})`"
                  :value="product.product_id"
                />
              </el-select>
            </el-form-item>

            <el-form-item label="聚合方式">
              <el-select
                v-model="form.aggregation_method"
                placeholder="选择数据聚合方式"
                style="width: 100%"
              >
                <el-option label="求和 (Sum)" value="sum" />
                <el-option label="平均值 (Mean)" value="mean" />
                <el-option label="加权平均 (Weighted)" value="weighted" />
              </el-select>
            </el-form-item>

            <el-form-item label="模型类型" required>
              <el-select
                v-model="form.model_type"
                placeholder="请选择模型"
                @change="onModelTypeChange"
                style="width: 100%"
              >
                <el-option
                  v-for="item in modelTypes"
                  :key="item.id"
                  :label="item.name"
                  :value="item.id"
                />
              </el-select>
            </el-form-item>

            <el-form-item label="训练模式">
              <el-radio-group v-model="form.training_type">
                <el-radio label="new">新训练</el-radio>
                <el-radio label="retrain" :disabled="!hasExistingVersions">
                  继续训练
                </el-radio>
              </el-radio-group>
            </el-form-item>

            <!-- Base version is only relevant when continuing from an existing model -->
            <el-form-item label="基础版本" v-if="form.training_type === 'retrain'">
              <el-select v-model="form.base_version" placeholder="选择基础版本" style="width: 100%">
                <el-option
                  v-for="version in existingVersions"
                  :key="version"
                  :label="version"
                  :value="version"
                />
              </el-select>
            </el-form-item>

            <el-form-item label="训练轮次">
              <el-input-number v-model="form.epochs" :min="1" :max="1000" style="width: 100%" />
            </el-form-item>

            <el-form-item>
              <el-button
                type="primary"
                @click="startTraining"
                :loading="trainingLoading"
                :disabled="!form.model_type"
                style="width: 100%"
              >
                <el-icon><Operation /></el-icon>
                启动全局训练
              </el-button>
            </el-form-item>
          </el-form>

          <!-- Statistics for the data covered by the current scope selection -->
          <el-card v-if="trainingStats" style="margin-top: 20px" shadow="never">
            <template #header>
              <span>训练数据统计</span>
            </template>
            <div class="training-stats">
              <p><strong>涉及店铺:</strong> {{ trainingStats.stores_count }} 个</p>
              <p><strong>涉及药品:</strong> {{ trainingStats.products_count }} 种</p>
              <p><strong>数据记录:</strong> {{ trainingStats.records_count }} 条</p>
              <p><strong>时间范围:</strong> {{ trainingStats.date_range }}</p>
            </div>
          </el-card>
        </el-card>

        <!-- Live status of the task started from this page -->
        <el-card
          v-if="currentTraining"
          style="margin-top: 20px"
          class="training-progress-container"
        >
          <template #header>
            <span>实时训练状态</span>
          </template>
          <div>
            <p><strong>任务ID:</strong> {{ currentTraining.task_id }}</p>
            <p><strong>训练范围:</strong> {{ getTrainingScopeText(currentTraining) }}</p>
            <p><strong>聚合方式:</strong> {{ getAggregationText(currentTraining.aggregation_method) }}</p>
            <p><strong>模型:</strong> {{ getModelTypeName(currentTraining.model_type) }}</p>
            <p><strong>版本:</strong> {{ currentTraining.version }}</p>
            <p>
              <strong>状态:</strong>
              <el-tag :type="statusTag(currentTraining.status)">
                {{ statusText(currentTraining.status) }}
              </el-tag>
            </p>
            <el-progress
              v-if="currentTraining.status === 'running'"
              :percentage="currentTraining.progress || 0"
              :format="formatProgress"
            />
            <div v-if="currentTraining.message" style="margin-top: 10px">
              <el-alert
                :title="currentTraining.message"
                type="info"
                show-icon
                :closable="false"
                class="training-status-text"
              />
            </div>
            <div
              v-if="currentTraining.metrics"
              style="margin-top: 10px"
              class="training-metrics"
            >
              <h4>训练指标:</h4>
              <pre>{{ JSON.stringify(currentTraining.metrics, null, 2) }}</pre>
            </div>
          </div>
        </el-card>
      </el-col>

      <!-- Right column: global training task queue -->
      <el-col :xs="24" :span="16">
        <el-card>
          <template #header>
            <div class="card-header">
              <span>全局训练任务队列</span>
              <el-button size="small" @click="fetchTrainingTasks">
                <el-icon><Refresh /></el-icon>
                刷新
              </el-button>
            </div>
          </template>
          <el-table :data="filteredTrainingTasks" stripe>
            <el-table-column
              prop="task_id"
              label="任务ID"
              width="120"
              show-overflow-tooltip
            />
            <el-table-column label="训练范围" width="150">
              <template #default="{ row }">
                {{ getTrainingScopeText(row) }}
              </template>
            </el-table-column>
            <el-table-column prop="aggregation_method" label="聚合方式" width="100">
              <template #default="{ row }">
                {{ getAggregationText(row.aggregation_method) }}
              </template>
            </el-table-column>
            <el-table-column prop="model_type" label="模型类型" width="120">
              <template #default="{ row }">
                {{ getModelTypeName(row.model_type) }}
              </template>
            </el-table-column>
            <el-table-column prop="version" label="版本" width="80" />
            <el-table-column prop="status" label="状态" width="100">
              <template #default="{ row }">
                <el-tag :type="statusTag(row.status)">
                  {{ statusText(row.status) }}
                </el-tag>
              </template>
            </el-table-column>
            <el-table-column prop="start_time" label="创建时间">
              <template #default="{ row }">
                {{ formatDateTime(row.start_time) }}
              </template>
            </el-table-column>
            <el-table-column label="详情">
              <template #default="{ row }">
                <el-popover placement="left" trigger="hover" width="400">
                  <template #reference>
                    <el-button type="text" size="small">查看</el-button>
                  </template>
                  <div v-if="row.status === 'completed'">
                    <!-- Skip the metrics heading when the task reported none -->
                    <template v-if="row.metrics">
                      <h4>评估指标</h4>
                      <pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
                    </template>
                    <div v-if="row.version">
                      <h4>版本信息</h4>
                      <p><strong>版本:</strong> {{ row.version }}</p>
                      <p><strong>模型路径:</strong> {{ row.model_path }}</p>
                    </div>
                  </div>
                  <div v-if="row.status === 'failed'">
                    <h4>错误信息</h4>
                    <p>{{ row.error }}</p>
                  </div>
                  <div v-if="row.status === 'running' || row.status === 'pending'">
                    <p>任务正在进行中...</p>
                    <!-- `!= null` also rejects null progress coming from the JSON API -->
                    <div v-if="row.progress != null">
                      <el-progress :percentage="row.progress" />
                    </div>
                  </div>
                </el-popover>
              </template>
            </el-table-column>
          </el-table>
        </el-card>
      </el-col>
    </el-row>
  </div>
</template>
|
||
|
||
<script setup>
|
||
import { ref, onMounted, onUnmounted, reactive, watch, computed } from "vue";
|
||
import axios from "axios";
|
||
import {
|
||
ElMessage,
|
||
ElPopover,
|
||
ElButton,
|
||
ElTag,
|
||
ElProgress,
|
||
ElAlert
|
||
} from "element-plus";
|
||
import { io } from "socket.io-client";
|
||
import { Operation, Refresh } from '@element-plus/icons-vue';
|
||
|
||
// ---- Reference data loaded from the backend ----
const stores = ref([]);
const products = ref([]);
const modelTypes = ref([]);

// ---- Existing versions of the selected model (for "continue training") ----
const existingVersions = ref([]);
const hasExistingVersions = ref(false);

// ---- Submission and live status ----
const trainingLoading = ref(false);
const currentTraining = ref(null);
const trainingStats = ref(null);

// Form model bound to <el-form>; defaults describe a fresh run over all data.
const form = reactive({
  training_scope: "all_stores_all_products",
  store_ids: [],
  product_ids: [],
  aggregation_method: "sum",
  model_type: "",
  epochs: 50,
  training_type: "new",
  base_version: ""
});

// Every task returned by GET /api/training (filteredTrainingTasks derives the global subset).
const trainingTasks = ref([]);

// Handles released in onUnmounted: the polling timer id and the Socket.IO client.
let pollInterval = null;
let socket = null;
|
||
|
||
// A task belongs on this page when it was submitted with training_mode "global".
const isGlobalTask = (task) => task.training_mode === 'global';

// Subset of trainingTasks shown in the task table.
const filteredTrainingTasks = computed(() => trainingTasks.value.filter(isGlobalTask));
|
||
|
||
// Merge a partial update into the matching entry of trainingTasks; no-op when the task is
// not in the list. `patch` is either an object or a function (currentTask) => object.
const patchTaskInList = (taskId, patch) => {
  const index = trainingTasks.value.findIndex((task) => task.task_id === taskId);
  if (index === -1) return;
  const current = trainingTasks.value[index];
  const changes = typeof patch === "function" ? patch(current) : patch;
  trainingTasks.value[index] = { ...current, ...changes };
};

/**
 * Open the Socket.IO connection used for live training updates and register its handlers.
 *
 * Handled events:
 *  - "training_update":    status/field updates for a task (completed / failed toasts)
 *  - "training_progress":  progress, message and optional metrics while a task runs
 *  - "training_completed": completion event for a task
 *
 * @param {string} [url] Socket.IO namespace URL; defaults to the local backend used so far.
 */
const initWebSocket = (url = "http://localhost:5000/training") => {
  // Don't leak a previous connection if this is ever called more than once.
  if (socket) {
    socket.disconnect();
  }

  socket = io(url, {
    transports: ["websocket", "polling"]
  });

  // Both "training_update" (status completed) and "training_completed" can announce the same
  // task; remember announced task ids so the user gets a single toast per task.
  const announcedTaskIds = new Set();
  const announceOnce = (taskId, show) => {
    if (taskId !== undefined && taskId !== null) {
      if (announcedTaskIds.has(taskId)) return;
      announcedTaskIds.add(taskId);
    }
    show();
  };

  const isCurrentTask = (taskId) =>
    Boolean(currentTraining.value) && currentTraining.value.task_id === taskId;

  socket.on("connect", () => {
    console.log("WebSocket连接成功");
    // Room membership belongs to a connection: after a reconnect, rejoin the room of the
    // task being watched, otherwise its updates stop arriving.
    if (currentTraining.value && currentTraining.value.task_id) {
      socket.emit("join_training", { task_id: currentTraining.value.task_id });
    }
  });

  socket.on("training_update", (data) => {
    console.log("收到训练更新:", data);

    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...data };
    }
    patchTaskInList(data.task_id, data);

    if (data.status === "completed") {
      announceOnce(data.task_id, () =>
        ElMessage.success(`全局模型 ${data.model_type} 版本 ${data.version} 训练完成!`)
      );
      // Only clear the live card when this update is about the task it shows.
      if (isCurrentTask(data.task_id)) currentTraining.value = null;
    } else if (data.status === "failed") {
      announceOnce(data.task_id, () =>
        ElMessage.error(`全局模型训练失败: ${data.error || data.message || "未知错误"}`)
      );
      if (isCurrentTask(data.task_id)) currentTraining.value = null;
    }
  });

  // Epoch-level progress updates.
  socket.on("training_progress", (data) => {
    console.log("收到训练进度更新:", data);
    const progress = data.progress || 0;

    // A late progress event must not flip an already-completed card back to "running".
    if (isCurrentTask(data.task_id) && currentTraining.value.status !== "completed") {
      currentTraining.value = {
        ...currentTraining.value,
        progress,
        message: data.message || currentTraining.value.message,
        status: 'running', // any progress report means the task is running
        ...(data.metrics ? { metrics: data.metrics } : {})
      };
    }

    patchTaskInList(data.task_id, (task) => ({
      progress,
      message: data.message || task.message
    }));
  });

  socket.on("training_completed", (data) => {
    console.log("收到训练完成事件:", data);
    const completedPatch = { ...data, status: "completed", progress: 100 };

    if (isCurrentTask(data.task_id)) {
      currentTraining.value = { ...currentTraining.value, ...completedPatch };
      announceOnce(data.task_id, () =>
        ElMessage.success(`全局模型 ${data.model_type} 训练完成!`)
      );
      // Keep the completed card visible for 2 seconds, then clear it (if still the same task).
      setTimeout(() => {
        if (isCurrentTask(data.task_id)) {
          currentTraining.value = null;
        }
      }, 2000);
    }

    patchTaskInList(data.task_id, completedPatch);
    // Reload so the table reflects the server's final record for the task.
    fetchTrainingTasks();
  });

  socket.on("disconnect", () => {
    console.log("WebSocket连接断开");
  });
};
|
||
|
||
/**
 * Load the store list for the store picker.
 * Only a response with status "success" replaces `stores`; failures are logged, not shown.
 */
const fetchStores = async () => {
  try {
    const { data: body } = await axios.get("/api/stores");
    if (body.status !== "success") return;
    stores.value = body.data;
  } catch (error) {
    console.error("获取店铺列表失败:", error);
  }
};
|
||
|
||
/**
 * Load the product list for the product picker.
 * Only a response with status "success" replaces `products`; failures are logged, not shown.
 */
const fetchProducts = async () => {
  try {
    const { data: body } = await axios.get("/api/products");
    if (body.status !== "success") return;
    products.value = body.data;
  } catch (error) {
    console.error("获取药品列表失败:", error);
  }
};
|
||
|
||
/**
 * Load the available model types. When the form has no model chosen yet, the first
 * returned model becomes the default. Failures show an error toast and are logged.
 */
const fetchModelTypes = async () => {
  try {
    const { data: body } = await axios.get("/api/model_types");
    if (body.status !== "success") return;

    modelTypes.value = body.data;
    if (modelTypes.value.length > 0 && !form.model_type) {
      form.model_type = modelTypes.value[0].id;
    }
  } catch (error) {
    ElMessage.error("获取模型类型列表失败");
    console.error(error);
  }
};
|
||
|
||
// Incremented on every fetchTrainingStats call so that responses arriving after a newer
// request has been issued are ignored instead of overwriting fresher stats.
let trainingStatsRequestId = 0;

/**
 * Load statistics (store/product/record counts, date range) for the current training scope.
 *
 * Sends training_scope and aggregation_method, plus comma-joined store_ids / product_ids
 * when any are selected. Called by several watchers in quick succession, so only the
 * latest request is allowed to update `trainingStats`. A non-success response or an error
 * clears the stats rather than leaving figures for a previous selection on screen.
 */
const fetchTrainingStats = async () => {
  const requestId = ++trainingStatsRequestId;
  try {
    const params = {
      training_scope: form.training_scope,
      aggregation_method: form.aggregation_method
    };
    if (form.store_ids.length > 0) {
      params.store_ids = form.store_ids.join(',');
    }
    if (form.product_ids.length > 0) {
      params.product_ids = form.product_ids.join(',');
    }

    const response = await axios.get("/api/training/global/stats", { params });
    if (requestId !== trainingStatsRequestId) return; // superseded by a newer request

    trainingStats.value = response.data.status === "success" ? response.data.data : null;
  } catch (error) {
    if (requestId !== trainingStatsRequestId) return; // superseded by a newer request
    console.error("获取训练统计信息失败:", error);
    trainingStats.value = null;
  }
};
|
||
|
||
/**
 * Load the saved global-model versions for the selected model type.
 *
 * Updates `existingVersions` / `hasExistingVersions` and, when no base version is chosen
 * yet, preselects the server's `latest_version`. A response for a model type that is no
 * longer selected is discarded, so quick switches cannot show another model's versions.
 * A non-success response or an error yields an empty list. When no versions exist, a
 * selected "retrain" mode (whose radio is disabled in that case) is reset to "new".
 */
const fetchExistingVersions = async () => {
  const requestedType = form.model_type;
  if (!requestedType) {
    existingVersions.value = [];
    hasExistingVersions.value = false;
    return;
  }

  let versions = [];
  let latestVersion = "";
  try {
    // Version lookup for global training
    const response = await axios.get(
      `/api/models/global/${requestedType}/versions`
    );
    if (form.model_type !== requestedType) return; // stale: model type changed meanwhile
    if (response.data.status === "success") {
      versions = response.data.data?.versions || [];
      latestVersion = response.data.data?.latest_version || "";
    }
  } catch (error) {
    if (form.model_type !== requestedType) return; // stale: model type changed meanwhile
    console.error("获取现有版本失败:", error);
  }

  existingVersions.value = versions;
  hasExistingVersions.value = versions.length > 0;

  if (hasExistingVersions.value) {
    if (!form.base_version) {
      form.base_version = latestVersion;
    }
  } else if (form.training_type === "retrain") {
    form.training_type = "new";
    form.base_version = "";
  }
};
|
||
|
||
/**
 * @change handler of the model-type select.
 *
 * Versions belong to a specific model type, so switching models resets the training mode
 * to a fresh run and clears the chosen base version. The version list itself is reloaded
 * by the `form.model_type` watcher, which fires for this change as well; fetching here too
 * only produced a duplicate request.
 */
const onModelTypeChange = () => {
  form.training_type = "new";
  form.base_version = "";
};
|
||
|
||
/**
 * Reload all training tasks into `trainingTasks` (the table shows the global subset).
 * Used on mount, by the 10-second poll, by the refresh button and after completion events.
 */
const fetchTrainingTasks = async () => {
  try {
    const { data: body } = await axios.get("/api/training");
    if (body.status === "success") {
      trainingTasks.value = body.data;
    }
  } catch (error) {
    // The toast is only shown while no polling timer is set, to avoid a toast every cycle.
    // NOTE(review): onMounted assigns pollInterval right after the first call, so in practice
    // this toast rarely appears — confirm this is the intended behaviour.
    const pollingActive = Boolean(pollInterval);
    if (!pollingActive) {
      ElMessage.error("获取训练任务列表失败");
    }
    console.error("获取训练任务列表失败", error);
  }
};
|
||
|
||
/**
 * Check the form before submission.
 * @returns {string|null} The first user-facing validation message, or null when valid.
 */
const getTrainingFormError = () => {
  const scope = form.training_scope;
  const needsStores = scope === 'selected_stores' || scope === 'custom';
  const needsProducts = scope === 'selected_products' || scope === 'custom';

  if (!form.model_type) return "请选择模型类型";
  if (needsStores && form.store_ids.length === 0) return "请选择参与训练的店铺";
  if (needsProducts && form.product_ids.length === 0) return "请选择参与训练的药品";
  if (form.training_type === "retrain" && !form.base_version) return "请选择基础版本进行继续训练";
  return null;
};

/**
 * Submit a global training job.
 *
 * Posts to /api/training (new run) or /api/training/retrain (continue from the chosen base
 * version) with training_mode "global". When the response carries a task_id, the task is
 * recorded in `currentTraining`, its Socket.IO room is joined for live updates and the task
 * list is refreshed; otherwise the server's error (or a generic message) is shown.
 */
const startTraining = async () => {
  const validationError = getTrainingFormError();
  if (validationError) {
    ElMessage.warning(validationError);
    return;
  }

  // Snapshot the form before awaiting: the user can keep editing while the request is in
  // flight, and currentTraining must describe what was actually submitted without sharing
  // the form's reactive arrays (which would change its displayed scope later).
  const isRetrain = form.training_type === "retrain";
  const submitted = {
    model_type: form.model_type,
    epochs: form.epochs,
    training_scope: form.training_scope,
    aggregation_method: form.aggregation_method,
    store_ids: [...(form.store_ids || [])],
    product_ids: [...(form.product_ids || [])],
    base_version: form.base_version
  };

  trainingLoading.value = true;
  try {
    const endpoint = isRetrain ? "/api/training/retrain" : "/api/training";

    const payload = {
      model_type: submitted.model_type,
      epochs: submitted.epochs,
      training_mode: 'global', // marks this as a global-training request
      training_scope: submitted.training_scope,
      aggregation_method: submitted.aggregation_method,
      store_ids: submitted.store_ids, // always an array
      product_ids: submitted.product_ids // always an array
    };

    // Even when lists are sent, also pass the first selected id as the representative
    // single store_id / product_id.
    if (submitted.store_ids.length > 0) {
      payload.store_id = submitted.store_ids[0];
    }
    if (submitted.product_ids.length > 0) {
      payload.product_id = submitted.product_ids[0];
    }
    if (isRetrain) {
      payload.base_version = submitted.base_version;
    }

    const response = await axios.post(endpoint, payload);
    const taskId = response.data.task_id;
    if (!taskId) {
      ElMessage.error(response.data.error || "启动训练失败");
      return;
    }

    ElMessage.success(`全局训练任务 ${taskId} 已启动`);

    currentTraining.value = {
      task_id: taskId,
      model_type: submitted.model_type,
      version: response.data.new_version || "v1",
      status: "starting",
      progress: 0,
      message: "正在启动全局训练...",
      training_mode: 'global',
      training_scope: submitted.training_scope,
      aggregation_method: submitted.aggregation_method,
      store_ids: submitted.store_ids.length > 0 ? submitted.store_ids : null,
      product_ids: submitted.product_ids.length > 0 ? submitted.product_ids : null
    };

    // Subscribe to this task's updates.
    if (socket) {
      socket.emit("join_training", { task_id: taskId });
    }

    fetchTrainingTasks();
  } catch (error) {
    const errorMsg = error.response?.data?.error || "启动训练请求失败";
    ElMessage.error(errorMsg);
    console.error(error);
  } finally {
    trainingLoading.value = false;
  }
};
|
||
|
||
// ---- Display helpers ----

/**
 * Human-readable name for a model type id.
 * Falls back to the raw id when the id is not in `modelTypes`.
 */
const getModelTypeName = (modelType) => {
  const match = modelTypes.value.find(({ id }) => id === modelType);
  if (!match) return modelType;
  return match.name;
};
|
||
|
||
/**
 * Short label describing a task's training scope.
 * For "selected stores" / "selected products" scopes it shows the number of selected ids
 * (0 when the list is missing); unknown scopes yield '未知范围'.
 */
const getTrainingScopeText = (task) => {
  const storeCount = task.store_ids?.length || 0;
  const productCount = task.product_ids?.length || 0;
  const labels = {
    all_stores_all_products: '全部数据',
    selected_stores: `${storeCount} 个店铺`,
    selected_products: `${productCount} 种药品`,
    custom: '自定义范围'
  };
  const label = labels[task.training_scope];
  return label ? label : '未知范围';
};
|
||
|
||
/**
 * Chinese label for an aggregation method id; unknown ids are returned unchanged.
 */
const getAggregationText = (method) => {
  const label = { sum: '求和', mean: '平均值', weighted: '加权平均' }[method];
  return label || method;
};
|
||
|
||
/**
 * Element Plus tag type for a task status; anything unrecognised renders as "info".
 */
const statusTag = (status) => {
  switch (status) {
    case "completed":
      return "success";
    case "running":
    case "starting":
      return "primary";
    case "pending":
      return "warning";
    case "failed":
      return "danger";
    default:
      return "info";
  }
};
|
||
|
||
/**
 * Chinese display text for a task status; unknown statuses show "未知".
 */
const statusText = (status) => {
  const labels = {
    pending: "等待中",
    starting: "启动中",
    running: "进行中",
    completed: "已完成",
    failed: "失败"
  };
  const label = labels[status];
  return label ? label : "未知";
};
|
||
|
||
// Text shown inside <el-progress>: the percentage followed by a percent sign.
const formatProgress = (percentage) => percentage + "%";
|
||
|
||
/**
 * Format a timestamp for the task table using the browser locale.
 *
 * @param {string|number|null|undefined} isoString Value accepted by `new Date(...)`.
 * @returns {string} "N/A" for empty input, the raw value when it cannot be parsed
 *   (instead of the literal "Invalid Date"), otherwise `Date#toLocaleString()`.
 */
const formatDateTime = (isoString) => {
  if (!isoString) return "N/A";
  const date = new Date(isoString);
  if (Number.isNaN(date.getTime())) return String(isoString);
  return date.toLocaleString();
};
|
||
|
||
// When the training scope changes, drop selections the new scope does not use.
// flush: 'sync' clears them immediately, before the (pre-flush) stats watcher below runs,
// so a scope switch issues one stats request with the final selections instead of one with
// stale ids followed by another after clearing. Arrays are only replaced when non-empty,
// to avoid triggering watchers for no change.
watch(() => form.training_scope, (newScope) => {
  const clearStores = newScope === 'all_stores_all_products' || newScope === 'selected_products';
  const clearProducts = newScope === 'all_stores_all_products' || newScope === 'selected_stores';

  if (clearStores && form.store_ids.length > 0) {
    form.store_ids = [];
  }
  if (clearProducts && form.product_ids.length > 0) {
    form.product_ids = [];
  }
}, { flush: 'sync' });

// Refresh the data statistics whenever the scope, the selections or the aggregation change.
// deep: true also catches in-place mutations of the id arrays.
watch([
  () => form.training_scope,
  () => form.store_ids,
  () => form.product_ids,
  () => form.aggregation_method
], () => {
  fetchTrainingStats();
}, { deep: true });

// Reload the version list whenever the model type changes (user selection or programmatic default).
watch(() => form.model_type, () => {
  fetchExistingVersions();
});
|
||
|
||
// On mount: load reference data, tasks and stats, open the socket, then poll tasks every 10 s.
onMounted(() => {
  const initialLoaders = [
    fetchStores,
    fetchProducts,
    fetchModelTypes,
    fetchTrainingTasks,
    fetchTrainingStats
  ];
  initialLoaders.forEach((load) => load());

  initWebSocket();

  const POLL_INTERVAL_MS = 10000;
  pollInterval = setInterval(fetchTrainingTasks, POLL_INTERVAL_MS);
});
|
||
|
||
// On unmount: stop polling and close the Socket.IO connection if they were started.
onUnmounted(() => {
  pollInterval && clearInterval(pollInterval);
  socket?.disconnect();
});
|
||
</script>
|
||
|
||
<style scoped>
/* Page padding */
.global-training-container {
  padding: 20px;
}

/* Card header: title on the left, tag/button on the right */
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
}

.training-description {
  background-color: #f0f9ff;
  padding: 15px;
  border-radius: 6px;
  margin-bottom: 20px;
  border-left: 4px solid #67c23a;
}

.training-description p {
  margin: 0;
  color: #606266;
  font-size: 14px;
  line-height: 1.5;
}

.training-stats {
  font-size: 14px;
}

.training-stats p {
  margin: 8px 0;
  color: #606266;
}

.training-progress-container {
  border-left: 4px solid #67c23a;
}

.training-status-text {
  margin-top: 10px;
}

.training-metrics {
  background-color: #f5f7fa;
  padding: 10px;
  border-radius: 4px;
}

.training-metrics pre {
  margin: 5px 0 0 0;
  font-size: 12px;
  line-height: 1.4;
  white-space: pre-wrap;
  /* Standard name of the legacy `word-wrap` property */
  overflow-wrap: break-word;
}

.el-radio-group {
  width: 100%;
}

.el-radio {
  margin-right: 20px;
  margin-bottom: 10px;
}

/* Phones: tighter padding, spacing between stacked columns, one radio per line */
@media (max-width: 768px) {
  .global-training-container {
    padding: 10px;
  }

  .el-col {
    margin-bottom: 20px;
  }

  .el-radio {
    display: block;
    margin-right: 0;
    margin-bottom: 10px;
  }
}
</style>