Compare commits
23 Commits: 87df49f764 ... e1980b3755

SHA1
e1980b3755
751de9b548
038289ae32
0d3b89abf6
ec636896da
244393670d
e4d170d667
311d71b653
ca7dc432c6
ada4e8e108
120caba3cd
c64343fe95
9d439c36ba
54428c80ca
6f3240c723
e437658b9d
ee9ba299fa
a1d9c60e61
a18c8dddf9
398e949935
cc30295f1d
066a0429e5
6c11aff234
@@ -237,13 +237,11 @@ body {
}

.logo-container {
padding: 10px 20px;
padding: 20px;
display: flex;
align-items: center;
gap: 12px;
border-bottom: 1px solid var(--card-border);
height: 60px;
box-sizing: border-box;
}

.logo-icon {
@@ -353,4 +351,4 @@ body {
::-webkit-scrollbar-thumb:hover {
background: rgba(93, 156, 255, 0.5);
}
</style>
</style>

@@ -242,7 +242,7 @@ watch(() => props.storeId, () => {

.product-name {
font-weight: 500;
color: var(--el-text-color-primary);
color: #303133;
margin-bottom: 2px;
}

@@ -254,7 +254,7 @@ watch(() => props.storeId, () => {
}

.product-id {
color: var(--el-text-color-secondary);
color: #909399;
margin-right: 8px;
}

@@ -242,7 +242,7 @@ watch(() => props.filterByStatus, () => {

.store-name {
font-weight: 500;
color: var(--el-text-color-primary);
color: #303133;
margin-bottom: 2px;
}

@@ -254,7 +254,7 @@ watch(() => props.filterByStatus, () => {
}

.store-location {
color: var(--el-text-color-secondary);
color: #909399;
margin-right: 8px;
}

@@ -5,7 +5,6 @@ import ElementPlus from 'element-plus'
import 'element-plus/dist/index.css'
import * as ElementPlusIconsVue from '@element-plus/icons-vue'
import axios from 'axios'
import zhCn from 'element-plus/dist/locale/zh-cn.mjs'

// Import the Google Roboto font
import '@/assets/fonts.css'
@@ -23,6 +22,6 @@ for (const [key, component] of Object.entries(ElementPlusIconsVue)) {
}

app.use(router)
app.use(ElementPlus, { locale: zhCn })
app.use(ElementPlus)

app.mount('#app')

@@ -90,7 +90,7 @@

<script setup>
import { ref, computed, onMounted } from 'vue'
import { ArrowRight, DataAnalysis, TrendCharts, CircleCheckFilled, WarningFilled, CircleCloseFilled, Shop } from '@element-plus/icons-vue'
import { ArrowRight, DataAnalysis, TrendCharts, CircleCheckFilled, WarningFilled, CircleCloseFilled } from '@element-plus/icons-vue'

// Mock data
const data = ref({
@@ -102,10 +102,10 @@ const data = ref({
// Feature card data
const featureCards = [
{
title: '店铺管理',
description: '管理店铺信息和库存',
icon: 'Shop',
path: '/store-management',
title: '数据管理',
description: '管理产品和销售数据',
icon: 'FolderOpened',
path: '/data',
type: 'data'
},
{

@@ -248,41 +248,9 @@ import { ref, onMounted, reactive, watch, nextTick } from 'vue';
import axios from 'axios';
import { ElMessage, ElMessageBox } from 'element-plus';
import { QuestionFilled, Search, View, Delete, ArrowUp, ArrowDown, Minus, Download } from '@element-plus/icons-vue';
import * as echarts from 'echarts/core';
import { LineChart, BarChart } from 'echarts/charts';
import {
TitleComponent,
TooltipComponent,
GridComponent,
DatasetComponent,
TransformComponent,
LegendComponent,
ToolboxComponent,
MarkLineComponent,
MarkPointComponent
} from 'echarts/components';
import { LabelLayout, UniversalTransition } from 'echarts/features';
import { CanvasRenderer } from 'echarts/renderers';
import Chart from 'chart.js/auto'; // << Key change: import Chart.js
import { computed, onUnmounted } from 'vue';

// Register the required components
echarts.use([
TitleComponent,
TooltipComponent,
GridComponent,
DatasetComponent,
TransformComponent,
LegendComponent,
ToolboxComponent,
MarkLineComponent,
MarkPointComponent,
LineChart,
BarChart,
LabelLayout,
UniversalTransition,
CanvasRenderer
]);

const loading = ref(false);
const history = ref([]);
const products = ref([]);
@@ -292,8 +260,8 @@ const currentPrediction = ref(null);
const rawResponseData = ref(null);
const showRawDataFlag = ref(false);

const fullscreenPredictionChart = ref(null);
const fullscreenHistoryChart = ref(null);
let predictionChart = null; // << Key change: use a single chart instance
let historyChart = null;

const filters = reactive({
product_id: '',

@@ -982,104 +950,133 @@ const getFactorsArray = computed(() => {
watch(detailsVisible, (newVal) => {
if (newVal && currentPrediction.value) {
nextTick(() => {
// Init Prediction Chart
if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
const predChartDom = document.getElementById('fullscreen-prediction-chart-history');
if (predChartDom) {
fullscreenPredictionChart.value = echarts.init(predChartDom);
if (currentPrediction.value.chart_data) {
updatePredictionChart(currentPrediction.value.chart_data, fullscreenPredictionChart.value, true);
}
}

// Init History Chart
if (currentPrediction.value.analysis) {
if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
const histChartDom = document.getElementById('fullscreen-history-chart-history');
if (histChartDom) {
fullscreenHistoryChart.value = echarts.init(histChartDom);
updateHistoryChart(currentPrediction.value.analysis, fullscreenHistoryChart.value, true);
}
}
renderChart();
// Logic for rendering the second chart could be added here
// renderHistoryAnalysisChart();
});
}
});

const updatePredictionChart = (chartData, chart, isFullscreen = false) => {
if (!chart || !chartData) return;
chart.showLoading();
const dates = chartData.dates || [];
const sales = chartData.sales || [];
const types = chartData.types || [];
// << Key change: renderChart function copied and adapted from ProductPredictionView.vue
const renderChart = () => {
const chartCanvas = document.getElementById('fullscreen-prediction-chart-history');
if (!chartCanvas || !currentPrediction.value || !currentPrediction.value.data) return;

const combinedData = [];
for (let i = 0; i < dates.length; i++) {
combinedData.push({ date: dates[i], sales: sales[i], type: types[i] });
if (predictionChart) {
predictionChart.destroy();
}
combinedData.sort((a, b) => new Date(a.date) - new Date(b.date));

const allDates = combinedData.map(item => item.date);
const historyDates = combinedData.filter(d => d.type === '历史销量').map(d => d.date);
const historySales = combinedData.filter(d => d.type === '历史销量').map(d => d.sales);
const predictionDates = combinedData.filter(d => d.type === '预测销量').map(d => d.date);
const predictionSales = combinedData.filter(d => d.type === '预测销量').map(d => d.sales);

const allSales = [...historySales, ...predictionSales].filter(val => !isNaN(val));
const minSale = Math.max(0, Math.floor(Math.min(...allSales) * 0.9));
const maxSale = Math.ceil(Math.max(...allSales) * 1.1);

const option = {
title: { text: '销量预测趋势图', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
tooltip: { trigger: 'axis', axisPointer: { type: 'cross' },
formatter: function(params) {
if (!params || params.length === 0) return '';
const date = params[0].axisValue;
let html = `<div style="font-weight:bold">${date}</div>`;
params.forEach(item => {
if (item.value !== '-') {
html += `<div style="display:flex;justify-content:space-between;align-items:center;margin:5px 0;">
<span style="display:inline-block;margin-right:5px;width:10px;height:10px;border-radius:50%;background-color:${item.color};"></span>
<span>${item.seriesName}:</span>
<span style="font-weight:bold;margin-left:5px;">${item.value.toFixed(2)}</span>
</div>`;
}
});
return html;
}
const formatDate = (date) => new Date(date).toISOString().split('T')[0];

const historyData = (currentPrediction.value.data.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
const predictionData = (currentPrediction.value.data.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

if (historyData.length === 0 && predictionData.length === 0) {
ElMessage.warning('没有可用于图表的数据。');
return;
}

const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort();
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

const historyMap = new Map(historyData.map(p => [p.date, p.sales]));
// Note: the 'sales' field is used here because the backend has already unified it
const predictionMap = new Map(predictionData.map(p => [p.date, p.sales]));

const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null);
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null);

if (historyData.length > 0 && predictionData.length > 0) {
const lastHistoryDate = historyData[historyData.length - 1].date;
const lastHistoryValue = historyData[historyData.length - 1].sales;
if (!predictionMap.has(lastHistoryDate)) {
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue;
}
}

let subtitleText = '';
if (historyData.length > 0) {
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
}
if (predictionData.length > 0) {
if (subtitleText) subtitleText += ' | ';
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
}

predictionChart = new Chart(chartCanvas, {
type: 'line',
data: {
labels: simplifiedLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
borderColor: '#67C23A',
backgroundColor: 'rgba(103, 194, 58, 0.2)',
tension: 0.4,
fill: true,
spanGaps: false,
},
{
label: '预测销量',
data: alignedPredictionSales,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.2)',
tension: 0.4,
fill: true,
borderDash: [5, 5],
}
]
},
legend: { data: ['历史销量', '预测销量'], top: isFullscreen ? 40 : 30, textStyle: { color: '#e0e6ff' } },
grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
xAxis: { type: 'category', boundaryGap: false, data: allDates, axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
yAxis: { type: 'value', name: '销量', min: minSale, max: maxSale, axisLabel: { color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
series: [
{ name: '历史销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => historyDates.includes(date) ? historySales[historyDates.indexOf(date)] : null), areaStyle: { color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [{ offset: 0, color: 'rgba(64, 158, 255, 0.3)' }, { offset: 1, color: 'rgba(64, 158, 255, 0.1)' }]) }, lineStyle: { color: '#409EFF' } },
{ name: '预测销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => predictionDates.includes(date) ? predictionSales[predictionDates.indexOf(date)] : null), lineStyle: { color: '#F56C6C' } }
]
};
chart.hideLoading();
chart.setOption(option, true);
};

const updateHistoryChart = (analysisData, chart, isFullscreen = false) => {
if (!chart || !analysisData || !analysisData.history_chart_data) return;
chart.showLoading();
const { dates, changes } = analysisData.history_chart_data;

const option = {
title: { text: '销量日环比变化', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
tooltip: { trigger: 'axis', axisPointer: { type: 'shadow' }, formatter: p => `${p[0].axisValue}<br/>环比: ${p[0].value.toFixed(2)}%` },
grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
xAxis: { type: 'category', data: dates.map(d => formatDate(d)), axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
yAxis: { type: 'value', name: '环比变化(%)', axisLabel: { formatter: '{value}%', color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
series: [{
name: '日环比变化', type: 'bar',
data: changes.map(val => ({ value: val, itemStyle: { color: val >= 0 ? '#67C23A' : '#F56C6C' } }))
}]
};
chart.hideLoading();
chart.setOption(option, true);
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
title: {
display: true,
text: `${currentPrediction.value.data.product_name} - 销量预测趋势图`,
color: '#ffffff',
font: {
size: 20,
weight: 'bold',
}
},
subtitle: {
display: true,
text: subtitleText,
color: '#6c757d',
font: {
size: 14,
},
padding: {
bottom: 20
}
}
},
scales: {
x: {
title: {
display: true,
text: '日期 (日)'
},
grid: {
display: false
}
},
y: {
title: {
display: true,
text: '销量'
},
grid: {
color: '#e9e9e9',
drawBorder: false,
},
beginAtZero: true
}
}
}
});
};

const exportHistoryData = () => {
@@ -1102,16 +1099,24 @@ const exportHistoryData = () => {
};

const resizeCharts = () => {
if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.resize();
if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.resize();
if (predictionChart) {
predictionChart.resize();
}
if (historyChart) {
historyChart.resize();
}
};

window.addEventListener('resize', resizeCharts);

onUnmounted(() => {
window.removeEventListener('resize', resizeCharts);
if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
if (predictionChart) {
predictionChart.destroy();
}
if (historyChart) {
historyChart.destroy();
}
});

onMounted(() => {

@@ -1,6 +1,6 @@
<template>
<div class="store-management-container">
<el-card class="full-height-card">
<el-card>
<template #header>
<div class="card-header">
<span>店铺管理</span>
@@ -18,70 +18,67 @@
</template>

<!-- Search and filtering -->
<div class="table-container" ref="tableContainerRef">
<div class="filter-section" ref="filterSectionRef">
<el-row :gutter="20" align="middle">
<el-col :span="6">
<el-input
v-model="searchQuery"
placeholder="搜索店铺名称或ID"
clearable
@input="handleSearch"
>
<template #prefix>
<el-icon><Search /></el-icon>
</template>
</el-input>
</el-col>
<el-col :span="4">
<el-select v-model="statusFilter" placeholder="状态筛选" clearable @change="handleFilter" style="width: 100%;">
<el-option label="全部状态" value="" />
<el-option label="营业中" value="active" />
<el-option label="暂停营业" value="inactive" />
</el-select>
</el-col>
<el-col :span="4">
<el-select v-model="typeFilter" placeholder="类型筛选" clearable @change="handleFilter" style="width: 100%;">
<el-option label="全部类型" value="" />
<el-option label="旗舰店" value="旗舰店" />
<el-option label="标准店" value="标准店" />
<el-option label="便民店" value="便民店" />
<el-option label="社区店" value="社区店" />
</el-select>
</el-col>
</el-row>
</div>
<div class="filter-section">
<el-row :gutter="20">
<el-col :span="6">
<el-input
v-model="searchQuery"
placeholder="搜索店铺名称或ID"
clearable
@input="handleSearch"
>
<template #prefix>
<el-icon><Search /></el-icon>
</template>
</el-input>
</el-col>
<el-col :span="4">
<el-select v-model="statusFilter" placeholder="状态筛选" clearable @change="handleFilter">
<el-option label="全部状态" value="" />
<el-option label="营业中" value="active" />
<el-option label="暂停营业" value="inactive" />
</el-select>
</el-col>
<el-col :span="4">
<el-select v-model="typeFilter" placeholder="类型筛选" clearable @change="handleFilter">
<el-option label="全部类型" value="" />
<el-option label="旗舰店" value="旗舰店" />
<el-option label="标准店" value="标准店" />
<el-option label="便民店" value="便民店" />
<el-option label="社区店" value="社区店" />
</el-select>
</el-col>
</el-row>
</div>

<!-- Store list -->
<el-table
:data="pagedStores"
v-loading="loading"
stripe
@selection-change="handleSelectionChange"
class="store-table"
:height="tableHeight"
>
<!-- Store list -->
<el-table
:data="filteredStores"
v-loading="loading"
stripe
@selection-change="handleSelectionChange"
>
<el-table-column type="selection" width="55" />
<el-table-column prop="store_id" label="店铺ID" width="100" align="center" />
<el-table-column prop="store_name" label="店铺名称" width="250" align="center" show-overflow-tooltip />
<el-table-column prop="location" label="位置" width="250" align="center" show-overflow-tooltip/>
<el-table-column prop="type" label="类型" width="120" align="center">
<el-table-column prop="store_id" label="店铺ID" width="100" />
<el-table-column prop="store_name" label="店铺名称" width="150" />
<el-table-column prop="location" label="位置" width="200" />
<el-table-column prop="type" label="类型" width="100">
<template #default="{ row }">
<el-tag :type="getStoreTypeTag(row.type)">
{{ row.type }}
</el-tag>
</template>
</el-table-column>
<el-table-column prop="size" label="面积(㎡)" width="150" align="center"/>
<el-table-column prop="opening_date" label="开业日期" width="150" align="center"/>
<el-table-column prop="status" label="状态" width="150" align="center">
<el-table-column prop="size" label="面积(㎡)" width="100" align="right" />
<el-table-column prop="opening_date" label="开业日期" width="120" />
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="row.status === 'active' ? 'success' : 'danger'">
{{ row.status === 'active' ? '营业中' : '暂停营业' }}
</el-tag>
</template>
</el-table-column>
<el-table-column label="操作" width="200" fixed="right" align="center">
<el-table-column label="操作" width="200" fixed="right">
<template #default="{ row }">
<el-button link type="primary" @click="viewStoreDetails(row)">
详情
@@ -102,16 +99,14 @@
<!-- Pagination -->
<el-pagination
v-if="total > pageSize"
layout="total, prev, pager, next, jumper"
layout="total, sizes, prev, pager, next, jumper"
:total="total"
:page-size="pageSize"
:current-page="currentPage"
:page-sizes="[10, 20, 50, 100]"
@current-change="handlePageChange"
@size-change="handleSizeChange"
class="pagination"
ref="paginationRef"
/>
</div>
</el-card>

<!-- Add/Edit store dialog -->
@@ -120,7 +115,6 @@
:title="isEditing ? '编辑店铺' : '新增店铺'"
width="600px"
@close="resetForm"
class="form-dialog"
>
<el-form
ref="formRef"
@@ -235,7 +229,6 @@
</div>
</el-dialog>


<!-- Store products dialog -->
<el-dialog
v-model="productsDialogVisible"
@@ -262,7 +255,7 @@
</template>

<script setup>
import { ref, onMounted, onUnmounted, computed, nextTick } from 'vue'
import { ref, onMounted, computed } from 'vue'
import axios from 'axios'
import { ElMessage, ElMessageBox } from 'element-plus'
import { Plus, Refresh, Search } from '@element-plus/icons-vue'
@@ -279,15 +272,9 @@ const typeFilter = ref('')

// Pagination
const currentPage = ref(1)
const pageSize = ref(12)
const pageSize = ref(20)
const total = ref(0)

// Layout and heights
const tableContainerRef = ref(null);
const filterSectionRef = ref(null);
const paginationRef = ref(null);
const tableHeight = ref(400); // default height

// Dialogs
const dialogVisible = ref(false)
const detailDialogVisible = ref(false)
@@ -332,34 +319,34 @@ const rules = {

// Computed properties
const filteredStores = computed(() => {
let result = stores.value;
let result = stores.value

// Search filter
if (searchQuery.value) {
const query = searchQuery.value.toLowerCase();
result = result.filter(
(store) =>
store.store_name.toLowerCase().includes(query) ||
store.store_id.toLowerCase().includes(query)
);
const query = searchQuery.value.toLowerCase()
result = result.filter(store =>
store.store_name.toLowerCase().includes(query) ||
store.store_id.toLowerCase().includes(query)
)
}

// Status filter
if (statusFilter.value) {
result = result.filter((store) => store.status === statusFilter.value);
result = result.filter(store => store.status === statusFilter.value)
}

// Type filter
if (typeFilter.value) {
result = result.filter((store) => store.type === typeFilter.value);
result = result.filter(store => store.type === typeFilter.value)
}

return result;
});

const pagedStores = computed(() => {
const start = (currentPage.value - 1) * pageSize.value;
const end = start + pageSize.value;
total.value = filteredStores.value.length;
return filteredStores.value.slice(start, end);
});
total.value = result.length

// Pagination
const start = (currentPage.value - 1) * pageSize.value
const end = start + pageSize.value
return result.slice(start, end)
})

// Methods
const fetchStores = async () => {
@@ -529,49 +516,14 @@ const viewStoreProducts = async (store) => {
}

// Lifecycle
const updateTableHeight = () => {
nextTick(() => {
if (tableContainerRef.value) {
const containerHeight = tableContainerRef.value.clientHeight;
const filterHeight = filterSectionRef.value?.offsetHeight || 0;
const paginationHeight = paginationRef.value?.$el.offsetHeight || 0;

// Subtract the filter area, the pagination area, and some spacing
const calculatedHeight = containerHeight - filterHeight - paginationHeight - 20;
tableHeight.value = calculatedHeight > 200 ? calculatedHeight : 200; // minimum height
}
});
};

onMounted(() => {
fetchStores();
updateTableHeight();
window.addEventListener('resize', updateTableHeight);
});

onUnmounted(() => {
window.removeEventListener('resize', updateTableHeight);
});
fetchStores()
})
</script>

<style scoped>
.store-management-container {
height: 97%;
padding: 6px 10px 15px 15px;
}

.full-height-card {
height: 100%;
display: flex;
flex-direction: column;
}

:deep(.el-card__body) {
flex-grow: 1;
display: flex;
flex-direction: column;
padding: 20px;
overflow: hidden;
}

.card-header {
@@ -585,35 +537,21 @@ onUnmounted(() => {
gap: 10px;
}

.table-container {
flex-grow: 1;
display: flex;
flex-direction: column;
overflow: hidden; /* keep the container itself from scrolling */
}

.filter-section {
padding-bottom: 20px;
}


.store-table {
width: 100%;
}

:deep(.store-table .el-table__cell) {
padding: 12px 2px;
margin-bottom: 20px;
padding: 20px;
background-color: #f8f9fa;
border-radius: 8px;
}

.pagination {
margin-top: 20px;
display: flex;
justify-content: center;
align-items: center;
padding: 14px 0;
}

.store-detail {
padding: 5px 0;
padding: 10px 0;
}

.store-stats {
@@ -641,9 +579,4 @@ onUnmounted(() => {
gap: 5px;
}
}

.form-dialog :deep(.el-dialog) {
background: transparent;
box-shadow: none;
}
</style>
</style>

@@ -34,7 +34,7 @@
</el-row>

<el-row :gutter="20" v-if="form.model_type">
<el-col :span="6">
<el-col :span="5">
<el-form-item label="模型版本">
<el-select
v-model="form.version"
@@ -52,7 +52,7 @@
</el-select>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="5">
<el-form-item label="预测天数">
<el-input-number
v-model="form.future_days"
@@ -62,7 +62,17 @@
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="5">
<el-form-item label="历史天数">
<el-input-number
v-model="form.history_lookback_days"
:min="7"
:max="365"
style="width: 100%"
/>
</el-form-item>
</el-col>
<el-col :span="5">
<el-form-item label="起始日期">
<el-date-picker
v-model="form.start_date"
@@ -75,7 +85,7 @@
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="4">
<el-form-item label="预测分析">
<el-switch
v-model="form.analyze_result"
@@ -135,6 +145,7 @@ const form = reactive({
model_type: '',
version: '',
future_days: 7,
history_lookback_days: 30,
start_date: '',
analyze_result: true
})
@@ -185,9 +196,11 @@ const startPrediction = async () => {
try {
predicting.value = true
const payload = {
training_mode: 'global', // explicitly specify the training mode
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
history_lookback_days: form.history_lookback_days,
start_date: form.start_date,
analyze_result: form.analyze_result
}
@@ -212,28 +225,113 @@ const renderChart = () => {
if (chart) {
chart.destroy()
}
const predictions = predictionResult.value.predictions
const labels = predictions.map(p => new Date(p.date).toLocaleDateString('zh-CN', { weekday: 'short', year: 'numeric', month: 'long', day: 'numeric' }))
const data = predictions.map(p => p.sales)

const formatDate = (date) => new Date(date).toISOString().split('T')[0];

const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

if (historyData.length === 0 && predictionData.length === 0) {
ElMessage.warning('没有可用于图表的数据。')
return
}

const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

if (historyData.length > 0 && predictionData.length > 0) {
const lastHistoryDate = historyData[historyData.length - 1].date
const lastHistoryValue = historyData[historyData.length - 1].sales
if (!predictionMap.has(lastHistoryDate)) {
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
}
}

let subtitleText = '';
if (historyData.length > 0) {
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
}
if (predictionData.length > 0) {
if (subtitleText) subtitleText += ' | ';
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
}

chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels,
datasets: [{
label: '预测销量',
data,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.1)',
tension: 0.4,
fill: true
}]
labels: simplifiedLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
borderColor: '#67C23A',
backgroundColor: 'rgba(103, 194, 58, 0.2)',
tension: 0.4,
fill: true,
spanGaps: false,
},
{
label: '预测销量',
data: alignedPredictionSales,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.2)',
tension: 0.4,
fill: true,
borderDash: [5, 5],
}
]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
title: {
display: true,
text: '销量预测趋势图'
text: '全局销量预测趋势图',
color: '#ffffff',
font: {
size: 20,
weight: 'bold',
}
},
subtitle: {
display: true,
text: subtitleText,
color: '#6c757d',
font: {
size: 14,
},
padding: {
bottom: 20
}
}
},
scales: {
x: {
title: {
display: true,
text: '日期 (日)'
},
grid: {
display: false
}
},
y: {
title: {
display: true,
text: '销量'
},
grid: {
color: '#e9e9e9',
drawBorder: false,
},
beginAtZero: true
}
}
}

@@ -4,157 +4,147 @@
<template #header>
<div class="card-header">
<span>按药品预测</span>
<el-tooltip content="使用针对特定药品训练的模型进行销售预测">
<el-tooltip content="对系统中的所有药品模型进行批量或单个预测">
<el-icon><QuestionFilled /></el-icon>
</el-tooltip>
</div>
</template>

<div class="model-selection-section">
<h4>🎯 选择预测模型</h4>
<el-form :model="form" label-width="120px">
<el-row :gutter="20">
<el-col :span="8">
<el-form-item label="目标药品">
<ProductSelector
v-model="form.product_id"
@change="handleProductChange"
:show-all-option="false"
/>
</el-form-item>
</el-col>
<el-col :span="8">
<el-form-item label="算法类型">
<el-select
v-model="form.model_type"
placeholder="选择算法"
@change="handleModelTypeChange"
style="width: 100%"
:disabled="!form.product_id"
>
<el-option
v-for="item in modelTypes"
:key="item.id"
:label="item.name"
:value="item.id"
/>
</el-select>
</el-form-item>
</el-col>
</el-row>

<el-row :gutter="20" v-if="form.model_type">
<el-col :span="6">
<el-form-item label="模型版本">
<el-select
v-model="form.version"
placeholder="选择版本"
style="width: 100%"
:disabled="!availableVersions.length"
:loading="versionsLoading"
>
<el-option
v-for="version in availableVersions"
:key="version"
:label="version"
:value="version"
/>
</el-select>
</el-form-item>
</el-col>
<el-col :span="6">
<el-form-item label="预测天数">
<el-input-number
v-model="form.future_days"
:min="1"
:max="365"
style="width: 100%"
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-form-item label="起始日期">
<el-date-picker
v-model="form.start_date"
type="date"
placeholder="选择日期"
format="YYYY-MM-DD"
value-format="YYYY-MM-DD"
style="width: 100%"
:clearable="false"
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-form-item label="预测分析">
<el-switch
v-model="form.analyze_result"
active-text="开启"
inactive-text="关闭"
/>
</el-form-item>
</el-col>
</el-row>
<div class="controls-section">
<el-form :model="filters" label-width="80px" inline>
<el-form-item label="目标药品">
<ProductSelector
v-model="filters.product_id"
:show-all-option="true"
all-option-label="所有药品"
/>
</el-form-item>
<el-form-item label="算法类型">
<el-select v-model="filters.model_type" placeholder="所有类型" clearable>
<el-option
v-for="item in modelTypes"
:key="item.id"
:label="item.name"
:value="item.id"
/>
</el-select>
</el-form-item>
<el-form-item label="预测天数">
<el-input-number v-model="form.future_days" :min="1" :max="365" />
</el-form-item>
<el-form-item label="历史天数">
<el-input-number v-model="form.history_lookback_days" :min="7" :max="365" />
</el-form-item>
<el-form-item label="起始日期">
<el-date-picker
v-model="form.start_date"
type="date"
placeholder="选择日期"
format="YYYY-MM-DD"
value-format="YYYY-MM-DD"
:clearable="false"
/>
</el-form-item>
</el-form>
</div>

<div class="prediction-actions">
<el-button
type="primary"
size="large"
@click="startPrediction"
:loading="predicting"
:disabled="!canPredict"
>
<el-icon><TrendCharts /></el-icon>
开始预测
</el-button>
<!-- Model list -->
<div class="model-list-section">
<h4>📦 可用药品模型列表</h4>
<el-table :data="paginatedModelList" style="width: 100%" v-loading="modelsLoading">
<el-table-column prop="product_name" label="药品名称" sortable />
<el-table-column prop="model_type" label="模型类型" sortable />
<el-table-column prop="version" label="版本" />
<el-table-column prop="created_at" label="创建时间" />
<el-table-column label="操作">
<template #default="{ row }">
<el-button
type="primary"
size="small"
@click="startPrediction(row)"
:loading="predicting[row.model_id]"
>
<el-icon><TrendCharts /></el-icon>
开始预测
</el-button>
</template>
</el-table-column>
</el-table>
<el-pagination
background
layout="prev, pager, next"
:total="filteredModelList.length"
:page-size="pagination.pageSize"
@current-change="handlePageChange"
style="margin-top: 20px; justify-content: center;"
/>
</div>
</el-card>

<el-card v-if="predictionResult" style="margin-top: 20px">
<template #header>
<div class="card-header">
<span>📈 预测结果</span>
</div>
</template>
<!-- Prediction result dialog -->
<el-dialog v-model="dialogVisible" title="📈 预测结果" width="70%">
<div class="prediction-chart">
<canvas ref="chartCanvas" width="800" height="400"></canvas>
</div>
</el-card>
<template #footer>
<el-button @click="dialogVisible = false">关闭</el-button>
</template>
</el-dialog>
</div>
</template>

<script setup>
import { ref, reactive, onMounted, computed, watch, nextTick } from 'vue'
import { ref, reactive, onMounted, nextTick, computed } from 'vue'
import axios from 'axios'
import { ElMessage } from 'element-plus'
import { ElMessage, ElDialog, ElTable, ElTableColumn, ElButton, ElIcon, ElCard, ElTooltip, ElForm, ElFormItem, ElInputNumber, ElDatePicker, ElSelect, ElOption, ElRow, ElCol, ElPagination } from 'element-plus'
import { QuestionFilled, TrendCharts } from '@element-plus/icons-vue'
import Chart from 'chart.js/auto'
import ProductSelector from '../../components/ProductSelector.vue'

const modelList = ref([])
const modelTypes = ref([])
const availableVersions = ref([])
const versionsLoading = ref(false)
const predicting = ref(false)
const modelsLoading = ref(false)
const predicting = reactive({})
const dialogVisible = ref(false)
const predictionResult = ref(null)
const chartCanvas = ref(null)
let chart = null

const form = reactive({
training_mode: 'product',
product_id: '',
model_type: '',
version: '',
future_days: 7,
history_lookback_days: 30,
start_date: '',
analyze_result: true
analyze_result: true // keep analysis enabled, but remove the toggle from the UI
})

const canPredict = computed(() => {
return form.product_id && form.model_type && form.version
const filters = reactive({
product_id: '',
model_type: ''
})

const pagination = reactive({
currentPage: 1,
pageSize: 8
})

const filteredModelList = computed(() => {
return modelList.value.filter(model => {
const productMatch = !filters.product_id || model.product_id === filters.product_id
const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type
return productMatch && modelTypeMatch
})
})

const paginatedModelList = computed(() => {
const start = (pagination.currentPage - 1) * pagination.pageSize
const end = start + pagination.pageSize
return filteredModelList.value.slice(start, end)
})

const handlePageChange = (page) => {
pagination.currentPage = page
}

const fetchModelTypes = async () => {
try {
const response = await axios.get('/api/model_types')
@@ -166,72 +156,48 @@ const fetchModelTypes = async () => {
}
}

const fetchAvailableVersions = async () => {
if (!form.product_id || !form.model_type) {
availableVersions.value = []
return
}
const fetchModels = async () => {
modelsLoading.value = true
try {
versionsLoading.value = true
const url = `/api/models/${form.product_id}/${form.model_type}/versions`
const response = await axios.get(url)
const response = await axios.get('/api/models', { params: { training_mode: 'product' } })
if (response.data.status === 'success') {
availableVersions.value = response.data.data.versions || []
if (response.data.data.latest_version) {
form.version = response.data.data.latest_version
}
modelList.value = response.data.data
} else {
ElMessage.error('获取模型列表失败')
}
} catch (error) {
availableVersions.value = []
ElMessage.error('获取模型列表失败')
} finally {
versionsLoading.value = false
modelsLoading.value = false
}
}

const handleProductChange = () => {
form.model_type = ''
form.version = ''
availableVersions.value = []
}

const handleModelTypeChange = () => {
form.version = ''
fetchAvailableVersions()
}

const startPrediction = async () => {
if (!form.product_id) {
ElMessage.error('请选择目标药品')
return
}
if (!form.model_type) {
ElMessage.error('请选择算法类型')
return
}

const startPrediction = async (model) => {
predicting[model.model_id] = true
try {
predicting.value = true
const payload = {
model_type: form.model_type,
version: form.version,
product_id: model.product_id,
model_type: model.model_type,
version: model.version,
future_days: form.future_days,
history_lookback_days: form.history_lookback_days,
start_date: form.start_date,
analyze_result: form.analyze_result,
product_id: form.product_id
include_visualization: true, // analysis is hard-coded on
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data.data
ElMessage.success('预测完成!')
dialogVisible.value = true
await nextTick()
renderChart()
} else {
ElMessage.error(response.data.message || '预测失败')
ElMessage.error(response.data.error || '预测失败')
}
} catch (error) {
ElMessage.error('预测请求失败')
ElMessage.error(error.response?.data?.error || '预测请求失败')
} finally {
predicting.value = false
predicting[model.model_id] = false
}
}

@@ -240,28 +206,113 @@ const renderChart = () => {
if (chart) {
chart.destroy()
}
const predictions = predictionResult.value.predictions
const labels = predictions.map(p => new Date(p.date).toLocaleDateString('zh-CN', { weekday: 'short', year: 'numeric', month: 'long', day: 'numeric' }))
const data = predictions.map(p => p.sales)

const formatDate = (date) => new Date(date).toISOString().split('T')[0];

const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

if (historyData.length === 0 && predictionData.length === 0) {
ElMessage.warning('没有可用于图表的数据。')
return
}

const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
const simplifiedLabels = allLabels.map(date => date.split('-').slice(1).join('/'));

const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

if (historyData.length > 0 && predictionData.length > 0) {
const lastHistoryDate = historyData[historyData.length - 1].date
const lastHistoryValue = historyData[historyData.length - 1].sales
if (!predictionMap.has(lastHistoryDate)) {
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
}
}

let subtitleText = '';
if (historyData.length > 0) {
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
}
if (predictionData.length > 0) {
if (subtitleText) subtitleText += ' | ';
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
}

chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels,
datasets: [{
label: '预测销量',
data,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.1)',
tension: 0.4,
fill: true
}]
labels: simplifiedLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
borderColor: '#67C23A',
backgroundColor: 'rgba(103, 194, 58, 0.2)',
tension: 0.4,
fill: true,
spanGaps: false,
},
{
label: '预测销量',
data: alignedPredictionSales,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.2)',
tension: 0.4,
fill: true,
borderDash: [5, 5],
}
]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
title: {
display: true,
text: '销量预测趋势图'
text: `${predictionResult.value.product_name} - 销量预测趋势图`,
color: '#303133',
font: {
size: 20,
weight: 'bold',
}
},
subtitle: {
display: true,
text: subtitleText,
color: '#606266',
font: {
size: 14,
},
padding: {
bottom: 20
}
}
},
scales: {
x: {
title: {
display: true,
text: '日期'
},
grid: {
display: false
}
},
y: {
title: {
display: true,
text: '销量'
},
grid: {
color: '#e9e9e9',
drawBorder: false,
},
beginAtZero: true
}
}
}
@@ -269,14 +320,11 @@ const renderChart = () => {
}

onMounted(() => {
fetchModels()
fetchModelTypes()
const today = new Date()
form.start_date = today.toISOString().split('T')[0]
})

watch([() => form.product_id, () => form.model_type], () => {
fetchAvailableVersions()
})
</script>

<style scoped>
@@ -288,15 +336,11 @@ watch([() => form.product_id, () => form.model_type], () => {
justify-content: space-between;
align-items: center;
}
.model-selection-section h4 {
margin-bottom: 16px;
}
.prediction-actions {
display: flex;
justify-content: center;
.filters-section, .global-settings-section, .model-list-section {
margin-top: 20px;
padding-top: 20px;
border-top: 1px solid #ebeef5;
}
.filters-section h4, .global-settings-section h4, .model-list-section h4 {
margin-bottom: 16px;
}
.prediction-chart {
margin-top: 20px;

@@ -44,7 +44,7 @@
</el-row>

<el-row :gutter="20" v-if="form.model_type">
<el-col :span="6">
<el-col :span="5">
<el-form-item label="模型版本">
<el-select
v-model="form.version"
@@ -62,7 +62,7 @@
</el-select>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="5">
<el-form-item label="预测天数">
<el-input-number
v-model="form.future_days"
@@ -72,7 +72,17 @@
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="5">
<el-form-item label="历史天数">
<el-input-number
v-model="form.history_lookback_days"
:min="7"
:max="365"
style="width: 100%"
/>
</el-form-item>
</el-col>
<el-col :span="5">
<el-form-item label="起始日期">
<el-date-picker
v-model="form.start_date"
@@ -85,7 +95,7 @@
/>
</el-form-item>
</el-col>
<el-col :span="6">
<el-col :span="4">
<el-form-item label="预测分析">
<el-switch
v-model="form.analyze_result"
@@ -147,6 +157,7 @@ const form = reactive({
model_type: '',
version: '',
future_days: 7,
history_lookback_days: 30,
start_date: '',
analyze_result: true
})
@@ -200,25 +211,17 @@ const handleModelTypeChange = () => {
}

const startPrediction = async () => {
if (!form.store_id) {
ElMessage.error('请选择目标店铺')
return
}
if (!form.model_type) {
ElMessage.error('请选择算法类型')
return
}

try {
predicting.value = true
const payload = {
training_mode: form.training_mode,
store_id: form.store_id,
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
history_lookback_days: form.history_lookback_days,
start_date: form.start_date,
analyze_result: form.analyze_result,
store_id: form.store_id,
training_mode: form.training_mode
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
@@ -241,28 +244,113 @@ const renderChart = () => {
if (chart) {
chart.destroy()
}
const predictions = predictionResult.value.predictions
const labels = predictions.map(p => new Date(p.date).toLocaleDateString('zh-CN', { weekday: 'short', year: 'numeric', month: 'long', day: 'numeric' }))
const data = predictions.map(p => p.sales)

const formatDate = (date) => new Date(date).toISOString().split('T')[0];

const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

if (historyData.length === 0 && predictionData.length === 0) {
ElMessage.warning('没有可用于图表的数据。')
return
}

const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

if (historyData.length > 0 && predictionData.length > 0) {
const lastHistoryDate = historyData[historyData.length - 1].date
const lastHistoryValue = historyData[historyData.length - 1].sales
if (!predictionMap.has(lastHistoryDate)) {
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
}
}

let subtitleText = '';
if (historyData.length > 0) {
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
}
if (predictionData.length > 0) {
if (subtitleText) subtitleText += ' | ';
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
}

chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels,
datasets: [{
label: '预测销量',
data,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.1)',
tension: 0.4,
fill: true
}]
labels: simplifiedLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
borderColor: '#67C23A',
backgroundColor: 'rgba(103, 194, 58, 0.2)',
tension: 0.4,
fill: true,
spanGaps: false,
},
{
label: '预测销量',
data: alignedPredictionSales,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.2)',
tension: 0.4,
fill: true,
borderDash: [5, 5],
}
]
},
options: {
responsive: true,
maintainAspectRatio: false,
plugins: {
title: {
display: true,
text: '销量预测趋势图'
text: `${predictionResult.value.product_name} - 销量预测趋势图`,
color: '#ffffff',
font: {
size: 20,
weight: 'bold',
}
},
subtitle: {
display: true,
text: subtitleText,
color: '#6c757d',
font: {
size: 14,
},
padding: {
bottom: 20
}
}
},
scales: {
x: {
title: {
display: true,
text: '日期 (日)'
},
grid: {
display: false
}
},
y: {
title: {
display: true,
text: '销量'
},
grid: {
color: '#e9e9e9',
drawBorder: false,
},
beginAtZero: true
}
}
}

@@ -244,12 +244,7 @@
prop="version"
label="版本"
width="80"
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
/>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">
@@ -271,11 +266,11 @@
<div v-if="row.status === 'completed'">
<h4>评估指标</h4>
<pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
<!-- <div v-if="row.version">
<div v-if="row.version">
<h4>版本信息</h4>
<p><strong>版本:</strong> {{ row.version }}</p>
<p><strong>模型路径:</strong> {{ row.model_path }}</p>
</div> -->
</div>
</div>
<div v-if="row.status === 'failed'">
<h4>错误信息</h4>

@@ -213,12 +213,7 @@
prop="version"
label="版本"
width="80"
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
/>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">
@@ -240,11 +235,11 @@
<div v-if="row.status === 'completed'">
<h4>评估指标</h4>
<pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
<!-- <div v-if="row.version">
<div v-if="row.version">
<h4>版本信息</h4>
<p><strong>版本:</strong> {{ row.version }}</p>
<p><strong>模型路径:</strong> {{ row.model_path }}</p>
</div> -->
</div>
</div>
<div v-if="row.status === 'failed'">
<h4>错误信息</h4>

@@ -433,8 +428,8 @@ const initWebSocket = () => {
};
}

// Refresh the task list (commented out because the WebSocket already delivers the latest data)
// fetchTrainingTasks();
// Refresh the task list
fetchTrainingTasks();
});

socket.on("disconnect", () => {
@@ -568,7 +563,7 @@ const startTraining = async () => {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
version: response.data.path_info?.version || response.data.new_version || "v1",
version: response.data.new_version || "v1",
status: "starting",
progress: 0,
message: "正在启动药品训练...",

@@ -228,12 +228,7 @@
prop="version"
label="版本"
width="80"
>
<template #default="{ row }">
<el-tag v-if="row.version" type="primary" size="small">v{{ row.version }}</el-tag>
<span v-else>-</span>
</template>
</el-table-column>
/>
<el-table-column prop="status" label="状态" width="100">
<template #default="{ row }">
<el-tag :type="statusTag(row.status)">
@@ -255,11 +250,11 @@
<div v-if="row.status === 'completed'">
<h4>评估指标</h4>
<pre>{{ JSON.stringify(row.metrics, null, 2) }}</pre>
<!-- <div v-if="row.version">
<div v-if="row.version">
<h4>版本信息</h4>
<p><strong>版本:</strong> {{ row.version }}</p>
<p><strong>模型路径:</strong> {{ row.model_path }}</p>
</div> -->
</div>
</div>
<div v-if="row.status === 'failed'">
<h4>错误信息</h4>

101
feature_branch_workflow.md
Normal file
@ -0,0 +1,101 @@
# Standard Workflow for Feature-Branch Development and Merging

This document describes a standard, safe feature-development workflow, covering every step from creating the branch to the final merge.

## Workflow Overview

1. **Create a feature branch**: In the remote repository, create a new feature branch (e.g. `lyf-dev-req0001`) based on the main development branch (e.g. `lyf-dev`).
2. **Sync it locally**: Fetch the new remote branch and switch to it for development.
3. **Develop and commit**: Write code on the feature branch and commit changes frequently.
4. **Push to the remote**: Regularly push local commits to the remote feature branch for backup and collaboration.
5. **Merge back into the main branch**: Once the feature is developed and tested, merge the feature branch back into the main development branch.

---

## Detailed Steps

### Step 1: Sync and switch to the feature branch

After the new feature branch (e.g. `lyf-dev-req0001`) has been created in the remote repository, run the following locally to sync and switch to it.

1. **Fetch the latest remote state**:
   ```bash
   git fetch
   ```
   This pulls all the latest information from the remote repository, including newly created branches.

2. **Create and switch to the local branch**:
   ```bash
   git checkout lyf-dev-req0001
   ```
   Git detects that a remote branch of the same name exists and automatically creates a local branch that tracks it.

### Step 2: Develop and commit on the feature branch

You can now develop safely on the `lyf-dev-req0001` branch.

1. **Make code changes**: add, modify, or delete files to implement the new feature.

2. **Commit the changes**:
   ```bash
   # Stage all modified files
   git add .

   # Commit to the local repository with a meaningful message
   git commit -m "feat: complete the user-authentication module"
   ```
   > **Best practice**: keep commits small and their messages clear; it makes code review and troubleshooting much easier.

### Step 3: Push the feature branch to the remote

To back up your work and collaborate with the team, push your local commits to the remote repository.

```bash
# Push the current branch (lyf-dev-req0001) to the remote branch of the same name
git push origin lyf-dev-req0001
```

### Step 4: Merge the feature into the main development branch (`lyf-dev`)

Once the feature is complete and has passed testing, it can be merged back into `lyf-dev`.

1. **Switch to the main development branch**:
   ```bash
   git checkout lyf-dev
   ```

2. **Make sure the main development branch is up to date**:
   Always pull the latest remote `lyf-dev` before merging, to reduce the chance of conflicts.
   ```bash
   git pull origin lyf-dev
   ```

3. **Merge the feature branch**:
   Merge all changes from `lyf-dev-req0001` into the current `lyf-dev` branch.
   ```bash
   git merge lyf-dev-req0001
   ```
   * **If conflicts occur**: Git reports which files conflict. Open them, resolve the conflicting sections by hand, then run `git add .` and `git commit` to complete the merge commit.
   * **If there are no conflicts**: Git creates the merge commit automatically.

4. **Push the merged main branch to the remote**:
   ```bash
   git push origin lyf-dev
   ```

### Step 5: Clean up (optional)

Once the feature branch is definitely no longer needed, delete it to keep the repository tidy.

1. **Delete the remote branch**:
   ```bash
   git push origin --delete lyf-dev-req0001
   ```

2. **Delete the local branch**:
   ```bash
   git branch -d lyf-dev-req0001
   ```

---
Following this workflow keeps your development process clear, safe, and efficient.
326
lyf开发日志记录文档.md
Normal file
@ -0,0 +1,326 @@
# Development Log

This document records the major changes, bug fixes, and key decisions made during development of the project.

---

## 2025-07-13: Early backend fixes and refactoring
**Developer**: lyf

### 13:30 - Fixed the data-loading path issue
- **Goal**: Fix the data-loading failures during model training caused by a wrong data file path.
- **Root cause**: The `PharmacyPredictor` class in `server/core/predictor.py` hard-coded an incorrect default data file path at initialization.
- **Fix**: Corrected the default data path to `'data/timeseries_training_data_sample_10s50p.parquet'` and updated all trainers to match.

### 14:00 - Data-flow refactoring
- **Goal**: Fix the root cause of model-training failures: key features were lost because the data-processing pipeline was interrupted midway.
- **Root cause**: `predictor.py` did not pass the preprocessed data downstream, so every trainer reloaded the data and processed it incorrectly.
- **Fix**: Refactored the core data flow so that data is loaded and preprocessed once in `predictor.py` and then passed explicitly as a DataFrame to all downstream trainer functions.

---

## 2025-07-14: Concentrated work on model training and concurrency issues
**Developer**: lyf

### 10:16 - Fixed a `KeyError` in the trainer layer
- **Problem**: All model training failed with `KeyError: "['sales', 'price'] not in index"`.
- **Analysis**: The feature list hard-coded in the trainers included a `'price'` column that does not exist in the data source.
- **Fix**: Removed the dependency on the nonexistent `'price'` column from the `features` lists of all four trainers (`mlstm`, `transformer`, `tcn`, `kan`).

### 10:38 - Fixed a `KeyError` in the data-standardization layer
- **Problem**: A new error followed the previous fix: `KeyError: "['sales'] not in index"`.
- **Analysis**: The `standardize_column_names` function in `server/utils/multi_store_data_utils.py` mapped column names incorrectly and lacked a final column-selection step.
- **Fix**: Corrected the column-name mapping and added a column-selection step so the returned `DataFrame` has a consistent structure that always includes the `sales` column.
### 11:04 - Fixed the JSON serialization failure
- **Problem**: After training completed, frontend-backend communication failed with `Object of type float32 is not JSON serializable`.
- **Analysis**: The evaluation metrics produced by training are NumPy `float32` values, which the standard `json` library cannot serialize.
- **Fix**: Added a `convert_numpy_types` helper in `server/utils/training_process_manager.py` that converts all NumPy numeric types to native Python types before data is returned over WebSocket or the API, eliminating the serialization problem at its source.
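A minimal sketch of what such a converter can look like (the actual `convert_numpy_types` implementation is not shown in this log, so the body below is an assumption):

```python
import numpy as np

def convert_numpy_types(obj):
    """Recursively convert NumPy scalars and arrays into JSON-serializable Python types."""
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
```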
### 11:15 - Fixed the MAPE computation error
- **Problem**: Training logs showed `MAPE: nan%` along with `RuntimeWarning: Mean of empty slice.`.
- **Analysis**: When every ground-truth value in the test set is 0, computing MAPE means averaging over an empty array.
- **Fix**: Added a guard in `server/analysis/metrics.py`: if there are no nonzero ground-truth values, MAPE is set directly to 0.
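A sketch of the guarded computation (the exact code in `server/analysis/metrics.py` is not shown here, so the function name below is illustrative):

```python
import numpy as np

def safe_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """MAPE in percent over nonzero targets; returns 0 when no nonzero target exists."""
    nonzero = y_true != 0
    if not np.any(nonzero):
        return 0.0  # avoids "Mean of empty slice" and a NaN result
    return float(np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100)
```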
### 11:41 - Fixed the empty list on the "train by store" page
- **Problem**: The "select store" dropdown was empty.
- **Analysis**: The `standardize_column_names` function was incorrectly dropping columns not required for training, including the store metadata.
- **Fix**: Moved the column-filtering logic out of the generic `standardize_column_names` function and applied it only in the function that prepares data specifically for model training.

### 13:00 - Fixed the "train by store - all products" mode
- **Problem**: Training with "all products" selected failed because `product_id` was incorrectly handled as the string `"unknown"`.
- **Fix**: Intercepted the `"unknown"` ID in `server/core/predictor.py` and translated it into its real intent, "aggregate all product data for this store". Also extended the `aggregate_multi_store_data` function to support aggregation by store ID.
### 14:19 - Fixed stability issues in concurrent training
- **Problem**: Concurrent training triggered an API list-sorting error and WebSocket connection errors.
- **Fixes**:
  1. **Sorting**: In `api.py`, supplied a default value when `start_time` is `None`, fixing the `TypeError`.
  2. **Connection**: Added the `allow_unsafe_werkzeug=True` argument to the `socketio.run()` call, resolving the conflict between Socket.IO and Werkzeug in debug mode.
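The sorting fix amounts to giving `None` a stable sort key. A sketch (the task structure is assumed for illustration):

```python
tasks = [
    {"task_id": "a", "start_time": "2025-07-14 10:00:00"},
    {"task_id": "b", "start_time": None},  # not started yet
]

# Substitute an empty string for None so the comparison never mixes None and str
tasks.sort(key=lambda t: t["start_time"] or "", reverse=True)
```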
### 15:30 - Root-caused the dimension mismatch in model training
- **Problem**: After every model finished training, the `R²` metric was always 0.0.
- **Root cause**: The `create_dataset` function in `server/utils/data_utils.py` kept a redundant dimension when building the target dataset `dataY`. The model files (`mlstm_model.py`, `transformer_model.py`) also had output-dimension problems.
- **Final fix**:
  1. **Data layer**: Used `.flatten()` in `create_dataset` to correct the dimensionality of the `y` labels.
  2. **Model layer**: Added `.squeeze(-1)` at the end of every model's `forward` method to guarantee the output dimensions are correct.
  3. **Trainer layer**: Reverted all temporary dimension workarounds made for this problem, restoring the most direct loss computation.
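Why an extra dimension silently ruins the metric: with a trailing size-1 axis on one side, broadcasting compares every prediction against every target instead of pairing them. A small sketch of the failure mode and the fix:

```python
import numpy as np

y_true = np.arange(5, dtype=float).reshape(-1, 1)  # shape (5, 1): the redundant dimension
y_pred = np.arange(5, dtype=float)                 # shape (5,)

# Broadcasting turns the residual into a (5, 5) matrix, not 5 paired errors
print((y_true - y_pred).shape)            # (5, 5)

# The fix from create_dataset: flatten the labels before computing losses/metrics
print((y_true.flatten() - y_pred).shape)  # (5,)
```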
### 16:10 - Fixed the "global training - all products" mode
- **Problem**: Like store training, global training's "all products" mode also failed on `product_id="unknown"`.
- **Fix**: Applied exactly the same pattern as the store-training fix: intercepted `"unknown"` in `predictor.py`, translated it into a true global aggregation (`product_id=None`), and extended `aggregate_multi_store_data` to support it.

---

## 2025-07-15: End-to-end fix of the "predict by product" chart
**Developer**: lyf
### 10:00 - Stage 1: Fixed the database write failure (`sqlite3.IntegrityError`)
- **Problem**: The backend logs showed `datatype mismatch`.
- **Analysis**: The `save_prediction_result` function tried to store complex Python objects directly in the database.
- **Fix**: In `server/api.py`, serialized the complex objects to JSON strings with `json.dumps()` before executing the database insert.
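A sketch of the pattern (the table and column names here are illustrative, not taken from the actual schema):

```python
import json
import sqlite3

def save_prediction_result(db_path: str, product_id: str, result: dict) -> None:
    """Store a nested result dict as a JSON string so SQLite sees a plain TEXT value."""
    conn = sqlite3.connect(db_path)
    try:
        conn.execute(
            "INSERT INTO prediction_history (product_id, result_json) VALUES (?, ?)",
            (product_id, json.dumps(result)),  # dict -> TEXT, avoiding the datatype mismatch
        )
        conn.commit()
    finally:
        conn.close()
```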
### 10:30 - Stage 2: Fixed the mismatch between the API response structure and the frontend
- **Problem**: The chart still would not render.
- **Analysis**: The frontend expected `history_data` at the top level, but the backend wrapped it inside a `data` sub-object.
- **Fix**: Modified the `predict` function in `server/api.py` to promote the key data to the root level of the response.

### 11:00 - Stage 3: Fixed the time gap between historical and predicted data
- **Problem**: The two data series on the chart were completely disconnected in time.
- **Analysis**: The logic that fetched historical data always took the last 30 rows of the whole dataset instead of the 30 rows preceding the prediction start date.
- **Fix**: Added the correct date-filtering logic in `server/api.py`.

### 14:00 - Stage 4: Refactored the data source to cure the data inconsistency
- **Problem**: The historical series (green line) and the predicted series (blue line) were computed on entirely different bases.
- **Root cause**: The API layer independently loaded the **raw data** for plotting, while the predictor used the **aggregated data** for prediction.
- **Fix (refactoring)**:
  1. Modified `server/predictors/model_predictor.py` so that along with the prediction it also returns the exact historical data it used, guaranteeing a consistent basis.
  2. Removed all redundant history-loading code from `server/api.py`, making the data source unique.
### 15:00 - Stage 5: Fixed the X-axis date format
- **Problem**: The X axis showed garbled GMT-style timestamps.
- **Analysis**: The `Timestamp` objects in `history_data` were never formatted.
- **Fix**: In `server/api.py`, formatted the dates in `history_data` with `.strftime('%Y-%m-%d')`.
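The formatting step in isolation (a sketch; the surrounding response-building code is assumed):

```python
import pandas as pd

history = pd.DataFrame({
    "date": pd.date_range("2025-07-01", periods=3, freq="D"),
    "sales": [12.0, 15.0, 9.0],
})

# Timestamps serialize as verbose GMT strings; normalize them before building the response
history["date"] = history["date"].dt.strftime("%Y-%m-%d")
records = history.to_dict("records")  # [{'date': '2025-07-01', 'sales': 12.0}, ...]
```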
### 16:00 - Stage 6: Fixed the root cause of models "not learning" (broken hyperparameter passing)
- **Problem**: Even with the whole pipeline correct, every model still predicted a flat line it could not learn past.
- **Root cause**: When `server/core/predictor.py` invoked the trainers, it **did not pass down key hyperparameters such as `sequence_length`**, so every model ran with wrong defaults.
- **Fix**:
  1. Modified `server/core/predictor.py` to pass the hyperparameters through in the trainer calls.
  2. Modified all four trainer files to accept and use those parameters.

---

## 2025-07-16: Final verification and project wrap-up
**Developer**: lyf

### 10:00 - Stage 7: Final verification and conclusion
- **Problem**: With every code issue fixed, predictions for a particular date were still a flat line.
- **Analysis**: A throwaway data-analysis script (`temp_check_parquet.py`) finally confirmed this is a property of **the data itself**: the chosen prediction date falls inside a "zero sales" gap in the sample dataset.
- **Conclusion**: The system code is fully fixed. The flat line on the chart is the model's **correct and logical** response to a zero-sales history.

### 11:45 - Project wrap-up and documentation
- **Task**: As requested, reviewed the entire debugging effort and recorded every problem, solution, optimization idea, and conclusion in this log in chronological order, producing a usable technical archive.
- **Result**: This document is now up to date.
### 13:15 - Final fix: eliminated the inconsistent model identifier
- **Problem**: Further testing and log analysis showed that even after the correction, the store model's `model_identifier` was still being built incorrectly at training time, as `01010023_store_01010023`.
- **Root cause**: In the `train_model` method of `server/core/predictor.py`, the `training_mode == 'store'` branch built `model_identifier` with redundant and incorrect logic.
- **Solution**: Deleted the faulty concatenation `model_identifier = f"{store_id}_{product_id}"` and used the `product_id` variable directly, which an earlier step had already correctly set to `f"store_{store_id}"`. The store model's unique identifier is now consistent from training through saving to the final API lookup.

### 13:30 - Final fix (round 2): eliminated the wrong model save path
- **Problem**: Even with the identifier fixed, model versions still failed to load.
- **Root cause**: The training logs revealed that the `save_checkpoint` function in every trainer (`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`) forcibly created a `checkpoints` subdirectory under `saved_models` and saved all model files there, while the `get_model_versions` function responsible for finding models only searched the root directory, so models could never be discovered.
- **Solution**: Modified `save_checkpoint` in each affected trainer to drop the `checkpoints` subdirectory and save all models directly under the `saved_models` root.
- **Conclusion**: Save paths and lookup paths are now fully unified, which removes the root cause of the unloadable model versions.

### 13:40 - Final fix (round 3): unified the save logic across all trainers
- **Problem**: After fixing `transformer_trainer.py`, the identical path and naming bugs turned up in `mlstm_trainer.py` and `tcn_trainer.py`, so the symptom persisted.
- **Root cause**: `save_checkpoint` was implemented incorrectly in every trainer: each forced a `checkpoints` subdirectory and used broken logic to assemble the file name.
- **Solution**:
  1. **Fixed each trainer**: updated `save_checkpoint` in `transformer_trainer.py`, `mlstm_trainer.py`, and `tcn_trainer.py` one by one.
  2. **Path fix**: removed the `checkpoints` subdirectory logic so models are saved directly in the root of `model_dir` (i.e. `saved_models`).
  3. **File-name fix**: simplified file-name generation to use the `product_id` parameter directly as the unique identifier (upstream logic already sets it correctly to either the product ID or `store_{store_id}`), with no further erroneous concatenation.
- **Conclusion**: All trainers now share one save convention, and save paths and file names match the API's lookup logic exactly, removing the root cause of the unloadable model versions.
---

## 2025-07-16 (continued): End-to-end fix of the "store prediction" chart
**Developer**: lyf

### 15:30 - Final fix (round 4): opened up the store-prediction data flow
- **Problem**: With model loading fixed, "store prediction" executed successfully, but the frontend chart stayed blank, showing neither historical nor predicted data.
- **Root cause**: Parameter passing broke at several points in the call chain.
  1. `server/api.py` did not pass `training_mode` when calling `run_prediction`.
  2. `server/core/predictor.py` did not pass `store_id` and `training_mode` when calling `load_model_and_predict`.
  3. The data-loading logic inside `server/predictors/model_predictor.py` incorrectly used the model identifier (`store_{id}`) as a product ID to filter the data in store mode, so no historical data could be loaded.
- **Solution (three fixes)**:
  1. **`model_predictor.py`**: Modified `load_model_and_predict` to load data according to `training_mode`. In `'store'` mode it now correctly aggregates all sales data for the store as the history, exactly matching how the training data was prepared.
  2. **`predictor.py`**: Modified the `predict` method to pass `store_id` and `training_mode` down to `load_model_and_predict`.
  3. **`api.py`**: Modified the `predict` route and the `run_prediction` helper so `training_mode` travels intact through the whole call chain.
- **Conclusion**: Parameter passing from the API down to the lowest-level data loader is now complete and correct. Both product and store predictions load the right historical data for charting, which fully resolves the blank-chart problem.

### 16:16 - Project status update
- **Status**: **All known issues fixed**.
- **Confirmation**: The user has confirmed that "the product and store prediction flows now work".
- **Next**: Archive this round of fixes in this document.
---

### 2025-07-16 18:38 - Universal fix: prediction across all model types

**Symptom**:
After fixing the `Transformer` prediction problem, a deeper systemic issue surfaced: in every prediction mode (by product, by store, global), only the `Transformer` algorithm could predict and render a chart. The other four models (`mLSTM`, `KAN`, optimized `KAN`, `TCN`) trained successfully but all failed at prediction time with "no data available for the chart".

**Root-cause analysis**:
The core problem is that **model-configuration persistence was incomplete and inconsistent**.

1. **Why Transformer "survived"**: the `Transformer` implementation happens not to depend on the specific hyperparameters that were omitted at save time, so it kept working.
2. **The shared defect in the other models**: the constructors of all the other models (`mLSTM`, `TCN`, `KAN`) depend on structural parameters that were defined at training time but **omitted** from the checkpoint file (`.pth`):
   * **mLSTM**: missing `mlstm_layers`, `embed_dim`, `dense_dim`, among others.
   * **TCN**: missing `num_channels`, `kernel_size`, among others.
   * **KAN**: missing the `hidden_sizes` list.
3. **The failure cascade**:
   * When `server/predictors/model_predictor.py` loaded these checkpoints, `checkpoint['config']` did not contain all the parameters needed to instantiate the model.
   * Model instantiation failed with a `KeyError` or `TypeError`.
   * The exception made `load_model_and_predict` return `None` early, so the response sent to the frontend lacked `history_data` and the chart could not render.
**Systemic, extensible solution**:
To fix this for good, and to let new algorithms land smoothly in the future, every non-Transformer trainer received the same standardized, thorough fix.

1. **`mlstm_trainer.py`**: added `mlstm_layers`, `embed_dim`, `dense_dim`, and every other missing parameter to the `config` dict.
2. **`tcn_trainer.py`**: added `num_channels`, `kernel_size`, and every other missing parameter to the `config` dict.
3. **`kan_trainer.py`**: added the `hidden_sizes` list to the `config` dict.

**Result**:
Every trainer now writes a complete configuration into the checkpoint file, sufficient to re-instantiate the model. This eliminates the prediction failures for all model algorithms and makes the system uniform and robust across different algorithms.
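The save/load contract this fix enforces, as a self-contained sketch (the stand-in model and parameter set below are illustrative; the real trainers and checkpoint layout are not shown in this log):

```python
import torch
import torch.nn as nn

class TinyForecaster(nn.Module):
    """Stand-in for MLSTMTransformer etc.; only the save/load contract matters here."""
    def __init__(self, num_features: int, hidden_size: int, mlstm_layers: int):
        super().__init__()
        self.net = nn.LSTM(num_features, hidden_size, num_layers=mlstm_layers, batch_first=True)

    def forward(self, x):
        out, _ = self.net(x)
        return out[:, -1]

# At save time: persist EVERY constructor argument alongside the weights
config = {"num_features": 7, "hidden_size": 64, "mlstm_layers": 2}
model = TinyForecaster(**config)
torch.save({"model_state_dict": model.state_dict(), "config": config}, "ckpt.pth")

# At load time: the persisted config alone must be able to rebuild the model
ckpt = torch.load("ckpt.pth")
restored = TinyForecaster(**ckpt["config"])   # a missing key here raises TypeError
restored.load_state_dict(ckpt["model_state_dict"])
```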
---

## 2025-07-17: Systematic pipeline cleanup and standardization
**Developer**: lyf

### 15:00 - Wrote the technical docs and onboarding guide
- **Task**: Created two core technical documents to help new members understand the system and to ease future maintenance.
- **Deliverables**:
  1. **`系统调用逻辑与核心代码分析.md`**: an end-to-end call-chain analysis that goes down to code level, covering frontend interaction, backend handling, model training, and prediction.
  2. **`项目快速上手指南.md`**: a high-level guide for newcomers (especially those from a Java background) that builds a quick mental model of the project through stack analogies, a layered architecture diagram, and a clear development workflow.

### 16:00 - Fixed the `mLSTM` model-loading chain
- **Problem**: The `mLSTM` model failed to load at prediction time because of mismatched parameter names.
- **Analysis**:
  - First failure: the loader expected `num_layers`, but the trainer saved `mlstm_layers`.
  - Second failure: the loader expected `dropout`, but the trainer saved `dropout_rate`.
- **Fix**: Following the principle that "the writer decides the names", changed `server/predictors/model_predictor.py` to load `mlstm_layers` and `dropout_rate`, matching the trainer.
### 16:45 - Fixed the `mLSTM` model's algorithmic defect
- **Problem**: After its loading problem was fixed, the `mLSTM` model predicted a useless flat line.
- **Root cause**: The model architecture in `server/models/mlstm_model.py` had a design flaw: its decoder logic wrongly copied the last time step of the input sequence several times as the "prediction", so the model could never learn how the series changes.
- **Fix**: Refactored the `forward` method of the `MLSTMTransformer` class, removing the broken decoder logic and instead predicting directly from the encoder's final hidden state through a linear layer, correcting the algorithm at its root.
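The shape of that fix, sketched with a plain LSTM encoder (the real `MLSTMTransformer` internals are not shown in this log):

```python
import torch
import torch.nn as nn

class EncoderHeadForecaster(nn.Module):
    """Predict a whole horizon from the encoder's final hidden state, not by repeating the last input step."""
    def __init__(self, num_features: int, hidden_size: int, horizon: int):
        super().__init__()
        self.encoder = nn.LSTM(num_features, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, horizon)  # one linear map to all future steps

    def forward(self, x):                       # x: (batch, seq_len, num_features)
        _, (h_n, _) = self.encoder(x)           # h_n: (1, batch, hidden_size)
        return self.head(h_n.squeeze(0))        # (batch, horizon)

model = EncoderHeadForecaster(num_features=7, hidden_size=64, horizon=7)
print(model(torch.randn(4, 30, 7)).shape)       # torch.Size([4, 7])
```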
### 17:00 - Fixed the `TCN` model-loading chain
- **Problem**: `TCN` model loading at prediction time contained a hard-coded parameter, a latent crash point.
- **Analysis**: `server/predictors/model_predictor.py` hard-coded `kernel_size=3` when creating the `TCNForecaster` instance instead of reading it from the model config.
- **Fix**: Changed `model_predictor.py` to read the parameter dynamically from `config['kernel_size']`, keeping the configuration complete and consistent.

### 17:15 - Fixed `KAN` model version discovery
- **Problem**: After `KAN` and optimized `KAN` trained successfully, the prediction page found no model versions at all.
- **Root cause**: The **save** and **search** logic did not match. `kan_trainer.py` saved through `model_manager.py` using the `..._product_...` naming format, while `get_model_versions` in `server/core/config.py` searched only for the `..._epoch_...` format.
- **Fix**: Extended `get_model_versions` in `config.py` to recognize several naming formats, including the `..._product_...` format used by the `KAN` models.

### 17:25 - Fixed `KAN` model file-path generation
- **Problem**: With version discovery fixed, prediction still failed with "model file not found".
- **Root cause**: Only the **version-discovery** logic had been fixed, not the matching **file-path generation**. `get_model_file_path` in `config.py` still used the `_epoch_` format when building paths for `KAN` models.
- **Fix**: Updated `get_model_file_path` with special handling for `kan` and `optimized_kan` so their paths are generated with the correct `_product_` naming format.

### 17:40 - Upgraded the `KAN` trainer's version management
- **Problem**: The `KAN` model only ever had a single static `'v1'` version, unlike the other models (which have `best`, `final_epoch_...`, and so on).
- **Root cause**: `kan_trainer.py` was too simplistic: it could not evaluate and save multiple versions during training and simply hard-coded `'v1'` at the end.
- **Fix (feature upgrade)**: Refactored `server/trainers/kan_trainer.py` to give it the same dynamic version management as the other trainers: it now tracks and saves the best-performing `best` version during training and saves a `final_epoch_...` version when training ends.

### 17:58 - Final conclusion
- **Status**: **All known issues fixed**.
- **Outcome**:
  1. The full **"data -> training -> saving -> loading -> prediction -> visualization"** chain is verified end to end for every model.
  2. All parameter inconsistencies in configuration persistence and loading are unified and fixed.
  3. Version-management logic and engineering standards are fully aligned across all models.
  4. The core technical documents are written and the development conventions are captured.
- **Project status**: The system is now in a robust, consistent, and extensible stable state.

---
## 2025-07-18: Systematic refactor of model version management
**Developer**: lyf

### 14:00 - Cured the version chaos and model-loading failures
- **Symptoms**: After training, `KAN` and the other algorithms showed chaotic version numbers at prediction time (bare digits like `1` and `3`, or invalid entries like `best`), duplicate versions, and `404` "model file not found" errors caused by version mismatches.
- **Root-cause analysis**:
  1. **Scattered logic**: Version generation lived in each trainer while version discovery lived in `config.py`, with conflicting regular expressions and hard-coded rules between them.
  2. **Inconsistent naming**: The `KAN` trainer saved through `model_manager` while the other trainers used a local `save_checkpoint` function, so incompatible formats such as `..._product_..._v1.pth` and `..._epoch_best.pth` coexisted.
  3. **Bad extraction**: The overly broad, conflicting match rules in `get_model_versions` in `config.py` extracted invalid version numbers from file names, which directly produced the garbage in the frontend dropdown.
- **Systematic refactor**:
  1. **One source of truth**: Made [`server/utils/model_manager.py`](server/utils/model_manager.py:1) the only component responsible for version management, model naming, and file IO.
  2. **Automatic versioning**: Added an internal `_get_next_version` method to `ModelManager` that scans existing files and safely produces the next incrementing, `v`-prefixed version number (e.g. `v3`).
  3. **Unified trainers**: Refactored `kan_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`, and `transformer_trainer.py`. When saving the final model, every trainer now calls `model_manager.save_model` and **no longer picks its own version number**; `ModelManager` generates it. The best in-training model is saved explicitly as the `best` version.
  4. **Cleanup and hardening**: Removed all the old, broken version-management functions from `config.py` and rewrote `get_model_versions` to parse only file names that match the new naming convention, using a strict regular expression.
  5. **API polish**: Updated `api.py` to work entirely through the new `ModelManager` and improved the error message returned when prediction fails.
- **Conclusion**: Version management went from scattered, chaotic, and hard-coded to centralized, uniform, and automatic. Every known related bug is fixed at the root.
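A sketch of the automatic-versioning idea (`_get_next_version` itself is internal to `ModelManager` and not shown in this log, so the shape below is an assumption):

```python
import glob
import os
import re

def get_next_version(model_dir: str, prefix: str) -> str:
    """Scan '<prefix>_v<N>.pth' files and return the next 'v'-prefixed version, e.g. 'v3'."""
    numbers = []
    for path in glob.glob(os.path.join(model_dir, f"{prefix}_v*.pth")):
        m = re.search(r"_v(\d+)\.pth$", os.path.basename(path))
        if m:
            numbers.append(int(m.group(1)))  # ignore 'best' and other non-numeric versions
    return f"v{max(numbers) + 1 if numbers else 1}"

# e.g. with transformer_store_S001_v1.pth and _v2.pth on disk:
# get_next_version("saved_models", "transformer_store_S001") -> "v3"
```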
---

## 2025-07-18 (continued): Closing the "by store" AI loop and fixing the chain of follow-on bugs
**Developer**: lyf

### 15:00 - Architecture upgrade: "by store" training and prediction
- **Goal**: On top of the existing "by product" mode, add and fully connect a "by store" AI loop.
- **Core challenge**: Data handling, model identification, the training flow, and the API calls all needed systematic changes to support the new training mode.
- **Solution (a four-step refactor)**:
  1. **Upgraded `ModelManager`**: Redesigned the model naming rules to give store and global models clear, unambiguous identifiers (e.g. `transformer_store_S001_v1.pth`) and updated the parsing logic to match.
  2. **Fixed the core predictor**: Repaired the key logic flaw in `predictor.py` so that store mode generates and uses the correct `model_identifier` (e.g. `store_S001`) and always calls the data-aggregation function.
  3. **Adapted the API layer**: Adjusted the training and prediction endpoints in `api.py` to accept and correctly handle store-mode requests.
  4. **Unified all trainers**: Changed all four trainer files so they save models under the new `model_identifier`.

### 15:30 - Follow-on bug 1: store model versions failed to load
- **Symptom**: The model-version dropdown on the "predict by store" page was empty.
- **Root cause**: The `get_store_model_versions_api` endpoint in `api.py` still used the old lookup function, which does not understand the new naming convention.
- **Fix**: Rewrote the endpoint to drop the old function and look up models through `ModelManager`, uniformly and reliably.

### 15:40 - Follow-on bug 2: store prediction returned `404`
- **Symptom**: With the version list loading fine, clicking "start prediction" returned a `404` error.
- **Root cause**: Inside the backend `predict()` endpoint, the `load_model_and_predict` helper contained an outdated, manual file-lookup routine that bypassed `ModelManager` entirely and built the wrong file path.
- **Fix (joint refactor)**:
  1. **`model_predictor.py`**: Removed all the outdated lookup code from `load_model_and_predict` and changed its signature to take an explicit `model_path` argument.
  2. **`api.py`**: Changed the `predict` endpoint to resolve the correct model path through `ModelManager` at the API layer and pass it all the way down to `load_model_and_predict`, keeping the call chain consistent.

### 15:50 - Follow-on bug 3: service failed to start with `NameError`
- **Symptom**: After the prediction fix, the API service would not start: `NameError: name 'Optional' is not defined`.
- **Root cause**: The edit to `model_predictor.py` used the `Optional` type hint without importing it from `typing`.
- **Fix**: Added `from typing import Optional` at the top of `server/predictors/model_predictor.py`.
- **Conclusion**: All the architecture work and follow-on bugs around the "by store" feature are resolved. The system now handles training and prediction in both dimensions stably and correctly, with more uniform and robust logic.

---

## 2025-07-21: Joint frontend/backend debugging and UI fixes
**Developer**: lyf

### 15:45 - Fixed the backend `DataFrame` serialization error
- **Symptom**: After clearing the historical models and predicting again, the frontend showed `Object of type DataFrame is not JSON serializable`.
- **Root cause**: In the result returned by `load_model_and_predict` in `server/predictors/model_predictor.py`, the `'predictions'` field kept for backward compatibility still held a raw Pandas DataFrame (`predictions_df`).
- **Fix**: Updated the return dict so `'predictions'` also carries `prediction_data_json`, produced with `.to_dict('records')`, making every part of the returned object JSON-compatible.

### 16:00 - Unified fix for chart rendering in all prediction views
- **Symptom**: With the backend serialization fixed, the charts in all three prediction views (by product, by store, global) were blank, and the date subtitle under each chart showed raw, unformatted JavaScript date strings.
- **Root-cause analysis**:
  1. **Imprecise data access**: The frontend read data straight from the root of the API response (`response.data`), while the reliable payload lives in `response.data.data`.
  2. **Mishandled date objects**: The frontend never normalized the dates it received (whether strings or Date objects auto-converted by axios) into a single string format. Deduplicating dates with a `Set` then failed on distinct object references, leaving the chart without data points.
- **Unified fix**:
  1. **File by file**: Updated `ProductPredictionView.vue`, `StorePredictionView.vue`, and `GlobalPredictionView.vue`.
  2. **Correct data access**: In `startPrediction`, assigned the core payload `response.data.data` to `predictionResult`.
  3. **Normalized dates**: Added a `formatDate` helper at the top of `renderChart` and applied it as soon as data is processed, converting every date to a `'YYYY-MM-DD'` string; this fixed both the missing data points and the subtitle format in one stroke.
- **Conclusion**: The data path and UI rendering of every prediction view now work end to end; the system is back to normal.
Binary file not shown.
@ -56,4 +56,6 @@ tzdata==2025.2
werkzeug==3.1.3
win32-setctime==1.2.0
wsproto==1.2.0
python-dateutil
xgboost
scikit-learn
BIN
sales_trends.png
Binary file not shown.
Before: 348 KiB
983
server/api.py
File diff suppressed because it is too large
@ -36,7 +36,7 @@ DEVICE = get_device()
# Use os.path.join to build cross-platform paths
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
DEFAULT_FEATURES = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

# Time-series parameters
LOOK_BACK = 5  # use the past 5 days of data (suits the small dataset)
@ -58,7 +58,9 @@ HIDDEN_SIZE = 64  # hidden layer size
NUM_LAYERS = 2  # number of layers

# Supported model types
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# Supported model types (v2 - loaded dynamically)
from models.model_registry import TRAINER_REGISTRY
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())

# Version-management configuration
MODEL_VERSION_PREFIX = 'v'  # version prefix
@ -71,6 +73,154 @@ TRAINING_UPDATE_INTERVAL = 1  # training-progress update interval (seconds)
# Create the model save directory
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)

# Note: all functions related to model paths and version management (get_next_model_version, get_model_file_path, etc.)
# have been removed, because those responsibilities are now handled centrally by server.utils.file_save.ModelPathManager.
# Centralizing them ensures the whole application follows one standardized, flat file-saving strategy.

def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
    """
    Build the exact model file path from the product ID, model type, and version.

    Args:
        product_id: product ID (digits only)
        model_type: model type
        version: version string (e.g. 'best', 'final_epoch_50', 'v1_legacy')

    Returns:
        The full path of the model file
    """
    # Handle the legacy "v1" format
    if version == "v1_legacy":
        filename = f"{model_type}_model_product_{product_id}.pth"
        return os.path.join(DEFAULT_MODEL_DIR, filename)

    # Correction: build the file name directly from the unique product_id (which may carry a store_ prefix)
    # Example file names: transformer_17002608_epoch_best.pth or transformer_store_01010023_epoch_best.pth
    # For KAN and optimized_kan, follow model_manager's naming convention
    # One naming format now covers every model
    filename = f"{model_type}_product_{product_id}_{version}.pth"
    # Correction: look directly in the root model directory; the checkpoints subdirectory is no longer used
    return os.path.join(DEFAULT_MODEL_DIR, filename)

def get_model_versions(product_id: str, model_type: str) -> list:
    """
    Get every version for the given product and model type.

    Args:
        product_id: product ID (now expected to be a plain numeric ID)
        model_type: model type

    Returns:
        A list of versions, sorted by version number
    """
    # Search with the unified new naming convention
    pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
    existing_files = glob.glob(pattern)

    versions = set()

    for file_path in existing_files:
        filename = os.path.basename(file_path)

        # Strictly match _v<number> or 'best'
        match = re.search(r'_(v\d+|best)\.pth$', filename)
        if match:
            versions.add(match.group(1))

    # Sort numeric versions in descending order, with 'best' always first
    def sort_key(v):
        if v == 'best':
            return float('inf')  # 'best' sorts first under reverse=True
        if v.startswith('v'):
            return int(v[1:])
        return -1  # should not happen

    sorted_versions = sorted(list(versions), key=sort_key, reverse=True)

    return sorted_versions


def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
    """
    Save model version information to the database.

    Args:
        product_id: product ID
        model_type: model type
        version: version number
        file_path: model file path
        metrics: model performance metrics
    """
    import sqlite3
    import json
    from datetime import datetime

    try:
        conn = sqlite3.connect('prediction_history.db')
        cursor = conn.cursor()

        # Insert the model version record
        cursor.execute('''
            INSERT INTO model_versions (
                product_id, model_type, version, file_path, created_at, metrics, is_active
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        ''', (
            product_id,
            model_type,
            version,
            file_path,
            datetime.now().isoformat(),
            json.dumps(metrics) if metrics else None,
            1  # new models are active by default
        ))

        conn.commit()
        conn.close()
        print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")

    except Exception as e:
        print(f"保存模型版本信息失败: {str(e)}")

def get_model_version_info(product_id: str, model_type: str, version: str = None):
    """
    Read model version information from the database.

    Args:
        product_id: product ID
        model_type: model type
        version: version number; when None, the latest version is returned

    Returns:
        A dict of model version information
    """
    import sqlite3
    import json

    try:
        conn = sqlite3.connect('prediction_history.db')
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()

        if version:
            cursor.execute('''
                SELECT * FROM model_versions
                WHERE product_id = ? AND model_type = ? AND version = ?
                ORDER BY created_at DESC LIMIT 1
            ''', (product_id, model_type, version))
        else:
            cursor.execute('''
                SELECT * FROM model_versions
                WHERE product_id = ? AND model_type = ?
                ORDER BY created_at DESC LIMIT 1
            ''', (product_id, model_type))

        row = cursor.fetchone()
        conn.close()

        if row:
            result = dict(row)
            if result['metrics']:
                result['metrics'] = json.loads(result['metrics'])
            return result
        return None

    except Exception as e:
        print(f"获取模型版本信息失败: {str(e)}")
        return None
@ -11,13 +11,13 @@ import time
import matplotlib.pyplot as plt
from datetime import datetime

from trainers import (
    train_product_model_with_mlstm,
    train_product_model_with_kan,
    train_product_model_with_tcn,
    train_product_model_with_transformer,
    train_product_model_with_xgboost
)
# from trainers import (
#     train_product_model_with_mlstm,
#     train_product_model_with_kan,
#     train_product_model_with_tcn,
#     train_product_model_with_transformer
# )
# The imports above are no longer needed; trainers are now resolved dynamically through the model registry
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
@ -65,9 +65,8 @@ class PharmacyPredictor:
                    learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                    hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
                    store_id=None, training_mode='product', aggregation_method='sum',
                    product_scope='all', product_ids=None,
                    socketio=None, task_id=None, version=None, continue_training=False,
                    progress_callback=None, path_info=None):
                    progress_callback=None):
        """
        Train a prediction model - supports multi-store training

@ -125,38 +124,29 @@
                return None

        # If product_id is 'unknown', train one aggregated model for all the store's products
        if product_scope == 'specific' and product_ids:
            # Train on an explicit list of products within the store
            try:
                # Filter the specified products out of this store's data
                store_data = self.data[self.data['store_id'] == store_id]
                product_data = store_data[store_data['product_id'].isin(product_ids)].copy()
                log_message(f"按店铺-指定药品训练: 店铺 {store_id}, {len(product_ids)}种药品, 数据量: {len(product_data)}")
            except Exception as e:
                log_message(f"获取店铺指定药品数据失败: {e}", 'error')
                return None
        elif product_id == 'unknown' or product_scope == 'all':
            # Train one aggregated model for every product in the store
        if product_id == 'unknown':
            try:
                # Use the new aggregation function to aggregate by store
                product_data = aggregate_multi_store_data(
                    store_id=store_id,
                    aggregation_method=aggregation_method,
                    file_path=self.data_path
                )
                log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
                product_id = store_id  # use the store ID as the model identifier
                # Set product_id to 'store_{store_id}' to stay consistent with the API lookup logic
                product_id = f"store_{store_id}"
            except Exception as e:
                log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
                return None
        else:
            # Train on a single specific product within the store (kept for legacy compatibility)
            # Train on a single specific product within the store
            try:
                product_data = get_store_product_sales_data(
                    store_id=store_id,
                    product_id=product_id,
                    file_path=self.data_path
                )
                log_message(f"按店铺-单个产品训练: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
                log_message(f"按店铺-产品训练: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
            except Exception as e:
                log_message(f"获取店铺产品数据失败: {e}", 'error')
                return None
@ -188,95 +178,59 @@
            log_message(f"不支持的训练模式: {training_mode}", 'error')
            return None

        # Build the model identifier according to the training mode
        # Build the model identifier according to the training mode (v2 correction)
        if training_mode == 'store':
            model_identifier = f"{store_id}_{product_id}"
            # A store model's identifier should be based on the store ID only
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # A global model's identifier should not depend on a single product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        # Call the matching training function
        # Call the matching training function (refactored to use the registry)
        try:
            log_message(f"🤖 开始调用 {model_type} 训练器")
            if model_type == 'transformer':
                model_result, metrics, actual_version, _ = train_product_model_with_transformer(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    version=version,
                    socketio=socketio,
                    task_id=task_id,
                    continue_training=continue_training,
                    path_info=path_info
                )
                log_message(f"✅ {model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
            elif model_type == 'mlstm':
                model_result, metrics, _, _ = train_product_model_with_mlstm(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id,
                    progress_callback=progress_callback,
                    path_info=path_info
                )
            elif model_type == 'kan':
                _, metrics = train_product_model_with_kan(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=use_optimized,
                    path_info=path_info
                )
            elif model_type == 'optimized_kan':
                _, metrics = train_product_model_with_kan(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=True,
                    path_info=path_info
                )
            elif model_type == 'tcn':
                model_result, metrics, _, _ = train_product_model_with_tcn(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id,
                    path_info=path_info
                )
            elif model_type == 'xgboost':
                metrics, _ = train_product_model_with_xgboost(
                    product_id=product_id,
                    store_id=store_id,
                    epochs=epochs,
                    socketio=socketio,
                    task_id=task_id,
                    version=version,
                    path_info=path_info
                )
            from models.model_registry import get_trainer
            log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
            trainer_function = get_trainer(model_type)
            log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")

            # Collect the common arguments
            trainer_args = {
                'product_id': product_id,
                'model_identifier': model_identifier,
                'product_df': product_data,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'epochs': epochs,
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_dir': self.model_dir,
                'socketio': socketio,
                'task_id': task_id,
                'progress_callback': progress_callback,
                'version': version,
                'continue_training': continue_training,
                'use_optimized': use_optimized  # needed by the KAN models
            }

            # Invoke the trainer dynamically (v2 - smart argument filtering)
            import inspect
            sig = inspect.signature(trainer_function)
            valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}

            log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")

            result = trainer_function(**valid_args)

            # Extract the metrics based on the shape of the return value
            if isinstance(result, tuple) and len(result) >= 2:
                metrics = result[1]  # the second return value is usually the metrics
            else:
                log_message(f"不支持的模型类型: {model_type}", 'error')
                return None
                log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
                metrics = None


        # Check and print the returned metrics
        log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
@ -320,21 +274,24 @@
        Returns:
            The prediction result, plus the analysis when analyze_result is True
        """
        # Build the model identifier according to the training mode
        # Build the model identifier according to the training mode (v2 correction)
        if training_mode == 'store' and store_id:
            model_identifier = f"{store_id}_{product_id}"
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # A global model's identifier should not depend on a single product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        return load_model_and_predict(
            model_identifier,
            model_type,
            future_days=future_days,
            start_date=start_date,
            model_identifier,
            model_type,
            store_id=store_id,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result,
            version=version
            version=version,
            training_mode=training_mode
        )

    def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
102
server/models/cnn_bilstm_attention.py
Normal file
@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model definition, adapted for the pharmacy sales prediction system.
Original source: "python机器学习回归全家桶"
"""

import torch
import torch.nn as nn

# Note: the original code used TensorFlow/Keras layers; this is an equivalent PyTorch implementation.
# That is the more robust choice and the one that fits the existing system architecture.

class Attention(nn.Module):
    """
    Attention mechanism implemented in PyTorch.
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)


class CnnBiLstmAttention(nn.Module):
    """
    PyTorch implementation of the CNN-BiLSTM-Attention model.
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN layer
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM layer
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention layer
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # Fully connected output layer
        self.dense = nn.Linear(lstm_units * 2, output_dim)

    def forward(self, x):
        # Input x has shape (batch_size, sequence_length, input_dim)

        # CNN stage
        x = x.permute(0, 2, 1)  # reshape to (batch_size, input_dim, sequence_length) for Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # back to (batch_size, sequence_length, cnn_filters)

        # BiLSTM stage
        lstm_out, _ = self.bilstm(x)  # lstm_out shape: (batch_size, sequence_length, lstm_units * 2)

        # Attention stage
        # Note: this Attention implementation may need tuning for the specific task.
        # A simpler alternative is to use the LSTM's final hidden state or output directly;
        # here the LSTM output is reduced to one vector by the attention weights.
        attention_out = self.attention(lstm_out)

        # Fully connected output
        output = self.dense(attention_out)

        return output
@ -14,31 +14,695 @@ from .kan_model import KANForecaster

class ModelManager:
    """
    Model manager: this class is now mainly responsible for mapping model types to classes.
    Note: all filesystem logic (saving, loading, deleting, etc.) has been removed
    and is handled centrally by server.utils.file_save.ModelPathManager,
    following the new flat file-storage convention.
    Model manager: handles saving, loading, listing, and deleting models
    """

    def __init__(self):
    def __init__(self, models_dir='models'):
        """
        Initialize the model manager

        Args:
            models_dir: model storage directory
        """
        # Mapping from model type to its class
        self.models_dir = models_dir
        self._ensure_model_dir()

        # Model type mapping
        self.model_types = {
            'mlstm': MLSTMTransformer,
            'transformer': TimeSeriesTransformer,
            'kan': KANForecaster
        }

    def get_model_class(self, model_type: str):

    def _ensure_model_dir(self):
        """Make sure the model directory exists"""
        if not os.path.exists(self.models_dir):
            try:
                os.makedirs(self.models_dir, exist_ok=True)
                print(f"创建模型目录: {os.path.abspath(self.models_dir)}")
            except Exception as e:
                print(f"创建模型目录失败: {str(e)}")
                raise

    def save_model(self, model, model_type, product_id, optimizer=None,
                   train_loss=None, test_loss=None, scaler_X=None,
                   scaler_y=None, features=None, look_back=None, T=None,
                   metrics=None, version=None):
        """
        Get the model class for a model-type string.

        Args:
            model_type (str): model type (e.g. 'mlstm', 'kan')

        Returns:
            The model class, or None if it does not exist.
        Save a model together with its related information

        Args:
            model: the trained model
            model_type: model type ('mlstm', 'transformer', 'kan')
            product_id: product ID
            optimizer: optimizer
            train_loss: training-loss history
            test_loss: test-loss history
            scaler_X: feature scaler
            scaler_y: target scaler
            features: list of features used
            look_back: look-back window in days
            T: number of days to predict
            metrics: model evaluation metrics
            version: model version (optional); a timestamp is used when omitted
        """
        return self.model_types.get(model_type)
        self._ensure_model_dir()

        # Resolve the version
        if version is None:
            version = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Build the file name
        model_filename = f"{product_id}_{model_type}_model_v{version}.pt"
        model_path = os.path.join(self.models_dir, model_filename)

        # Assemble the data to save
        save_dict = {
            'model_state_dict': model.state_dict(),
            'model_type': model_type,
            'product_id': product_id,
            'version': version,
            'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'features': features,
            'look_back': look_back,
            'T': T
        }

        # Attach the optional pieces
        if optimizer is not None:
            save_dict['optimizer_state_dict'] = optimizer.state_dict()
        if train_loss is not None:
            save_dict['train_loss'] = train_loss
        if test_loss is not None:
            save_dict['test_loss'] = test_loss
        if scaler_X is not None:
            save_dict['scaler_X'] = scaler_X
        if scaler_y is not None:
            save_dict['scaler_y'] = scaler_y
        if metrics is not None:
            save_dict['metrics'] = metrics

        try:
            # Save the model
            torch.save(save_dict, model_path)
            print(f"模型已成功保存到 {os.path.abspath(model_path)}")

            # Also save the model metadata to a JSON file for easy querying
            meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
            meta_dict = {k: str(v) if not isinstance(v, (int, float, bool, list, dict, type(None))) else v
                         for k, v in save_dict.items() if k != 'model_state_dict' and
                         k != 'optimizer_state_dict' and k != 'scaler_X' and k != 'scaler_y'}

            # Add the evaluation metrics to the metadata when present
            if metrics is not None:
                meta_dict['metrics'] = metrics

            with open(meta_path, 'w') as f:
                json.dump(meta_dict, f, indent=4)

            return model_path
        except Exception as e:
            print(f"保存模型时出错: {str(e)}")
            raise

    def load_model(self, product_id, model_type='mlstm', version=None, device=None):
        """
        Load the specified model

        Args:
            product_id: product ID
            model_type: model type ('mlstm', 'transformer', 'kan')
            version: model version; the latest version is loaded when omitted
            device: device (cuda/cpu)

        Returns:
            model: the loaded model
            checkpoint: a dict with the model information
        """
        if device is None:
            device = get_device()

        # Find the matching model file
        if version is None:
            # Look for the latest version
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
            model_files = glob.glob(pattern)

            if not model_files:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None, None

            # Sort by file modification time and take the newest
            model_path = max(model_files, key=os.path.getmtime)
        else:
            # A specific version
            model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return None, None

        try:
            # Load the checkpoint
            checkpoint = torch.load(model_path, map_location=device)

            # Instantiate the model
            if model_type == 'mlstm':
                model = MLSTMTransformer(
                    num_features=len(checkpoint['features']),
                    hidden_size=128,
                    mlstm_layers=1,
                    embed_dim=32,
                    dense_dim=32,
                    num_heads=4,
                    dropout_rate=0.1,
                    num_blocks=3,
                    output_sequence_length=checkpoint['T']
                )
            elif model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=len(checkpoint['features']),
                    d_model=32,
                    nhead=4,
                    num_encoder_layers=3,
                    dim_feedforward=32,
                    dropout=0.1,
                    output_sequence_length=checkpoint['T']
                )
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=len(checkpoint['features']),
                    hidden_sizes=[64, 128, 64],
                    output_size=1,
                    grid_size=5,
                    spline_order=3,
                    dropout_rate=0.1,
                    output_sequence_length=checkpoint['T']
                )
            else:
                raise ValueError(f"不支持的模型类型: {model_type}")

            # Load the model weights
            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(device)
            model.eval()

            print(f"模型已从 {os.path.abspath(model_path)} 成功加载")
            return model, checkpoint
        except Exception as e:
            print(f"加载模型时出错: {str(e)}")
            raise

    def list_models(self, product_id=None, model_type=None):
        """
        List all saved models

        Args:
            product_id: filter by product ID (optional)
            model_type: filter by model type (optional)

        Returns:
            models_list: a list of model-information dicts
        """
        self._ensure_model_dir()

        # Build the search pattern
        if product_id and model_type:
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
        elif product_id:
            pattern = os.path.join(self.models_dir, f"{product_id}_*_model_v*.pt")
        elif model_type:
            pattern = os.path.join(self.models_dir, f"*_{model_type}_model_v*.pt")
        else:
            pattern = os.path.join(self.models_dir, "*_model_v*.pt")

        model_files = glob.glob(pattern)

        if not model_files:
            print("未找到匹配的模型文件")
            return []

        # Collect the model information
        models_list = []
        for model_path in model_files:
            try:
                # Parse the information out of the file name
                filename = os.path.basename(model_path)
                parts = filename.split('_')
                if len(parts) < 4:
                    continue

                product_id = parts[0]
                model_type = parts[1]
                version = parts[-1].replace('model_v', '').replace('.pt', '')

                # Look for the matching metadata file
                meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")

                model_info = {
                    'product_id': product_id,
                    'model_type': model_type,
                    'version': version,
                    'file_path': model_path,
                    'created_at': datetime.fromtimestamp(os.path.getctime(model_path)).strftime("%Y-%m-%d %H:%M:%S"),
                    'file_size': f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
                }

                # Enrich the info from the metadata file when it exists
                if os.path.exists(meta_path):
                    with open(meta_path, 'r') as f:
                        meta = json.load(f)
                        model_info.update(meta)

                models_list.append(model_info)
            except Exception as e:
                print(f"解析模型文件 {model_path} 时出错: {str(e)}")

        # Sort by creation time
        models_list.sort(key=lambda x: x['created_at'], reverse=True)

        return models_list

    def delete_model(self, product_id, model_type, version=None):
        """
        Delete the specified model

        Args:
            product_id: product ID
            model_type: model type
            version: model version; all versions are deleted when omitted

        Returns:
            success: whether deletion succeeded
        """
        self._ensure_model_dir()

        if version:
            # Delete one specific version
            model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
            meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")

            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return False

            try:
                os.remove(model_path)
                if os.path.exists(meta_path):
                    os.remove(meta_path)
                print(f"已删除产品 {product_id} 的 {model_type} 模型版本 {version}")
                return True
            except Exception as e:
                print(f"删除模型时出错: {str(e)}")
                return False
        else:
            # Delete all versions
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
            meta_pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v*.json")

            model_files = glob.glob(pattern)
            meta_files = glob.glob(meta_pattern)

            if not model_files:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return False

            try:
                for file_path in model_files:
                    os.remove(file_path)

                for file_path in meta_files:
                    os.remove(file_path)

                print(f"已删除产品 {product_id} 的所有 {model_type} 模型")
                return True
            except Exception as e:
                print(f"删除模型时出错: {str(e)}")
                return False

    def get_model_details(self, product_id, model_type, version=None):
        """
        Get detailed information about a model

        Args:
            product_id: product ID
            model_type: model type
            version: model version; the latest version is used when omitted

        Returns:
            details: a dict of model details
        """
        # Find the matching model file
        if version is None:
            # Look for the latest version
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
            model_files = glob.glob(pattern)

            if not model_files:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None

            # Sort by file modification time and take the newest
            model_path = max(model_files, key=os.path.getmtime)
            # Parse the version out of the file name
            filename = os.path.basename(model_path)
            version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '')

        # The metadata file
        meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")

        if not os.path.exists(meta_path):
            print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version} 的元数据")
            return None

        try:
            with open(meta_path, 'r') as f:
                details = json.load(f)

            # Attach the file path and size
            model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
            details['file_path'] = model_path
            details['file_size'] = f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"

            return details
        except Exception as e:
            print(f"获取模型详情时出错: {str(e)}")
            return None

    def predict_with_model(self, product_id, model_type='mlstm', version=None, future_days=7,
                           product_df=None, features=None, visualize=True, save_results=True):
        """
        Run a prediction with the specified model

        Args:
            product_id: product ID
            model_type: model type ('mlstm', 'transformer', 'kan')
            version: model version; the latest version is used when omitted
            future_days: number of future days to predict
            product_df: product data DataFrame
            features: feature list
            visualize: whether to visualize the results
            save_results: whether to save the results

        Returns:
            predictions_df: a DataFrame of prediction results
        """
        # Pick the device
        device = get_device()
        print(f"使用设备: {device} 进行预测")

        # Load the model
        model, checkpoint = self.load_model(product_id, model_type, version, device)

        if model is None or checkpoint is None:
            return None

        # Load the product data from the Excel file when not provided
        if product_df is None:
            try:
                df = pd.read_excel('pharmacy_sales.xlsx')
                product_df = df[df['product_id'] == product_id].sort_values('date')
            except Exception as e:
                print(f"加载产品数据时出错: {str(e)}")
                return None

        product_name = product_df['product_name'].iloc[0]

        # Pull the model parameters out of the checkpoint
        features = checkpoint['features']
        look_back = checkpoint['look_back']
        T = checkpoint['T']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Take the most recent look_back days of data
        last_data = product_df[features].values[-look_back:]
        last_data_scaled = scaler_X.transform(last_data)

        # Prepare the input tensor
        X_input = torch.Tensor(last_data_scaled).unsqueeze(0)  # add the batch dimension
        X_input = X_input.to(device)  # move to the device

        # Predict
        with torch.no_grad():
            y_pred_scaled = model(X_input).squeeze(0).cpu().numpy()  # back to CPU and to numpy

        # Invert the scaling of the predictions
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        # Build the prediction date range
        last_date = product_df['date'].iloc[-1]
        future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=T, freq='D')

        # Build the prediction-result DataFrame
        predictions_df = pd.DataFrame({
            'date': future_dates,
            'product_id': product_id,
            'product_name': product_name,
            'predicted_sales': y_pred
        })

        print(f"\n{product_name} 未来 {T} 天销售预测 (使用{model_type.upper()}模型):")
        print(predictions_df[['date', 'predicted_sales']])

        # Visualize the predictions
        if visualize:
            plt.figure(figsize=(12, 6))

            # Show historical data together with the predictions
            history_days = 30  # show the most recent 30 days of history
            history_dates = product_df['date'].iloc[-history_days:].values
            history_sales = product_df['sales'].iloc[-history_days:].values

            plt.plot(history_dates, history_sales, 'b-', label='历史销量')
            plt.plot(future_dates, y_pred, 'r--', label=f'{model_type.upper()}预测销量')

            plt.title(f'{product_name} - {model_type.upper()}销量预测 (未来{T}天)')
            plt.xlabel('日期')
            plt.ylabel('销量')
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()

            # Save and show the chart
            forecast_chart = f'{product_id}_{model_type}_forecast.png'
            plt.savefig(forecast_chart)
            print(f"预测图表已保存为: {forecast_chart}")

        # Save the predictions to CSV
        if save_results:
            forecast_csv = f'{product_id}_{model_type}_forecast.csv'
            predictions_df.to_csv(forecast_csv, index=False)
            print(f"预测结果已保存到: {forecast_csv}")

        return predictions_df

    def compare_models(self, product_id, model_types=None, versions=None, product_df=None, visualize=True):
        """
        Compare the predictions of different models

        Args:
            product_id: product ID
            model_types: list of model types to compare
            versions: matching list of model versions; the latest versions are used when omitted
            product_df: product data DataFrame
            visualize: whether to visualize the results

        Returns:
            A DataFrame with the comparison results
        """
        if model_types is None:
            model_types = ['mlstm', 'transformer', 'kan']

        if versions is None:
            versions = [None] * len(model_types)

        if len(versions) != len(model_types):
            print("错误: 模型类型和版本列表长度不匹配")
            return None

        # Load the product data from the Excel file when not provided
        if product_df is None:
            try:
                df = pd.read_excel('pharmacy_sales.xlsx')
                product_df = df[df['product_id'] == product_id].sort_values('date')
            except Exception as e:
                print(f"加载产品数据时出错: {str(e)}")
                return None

        product_name = product_df['product_name'].iloc[0]

        # Collect the predictions of every model
        predictions = {}

        # Predict with each model in turn
        for i, model_type in enumerate(model_types):
            version = versions[i]

            try:
                pred_df = self.predict_with_model(
                    product_id,
                    model_type=model_type,
                    version=version,
                    product_df=product_df,
                    visualize=False,
                    save_results=False
                )

                if pred_df is not None:
                    predictions[model_type] = pred_df
            except Exception as e:
                print(f"{model_type} 模型预测出错: {str(e)}")

        if not predictions:
            print("没有成功的预测结果")
            return None

        # Merge the prediction results
        result_df = predictions[list(predictions.keys())[0]][['date', 'product_id', 'product_name']].copy()

        for model_type, pred_df in predictions.items():
            result_df[f'{model_type}_prediction'] = pred_df['predicted_sales'].values

        # Visualize the comparison
        if visualize and len(predictions) > 0:
            plt.figure(figsize=(12, 6))

            # Show the historical data
            history_days = 30  # show the most recent 30 days of history
            history_dates = product_df['date'].iloc[-history_days:].values
            history_sales = product_df['sales'].iloc[-history_days:].values

            plt.plot(history_dates, history_sales, 'k-', label='历史销量')

            # Show the predictions
            colors = ['r', 'g', 'b', 'c', 'm', 'y']
            future_dates = result_df['date'].values

            for i, (model_type, pred_df) in enumerate(predictions.items()):
                color = colors[i % len(colors)]
                plt.plot(future_dates, pred_df['predicted_sales'].values,
                         f'{color}--', label=f'{model_type.upper()}预测')

            plt.title(f'{product_name} - 不同模型预测结果比较')
            plt.xlabel('日期')
            plt.ylabel('销量')
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()

            # Save and show the chart
            compare_chart = f'{product_id}_model_comparison.png'
            plt.savefig(compare_chart)
            print(f"比较图表已保存为: {compare_chart}")

            # Save the comparison results to CSV
            compare_csv = f'{product_id}_model_comparison.csv'
            result_df.to_csv(compare_csv, index=False)
            print(f"比较结果已保存到: {compare_csv}")

        return result_df

    def export_model(self, product_id, model_type, version=None, export_dir='exported_models'):
        """
        Export a model to the given directory

        Args:
            product_id: product ID
            model_type: model type
            version: model version; the latest version is exported when omitted
            export_dir: export directory

        Returns:
            export_path: the path of the exported file
        """
        # Make sure the export directory exists
        if not os.path.exists(export_dir):
            os.makedirs(export_dir, exist_ok=True)

        # Find the matching model file
        if version is None:
            # Look for the latest version
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
            model_files = glob.glob(pattern)

            if not model_files:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None

            # Sort by file modification time and take the newest
            model_path = max(model_files, key=os.path.getmtime)
            # Parse the version out of the file name
            filename = os.path.basename(model_path)
            version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '')
        else:
            model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return None

        # The metadata file
        meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")

        # Export paths
        export_model_path = os.path.join(export_dir, f"{product_id}_{model_type}_model_v{version}.pt")
        export_meta_path = os.path.join(export_dir, f"{product_id}_{model_type}_meta_v{version}.json")

        try:
            # Copy the files
            shutil.copy2(model_path, export_model_path)
            if os.path.exists(meta_path):
                shutil.copy2(meta_path, export_meta_path)

            print(f"模型已导出到 {os.path.abspath(export_model_path)}")
            return export_model_path
        except Exception as e:
            print(f"导出模型时出错: {str(e)}")
            return None

    def import_model(self, import_file, overwrite=False):
        """
        Import a model file

        Args:
            import_file: path of the model file to import
            overwrite: whether to overwrite an existing file with the same name

        Returns:
            import_path: the path of the imported file
        """
        self._ensure_model_dir()

        if not os.path.exists(import_file):
            print(f"错误: 导入文件 {import_file} 不存在")
            return None

        # Extract the file name
        filename = os.path.basename(import_file)

        # The target path
        target_path = os.path.join(self.models_dir, filename)

        # Check for an existing file with the same name
        if os.path.exists(target_path) and not overwrite:
            print(f"错误: 目标文件 {target_path} 已存在,如需覆盖请设置overwrite=True")
            return None

        try:
            # Copy the file
            shutil.copy2(import_file, target_path)

            # Import the matching metadata file as well, when it exists
            meta_filename = filename.replace('_model_v', '_meta_v')
            meta_import_file = import_file.replace('_model_v', '_meta_v').replace('.pt', '.json')
            meta_target_path = os.path.join(self.models_dir, meta_filename.replace('.pt', '.json'))
|
||||
if os.path.exists(meta_import_file):
|
||||
shutil.copy2(meta_import_file, meta_target_path)
|
||||
|
||||
print(f"模型已导入到 {os.path.abspath(target_path)}")
|
||||
return target_path
|
||||
except Exception as e:
|
||||
print(f"导入模型时出错: {str(e)}")
|
||||
return None
|
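For orientation, a minimal usage sketch of the comparison/export/import API above; the manager class name ModelManager and the product ID are assumptions made for illustration only:

# Hypothetical usage of the methods above (class name and product ID assumed):
manager = ModelManager()
comparison = manager.compare_models('P001', model_types=['mlstm', 'transformer'])
exported_path = manager.export_model('P001', 'mlstm')   # latest version by default
manager.import_model(exported_path, overwrite=True)      # round-trips the exported file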
server/models/model_registry.py (new file, 64 lines)
@ -0,0 +1,64 @@
"""
Model registry
Decouples model invocation from model implementation and supports plugin-style registration of new models.
"""

# Trainer registry
TRAINER_REGISTRY = {}

def register_trainer(name, func):
    """
    Register a model trainer.

    Parameters:
        name (str): model type name (e.g., 'xgboost')
        func (function): the corresponding training function
    """
    if name in TRAINER_REGISTRY:
        print(f"Warning: trainer '{name}' is being re-registered and will be overwritten.")
    TRAINER_REGISTRY[name] = func
    print(f"✅ Registered trainer: {name}")

def get_trainer(name):
    """
    Look up a registered trainer by model type name.
    """
    if name not in TRAINER_REGISTRY:
        # Make sure the trainers have been loaded before reporting what is available
        from trainers import discover_trainers
        discover_trainers()
        if name not in TRAINER_REGISTRY:
            raise ValueError(f"Unregistered trainer: '{name}'. Available: {list(TRAINER_REGISTRY.keys())}")
    return TRAINER_REGISTRY[name]

# --- Predictor registry ---

# Predictor functions must share a uniform interface, for example:
# def predictor_function(model, checkpoint, **kwargs): -> predictions

PREDICTOR_REGISTRY = {}

def register_predictor(name, func):
    """
    Register a model predictor.
    """
    if name in PREDICTOR_REGISTRY:
        print(f"Warning: predictor '{name}' is being re-registered and will be overwritten.")
    PREDICTOR_REGISTRY[name] = func

def get_predictor(name):
    """
    Look up a registered predictor by model type name.
    Falls back to the default predictor when no specific one is found.
    """
    return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))

# The default PyTorch prediction logic can be registered as 'default'
def register_default_predictors():
    from predictors.model_predictor import default_pytorch_predictor
    register_predictor('default', default_pytorch_predictor)
    # Models with special prediction logic can also be registered here
    # register_predictor('kan', kan_predictor_func)

# Note: when this function runs matters; it must execute once at application startup.
# For now it can be invoked from model_predictor.py right after the registry is imported.
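A minimal sketch of the intended registry round trip; the 'dummy' trainer below is hypothetical and only illustrates the lookup contract:

from models.model_registry import register_trainer, get_trainer

def train_with_dummy(product_id, model_identifier, product_df, **kwargs):
    # Hypothetical trainer body; real trainers fit a model and return it
    # together with their evaluation metrics.
    ...

register_trainer('dummy', train_with_dummy)
trainer = get_trainer('dummy')  # resolved by model type name at call time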
Binary file not shown.
@ -8,10 +8,9 @@ import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # needed so that MinMaxScaler can be deserialized
import joblib
from typing import Optional

from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
@ -19,395 +18,170 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb

from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE
from utils.file_save import ModelPathManager
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor

def load_model_and_predict(product_id, model_type, model_path=None, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product', **kwargs):
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """
    Load a trained model and run a prediction

    Parameters:
        product_id: product ID
        model_type: model type ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan', 'xgboost')
        model_path: full path of the model file
        store_id: store ID; the global model is used when None
        future_days: number of days to forecast
        start_date: forecast start date; the last known date is used when None
        analyze_result: whether to analyze the prediction result
        version: model version

    Returns:
        The prediction result, plus the analysis when analyze_result is True
    Default prediction logic for PyTorch models, with autoregressive decoding.
    """
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        prediction_input_df = product_df.tail(sequence_length)
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"Not enough history for the prediction. {sequence_length} days are required but only {len(prediction_input_df)} are available.")

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()

    for _ in range(future_days):
        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
        # **Core improvement**: detect the model type and call the matching prediction path
        if isinstance(model, xgb.Booster):
            # XGBoost prediction path
            X_input_reshaped = X_current_scaled.reshape(1, -1)
            d_input = xgb.DMatrix(X_input_reshaped)
            # **Key fix**: predict with best_iteration so that early stopping is honored
            y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
        else:
            # Default PyTorch prediction path
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
        new_row_df = pd.DataFrame([new_row])
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    predictions_df = pd.DataFrame(all_predictions)
    return predictions_df, history_for_chart_df, prediction_input_df

# Register the default PyTorch predictor
register_predictor('default', default_pytorch_predictor)
# The enhanced default predictor also serves xgboost
register_predictor('xgboost', default_pytorch_predictor)
# New models reuse the default predictor as well
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)

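Only 'default', 'xgboost' and 'cnn_bilstm_attention' are registered explicitly, so every other model type resolves to the default predictor. A sketch under the assumption that model, checkpoint and product_df come from an already loaded checkpoint:

# 'tcn' has no dedicated entry, so get_predictor falls back to 'default'
predict_fn = get_predictor('tcn')
predictions_df, history_df, input_df = predict_fn(
    model=model,            # assumed: instantiated PyTorch model
    checkpoint=checkpoint,  # assumed: dict with 'config', 'scaler_X', 'scaler_y'
    product_df=product_df,
    future_days=7,
    start_date=None,
    history_lookback_days=30,
)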
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """
    Load a trained model and run a prediction (v4 - plugin architecture)
    """
    try:
        print(f"Trying to load model file: {model_path}")

        # If no model_path was provided, build one dynamically with ModelPathManager
        if not model_path:
            if version is None:
                raise ValueError("'version' must be provided when loading via a dynamic path.")

            path_manager = ModelPathManager()
            # Pass along everything needed to reconstruct the path
            path_params = {
                'product_id': product_id,
                'store_id': store_id,
                **kwargs
            }
            model_path = path_manager.get_model_path_for_prediction(
                training_mode=training_mode,
                model_type=model_type,
                version=version,
                **path_params
            )
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file {model_path} does not exist")

        # --- The data-loading part is unchanged ---
        from utils.multi_store_data_utils import aggregate_multi_store_data
        if training_mode == 'store' and store_id:
            from utils.multi_store_data_utils import load_multi_store_data
            store_df_for_name = load_multi_store_data(store_id=store_id)
            product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"Store {store_id}"
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "Global sales data"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id

        if not model_path or not os.path.exists(model_path):
            print(f"Model file {model_path} does not exist or could not be generated.")
            return None

        # Load the sales data (multi-store aware)
        try:
            if store_id:
                # Load the data of a specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    None  # use the default data path
                )
                store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"Store {store_id}"
                prediction_scope = f"store '{store_name}' ({store_id})"
            else:
                # Aggregate all stores for the prediction
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method='sum',
                    file_path=None  # use the default data path
                )
                prediction_scope = "all stores (aggregated data)"
        except Exception as e:
            print(f"Loading multi-store data failed, falling back to the original data format: {e}")
            # Backwards compatibility: try the original data format
            try:
                from core.config import DEFAULT_DATA_PATH
                from utils.multi_store_data_utils import load_multi_store_data
                df = load_multi_store_data(DEFAULT_DATA_PATH)
                product_df = df[df['product_id'] == product_id].sort_values('date')
                if store_id:
                    print("Warning: the original data does not support store filtering; predicting on all data")
                prediction_scope = "default data"
            except Exception as e2:
                print(f"Loading product data failed: {str(e2)}")
                return None

        if product_df.empty:
            print(f"Product {product_id} has no sales data")
            return None

        product_name = product_df['product_name'].iloc[0]
        print(f"Predicting the next {future_days} days of sales for product '{product_name}' (ID: {product_id}) with the {model_type} model")
        print(f"Prediction scope: {prediction_scope}")

        # Register safe globals so that MinMaxScaler can be deserialized
            raise ValueError(f"Product {product_id} or store {store_id} has no sales data")

        # --- Model loading and instantiation (refactored) ---
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception as e:
            print(f"Registering safe globals failed, but this may not affect model loading: {str(e)}")
        except Exception: pass

        # Load the model and its configuration
        try:
            # First try loading with weights_only=False
            if model_type == 'xgboost':
                if not os.path.exists(model_path):
                    print(f"XGBoost model file does not exist: {model_path}")
                    return None
                # Load the metadata
                metadata = joblib.load(model_path)
                model_file_path = metadata['model_file']

                if not os.path.exists(model_file_path):
                    print(f"The referenced XGBoost model file does not exist: {model_file_path}")
                    return None

                # Load the native Booster model
                model = xgb.Booster()
                model.load_model(model_file_path)

                config = metadata['config']
                metrics = metadata['metrics']
                scaler_X = metadata['scaler_X']
                scaler_y = metadata['scaler_y']
                print("Native XGBoost model and metadata loaded successfully")
            checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            config = checkpoint.get('config', {})
            loaded_model_type = config.get('model_type', model_type)  # prefer the type stored inside the model

            # Decide how to obtain the model instance based on its type
            if loaded_model_type == 'xgboost':
                # For XGBoost the model object is stored directly under 'model_state_dict'
                model = checkpoint['model_state_dict']
            else:
                # PyTorch models must be re-instantiated before loading the state_dict
                if loaded_model_type == 'transformer':
                    model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
                elif loaded_model_type == 'mlstm':
                    model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
                elif loaded_model_type == 'kan':
                    model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
                elif loaded_model_type == 'optimized_kan':
                    model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
                elif loaded_model_type == 'tcn':
                    model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
                elif loaded_model_type == 'cnn_bilstm_attention':
                    model = CnnBiLstmAttention(
                        input_dim=config['input_dim'],
                        output_dim=config['output_dim'],
                        sequence_length=config['sequence_length']
                    ).to(DEVICE)
                else:
            try:
                print("Trying to load the model with weights_only=False")
                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            except Exception as e:
                print(f"Loading with weights_only=False failed: {str(e)}")
                print("Trying to load the model with default arguments")
                checkpoint = torch.load(model_path, map_location=DEVICE)

            print(f"Model loaded; checking the checkpoint type: {type(checkpoint)}")
            if isinstance(checkpoint, dict):
                print(f"Keys in the checkpoint: {list(checkpoint.keys())}")
            else:
                print(f"The checkpoint is not a dict but: {type(checkpoint)}")
                return None
        except Exception as e:
            print(f"Loading the model failed: {str(e)}")
            return None

        # XGBoost follows its own code path
        if model_type == 'xgboost':
            look_back = config['look_back']
            features = config['features']
                    raise ValueError(f"Unsupported model type: {loaded_model_type}")

            # Prepare the input data
            recent_data = product_df.iloc[-look_back:].copy()

            predictions = []
            current_input_df = recent_data[features].copy()
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

            for _ in range(future_days):
                # Scale the input data and flatten it
                input_scaled = scaler_X.transform(current_input_df.values)
                input_vector = input_scaled.flatten().reshape(1, -1)

                # Predict the scaled value
                dpredict = xgb.DMatrix(input_vector)
                prediction_scaled = model.predict(dpredict)

                # Inverse-transform to get the real prediction
                prediction = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).flatten()[0]
                predictions.append(prediction)

                # Update the input window for the next step
                # Build a new row that carries the fresh prediction
                new_row_values = current_input_df.iloc[-1].copy()
                new_row_values['sales'] = prediction
                # More elaborate future-feature generation (e.g., updating weekday and month from the new date) could go here

                new_row_df = pd.DataFrame([new_row_values], columns=features)

                # Roll the window forward
                current_input_df = pd.concat([current_input_df.iloc[1:], new_row_df], ignore_index=True)
        # --- Dispatch to the registered predictor ---
        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"No predictor implementation found for model type '{loaded_model_type}'")

            # Generate the prediction dates
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [last_date + timedelta(days=i+1) for i in range(future_days)]

            y_pred = np.array(predictions)
        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days
        )

        else:  # the original PyTorch model logic
            # Check for and fetch the configuration
            if 'config' not in checkpoint:
                print("The model file carries no configuration")
                return None

            config = checkpoint['config']
            print(f"Model configuration: {config}")

            # Check for and fetch the scalers
            if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
                print("The model file carries no scalers")
                return None

            scaler_X = checkpoint['scaler_X']
            scaler_y = checkpoint['scaler_y']

            # Create the model instance
            try:
                if model_type == 'transformer':
                    model = TimeSeriesTransformer(
                        num_features=config['input_dim'],
                        d_model=config['hidden_size'],
                        nhead=config['num_heads'],
                        num_encoder_layers=config['num_layers'],
                        dim_feedforward=config['hidden_size'] * 2,
                        dropout=config['dropout'],
                        output_sequence_length=config['output_dim'],
                        seq_length=config['sequence_length'],
                        batch_size=32
                    ).to(DEVICE)
                elif model_type == 'slstm':
                    model = ScalarLSTM(
                        input_dim=config['input_dim'],
                        hidden_dim=config['hidden_size'],
                        output_dim=config['output_dim'],
                        num_layers=config['num_layers'],
                        dropout=config['dropout']
                    ).to(DEVICE)
                elif model_type == 'mlstm':
                    # Read the configuration values, falling back to defaults
                    embed_dim = config.get('embed_dim', 32)
                    dense_dim = config.get('dense_dim', 32)
                    num_heads = config.get('num_heads', 4)
                    num_blocks = config.get('num_blocks', 3)

                    model = MatrixLSTM(
                        num_features=config['input_dim'],
                        hidden_size=config['hidden_size'],
                        mlstm_layers=config['num_layers'],
                        embed_dim=embed_dim,
                        dense_dim=dense_dim,
                        num_heads=num_heads,
                        dropout_rate=config['dropout'],
                        num_blocks=num_blocks,
                        output_sequence_length=config['output_dim']
                    ).to(DEVICE)
                elif model_type == 'kan':
                    model = KANForecaster(
                        input_features=config['input_dim'],
                        hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                        output_sequence_length=config['output_dim']
                    ).to(DEVICE)
                elif model_type == 'optimized_kan':
                    model = OptimizedKANForecaster(
                        input_features=config['input_dim'],
                        hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                        output_sequence_length=config['output_dim']
                    ).to(DEVICE)
                elif model_type == 'tcn':
                    model = TCNForecaster(
                        num_features=config['input_dim'],
                        output_sequence_length=config['output_dim'],
                        num_channels=[config['hidden_size']] * config['num_layers'],
                        kernel_size=3,
                        dropout=config['dropout']
                    ).to(DEVICE)
                else:
                    print(f"Unsupported model type: {model_type}")
                    return None

                print(f"Model instance created successfully: {type(model)}")
            except Exception as e:
                print(f"Creating the model instance failed: {str(e)}")
                return None

            # Load the model parameters
            try:
                model.load_state_dict(checkpoint['model_state_dict'])
                model.eval()
                print("Model parameters loaded successfully")
            except Exception as e:
                print(f"Loading the model parameters failed: {str(e)}")
                return None

            # Prepare the input data
            try:
                features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
                sequence_length = config['sequence_length']

                # Take the most recent sequence_length days as input
                recent_data = product_df.iloc[-sequence_length:].copy()

                # If a start date was given, use the data from that date onward
                if start_date:
                    if isinstance(start_date, str):
                        start_date = datetime.strptime(start_date, '%Y-%m-%d')
                    recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
                    if len(recent_data) < sequence_length:
                        print(f"Warning: fewer than the required {sequence_length} days of data are available from {start_date}")
                        # Top up with earlier data
                        missing_days = sequence_length - len(recent_data)
                        additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
                        recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)

                print(f"Input data prepared, shape: {recent_data.shape}")
            except Exception as e:
                print(f"Preparing the input data failed: {str(e)}")
                return None

            # Scale the input data
            try:
                X = recent_data[features].values
                X_scaled = scaler_X.transform(X)

                # Convert to the model's input format
                X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
                print(f"Input tensor prepared, shape: {X_input.shape}")
            except Exception as e:
                print(f"Scaling the input data failed: {str(e)}")
                return None

            # Predict
            try:
                with torch.no_grad():
                    y_pred_scaled = model(X_input).cpu().numpy()
                print(f"Raw prediction output shape: {y_pred_scaled.shape}")

                # Squeeze the outputs of TCN, Transformer, mLSTM and KAN models into the expected shape
                if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
                    y_pred_scaled = y_pred_scaled.squeeze(-1)
                    print(f"Post-processed prediction output shape: {y_pred_scaled.shape}")

                # Inverse-transform the predictions
                y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
                print(f"Predictions after inverse scaling: {y_pred}")

                # Generate the prediction dates
                last_date = recent_data['date'].iloc[-1]
                pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
                print(f"Prediction dates: {pred_dates}")
            except Exception as e:
                print(f"Running the prediction failed: {str(e)}")
                return None

            # Build the prediction result DataFrame
            try:
                predictions_df = pd.DataFrame({
                    'date': pred_dates,
                    'sales': y_pred  # named 'sales' instead of 'predicted_sales' for compatibility with the history data
                })
                print(f"Prediction DataFrame created, shape: {predictions_df.shape}")
            except Exception as e:
                print(f"Creating the prediction DataFrame failed: {str(e)}")
                return None

            # Plot the prediction
            try:
                plt.figure(figsize=(12, 6))
                plt.plot(product_df['date'], product_df['sales'], 'b-', label='Historical sales')
                plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='Predicted sales')
                plt.title(f'{product_name} - {model_type} sales forecast')
                plt.xlabel('Date')
                plt.ylabel('Sales')
                plt.legend()
                plt.grid(True)
                plt.xticks(rotation=45)
                plt.tight_layout()

                # Save the figure
                plt.savefig(f'{product_id}_{model_type}_prediction.png')
                plt.close()

                print(f"Prediction chart saved to {product_id}_{model_type}_prediction.png")
            except Exception as e:
                print(f"Plotting the prediction chart failed: {str(e)}")
                # This error does not affect the main flow; keep going

        # Analyze the prediction result
        # --- The analysis and return part is unchanged ---
        analysis = None
        if analyze_result:
            try:
                analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
                print("\nPrediction analysis:")
                if analysis and 'explanation' in analysis:
                    print(analysis['explanation'])
                else:
                    print("The analysis result has no 'explanation' field")
                analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
            except Exception as e:
                print(f"Analyzing the prediction failed: {str(e)}")
                # An analysis failure does not affect the main flow; keep going

        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
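A minimal call into the refactored entry point; the model path and product ID below are hypothetical:

result = load_model_and_predict(
    model_path='saved_models/P001_mlstm_model_v2.pt',  # hypothetical path
    product_id='P001',
    model_type='mlstm',
    future_days=7,
    history_lookback_days=30,
)
if result is not None:
    print(result['model_type'], len(result['predictions']))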
@ -2,20 +2,44 @@
Pharmacy sales forecasting system - model training module
"""

from .mlstm_trainer import train_product_model_with_mlstm
from .kan_trainer import train_product_model_with_kan
from .tcn_trainer import train_product_model_with_tcn
from .transformer_trainer import train_product_model_with_transformer
from .xgboost_trainer import train_product_model_with_xgboost
import os
import glob
import importlib

# Default training function
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
_TRAINERS_LOADED = False

def discover_trainers():
    """
    Automatically discover and load every trainer plugin.
    A flag makes sure the discovery only runs once.
    """
    global _TRAINERS_LOADED
    if _TRAINERS_LOADED:
        return

    print("🚀 Discovering and loading trainer plugins...")

    package_dir = os.path.dirname(__file__)
    module_name = __name__

    trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py"))

    for f in trainer_files:
        base_name = os.path.basename(f)
        if base_name.startswith('__'):
            continue

        module_stem = base_name.replace('.py', '')

        try:
            # Import the module dynamically to trigger its self-registration
            importlib.import_module(f".{module_stem}", package=module_name)
        except ImportError as e:
            print(f"⚠️ Loading trainer {module_stem} failed: {e}")

    _TRAINERS_LOADED = True
    print("✅ All trainer plugins loaded.")

# Run the discovery automatically when the package is first imported
discover_trainers()

__all__ = [
    'train_product_model',
    'train_product_model_with_mlstm',
    'train_product_model_with_kan',
    'train_product_model_with_tcn',
    'train_product_model_with_transformer',
    'train_product_model_with_xgboost'
]

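Under this discovery scheme, adding a model amounts to dropping a self-registering file into server/trainers/. A sketch with a hypothetical module my_model_trainer.py:

# server/trainers/my_model_trainer.py (hypothetical plugin)
from models.model_registry import register_trainer

def train_product_model_with_my_model(product_id, model_identifier, product_df=None, **kwargs):
    # Fit and persist the model here, mirroring the existing trainers.
    ...

register_trainer('my_model', train_product_model_with_my_model)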
server/trainers/cnn_bilstm_attention_trainer.py (new file, 118 lines)
@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model trainer
"""

import torch
import torch.optim as optim
import numpy as np

from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler

# Import the newly created model
from models.cnn_bilstm_attention import CnnBiLstmAttention

def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    Train with the CNN-BiLSTM-Attention model.
    The signature follows the system-wide standard.
    """
    print(f"🚀 CNN-BiLSTM-Attention trainer started: model_identifier='{model_identifier}'")

    # --- 1. Data preparation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # --- 2. Instantiate the model and the optimizer ---
    input_dim = trainX.shape[2]

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )

    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop ---
    print("Training the CNN-BiLSTM-Attention model...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(trainX)
        loss = criterion(outputs, trainY.squeeze(-1))  # make sure the target dimensions match

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Model evaluation ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    print(f"Model evaluation finished: RMSE={metrics['rmse']:.4f}")

    # --- 5. Model saving ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_df['product_name'].iloc[0]
    )

    print(f"✅ CNN-BiLSTM-Attention model saved, version: {final_version}")
    return model, metrics, final_version, final_model_path

# --- Key step: register this trainer with the system ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
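Note that the loop above trains on the full training tensor each epoch rather than on mini-batches, which keeps the trainer simple for daily series of this size. Invoking it through the registry could look like this (argument values are illustrative; df is assumed to be a prepared DataFrame):

from models.model_registry import get_trainer

trainer = get_trainer('cnn_bilstm_attention')
model, metrics, version, path = trainer(
    product_id='P001', model_identifier='P001', product_df=df,
    store_id=None, training_mode='product', aggregation_method='sum',
    epochs=50, sequence_length=30, forecast_horizon=7, model_dir='saved_models',
)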
@ -21,7 +21,7 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON

def train_product_model_with_kan(product_id, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, use_optimized=False, path_info=None, **kwargs):
def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
    """
    Train a product sales forecasting model with KAN

@ -29,14 +29,12 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    product_id: product ID
    epochs: number of training epochs
    use_optimized: whether to use the optimized KAN
    path_info: dict carrying all path information
    model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the configuration

    Returns:
    model: the trained model
    metrics: the model's evaluation metrics
    """
    if not path_info:
        raise ValueError("train_product_model_with_kan requires the 'path_info' argument.")
    # If no product_df was passed in, load the data according to the training mode
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
@ -81,11 +79,11 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
        raise ValueError(f"Product {product_id} has no usable sales data")

    # Data volume check
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual data volume: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
@ -103,7 +101,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    print(f"Training a sales forecasting model for product '{product_name}' (ID: {product_id}) with {model_type}")
    print(f"Training scope: {training_scope}")
    print(f"Device: {DEVICE}")
    print(f"The model will be saved to: {path_info['base_dir']}")
    print(f"The model will be saved to directory: {model_dir}")

    # Build the features and the target variable
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
@ -125,8 +123,8 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # Build the time-series samples
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
@ -144,7 +142,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

    # Initialize the KAN model
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 64

    if use_optimized:
@ -170,6 +168,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    train_losses = []
    test_losses = []
    start_time = time.time()
    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
@ -227,6 +226,43 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # Check whether this is the best model so far
        model_type_name = 'optimized_kan' if use_optimized else 'kan'
        if test_loss < best_loss:
            best_loss = test_loss
            print(f"🎉 New best model found at epoch {epoch+1}, test loss: {test_loss:.4f}")

            # Assemble the data for saving the best model
            best_model_data = {
                'model_state_dict': model.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'input_dim': input_dim,
                    'output_dim': output_dim,
                    'hidden_size': hidden_size,
                    'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
                    'sequence_length': sequence_length,
                    'forecast_horizon': forecast_horizon,
                    'model_type': model_type_name,
                    'use_optimized': use_optimized
                },
                'epoch': epoch + 1
            }

            # Save the 'best' version through the model manager
            from utils.model_manager import model_manager
            model_manager.save_model(
                model_data=best_model_data,
                product_id=model_identifier,  # fix: use the unique identifier
                model_type=model_type_name,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version='best'  # explicitly override the version as 'best'
            )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
@ -236,13 +272,12 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

    # Plot the loss curve and save it to the model directory
    model_name = 'optimized_kan' if use_optimized else 'kan'
    loss_curve_path = path_info['loss_curve_path']
    plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        model_type,
        save_path=loss_curve_path
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        model_type,
        model_dir=model_dir
    )
    print(f"Loss curve saved to: {loss_curve_path}")

@ -272,6 +307,9 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f}s")

    # Save the model through the unified model manager
    from utils.model_manager import model_manager

    model_type_name = 'optimized_kan' if use_optimized else 'kan'

    model_data = {
@ -282,9 +320,9 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'hidden_sizes': [hidden_size, hidden_size*2, hidden_size],
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': model_type_name,
            'use_optimized': use_optimized
        },
@ -297,14 +335,23 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
        'loss_curve_path': loss_curve_path
    }

    # Check whether the model's performance is acceptable
    # R2 check removed: always save the model
    if metrics:
        # Save the model to the path from path_info
        model_path = path_info['model_path']
        torch.save(model_data, model_path)
        print(f"Model saved to: {model_path}")
    else:
        print("No evaluation metrics were produced during training; the final model is not saved.")
    # Save the final model and let model_manager assign the version number
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=model_identifier,  # fix: use the unique identifier
        model_type=model_type_name,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
        # note: no version argument here; the manager generates it
    )

    return model, metrics
    print(f"Final model saved, version: {final_version}, path: {final_model_path}")

    return model, metrics

# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('kan', train_product_model_with_kan)
register_trainer('optimized_kan', train_product_model_with_kan)
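Since both names map to the same function, a caller that resolves 'optimized_kan' from the registry presumably still has to request the optimized variant explicitly; a sketch under that assumption:

trainer = get_trainer('optimized_kan')
model, metrics = trainer(
    product_id='P001', model_identifier='P001', product_df=df,  # df assumed loaded
    use_optimized=True,  # the registry name alone does not select the variant
)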
@ -23,46 +23,24 @@ from core.config import (
    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, path_info: dict):
    """
    Save a training checkpoint (adapted to the flat path layout)

    Args:
        checkpoint_data: the checkpoint data
        epoch_or_label: epoch number or label (e.g. 'best' or an integer)
        path_info (dict): dict carrying all path information
    """
    if epoch_or_label == 'best':
        # Use the best-checkpoint path provided directly by ModelPathManager
        checkpoint_path = path_info['best_checkpoint_path']
    else:
        # Build the path from the epoch checkpoint template
        template = path_info.get('epoch_checkpoint_template')
        if not template:
            raise ValueError("'epoch_checkpoint_template' is missing from 'path_info'.")
        checkpoint_path = template.format(N=epoch_or_label)

    # Save the checkpoint
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[mLSTM] Checkpoint saved: {checkpoint_path}", flush=True)

    return checkpoint_path
from utils.model_manager import model_manager

def train_product_model_with_mlstm(
    product_id,
    model_identifier,
    product_df,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    model_dir=DEFAULT_MODEL_DIR,  # superseded by path_info
    version=None,  # superseded by path_info
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False,
    progress_callback=None,
    path_info=None,  # new parameter
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
@ -84,12 +62,6 @@ def train_product_model_with_mlstm(
        progress_callback: progress callback used for multi-process training
    """

    # Make sure path_info was provided
    if not path_info:
        raise ValueError("train_product_model_with_mlstm requires the 'path_info' argument.")

    version = path_info['version']

    # Build the WebSocket progress feedback function; supports multiple processes
    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the frontend"""
@ -123,7 +95,12 @@ def train_product_model_with_mlstm(
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress(f"Starting training of mLSTM model version v{version}")
    emit_progress("Starting mLSTM model training...")

    # Determine the version number
    emit_progress(f"Starting training of the mLSTM model")
    if version:
        emit_progress(f"Using the requested version: {version}")

    # Initialize the training progress manager (unless already initialized)
    if socketio and task_id:
@ -156,11 +133,11 @@ def train_product_model_with_mlstm(
        training_scope = "All stores"

    # Data volume check
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual data volume: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
@ -176,9 +153,9 @@ def train_product_model_with_mlstm(

    print(f"[mLSTM] Training a sales forecasting model for product '{product_name}' (ID: {product_id})", flush=True)
    print(f"[mLSTM] Training scope: {training_scope}", flush=True)
    print(f"[mLSTM] Version: v{version}", flush=True)
    # print(f"[mLSTM] Version: {version}", flush=True)  # Version is now handled by model_manager
    print(f"[mLSTM] Device: {DEVICE}", flush=True)
    print(f"[mLSTM] The model will be saved to: {path_info['base_dir']}", flush=True)
    print(f"[mLSTM] The model will be saved to directory: {model_dir}", flush=True)
    print(f"[mLSTM] Data volume: {len(product_df)} records", flush=True)

    emit_progress(f"Training product: {product_name} (ID: {product_id}) - {training_scope}")
@ -210,8 +187,8 @@ def train_product_model_with_mlstm(
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # Build the time-series samples
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
@ -236,7 +213,7 @@ def train_product_model_with_mlstm(

    # Initialize the mLSTM-with-Transformer model
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 128
    num_heads = 4
    dropout_rate = 0.1
@ -265,9 +242,8 @@ def train_product_model_with_mlstm(

    # When continuing training, load the existing model
    if continue_training and version != 'v1':
        # TODO: the continue-training logic must be adapted to the new path layout,
        # e.g. by loading the best checkpoint of the previous version
        emit_progress("Continue-training is not yet adapted to the new path layout; starting as a fresh run.")
        # TODO: Implement continue_training logic with the new model_manager
        pass

    # Move the model to the device
    model = model.to(DEVICE)
@ -366,12 +342,13 @@ def train_product_model_with_mlstm(
                'output_dim': output_dim,
                'hidden_size': hidden_size,
                'num_heads': num_heads,
                'dropout': dropout_rate,
                'dropout_rate': dropout_rate,
                'num_blocks': num_blocks,
                'embed_dim': embed_dim,
                'dense_dim': dense_dim,
                'sequence_length': LOOK_BACK,
                'forecast_horizon': FORECAST_HORIZON,
                'mlstm_layers': 2,  # make sure this parameter is saved
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_type': 'mlstm'
            },
            'training_info': {
@ -385,19 +362,23 @@ def train_product_model_with_mlstm(
            }
        }

        # Save a checkpoint
        save_checkpoint(checkpoint_data, epoch + 1, path_info)

        # If this is the best model, save an extra copy
        if test_loss < best_loss:
            best_loss = test_loss
            save_checkpoint(checkpoint_data, 'best', path_info)
            model_manager.save_model(
                model_data=checkpoint_data,
                product_id=model_identifier,  # fix: use the unique identifier
                model_type='mlstm',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version='best'
            )
            emit_progress(f"💾 Saved best model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)
@ -412,15 +393,26 @@ def train_product_model_with_mlstm(

    emit_progress("Generating the loss curve...", progress=95)

    # Take the loss-curve path from path_info
    loss_curve_path = path_info['loss_curve_path']
    # Determine the model save directory (multi-store aware)
    if store_id:
        # Create a sub-directory for the specific store
        store_model_dir = os.path.join(model_dir, 'mlstm', store_id)
        os.makedirs(store_model_dir, exist_ok=True)
        loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
        loss_curve_path = os.path.join(store_model_dir, loss_curve_filename)
    else:
        # Global models go into the 'global' directory
        global_model_dir = os.path.join(model_dir, 'mlstm', 'global')
        os.makedirs(global_model_dir, exist_ok=True)
        loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
        loss_curve_path = os.path.join(global_model_dir, loss_curve_filename)

    # Plot the loss curve and save it to the model directory
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Test Loss')
    title_suffix = f" - {training_scope}" if store_id else " - global model"
    plt.title(f'mLSTM training loss curve - {product_name} (v{version}){title_suffix}')
    plt.title(f'mLSTM training loss curve - {product_name} ({version}){title_suffix}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
@ -445,7 +437,6 @@ def train_product_model_with_mlstm(
    # Compute the evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time
    metrics['version'] = version

    # Print the evaluation metrics
    print("\nModel evaluation metrics:")
@ -474,12 +465,13 @@ def train_product_model_with_mlstm(
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'num_heads': num_heads,
            'dropout': dropout_rate,
            'dropout_rate': dropout_rate,
            'num_blocks': num_blocks,
            'embed_dim': embed_dim,
            'dense_dim': dense_dim,
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'mlstm_layers': 2,  # make sure this parameter is saved
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'mlstm'
        },
        'metrics': metrics,
@ -496,17 +488,17 @@ def train_product_model_with_mlstm(
        }
    }

    # Check whether the model's performance is acceptable
    # R2 check removed: always save the model
    if metrics:
        # Save the final model to model.pth
        final_model_path = path_info['model_path']
        torch.save(final_model_data, final_model_path)
        print(f"[mLSTM] Final model saved: {final_model_path}", flush=True)
    else:
        final_model_path = None
        print(f"[mLSTM] No evaluation metrics were produced during training; the final model is not saved.", flush=True)

    # Save the final model and let model_manager assign the version number
    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # fix: use the unique identifier
        model_type='mlstm',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    # Send the training-finished message
    final_metrics = {
        'mse': metrics['mse'],
@ -516,12 +508,14 @@ def train_product_model_with_mlstm(
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs,
        'model_path': final_model_path
        'model_path': final_model_path,
        'version': final_version
    }

    if final_model_path:
        emit_progress(f"✅ mLSTM training finished! Final epoch: {epochs} saved", progress=100, metrics=final_metrics)
    else:
        emit_progress(f"❌ mLSTM training failed: performance below threshold", progress=100, metrics={'error': 'poor model performance'})
    emit_progress(f"✅ mLSTM training finished! Version {final_version} saved", progress=100, metrics=final_metrics)

    return model, metrics, epochs, final_model_path

    return model, metrics, epochs, final_model_path
# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('mlstm', train_product_model_with_mlstm)
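The trainer reports progress either over SocketIO or through the plain callback. A sketch of the callback contract implied by emit_progress above (the exact payload keys are an assumption):

def on_progress(data):
    # data is expected to carry at least 'message', plus optional 'progress' and 'metrics'
    print(data.get('message'), data.get('progress'))

model, metrics, epochs, path = train_product_model_with_mlstm(
    product_id='P001', model_identifier='P001', product_df=df,  # df assumed loaded
    epochs=50, progress_callback=on_progress,
)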
@ -20,71 +20,28 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, path_info: dict):
    """
    Save a training checkpoint

    Args:
        checkpoint_data: the checkpoint data
        epoch_or_label: epoch number or label (e.g. 'best', 'final', 50)
        path_info (dict): dict carrying all path information
    """
    if epoch_or_label == 'best':
        # Use the best-checkpoint path provided directly by ModelPathManager
        checkpoint_path = path_info['best_checkpoint_path']
    else:
        # Build the path from the epoch checkpoint template
        template = path_info.get('epoch_checkpoint_template')
        if not template:
            raise ValueError("'epoch_checkpoint_template' is missing from 'path_info'.")
        checkpoint_path = template.format(N=epoch_or_label)

    # Save the checkpoint
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[TCN] Checkpoint saved: {checkpoint_path}", flush=True)

    return checkpoint_path
from utils.model_manager import model_manager

def train_product_model_with_tcn(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    model_dir=DEFAULT_MODEL_DIR,  # superseded by path_info
    version=None,  # superseded by path_info
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False,
    path_info=None,  # new parameter
    **kwargs
    continue_training=False
):
    """
    Train a product sales forecasting model with TCN

    Parameters:
        product_id: product ID
        epochs: number of training epochs
        model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the configuration
        version: requested version number; generated automatically when None
        socketio: WebSocket object for real-time feedback
        task_id: training task ID
        continue_training: whether to continue training an existing model

    Returns:
        model: the trained model
        metrics: the model's evaluation metrics
        version: the version number actually used
        model_path: path of the model file
    """

    if not path_info:
        raise ValueError("train_product_model_with_tcn requires the 'path_info' argument.")

    version = path_info['version']

    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the frontend"""
        if socketio and task_id:
@ -99,63 +56,28 @@ def train_product_model_with_tcn(
            data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')

    emit_progress(f"Starting training of TCN model version v{version}")
    emit_progress(f"Starting training of the TCN model")

    # If no product_df was passed in, load the data according to the training mode
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data

        try:
            if training_mode == 'store' and store_id:
                # Load the data of a specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                training_scope = f"Store {store_id}"
            elif training_mode == 'global':
                # Aggregate the data of all stores
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method=aggregation_method,
                    file_path='pharmacy_sales_multi_store.csv'
                )
                training_scope = f"Global aggregation ({aggregation_method})"
            else:
                # Default: load the product data of all stores
                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
                training_scope = "All stores"
        except Exception as e:
            print(f"Loading multi-store data failed: {e}")
            # Fallback: try the original data
            df = pd.read_excel('pharmacy_sales.xlsx')
            product_df = df[df['product_id'] == product_id].sort_values('date')
            training_scope = "Original data"
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )
        training_scope = f"Global aggregation ({aggregation_method})"
    else:
        # A product_df was passed in; use it directly
        if training_mode == 'store' and store_id:
            training_scope = f"Store {store_id}"
        elif training_mode == 'global':
            training_scope = f"Global aggregation ({aggregation_method})"
        else:
            training_scope = "All stores"
        training_scope = "All stores"

    if product_df.empty:
        raise ValueError(f"Product {product_id} has no usable sales data")

    # Data volume check
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual data volume: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
            f"1. Generate more data: uv run generate_multi_store_data.py\n"
            f"2. Adjust the configuration: lower LOOK_BACK or FORECAST_HORIZON\n"
            f"3. Use global training mode to aggregate more data"
        )
        print(error_msg)
        emit_progress(f"Training failed: not enough data ({len(product_df)}/{min_required_samples} days)")
@ -166,48 +88,39 @@ def train_product_model_with_tcn(

    print(f"Training a sales forecasting model for product '{product_name}' (ID: {product_id}) with TCN")
    print(f"Training scope: {training_scope}")
    print(f"Version: v{version}")
    print(f"Device: {DEVICE}")
    print(f"The model will be saved to: {path_info['base_dir']}")
    print(f"The model will be saved to directory: {model_dir}")

    emit_progress(f"Training product: {product_name} (ID: {product_id})")

    # Build the features and the target variable
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Preprocess the data
    X = product_df[features].values
    y = product_df[['sales']].values  # keep as a 2-D array
    y = product_df[['sales']].values

    # Enter the data-preprocessing stage
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("Preprocessing the data...")

    # Scale the data
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Split into training and test sets (80% train, 20% test)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    progress_manager.set_stage("data_preprocessing", 50)

    # Build the time-series samples
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
|
||||
testX_tensor = torch.Tensor(testX)
|
||||
testY_tensor = torch.Tensor(testY)
|
||||
|
||||
# 创建数据加载器
|
||||
train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
|
||||
test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
|
||||
|
||||
@ -215,7 +128,6 @@ def train_product_model_with_tcn(
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
||||
|
||||
# 更新进度管理器的批次信息
|
||||
total_batches = len(train_loader)
|
||||
total_samples = len(train_dataset)
|
||||
progress_manager.total_batches_per_epoch = total_batches
|
||||
@ -224,9 +136,8 @@ def train_product_model_with_tcn(
|
||||
|
||||
progress_manager.set_stage("data_preprocessing", 100)
|
||||
|
||||
# 初始化TCN模型
|
||||
input_dim = X_train.shape[1]
|
||||
output_dim = FORECAST_HORIZON
|
||||
output_dim = forecast_horizon
|
||||
hidden_size = 64
|
||||
num_layers = 3
|
||||
kernel_size = 3
|
||||
@ -240,12 +151,8 @@ def train_product_model_with_tcn(
|
||||
dropout=dropout_rate
|
||||
)
|
||||
|
||||
# 如果是继续训练,加载现有模型
|
||||
if continue_training:
|
||||
# TODO: 继续训练的逻辑需要调整以适应新的路径结构
|
||||
emit_progress("继续训练功能待适配新路径结构,暂时作为新训练开始。")
|
||||
# TODO: Implement continue_training logic with the new model_manager
|
||||
|
||||
# 将模型移动到设备上
|
||||
model = model.to(DEVICE)
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
@ -253,20 +160,17 @@ def train_product_model_with_tcn(
|
||||
|
||||
emit_progress("开始模型训练...")
|
||||
|
||||
# 训练模型
|
||||
train_losses = []
|
||||
test_losses = []
|
||||
start_time = time.time()
|
||||
|
||||
# 配置检查点保存
|
||||
checkpoint_interval = max(1, epochs // 10) # 每10%进度保存一次,最少每1个epoch
|
||||
checkpoint_interval = max(1, epochs // 10)
|
||||
best_loss = float('inf')
|
||||
|
||||
progress_manager.set_stage("model_training", 0)
|
||||
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
|
||||
|
||||
for epoch in range(epochs):
|
||||
# 开始新的轮次
|
||||
progress_manager.start_epoch(epoch)
|
||||
|
||||
model.train()
|
||||
@ -275,43 +179,34 @@ def train_product_model_with_tcn(
|
||||
for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
|
||||
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
|
||||
|
||||
# 确保目标张量有正确的形状 (batch_size, forecast_horizon, 1)
|
||||
if y_batch.dim() == 2:
|
||||
y_batch = y_batch.unsqueeze(-1)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(X_batch)
|
||||
|
||||
# 确保输出和目标形状匹配
|
||||
loss = criterion(outputs, y_batch)
|
||||
|
||||
# 反向传播和优化
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
|
||||
# 更新批次进度(每10个批次更新一次)
|
||||
if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
progress_manager.update_batch(batch_idx, loss.item(), current_lr)
|
||||
|
||||
# 计算训练损失
|
||||
train_loss = epoch_loss / len(train_loader)
|
||||
train_losses.append(train_loss)
|
||||
|
||||
# 设置验证阶段
|
||||
progress_manager.set_stage("validation", 0)
|
||||
|
||||
# 在测试集上评估
|
||||
model.eval()
|
||||
test_loss = 0
|
||||
with torch.no_grad():
|
||||
for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
|
||||
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
|
||||
|
||||
# 确保目标张量有正确的形状
|
||||
if y_batch.dim() == 2:
|
||||
y_batch = y_batch.unsqueeze(-1)
|
||||
|
||||
@ -319,7 +214,6 @@ def train_product_model_with_tcn(
|
||||
loss = criterion(outputs, y_batch)
|
||||
test_loss += loss.item()
|
||||
|
||||
# 更新验证进度
|
||||
if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
|
||||
val_progress = (batch_idx / len(test_loader)) * 100
|
||||
progress_manager.set_stage("validation", val_progress)
|
||||
@ -327,10 +221,8 @@ def train_product_model_with_tcn(
|
||||
test_loss = test_loss / len(test_loader)
|
||||
test_losses.append(test_loss)
|
||||
|
||||
# 完成当前轮次
|
||||
progress_manager.finish_epoch(train_loss, test_loss)
|
||||
|
||||
# 发送训练进度(保持与旧系统的兼容性)
|
||||
if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
|
||||
progress = ((epoch + 1) / epochs) * 100
|
||||
current_metrics = {
|
||||
@ -342,7 +234,6 @@ def train_product_model_with_tcn(
|
||||
emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
|
||||
progress=progress, metrics=current_metrics)
|
||||
|
||||
# 定期保存检查点
|
||||
if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
|
||||
checkpoint_data = {
|
||||
'epoch': epoch + 1,
|
||||
@ -359,10 +250,11 @@ def train_product_model_with_tcn(
|
||||
'output_dim': output_dim,
|
||||
'hidden_size': hidden_size,
|
||||
'num_layers': num_layers,
|
||||
'num_channels': [hidden_size] * num_layers,
|
||||
'dropout': dropout_rate,
|
||||
'kernel_size': kernel_size,
|
||||
'sequence_length': LOOK_BACK,
|
||||
'forecast_horizon': FORECAST_HORIZON,
|
||||
'sequence_length': sequence_length,
|
||||
'forecast_horizon': forecast_horizon,
|
||||
'model_type': 'tcn'
|
||||
},
|
||||
'training_info': {
|
||||
@ -375,55 +267,48 @@ def train_product_model_with_tcn(
|
||||
}
|
||||
}
|
||||
|
||||
# 保存检查点
|
||||
save_checkpoint(checkpoint_data, epoch + 1, path_info)
|
||||
|
||||
# 如果是最佳模型,额外保存一份
|
||||
if test_loss < best_loss:
|
||||
best_loss = test_loss
|
||||
save_checkpoint(checkpoint_data, 'best', path_info)
|
||||
model_manager.save_model(
|
||||
model_data=checkpoint_data,
|
||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||
model_type='tcn',
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
product_name=product_name,
|
||||
version='best'
|
||||
)
|
||||
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
|
||||
|
||||
emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
|
||||
|
||||
if (epoch + 1) % 10 == 0:
|
||||
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
|
||||
|
||||
# 计算训练时间
|
||||
training_time = time.time() - start_time
|
||||
|
||||
# 设置模型保存阶段
|
||||
progress_manager.set_stage("model_saving", 0)
|
||||
emit_progress("训练完成,正在保存模型...")
|
||||
|
||||
# 绘制损失曲线并保存到模型目录
|
||||
loss_curve_path = path_info['loss_curve_path']
|
||||
plot_loss_curve(
|
||||
train_losses,
|
||||
test_losses,
|
||||
product_name,
|
||||
'TCN',
|
||||
save_path=loss_curve_path
|
||||
loss_curve_path = plot_loss_curve(
|
||||
train_losses,
|
||||
test_losses,
|
||||
product_name,
|
||||
'TCN',
|
||||
model_dir=model_dir
|
||||
)
|
||||
print(f"损失曲线已保存到: {loss_curve_path}")
|
||||
|
||||
# 评估模型
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# 确保测试数据的形状正确
|
||||
test_pred = model(testX_tensor.to(DEVICE))
|
||||
# 将输出转换为二维数组 [samples, forecast_horizon]
|
||||
test_pred = test_pred.squeeze(-1).cpu().numpy()
|
||||
|
||||
# 反归一化预测结果和真实值
|
||||
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
|
||||
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
|
||||
|
||||
# 计算评估指标
|
||||
metrics = evaluate_model(test_true_inv, test_pred_inv)
|
||||
metrics['training_time'] = training_time
|
||||
|
||||
# 打印评估指标
|
||||
print("\n模型评估指标:")
|
||||
print(f"MSE: {metrics['mse']:.4f}")
|
||||
print(f"RMSE: {metrics['rmse']:.4f}")
|
||||
@ -432,9 +317,8 @@ def train_product_model_with_tcn(
|
||||
print(f"MAPE: {metrics['mape']:.2f}%")
|
||||
print(f"训练时间: {training_time:.2f}秒")
|
||||
|
||||
# 保存最终训练完成的模型(基于最终epoch)
|
||||
final_model_data = {
|
||||
'epoch': epochs, # 最终epoch
|
||||
'epoch': epochs,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'train_loss': train_losses[-1],
|
||||
@ -448,10 +332,11 @@ def train_product_model_with_tcn(
|
||||
'output_dim': output_dim,
|
||||
'hidden_size': hidden_size,
|
||||
'num_layers': num_layers,
|
||||
'num_channels': [hidden_size] * num_layers,
|
||||
'dropout': dropout_rate,
|
||||
'kernel_size': kernel_size,
|
||||
'sequence_length': LOOK_BACK,
|
||||
'forecast_horizon': FORECAST_HORIZON,
|
||||
'sequence_length': sequence_length,
|
||||
'forecast_horizon': forecast_horizon,
|
||||
'model_type': 'tcn'
|
||||
},
|
||||
'metrics': metrics,
|
||||
@ -469,17 +354,17 @@ def train_product_model_with_tcn(
|
||||
|
||||
progress_manager.set_stage("model_saving", 50)
|
||||
|
||||
# 检查模型性能是否达标
|
||||
# 移除R2检查,始终保存模型
|
||||
if metrics:
|
||||
# 保存最终模型
|
||||
final_model_path = path_info['model_path']
|
||||
torch.save(final_model_data, final_model_path)
|
||||
print(f"[TCN] 最终模型已保存: {final_model_path}", flush=True)
|
||||
progress_manager.set_stage("model_saving", 100)
|
||||
else:
|
||||
final_model_path = None
|
||||
print(f"[TCN] 训练过程中未生成评估指标,不保存最终模型。", flush=True)
|
||||
final_model_path, final_version = model_manager.save_model(
|
||||
model_data=final_model_data,
|
||||
product_id=model_identifier, # 修正:使用唯一的标识符
|
||||
model_type='tcn',
|
||||
store_id=store_id,
|
||||
training_mode=training_mode,
|
||||
aggregation_method=aggregation_method,
|
||||
product_name=product_name
|
||||
)
|
||||
|
||||
progress_manager.set_stage("model_saving", 100)
|
||||
|
||||
final_metrics = {
|
||||
'mse': metrics['mse'],
|
||||
@ -488,12 +373,14 @@ def train_product_model_with_tcn(
|
||||
'r2': metrics['r2'],
|
||||
'mape': metrics['mape'],
|
||||
'training_time': training_time,
|
||||
'final_epoch': epochs
|
||||
'final_epoch': epochs,
|
||||
'version': final_version
|
||||
}
|
||||
|
||||
if final_model_path:
|
||||
emit_progress(f"模型训练完成!最终epoch: {epochs}", progress=100, metrics=final_metrics)
|
||||
else:
|
||||
emit_progress(f"❌ TCN模型训练失败:性能不达标", progress=100, metrics={'error': '模型性能不佳'})
|
||||
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
|
||||
|
||||
return model, metrics, epochs, final_model_path
|
||||
return model, metrics, epochs, final_model_path
|
||||
|
||||
# --- 将此训练器注册到系统中 ---
|
||||
from models.model_registry import register_trainer
|
||||
register_trainer('tcn', train_product_model_with_tcn)
|
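The `save_checkpoint` helper above dispatches between a fixed best-model path and a per-epoch template. A minimal sketch of that dispatch is below; the `path_info` dict shape and the `saved_models/...` paths are illustrative assumptions, not the project's actual `ModelPathManager` output.

```python
import os
import torch

# Hypothetical path_info, shaped like what save_checkpoint expects
path_info = {
    'best_checkpoint_path': 'saved_models/checkpoints/demo_checkpoint_best.pth',
    'epoch_checkpoint_template': 'saved_models/checkpoints/demo_checkpoint_epoch_{N}.pth',
}
os.makedirs('saved_models/checkpoints', exist_ok=True)

checkpoint_data = {'epoch': 5, 'model_state_dict': {}}

for label in ('best', 5):
    if label == 'best':
        path = path_info['best_checkpoint_path']  # fixed path for the best model
    else:
        # numeric epochs are formatted into the {N} slot of the template
        path = path_info['epoch_checkpoint_template'].format(N=label)
    torch.save(checkpoint_data, path)
    print('saved', path)
```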
@@ -26,72 +26,29 @@ from core.config import (
from utils.training_progress import progress_manager
from utils.model_manager import model_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, path_info: dict):
    """
    保存训练检查点

    Args:
        checkpoint_data: 检查点数据
        epoch_or_label: epoch编号或标签(如'best', 'final', 50)
        path_info (dict): 包含所有路径信息的字典
    """
    if epoch_or_label == 'best':
        # 使用由 ModelPathManager 直接提供的最佳检查点路径
        checkpoint_path = path_info['best_checkpoint_path']
    else:
        # 使用 epoch 检查点模板生成路径
        template = path_info.get('epoch_checkpoint_template')
        if not template:
            raise ValueError("路径信息 'path_info' 中缺少 'epoch_checkpoint_template'。")
        checkpoint_path = template.format(N=epoch_or_label)

    # 保存检查点
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[Transformer] 检查点已保存: {checkpoint_path}", flush=True)

    return checkpoint_path

def train_product_model_with_transformer(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    model_dir=DEFAULT_MODEL_DIR,  # 将被 path_info 替代
    version=None,  # 将被 path_info 替代
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False,
    path_info=None,  # 新增参数
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
):
    """
    使用Transformer模型训练产品销售预测模型

    参数:
    product_id: 产品ID
    epochs: 训练轮次
    model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
    version: 指定版本号,如果为None则自动生成
    socketio: WebSocket对象,用于实时反馈
    task_id: 训练任务ID
    continue_training: 是否继续训练现有模型

    返回:
    model: 训练好的模型
    metrics: 模型评估指标
    version: 实际使用的版本号
    """

    if not path_info:
        raise ValueError("train_product_model_with_transformer 需要 'path_info' 参数。")

    version = path_info['version']

    # WebSocket进度反馈函数
    def emit_progress(message, progress=None, metrics=None):
        """发送训练进度到前端"""
        if socketio and task_id:
@@ -106,18 +63,15 @@ def train_product_model_with_transformer(
            data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')
        print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
        # 强制刷新输出缓冲区
        import sys
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress(f"开始Transformer模型训练... 版本 v{version}")
    emit_progress("开始Transformer模型训练...")

    # 获取训练进度管理器实例
    try:
        from utils.training_progress import progress_manager
    except ImportError:
        # 如果无法导入,创建一个空的管理器以避免错误
        class DummyProgressManager:
            def set_stage(self, *args, **kwargs): pass
            def start_training(self, *args, **kwargs): pass
@@ -127,75 +81,26 @@ def train_product_model_with_transformer(
            def finish_training(self, *args, **kwargs): pass
        progress_manager = DummyProgressManager()

    # 如果没有传入product_df,则根据训练模式加载数据
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data

        try:
            if training_mode == 'store' and store_id:
                # 加载特定店铺的数据
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                training_scope = f"店铺 {store_id}"
            elif training_mode == 'global':
                # 聚合所有店铺的数据
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method=aggregation_method,
                    file_path='pharmacy_sales_multi_store.csv'
                )
                training_scope = f"全局聚合({aggregation_method})"
            else:
                # 默认:加载所有店铺的产品数据
                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
                training_scope = "所有店铺"
        except ValueError as e:
            if "No objects to concatenate" in str(e):
                err_msg = f"聚合数据失败 (product: {product_id}, store: {store_id}, mode: {training_mode}): 没有找到可聚合的数据。"
                emit_progress(err_msg)
                # 在这种情况下,我们不能继续,所以抛出异常
                raise ValueError(err_msg) from e
            # 对于其他 ValueError,也打印并重新抛出
            emit_progress(f"数据加载时发生值错误: {e}")
            raise e
        except Exception as e:
            emit_progress(f"多店铺数据加载失败: {e}, 尝试后备方案...")
            # 后备方案:尝试原始数据
            try:
                df = pd.read_excel('pharmacy_sales.xlsx')
                product_df = df[df['product_id'] == product_id].sort_values('date')
                training_scope = "原始数据"
                emit_progress("成功从 'pharmacy_sales.xlsx' 加载后备数据。")
            except Exception as fallback_e:
                emit_progress(f"后备数据加载失败: {fallback_e}")
                raise fallback_e from e
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )
        training_scope = f"全局聚合({aggregation_method})"
    else:
        # 如果传入了product_df,直接使用
        if training_mode == 'store' and store_id:
            training_scope = f"店铺 {store_id}"
        elif training_mode == 'global':
            training_scope = f"全局聚合({aggregation_method})"
        else:
            training_scope = "所有店铺"
        training_scope = "所有店铺"

    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # 数据量检查
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"
            f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
            f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
            f"3. 使用全局训练模式聚合更多数据"
        )
        print(error_msg)
        raise ValueError(error_msg)
@@ -205,20 +110,16 @@ def train_product_model_with_transformer(

    print(f"[Transformer] 训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
    print(f"[Device] 使用设备: {DEVICE}", flush=True)
    print(f"[Model] 模型将保存到: {path_info['base_dir']}", flush=True)
    print(f"[Model] 模型将保存到目录: {model_dir}", flush=True)

    # 创建特征和目标变量
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # 设置数据预处理阶段
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("数据预处理中...")

    # 预处理数据
    X = product_df[features].values
    y = product_df[['sales']].values  # 保持为二维数组
    y = product_df[['sales']].values

    # 归一化数据
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

@@ -227,24 +128,20 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("data_preprocessing", 40)

    # 划分训练集和测试集(80% 训练,20% 测试)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # 创建时间序列数据
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    progress_manager.set_stage("data_preprocessing", 70)

    # 转换为PyTorch的Tensor
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    # 创建数据加载器
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

@@ -252,7 +149,6 @@ def train_product_model_with_transformer(
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # 更新进度管理器的批次信息
    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    progress_manager.total_batches_per_epoch = total_batches
@@ -262,9 +158,8 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("data_preprocessing", 100)
    emit_progress("数据预处理完成,开始模型训练...")

    # 初始化Transformer模型
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1
@@ -278,24 +173,21 @@ def train_product_model_with_transformer(
        dim_feedforward=hidden_size * 2,
        dropout=dropout_rate,
        output_sequence_length=output_dim,
        seq_length=LOOK_BACK,
        seq_length=sequence_length,
        batch_size=batch_size
    )

    # 将模型移动到设备上
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    # 训练模型
    train_losses = []
    test_losses = []
    start_time = time.time()

    # 配置检查点保存
    checkpoint_interval = max(1, epochs // 10)  # 每10%进度保存一次,最少每1个epoch
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    epochs_no_improve = 0

@@ -303,7 +195,6 @@ def train_product_model_with_transformer(
    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")

    for epoch in range(epochs):
        # 开始新的轮次
        progress_manager.start_epoch(epoch)

        model.train()
@@ -312,12 +203,9 @@ def train_product_model_with_transformer(
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # 确保目标张量有正确的形状
            # 前向传播
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            if clip_norm:
@@ -326,31 +214,25 @@ def train_product_model_with_transformer(

            epoch_loss += loss.item()

            # 更新批次进度
            if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        # 计算训练损失
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # 设置验证阶段
        progress_manager.set_stage("validation", 0)

        # 在测试集上评估
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # 确保目标张量有正确的形状
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                # 更新验证进度
                if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
                    val_progress = (batch_idx / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)
@@ -358,13 +240,10 @@ def train_product_model_with_transformer(
        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # 更新学习率
        scheduler.step(test_loss)

        # 完成当前轮次
        progress_manager.finish_epoch(train_loss, test_loss)

        # 发送训练进度
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
@@ -376,7 +255,6 @@ def train_product_model_with_transformer(
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # 定期保存检查点
        if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
            checkpoint_data = {
                'epoch': epoch + 1,
@@ -395,8 +273,8 @@ def train_product_model_with_transformer(
                    'num_heads': num_heads,
                    'dropout': dropout_rate,
                    'num_layers': num_layers,
                    'sequence_length': LOOK_BACK,
                    'forecast_horizon': FORECAST_HORIZON,
                    'sequence_length': sequence_length,
                    'forecast_horizon': forecast_horizon,
                    'model_type': 'transformer'
                },
                'training_info': {
@@ -409,62 +287,55 @@ def train_product_model_with_transformer(
                }
            }

            # 检查是否为最佳模型
            is_best = test_loss < best_loss
            if is_best:
            if test_loss < best_loss:
                best_loss = test_loss
                epochs_no_improve = 0
                # 保存最佳模型检查点
                save_checkpoint(checkpoint_data, 'best', path_info)
                model_manager.save_model(
                    model_data=checkpoint_data,
                    product_id=model_identifier,  # 修正:使用唯一的标识符
                    model_type='transformer',
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    product_name=product_name,
                    version='best'
                )
                emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            # 保存定期的epoch检查点(如果不是最佳模型,或者即使是最佳也保存一份epoch版本)
            save_checkpoint(checkpoint_data, epoch + 1, path_info)
            emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")

        if (epoch + 1) % 10 == 0:
            print(f"📊 Epoch {epoch+1}/{epochs}, 训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}", flush=True)

        # 提前停止逻辑
        if epochs_no_improve >= patience:
            emit_progress(f"连续 {patience} 个epoch测试损失未改善,提前停止训练。")
            break

    # 计算训练时间
    training_time = time.time() - start_time

    # 设置模型保存阶段
    progress_manager.set_stage("model_saving", 0)
    emit_progress("训练完成,正在保存模型...")

    # 绘制损失曲线并保存到模型目录
    loss_curve_path = path_info['loss_curve_path']
    plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'Transformer',
        save_path=loss_curve_path
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'Transformer',
        model_dir=model_dir
    )
    print(f"📈 损失曲线已保存到: {loss_curve_path}", flush=True)

    # 评估模型
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
        test_true = testY

    # 反归一化预测结果和真实值
    test_pred_inv = scaler_y.inverse_transform(test_pred)
    test_true_inv = scaler_y.inverse_transform(test_true)

    # 计算评估指标
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # 打印评估指标
    print(f"\n📊 模型评估指标:", flush=True)
    print(f"  MSE: {metrics['mse']:.4f}", flush=True)
    print(f"  RMSE: {metrics['rmse']:.4f}", flush=True)
@@ -473,9 +344,8 @@ def train_product_model_with_transformer(
    print(f"  MAPE: {metrics['mape']:.2f}%", flush=True)
    print(f"  ⏱️ 训练时间: {training_time:.2f}秒", flush=True)

    # 保存最终训练完成的模型(基于最终epoch)
    final_model_data = {
        'epoch': epochs,  # 最终epoch
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
@@ -491,8 +361,8 @@ def train_product_model_with_transformer(
            'num_heads': num_heads,
            'dropout': dropout_rate,
            'num_layers': num_layers,
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'transformer'
        },
        'metrics': metrics,
@@ -510,20 +380,21 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("model_saving", 50)

    # 检查模型性能是否达标
    # 移除R2检查,始终保存模型
    if metrics:
        # 保存最终模型
        final_model_path = path_info['model_path']
        torch.save(final_model_data, final_model_path)
        progress_manager.set_stage("model_saving", 100)
        emit_progress(f"模型已保存到 {final_model_path}")
        print(f"💾 模型已保存到 {final_model_path}", flush=True)
    else:
        final_model_path = None
        print(f"[Transformer] 训练过程中未生成评估指标,不保存最终模型。", flush=True)
    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # 修正:使用唯一的标识符
        model_type='transformer',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    progress_manager.set_stage("model_saving", 100)
    emit_progress(f"模型已保存到 {final_model_path}")

    print(f"💾 模型已保存到 {final_model_path}", flush=True)

    # 准备最终返回的指标
    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
@@ -531,12 +402,12 @@ def train_product_model_with_transformer(
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs
        'final_epoch': epochs,
        'version': final_version
    }

    if final_model_path:
        emit_progress(f"✅ Transformer模型训练完成!", progress=100, metrics=final_metrics)
    else:
        emit_progress(f"❌ Transformer模型训练失败:性能不达标", progress=100, metrics={'error': '模型性能不佳'})

    return model, metrics, epochs, final_model_path
    return model, final_metrics, epochs

# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('transformer', train_product_model_with_transformer)
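The Transformer trainer combines a patience counter for early stopping with a `ReduceLROnPlateau` scheduler keyed to the same test loss. A minimal sketch of just the bookkeeping follows; the loss sequence is fabricated for illustration, and the variable names mirror the trainer above.

```python
patience = 3
best_loss = float('inf')
epochs_no_improve = 0

# fabricated per-epoch test losses
for epoch, test_loss in enumerate([0.90, 0.70, 0.72, 0.71, 0.73, 0.74]):
    if test_loss < best_loss:
        best_loss = test_loss
        epochs_no_improve = 0   # reset the counter on any improvement
    else:
        epochs_no_improve += 1  # count consecutive stagnant epochs
    if epochs_no_improve >= patience:
        print(f"early stop at epoch {epoch}, best_loss={best_loss}")
        break
```

Because the scheduler is created with `patience=patience // 2`, the learning rate is halved roughly one plateau-detection cycle before training would stop outright, giving the model one reduced-rate attempt at improving before the break.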
@@ -1,296 +1,142 @@
import xgboost as xgb
import numpy as np
"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)
"""

import time
import pandas as pd
import os
import joblib
import numpy as np
import xgboost as xgb
from xgboost.callback import EarlyStopping
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping

# 从项目中导入正确的工具函数和配置
from utils.multi_store_data_utils import load_multi_store_data
from core.config import DEFAULT_DATA_PATH
from utils.file_save import ModelPathManager
# 导入核心工具
from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer

# 重构后的原生API兼容回调
class EpochCheckpointCallback(xgb.callback.TrainingCallback):
    def __init__(self, save_period, payload, base_path):
        super().__init__()
        self.save_period = save_period
        self.payload = payload
        self.base_path = base_path
        self.best_score = float('inf')

    def _save_checkpoint(self, model, path_suffix):
        """辅助函数,用于保存模型和元数据检查点"""
        metadata_path = self.base_path.replace('_model.pth', f'_{path_suffix}.pth')
        model_file_path = metadata_path.replace('.pth', '.xgb')

        # 确保目录存在
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

        # 保存原生Booster模型
        model.save_model(model_file_path)

        # 更新payload中的模型文件引用
        self.payload['model_file'] = model_file_path
        joblib.dump(self.payload, metadata_path)

        print(f"[Checkpoint] 已保存检查点到: {metadata_path}")

    def after_iteration(self, model, epoch, evals_log):
        # 获取当前验证集的分数 (假设'test'是验证集)
        current_score = evals_log['test']['rmse'][-1]

        # 保存最佳模型
        if current_score < self.best_score:
            self.best_score = current_score
            self._save_checkpoint(model, 'checkpoint_best')

        # 保存周期性检查点
        if (epoch + 1) % self.save_period == 0:
            self._save_checkpoint(model, f'checkpoint_epoch_{epoch + 1}')

        return False  # 继续训练

def create_dataset(data, look_back=7):
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    将时间序列数据转换为监督学习格式。
    :param data: 输入的DataFrame,包含特征和目标。
    :param look_back: 用于预测的时间窗口大小。
    :return: X (特征), y (目标)
    使用 XGBoost 模型训练产品销售预测模型。
    此函数签名与其他训练器保持一致,以兼容注册表调用。
    """
    X, y = [], []
    feature_columns = [col for col in data.columns if col != 'date']
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")

    for i in range(len(data) - look_back):
        # 展平look_back窗口内的所有特征
        features = data[feature_columns].iloc[i:(i + look_back)].values.flatten()
        X.append(features)
        # 目标是窗口后的第一个销售值
        y.append(data['sales'].iloc[i + look_back])

    return np.array(X), np.array(y)
    # --- 1. 数据准备和验证 ---
    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

def train_product_model_with_xgboost(
    product_id,
    store_id=None,
    epochs=100,  # XGBoost中n_estimators更常用
    look_back=7,
    socketio=None,
    task_id=None,
    version='v1',
    path_info=None,
    **kwargs):
    """
    使用XGBoost训练产品销售预测模型。
    """
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    def emit_progress(message, progress=None):
        if socketio and task_id:
            payload = {'task_id': task_id, 'message': message}
            if progress is not None:
                payload['progress'] = progress
            socketio.emit('training_update', payload, namespace='/api/training', room=task_id)
        print(f"[{task_id}] {message}")
    # --- 2. 数据预处理和适配 ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    try:
        model_path = None
        emit_progress("开始XGBoost模型训练...", 0)
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

        # 1. 加载数据
        # 使用正确的函数并从config导入路径
        full_df = load_multi_store_data(DEFAULT_DATA_PATH)

        # 根据 store_id 和 product_id 筛选数据
        if store_id:
            df = full_df[(full_df['product_id'] == product_id) & (full_df['store_id'] == store_id)].copy()
        else:
            # 如果没有store_id,则聚合该产品在所有店铺的数据
            df = full_df[full_df['product_id'] == product_id].groupby('date').agg({
                'sales': 'sum',
                'weekday': 'first',
                'month': 'first',
                'is_holiday': 'max',
                'is_weekend': 'max',
                'is_promotion': 'max',
                'temperature': 'mean'
            }).reset_index()
    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

        if df.empty:
            raise ValueError(f"加载的数据为空 (product: {product_id}, store: {store_id}),无法进行训练。")

        # 确保数据按日期排序
        df = df.sort_values('date').reset_index(drop=True)

        emit_progress("数据加载完成。", 10)
    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

        # 2. 创建数据集
        features_to_use = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        # 确保所有需要的特征都存在
        for col in features_to_use:
            if col not in df.columns:
                # 如果特征不存在,用0填充
                df[col] = 0

        df_features = df[['date'] + features_to_use]
    # **关键适配步骤**: XGBoost 需要二维输入
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

        X, y = create_dataset(df_features, look_back)
        if X.shape[0] == 0:
            raise ValueError("创建数据集后样本数量为0,请检查数据量和look_back参数。")

        emit_progress(f"数据集创建完成,样本数: {X.shape[0]}", 20)
    # **关键适配**: 转换为 XGBoost 核心 DMatrix 格式,以使用稳定的 xgb.train API
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

        # 3. 划分训练集和测试集
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # 数据缩放
        scaler_X = MinMaxScaler()
        X_train_scaled = scaler_X.fit_transform(X_train)
        X_test_scaled = scaler_X.transform(X_test)

        scaler_y = MinMaxScaler()
        y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1, 1))
        # y_test is not scaled, used for metric calculation against inverse_transformed predictions

        emit_progress("数据划分和缩放完成。", 30)
    # --- 3. 模型训练 (使用核心 xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',  # eval_metric 在这里是原生支持的
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)

        # 4. 切换到XGBoost原生API
        params = {
            'learning_rate': kwargs.get('learning_rate', 0.1),
            'max_depth': kwargs.get('max_depth', 5),
            'subsample': kwargs.get('subsample', 0.8),
            'colsample_bytree': kwargs.get('colsample_bytree', 0.8),
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'random_state': 42
        }
    print("开始训练XGBoost模型 (使用核心xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,  # early_stopping_rounds 在这里是原生支持的
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒")

        dtrain = xgb.DMatrix(X_train_scaled, label=y_train_scaled.ravel())
        dtest = xgb.DMatrix(X_test_scaled, label=scaler_y.transform(y_test.reshape(-1, 1)).ravel())

        emit_progress("开始模型训练...", 40)

        # 定义验证集
        evals = [(dtrain, 'train'), (dtest, 'test')]

        # 准备回调
        callbacks = []
        checkpoint_interval = kwargs.get('checkpoint_interval', 10)  # 默认每10轮保存一次
        if path_info and path_info.get('model_path') and checkpoint_interval > 0:
            # 准备用于保存的payload,模型对象将在回调中动态更新
            checkpoint_payload = {
                'metrics': {},  # 检查点不保存最终指标
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            checkpoint_callback = EpochCheckpointCallback(
                save_period=checkpoint_interval,
                payload=checkpoint_payload,
                base_path=path_info['model_path']
            )
            callbacks.append(checkpoint_callback)
    # --- 4. 模型评估 ---
    # 使用 model.best_iteration 获取最佳轮次的预测结果
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration))

    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

        # 添加早停回调 (移除save_best)
        callbacks.append(EarlyStopping(rounds=10))
    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time

        # 用于存储评估结果
        evals_result = {}
    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=epochs,
            evals=evals,
            callbacks=callbacks,
            evals_result=evals_result,
            verbose_eval=False
        )
        emit_progress("模型训练完成。", 80)
    # --- 5. 模型保存 (借道 utils.model_manager) ---
    # **关键适配点**: 我们将完整的XGBoost模型对象存入字典
    # torch.save 可以序列化多种Python对象,包括sklearn模型
    model_data = {
        'model_state_dict': model,  # 直接保存模型对象
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result
    }

        # 绘制并保存损失曲线
        if path_info and path_info.get('model_path'):
            try:
                loss_curve_path = path_info['model_path'].replace('_model.pth', '_loss_curve.png')
                results = evals_result
                train_rmse = results['train']['rmse']
                test_rmse = results['test']['rmse']
                num_epochs = len(train_rmse)
                x_axis = range(0, num_epochs)

                fig, ax = plt.subplots(figsize=(10, 6))
                ax.plot(x_axis, train_rmse, label='Train')
                ax.plot(x_axis, test_rmse, label='Test')
                ax.legend()
                plt.ylabel('RMSE')
                plt.xlabel('Epoch')
                plt.title('XGBoost RMSE Loss Curve')
                plt.savefig(loss_curve_path)
                plt.close(fig)
                emit_progress(f"损失曲线图已保存到: {loss_curve_path}")
            except Exception as e:
                emit_progress(f"警告: 绘制损失曲线失败: {str(e)}")
    # 调用全局管理器进行保存,复用其命名和版本逻辑
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    print(f"XGBoost模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}")

        # 5. 评估模型
        dtest_pred = xgb.DMatrix(X_test_scaled)
        y_pred_scaled = model.predict(dtest_pred)
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        metrics = {
            'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
            'MAE': mean_absolute_error(y_test, y_pred),
            'R2': r2_score(y_test, y_pred)
        }
        emit_progress(f"模型评估完成: {metrics}", 90)
    # 返回值遵循统一格式
    return model, metrics, final_version, final_model_path

        # 6. 保存模型 (原生API方式)
        if path_info and path_info.get('model_path'):
            metadata_path = path_info['model_path']
            # 使用 .xgb 扩展名保存原生Booster模型
            model_file_path = metadata_path.replace('.pth', '.xgb')

            # 确保目录存在
            os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

            # 使用原生方法保存Booster模型
            model.save_model(model_file_path)
            emit_progress(f"原生XGBoost模型已保存到: {model_file_path}")

            # 保存元数据(包括模型文件路径)
            metadata_payload = {
                'model_file': model_file_path,  # 保存模型文件的引用
                'metrics': metrics,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            joblib.dump(metadata_payload, metadata_path)
            model_path = metadata_path  # 确保model_path被赋值
            emit_progress(f"模型元数据已保存到: {metadata_path}", 100)
        else:
            emit_progress("警告: 未提供path_info,模型未保存。", 100)

        return metrics, model_path

    except Exception as e:
        emit_progress(f"XGBoost训练失败: {str(e)}", 100)
        import traceback
        traceback.print_exc()
        return {'error': str(e)}, None
# --- 将此训练器注册到系统中 ---
register_trainer('xgboost', train_product_model_with_xgboost)
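The key adaptation in this trainer is flattening each sliding window to a 2-D matrix before handing it to XGBoost's core `xgb.train` API. A self-contained sketch under synthetic data follows; the shapes, `seq_len`/`horizon` values, and the random data are all illustrative, not the project's real pipeline.

```python
import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
n, n_features, seq_len = 200, 7, 14
data = rng.random((n, n_features))
target = rng.random(n)

# Build supervised samples: each window of seq_len rows is flattened
# into one feature vector of length seq_len * n_features.
X, y = [], []
for i in range(n - seq_len):
    X.append(data[i:i + seq_len].flatten())
    y.append(target[i + seq_len])
X, y = np.array(X), np.array(y)

split = int(len(X) * 0.8)
dtrain = xgb.DMatrix(X[:split], label=y[:split])
dtest = xgb.DMatrix(X[split:], label=y[split:])

booster = xgb.train(
    {'objective': 'reg:squarederror', 'eval_metric': 'rmse'},
    dtrain,
    num_boost_round=100,
    evals=[(dtrain, 'train'), (dtest, 'test')],
    early_stopping_rounds=10,
    verbose_eval=False,
)

# iteration_range is half-open, so include best_iteration itself with +1
pred = booster.predict(dtest, iteration_range=(0, booster.best_iteration + 1))
print(pred.shape)
```

Note that `iteration_range=(a, b)` covers iterations `[a, b)`; the trainer above passes `(0, model.best_iteration)`, which excludes the best round itself, so the `+1` form in this sketch is the safer reading of the API.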
@@ -60,7 +60,7 @@ def prepare_data(product_data, sequence_length=30, forecast_horizon=7):
        scaler_X, scaler_y: 特征和目标的归一化器
    """
    # 创建特征和目标变量
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # 预处理数据
    X_raw = product_data[features].values
@@ -1,256 +0,0 @@
import os
import json
import hashlib
from threading import Lock
from typing import List, Dict, Any, Optional

class ModelPathManager:
    """
    根据定义的规则管理模型训练产物的保存路径。
    此类旨在集中处理所有与文件系统交互的路径生成逻辑,
    确保整个应用程序遵循统一的模型保存标准。
    """
    def __init__(self, base_dir: str = 'saved_models'):
        """
        初始化路径管理器。

        Args:
            base_dir (str): 所有模型保存的根目录。
        """
        # 始终使用相对于项目根目录的相对路径
        self.base_dir = base_dir
        self.versions_file = os.path.join(self.base_dir, 'versions.json')
        self.lock = Lock()

        # 确保根目录存在
        os.makedirs(self.base_dir, exist_ok=True)

    def _hash_ids(self, ids: List[str]) -> str:
        """
        对ID列表进行排序和哈希,生成一个稳定的、简短的哈希值。

        Args:
            ids (List[str]): 需要哈希的ID列表。

        Returns:
            str: 代表该ID集合的10位短哈希字符串。
        """
        if not ids:
            return 'none'
        # 排序以确保对于相同集合的ID,即使顺序不同,结果也一样
        sorted_ids = sorted([str(i) for i in ids])
        id_string = ",".join(sorted_ids)

        # 使用SHA256生成哈希值并截取前10位
        return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10]

    def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str:
        """
        根据训练模式和参数生成模型的唯一标识符 (identifier)。
        这个标识符将作为版本文件中的key,并用于构建目录路径。

        Args:
            training_mode (str): 训练模式 ('product', 'store', 'global')。
            **kwargs: 从API请求中传递的参数字典。

        Returns:
            str: 模型的唯一标识符。

        Raises:
            ValueError: 如果缺少必要的参数。
        """
        if training_mode == 'product':
            product_id = kwargs.get('product_id')
            if not product_id:
                raise ValueError("按药品训练模式需要 'product_id'。")
            # 对于药品训练,数据范围由 store_id 定义
            store_id = kwargs.get('store_id')
            scope = store_id if store_id is not None else 'all'
            return f"product_{product_id}_scope_{scope}"

        elif training_mode == 'store':
            store_id = kwargs.get('store_id')
            if not store_id:
                raise ValueError("按店铺训练模式需要 'store_id'。")

            product_scope = kwargs.get('product_scope', 'all')
            if product_scope == 'specific':
                product_ids = kwargs.get('product_ids')
                if not product_ids:
                    raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'。")
                # 如果只有一个ID,直接使用ID;否则使用哈希
                scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
            else:
                scope = 'all'
            return f"store_{store_id}_products_{scope}"

        elif training_mode == 'global':
            training_scope = kwargs.get('training_scope', 'all')

            if training_scope in ['all', 'all_stores_all_products']:
                scope_part = 'all'
            elif training_scope == 'selected_stores':
                store_ids = kwargs.get('store_ids')
                if not store_ids:
                    raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'。")
                # 如果只有一个ID,直接使用ID;否则使用哈希
                scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
                scope_part = f"stores_{scope_id}"
            elif training_scope == 'selected_products':
                product_ids = kwargs.get('product_ids')
                if not product_ids:
                    raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'。")
                # 如果只有一个ID,直接使用ID;否则使用哈希
                scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
                scope_part = f"products_{scope_id}"
            elif training_scope == 'custom':
                store_ids = kwargs.get('store_ids')
                product_ids = kwargs.get('product_ids')
                if not store_ids or not product_ids:
                    raise ValueError("全局训练选择 custom 范围时需要 'store_ids' 和 'product_ids'。")
                s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
                p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
                scope_part = f"custom_s_{s_id}_p_{p_id}"
            else:
                raise ValueError(f"未知的全局训练范围: {training_scope}")

            aggregation_method = kwargs.get('aggregation_method', 'sum')
            return f"global_{scope_part}_{aggregation_method}"

        else:
            raise ValueError(f"未知的训练模式: {training_mode}")

    def get_next_version(self, identifier: str) -> int:
        """
        获取指定标识符的下一个版本号。
        此方法是线程安全的。

        Args:
            identifier (str): 模型的唯一标识符。

        Returns:
            int: 下一个可用的版本号 (从1开始)。
        """
        with self.lock:
            try:
                if os.path.exists(self.versions_file):
                    with open(self.versions_file, 'r', encoding='utf-8') as f:
                        versions_data = json.load(f)
                else:
                    versions_data = {}

                # 如果标识符不存在,当前版本为0,下一个版本即为1
                current_version = versions_data.get(identifier, 0)
                return current_version + 1
            except (IOError, json.JSONDecodeError) as e:
                # 如果文件损坏或读取失败,从0开始
                print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。")
                return 1

    def save_version_info(self, identifier: str, new_version: int):
        """
        训练成功后,更新版本文件。
        此方法是线程安全的。

        Args:
            identifier (str): 模型的唯一标识符。
            new_version (int): 要保存的新的版本号。
        """
        with self.lock:
            try:
                if os.path.exists(self.versions_file):
                    with open(self.versions_file, 'r', encoding='utf-8') as f:
                        versions_data = json.load(f)
                else:
                    versions_data = {}

                versions_data[identifier] = new_version

                with open(self.versions_file, 'w', encoding='utf-8') as f:
                    json.dump(versions_data, f, indent=4, ensure_ascii=False)
            except (IOError, json.JSONDecodeError) as e:
                print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}")
                # 在这种情况下,可以选择抛出异常或采取其他恢复措施
                raise

    def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]:
        """
        主入口函数:为一次新的训练获取所有相关路径和版本信息。
        此方法遵循扁平化文件存储规范,将逻辑路径编码到文件名中。

        Args:
            training_mode (str): 训练模式 ('product', 'store', 'global')。
            model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。
            **kwargs: 从API请求中传递的参数字典。

        Returns:
            Dict[str, Any]: 一个包含所有路径和关键信息的字典。
        """
        # 1. 生成不含模型类型和版本的核心标识符,并将其中的分隔符替换为下划线
        # 例如:product/P001/all -> product_P001_all
        base_identifier = self._generate_identifier(training_mode, **kwargs)

        # 规范化处理,将 'scope' 'products' 等关键字替换为更简洁的形式
        # 例如 product_P001_scope_all -> product_P001_all
        core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')

        # 2. 构建用于版本控制的完整标识符 (不含版本号)
        # 例如: product_P001_all_mlstm
        version_control_identifier = f"{core_prefix}_{model_type}"

        # 3. 获取下一个版本号
        next_version = self.get_next_version(version_control_identifier)
        version_str = f"v{next_version}"

        # 4. 构建最终的文件名前缀,包含版本号
        # 例如: product_P001_all_mlstm_v2
        filename_prefix = f"{version_control_identifier}_{version_str}"

        # 5. 确保 `saved_models` 和 `saved_models/checkpoints` 目录存在
        checkpoints_base_dir = os.path.join(self.base_dir, 'checkpoints')
        os.makedirs(self.base_dir, exist_ok=True)
        os.makedirs(checkpoints_base_dir, exist_ok=True)

        # 6. 构建并返回包含所有扁平化路径和关键信息的字典
        return {
            "identifier": version_control_identifier,  # 用于版本控制的key
            "filename_prefix": filename_prefix,  # 用于数据库和文件查找
            "version": next_version,
            "base_dir": self.base_dir,
            "model_path": os.path.join(self.base_dir, f"{filename_prefix}_model.pth"),
            "metadata_path": os.path.join(self.base_dir, f"{filename_prefix}_metadata.json"),
            "loss_curve_path": os.path.join(self.base_dir, f"{filename_prefix}_loss_curve.png"),
            "checkpoint_dir": checkpoints_base_dir,  # 指向公共的检查点目录
            "best_checkpoint_path": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_best.pth"),
            # 为动态epoch检查点提供一个格式化模板
            "epoch_checkpoint_template": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_epoch_{{N}}.pth")
        }

    def get_model_path_for_prediction(self, training_mode: str, model_type: str, version: int, **kwargs: Any) -> Optional[str]:
        """
        获取用于预测的已存在模型的完整路径 (遵循扁平化规范)。

        Args:
            training_mode (str): 训练模式。
            model_type (str): 模型类型。
            version (int): 模型版本号。
            **kwargs: 其他用于定位模型的参数。

        Returns:
            Optional[str]: 模型的完整路径,如果不存在则返回None。
        """
        # 1. 生成不含模型类型和版本的核心标识符
        base_identifier = self._generate_identifier(training_mode, **kwargs)
        core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')

        # 2. 构建用于版本控制的标识符
        version_control_identifier = f"{core_prefix}_{model_type}"

        # 3. 构建完整的文件名前缀
        version_str = f"v{version}"
        filename_prefix = f"{version_control_identifier}_{version_str}"

        # 4. 构建模型文件的完整路径
        model_path = os.path.join(self.base_dir, f"{filename_prefix}_model.pth")

        return model_path if os.path.exists(model_path) else None
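The deleted `ModelPathManager` built its scope tokens from a stable short hash: IDs are sorted, comma-joined, SHA-256 hashed, and truncated to 10 hex characters, so the same set of IDs always maps to the same token regardless of order. A minimal standalone sketch of that scheme:

```python
import hashlib

def hash_ids(ids):
    # sort first so the hash is order-insensitive for the same ID set
    sorted_ids = sorted(str(i) for i in ids)
    return hashlib.sha256(",".join(sorted_ids).encode('utf-8')).hexdigest()[:10]

# Hypothetical store IDs; both orderings yield the same scope token
assert hash_ids(['S002', 'S001']) == hash_ids(['S001', 'S002'])
print(hash_ids(['S001', 'S002']))  # a 10-char hex prefix
```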
@ -8,6 +8,7 @@ import json
import torch
import glob
from datetime import datetime
import re
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR

@ -24,56 +25,91 @@ class ModelManager:
if not os.path.exists(self.model_dir):
    os.makedirs(self.model_dir)

def generate_model_filename(self,
                            product_id: str,
                            model_type: str,
def _get_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> int:
    """Get the next model version number (a plain integer)."""
    search_pattern = self.generate_model_filename(
        model_type=model_type,
        version='v*',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )

    full_search_path = os.path.join(self.model_dir, search_pattern)
    existing_files = glob.glob(full_search_path)

    max_version = 0
    for f in existing_files:
        match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
        if match:
            max_version = max(max_version, int(match.group(1)))

    return max_version + 1
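The trailing anchor in `_v(\d+)\.pth$` matters: it binds the version token to the very end of the basename, so an ID that happens to contain `_v` cannot be misread. A standalone check with hypothetical filenames:

```python
import re

# Hypothetical existing files for one identifier
existing_files = ['mlstm_product_P001_v1.pth', 'mlstm_product_P001_v3.pth',
                  'mlstm_product_P001_v2.pth']

max_version = 0
for f in existing_files:
    match = re.search(r'_v(\d+)\.pth$', f)
    if match:
        max_version = max(max_version, int(match.group(1)))

print(max_version + 1)  # 4 -- the version the next training run will get
```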

def generate_model_filename(self,
                            model_type: str,
                            version: str,
                            store_id: Optional[str] = None,
                            training_mode: str = 'product',
                            product_id: Optional[str] = None,
                            store_id: Optional[str] = None,
                            aggregation_method: Optional[str] = None) -> str:
    """
    Generate a unified model filename

    Format spec:

    Format spec (v2):
    - product mode: {model_type}_product_{product_id}_{version}.pth
    - store mode:   {model_type}_store_{store_id}_{product_id}_{version}.pth
    - global mode:  {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
    - store mode:   {model_type}_store_{store_id}_{version}.pth
    - global mode:  {model_type}_global_{aggregation_method}_{version}.pth
    """
    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
        return f"{model_type}_store_{store_id}_{version}.pth"
    elif training_mode == 'global' and aggregation_method:
        return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
    else:
        # Default product mode
        return f"{model_type}_global_{aggregation_method}_{version}.pth"
    elif training_mode == 'product' and product_id:
        return f"{model_type}_product_{product_id}_{version}.pth"
    else:
        # Provide a fallback or raise, to avoid producing an invalid filename
        raise ValueError(f"无法为训练模式 '{training_mode}' 生成有效的文件名,缺少必需的ID。")
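Condensed for reference, the three v2 name shapes the new branches produce (an illustrative standalone sketch; the authoritative logic is the method above):

```python
def v2_filename(model_type, version, training_mode='product',
                product_id=None, store_id=None, aggregation_method=None):
    # Mirrors the new-format branches above (store / global / product)
    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{version}.pth"
    elif training_mode == 'global' and aggregation_method:
        return f"{model_type}_global_{aggregation_method}_{version}.pth"
    elif training_mode == 'product' and product_id:
        return f"{model_type}_product_{product_id}_{version}.pth"
    raise ValueError(f"no valid filename for mode '{training_mode}'")

print(v2_filename('mlstm', 'v1', 'product', product_id='P001'))      # mlstm_product_P001_v1.pth
print(v2_filename('kan', 'v1', 'store', store_id='S001'))            # kan_store_S001_v1.pth
print(v2_filename('kan', 'v2', 'global', aggregation_method='sum'))  # kan_global_sum_v2.pth
```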

def save_model(self,
def save_model(self,
               model_data: dict,
               product_id: str,
               model_type: str,
               version: str,
               model_type: str,
               store_id: Optional[str] = None,
               training_mode: str = 'product',
               aggregation_method: Optional[str] = None,
               product_name: Optional[str] = None) -> str:
               product_name: Optional[str] = None,
               version: Optional[str] = None) -> Tuple[str, str]:
    """
    Save the model to the unified location
    Save the model to the unified location, with automatic version management.

    Args:
        model_data: dict containing the model state and configuration
        product_id: product ID
        model_type: model type
        version: version number
        store_id: store ID (optional)
        training_mode: training mode
        aggregation_method: aggregation method (optional)
        product_name: product name (optional)
        ...
        version: (optional) overrides automatic versioning if provided (e.g. 'best').

    Returns:
        path to the model file
        (path to the model file, version string that was used)
    """
    if version is None:
        next_version_num = self._get_next_version(
            model_type=model_type,
            product_id=product_id,
            store_id=store_id,
            training_mode=training_mode,
            aggregation_method=aggregation_method
        )
        version_str = f"v{next_version_num}"
    else:
        version_str = version

    filename = self.generate_model_filename(
        product_id, model_type, version, store_id, training_mode, aggregation_method
        model_type=model_type,
        version=version_str,
        training_mode=training_mode,
        product_id=product_id,
        store_id=store_id,
        aggregation_method=aggregation_method
    )

    # Save everything in the root directory to avoid a complex subdirectory structure
@ -86,7 +122,7 @@ class ModelManager:
    'product_id': product_id,
    'product_name': product_name or product_id,
    'model_type': model_type,
    'version': version,
    'version': version_str,
    'store_id': store_id,
    'training_mode': training_mode,
    'aggregation_method': aggregation_method,
@ -99,7 +135,7 @@ class ModelManager:
    torch.save(enhanced_model_data, model_path)

    print(f"模型已保存: {model_path}")
    return model_path
    return model_path, version_str

def list_models(self,
                product_id: Optional[str] = None,
@ -228,145 +264,58 @@ class ModelManager:

def parse_model_filename(self, filename: str) -> Optional[Dict]:
    """
    Parse a model filename and extract model information

    Parse a model filename and extract model information (v2)

    Supported formats:
    - {model_type}_product_{product_id}_{version}.pth
    - {model_type}_store_{store_id}_{product_id}_{version}.pth
    - {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
    - legacy-format compatibility
    - product: {model_type}_product_{product_id}_{version}.pth
    - store:   {model_type}_store_{store_id}_{version}.pth
    - global:  {model_type}_global_{aggregation_method}_{version}.pth
    """
    if not filename.endswith('.pth'):
        return None

    base_name = filename.replace('.pth', '')

    parts = base_name.split('_')

    if len(parts) < 3:
        return None  # does not meet the basic format requirements

    # **Core fix**: use a more robust, right-to-left parse so that model names containing underscores are supported
    try:
        # New-format parsing
        if '_product_' in base_name:
            # product mode: model_type_product_product_id_version
            parts = base_name.split('_product_')
            model_type = parts[0]
            rest = parts[1]

            # Separate the product ID and the version
            if '_v' in rest:
                last_v_index = rest.rfind('_v')
                product_id = rest[:last_v_index]
                version = rest[last_v_index+1:]
            else:
                product_id = rest
                version = 'v1'

        version = parts[-1]
        identifier = parts[-2]
        mode_candidate = parts[-3]

        if mode_candidate == 'product':
            model_type = '_'.join(parts[:-3])
            return {
                'model_type': model_type,
                'product_id': product_id,
                'version': version,
                'training_mode': 'product',
                'store_id': None,
                'aggregation_method': None
                'product_id': identifier,
                'version': version,
            }

        elif '_store_' in base_name:
            # store mode: model_type_store_store_id_product_id_version
            parts = base_name.split('_store_')
            model_type = parts[0]
            rest = parts[1]

            # Separate the store ID, product ID and version
            rest_parts = rest.split('_')
            if len(rest_parts) >= 3:
                store_id = rest_parts[0]
                if rest_parts[-1].startswith('v'):
                    # the last part is the version number
                    version = rest_parts[-1]
                    product_id = '_'.join(rest_parts[1:-1])
                else:
                    version = 'v1'
                    product_id = '_'.join(rest_parts[1:])

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'store',
                    'store_id': store_id,
                    'aggregation_method': None
                }

        elif '_global_' in base_name:
            # global mode: model_type_global_product_id_aggregation_method_version
            parts = base_name.split('_global_')
            model_type = parts[0]
            rest = parts[1]

            rest_parts = rest.split('_')
            if len(rest_parts) >= 3:
                if rest_parts[-1].startswith('v'):
                    # the last part is the version number
                    version = rest_parts[-1]
                    aggregation_method = rest_parts[-2]
                    product_id = '_'.join(rest_parts[:-2])
                else:
                    version = 'v1'
                    aggregation_method = rest_parts[-1]
                    product_id = '_'.join(rest_parts[:-1])

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'global',
                    'store_id': None,
                    'aggregation_method': aggregation_method
                }

        # Compatibility with names ending in _model.pth
        elif base_name.endswith('_model'):
            name_part = base_name.rsplit('_model', 1)[0]
            parts = name_part.split('_')
            # Assumed format: {product_id}_{...}_{model_type}_{version}
            if len(parts) >= 3:
                version = parts[-1]
                model_type = parts[-2]
                product_id = '_'.join(parts[:-2])  # the rest is product_id + scope
                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'product',  # assumption
                    'store_id': None,
                    'aggregation_method': None
                }

        # Legacy-format compatibility
        else:
            # Try to parse other formats
            if 'model_product_' in base_name:
                parts = base_name.split('_model_product_')
                model_type = parts[0]
                product_part = parts[1]

                if '_v' in product_part:
                    last_v_index = product_part.rfind('_v')
                    product_id = product_part[:last_v_index]
                    version = product_part[last_v_index+1:]
                else:
                    product_id = product_part
                    version = 'v1'

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'product',
                    'store_id': None,
                    'aggregation_method': None
                }

        elif mode_candidate == 'store':
            model_type = '_'.join(parts[:-3])
            return {
                'model_type': model_type,
                'training_mode': 'store',
                'store_id': identifier,
                'version': version,
            }
        elif mode_candidate == 'global':
            model_type = '_'.join(parts[:-3])
            return {
                'model_type': model_type,
                'training_mode': 'global',
                'aggregation_method': identifier,
                'version': version,
            }
    except IndexError:
        # Parsing fails if the filename has too few parts
        pass
    except Exception as e:
        print(f"解析文件名失败 {filename}: {e}")

    return None
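The right-to-left split is the key idea here: because the mode, identifier and version tokens sit at fixed positions from the end, model names containing underscores parse unambiguously. A standalone sketch of the same logic (hypothetical filenames, for illustration only):

```python
def parse_v2_filename(filename: str):
    # Right-to-left parse of {model_type}_{mode}_{identifier}_{version}.pth
    if not filename.endswith('.pth'):
        return None
    parts = filename[:-4].split('_')
    if len(parts) < 3 or not parts[-1].startswith('v'):
        return None
    version, identifier, mode = parts[-1], parts[-2], parts[-3]
    if mode not in ('product', 'store', 'global'):
        return None
    return {'model_type': '_'.join(parts[:-3]), 'mode': mode,
            'identifier': identifier, 'version': version}

# A model name containing underscores still parses cleanly:
print(parse_v2_filename('my_new_model_product_P001_v3.pth'))
# -> {'model_type': 'my_new_model', 'mode': 'product', 'identifier': 'P001', 'version': 'v3'}
```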

def delete_model(self, model_file: str) -> bool:
@ -268,7 +268,7 @@ def get_store_product_sales_data(store_id: str,

    # Data normalization was already done in load_multi_store_data
    # Verify that the required columns exist
    required_columns = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    required_columns = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    missing_columns = [col for col in required_columns if col not in df.columns]

    if missing_columns:
@ -324,21 +324,29 @@ def aggregate_multi_store_data(product_id: Optional[str] = None,
    grouping_entity = "所有产品"

    # Aggregate by date (using the normalized column names)
    # Define a more robust aggregation spec that preserves all features
    agg_spec = {
        'sales': aggregation_method,
        'sales_amount': aggregation_method,
        'price': 'mean',
        'weekday': 'first',
        'month': 'first',
        'is_holiday': 'first',
        'is_weekend': 'first',
        'is_promotion': 'first',
        'temperature': 'mean'
    }
    agg_dict = {}
    if aggregation_method == 'sum':
        agg_dict = {
            'sales': 'sum',  # normalized sales column
            'sales_amount': 'sum',
            'price': 'mean'  # normalized price column, averaged
        }
    elif aggregation_method == 'mean':
        agg_dict = {
            'sales': 'mean',
            'sales_amount': 'mean',
            'price': 'mean'
        }
    elif aggregation_method == 'median':
        agg_dict = {
            'sales': 'median',
            'sales_amount': 'median',
            'price': 'median'
        }

    # Only aggregate columns that actually exist in the DataFrame
    agg_dict = {k: v for k, v in agg_spec.items() if k in df.columns}
    # Make sure the column names exist
    available_cols = df.columns.tolist()
    agg_dict = {k: v for k, v in agg_dict.items() if k in available_cols}

    # Aggregate the data
    aggregated_df = df.groupby('date').agg(agg_dict).reset_index()
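To see why the feature-preserving spec matters: with a plain `{'sales': 'sum'}` aggregation, columns like `weekday` or `temperature` would simply be dropped from the grouped result. A toy pandas sketch (made-up data, for illustration):

```python
import pandas as pd

df = pd.DataFrame({
    'date': ['2025-01-01', '2025-01-01', '2025-01-02'],
    'sales': [10, 5, 8],
    'price': [2.0, 3.0, 2.5],
    'weekday': [2, 2, 3],
    'temperature': [20.0, 22.0, 18.0],
})

# Same shape as agg_spec: sum sales, average price/temperature,
# carry calendar features through with 'first'.
agg_spec = {'sales': 'sum', 'price': 'mean', 'weekday': 'first', 'temperature': 'mean'}
agg_dict = {k: v for k, v in agg_spec.items() if k in df.columns}

print(df.groupby('date').agg(agg_dict).reset_index())
```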
@ -24,7 +24,6 @@ server_dir = os.path.dirname(current_dir)
sys.path.append(server_dir)

from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
from utils.file_save import ModelPathManager
import numpy as np

def convert_numpy_types(obj):
@ -45,9 +44,6 @@ class TrainingTask:
    model_type: str
    training_mode: str
    store_id: Optional[str] = None
    aggregation_method: Optional[str] = None  # new: aggregation method
    product_scope: str = 'all'
    product_ids: Optional[list] = None
    epochs: int = 100
    status: str = "pending"  # pending, running, completed, failed
    start_time: Optional[str] = None
@ -57,8 +53,6 @@ class TrainingTask:
    error: Optional[str] = None
    metrics: Optional[Dict[str, Any]] = None
    process_id: Optional[int] = None
    path_info: Optional[Dict[str, Any]] = None  # new field
    version: Optional[int] = None  # new: version field

class TrainingWorker:
    """Training worker process"""
@ -143,20 +137,16 @@ class TrainingWorker:
    except Exception as e:
        training_logger.error(f"进度回调失败: {e}")

    # Run the real training, passing the progress callback and path info
    # Run the real training, passing the progress callback
    metrics = predictor.train_model(
        product_id=task.product_id,
        model_type=task.model_type,
        epochs=task.epochs,
        store_id=task.store_id,
        training_mode=task.training_mode,
        aggregation_method=task.aggregation_method,  # pass the aggregation method
        product_scope=task.product_scope,  # pass the product scope
        product_ids=task.product_ids,  # pass the list of product IDs
        socketio=None,  # socketio cannot be used directly in a child process
        task_id=task.task_id,
        progress_callback=progress_callback,  # pass the progress callback
        path_info=task.path_info  # pass the path info
        progress_callback=progress_callback  # pass the progress callback
    )

    # Send the training-complete log to the main console
@ -167,25 +157,11 @@ class TrainingWorker:
    })

    if metrics:
        if 'error' in metrics:
            self.progress_queue.put({
                'task_id': task.task_id,
                'log_type': 'error',
                'message': f"❌ 训练返回错误: {metrics['error']}"
            })
        else:
            # Only format the metrics when there is no error
            mse_val = metrics.get('mse', 'N/A')
            rmse_val = metrics.get('rmse', 'N/A')

            mse_str = f"{mse_val:.4f}" if isinstance(mse_val, (int, float)) else mse_val
            rmse_str = f"{rmse_val:.4f}" if isinstance(rmse_val, (int, float)) else rmse_val

            self.progress_queue.put({
                'task_id': task.task_id,
                'log_type': 'info',
                'message': f"📊 训练指标: MSE={mse_str}, RMSE={rmse_str}"
            })
        self.progress_queue.put({
            'task_id': task.task_id,
            'log_type': 'info',
            'message': f"📊 训练指标: MSE={metrics.get('mse', 'N/A'):.4f}, RMSE={metrics.get('rmse', 'N/A'):.4f}"
        })
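Note the difference between the two variants above: applying `:.4f` to the `'N/A'` fallback raises `ValueError: Unknown format code 'f'`, which is why the guarded version checks the type first. A minimal illustration:

```python
def format_metric(val):
    # f"{val:.4f}" raises ValueError when val is the string 'N/A',
    # so only apply the float format to real numbers.
    return f"{val:.4f}" if isinstance(val, (int, float)) else val

print(format_metric(0.123456))  # 0.1235
print(format_metric('N/A'))     # N/A
```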
    except ImportError as e:
        training_logger.error(f"❌ 导入训练器失败: {e}")
        # Return simulated training results for testing
@ -200,29 +176,18 @@ class TrainingWorker:
        }
        training_logger.warning("⚠️ 使用模拟训练结果")

    # Check whether training succeeded
    # Training finished
    task.status = "completed"
    task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
    task.progress = 100.0
    task.metrics = metrics
    task.message = "训练完成"

    training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
    if metrics:
        # Training succeeded
        task.status = "completed"
        task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
        task.progress = 100.0
        task.metrics = metrics
        task.message = "训练完成"

        training_logger.success(f"✅ 训练任务完成 - 耗时: {task.end_time}")
        training_logger.info(f"📊 训练指标: {metrics}")

        self.result_queue.put(('complete', asdict(task)))
    else:
        # Training failed (poor performance)
        # Mark it completed even so, and let the user decide whether to use the model
        task.status = "completed"
        task.end_time = time.strftime('%Y-%m-%d %H:%M:%S')
        task.metrics = metrics if metrics else {}
        task.message = "训练完成(性能可能不佳)"

        training_logger.warning(f"⚠️ 训练完成,但性能可能不佳 (metrics: {metrics})")
        self.result_queue.put(('complete', asdict(task)))

    self.result_queue.put(('complete', asdict(task)))

except Exception as e:
    error_msg = str(e)
@ -270,7 +235,6 @@ class TrainingProcessManager:

    # Set up logging
    self.logger = setup_api_logging()
    self.path_manager = ModelPathManager()  # instantiate

def start(self):
    """Start the process manager"""
@ -317,26 +281,18 @@ class TrainingProcessManager:

    self.logger.info("✅ 训练进程管理器已停止")

def submit_task(self, training_params: Dict[str, Any], path_info: Dict[str, Any]) -> str:
    """
    Submit a training task
    Args:
        training_params (Dict[str, Any]): raw parameters from the API request
        path_info (Dict[str, Any]): path and version info generated by ModelPathManager
    """
def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
                store_id: str = None, epochs: int = 100, **kwargs) -> str:
    """Submit a training task"""
    task_id = str(uuid.uuid4())

    task = TrainingTask(
        task_id=task_id,
        product_id=training_params.get('product_id'),
        model_type=training_params.get('model_type'),
        training_mode=training_params.get('training_mode', 'product'),
        store_id=training_params.get('store_id'),
        epochs=training_params.get('epochs', 100),
        aggregation_method=training_params.get('aggregation_method'),  # new
        product_scope=training_params.get('product_scope', 'all'),
        product_ids=training_params.get('product_ids'),
        path_info=path_info  # store the path info
        product_id=product_id,
        model_type=model_type,
        training_mode=training_mode,
        store_id=store_id,
        epochs=epochs
    )

    with self.lock:
@ -345,7 +301,7 @@ class TrainingProcessManager:
    # Put the task on the queue
    self.task_queue.put(asdict(task))

    self.logger.info(f"📋 训练任务已提交: {task_id[:8]} | {task.model_type} | {task.product_id}")
    self.logger.info(f"📋 训练任务已提交: {task_id[:8]} | {model_type} | {product_id}")
    return task_id

def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
@ -385,41 +341,14 @@ class TrainingProcessManager:

    with self.lock:
        if task_id in self.tasks:
            task = self.tasks[task_id]
            # Update the task state with the converted data
            for key, value in serializable_task_data.items():
                if hasattr(task, key):
                    setattr(task, key, value)

            # If the task completed successfully, update the version file and the version on the task object
            if action == 'complete':
                # Only save version info when training succeeded (metrics are valid)
                if task.metrics and task.metrics.get('r2', -1) >= 0:
                    if task.path_info:
                        # Make sure to use the correct, normalized identifier
                        version_control_identifier = task.path_info.get('identifier')
                        version = task.path_info.get('version')
                        if version_control_identifier and version:
                            try:
                                self.path_manager.save_version_info(version_control_identifier, version)
                                self.logger.info(f"✅ 版本信息已更新: identifier={version_control_identifier}, version={version}")
                                task.version = version  # key fix: store the version number on the task object
                            except Exception as e:
                                self.logger.error(f"❌ 更新版本文件失败: {e}")
                else:
                    self.logger.warning(f"⚠️ 任务 {task_id} 训练性能不佳或失败,不保存版本信息。")
                    setattr(self.tasks[task_id], key, value)

    # WebSocket notification - use the already converted data
    if self.websocket_callback:
        try:
            if action == 'complete':
                # Get the authoritative version number from the task object
                version = None
                with self.lock:
                    task = self.tasks.get(task_id)
                    if task:
                        version = task.version

                # Training finished - send the completed status
                self.websocket_callback('training_update', {
                    'task_id': task_id,
@ -430,10 +359,7 @@ class TrainingProcessManager:
                    'metrics': serializable_task_data.get('metrics'),
                    'end_time': serializable_task_data.get('end_time'),
                    'product_id': serializable_task_data.get('product_id'),
                    'model_type': serializable_task_data.get('model_type'),
                    'version': version,  # add the version number
                    'product_scope': serializable_task_data.get('product_scope'),
                    'product_ids': serializable_task_data.get('product_ids')
                    'model_type': serializable_task_data.get('model_type')
                })
                # Additionally emit a dedicated completion event so the frontend is sure to receive it
                self.websocket_callback('training_completed', {
@ -443,10 +369,7 @@ class TrainingProcessManager:
                    'message': serializable_task_data.get('message', '训练完成'),
                    'metrics': serializable_task_data.get('metrics'),
                    'product_id': serializable_task_data.get('product_id'),
                    'model_type': serializable_task_data.get('model_type'),
                    'version': version,  # add the version number
                    'product_scope': serializable_task_data.get('product_scope'),
                    'product_ids': serializable_task_data.get('product_ids')
                    'model_type': serializable_task_data.get('model_type')
                })
            elif action == 'error':
                # Training failed
@ -458,9 +381,7 @@ class TrainingProcessManager:
                    'message': serializable_task_data.get('message', '训练失败'),
                    'error': serializable_task_data.get('error'),
                    'product_id': serializable_task_data.get('product_id'),
                    'model_type': serializable_task_data.get('model_type'),
                    'product_scope': serializable_task_data.get('product_scope'),
                    'product_ids': serializable_task_data.get('product_ids')
                    'model_type': serializable_task_data.get('model_type')
                })
            else:
                # Status update
@ -472,9 +393,7 @@ class TrainingProcessManager:
                    'message': serializable_task_data.get('message', ''),
                    'metrics': serializable_task_data.get('metrics'),
                    'product_id': serializable_task_data.get('product_id'),
                    'model_type': serializable_task_data.get('model_type'),
                    'product_scope': serializable_task_data.get('product_scope'),
                    'product_ids': serializable_task_data.get('product_ids')
                    'model_type': serializable_task_data.get('model_type')
                })
        except Exception as e:
            self.logger.error(f"WebSocket通知失败: {e}")
@ -1,259 +0,0 @@
import os
import sys
import shutil
import json

# Add the project root to the system path so that server modules can be imported
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, project_root)

from server.utils.file_save import ModelPathManager

def run_tests():
    """Run all tests for the path-generation logic"""

    # --- Test setup ---
    test_base_dir = 'test_saved_models'
    if os.path.exists(test_base_dir):
        shutil.rmtree(test_base_dir)  # clean up the old test directory

    path_manager = ModelPathManager(base_dir=test_base_dir)
    model_type = 'mlstm'

    print("="*50)
    print("🚀 开始测试 ModelPathManager 路径生成逻辑...")
    print(f"测试根目录: {os.path.abspath(test_base_dir)}")
    print("="*50)

    # --- 1. Store-training tests ---
    print("\n--- 🧪 1. 按店铺训练 (Store Training) ---")

    # a) store training - all products
    print("\n[1a] 场景: 店铺训练 - 所有药品")
    store_payload_all = {
        'store_id': 'S001',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'all'
    }
    payload = store_payload_all.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_store_all = path_manager.get_model_paths(training_mode='store', model_type=model_type, **payload)
    print(f" - Identifier: {paths_store_all['identifier']}")
    print(f" - Version Dir: {paths_store_all['version_dir']}")
    assert f"store_S001_products_all_{model_type}" == paths_store_all['identifier']
    expected_path = os.path.join(test_base_dir, 'store', 'S001_all', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_store_all['version_dir'])

    # b) store training - specific products (hashed)
    print("\n[1b] 场景: 店铺训练 - 特定药品 (使用哈希)")
    store_payload_specific = {
        'store_id': 'S002',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P001', 'P005', 'P002']
    }
    payload = store_payload_specific.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_store_specific = path_manager.get_model_paths(training_mode='store', model_type=model_type, **payload)
    hashed_ids = path_manager._hash_ids(['P001', 'P005', 'P002'])
    print(f" - Hashed IDs: {hashed_ids}")
    print(f" - Identifier: {paths_store_specific['identifier']}")
    print(f" - Version Dir: {paths_store_specific['version_dir']}")
    assert f"store_S002_products_{hashed_ids}_{model_type}" == paths_store_specific['identifier']
    expected_path = os.path.join(test_base_dir, 'store', f'S002_{hashed_ids}', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_store_specific['version_dir'])

    # c) store training - a single specified product
    print("\n[1c] 场景: 店铺训练 - 单个指定药品")
    store_payload_single_product = {
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'store',
        'product_scope': 'specific',
        'product_ids': ['P789']
    }
    payload = store_payload_single_product.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_store_single_product = path_manager.get_model_paths(training_mode='store', model_type=model_type, **payload)
    print(f" - Identifier: {paths_store_single_product['identifier']}")
    print(f" - Version Dir: {paths_store_single_product['version_dir']}")
    assert f"store_S003_products_P789_{model_type}" == paths_store_single_product['identifier']
    expected_path = os.path.join(test_base_dir, 'store', 'S003_P789', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_store_single_product['version_dir'])

    # --- 2. Product-training tests ---
    print("\n--- 🧪 2. 按药品训练 (Product Training) ---")

    # a) product training - all stores
    print("\n[2a] 场景: 药品训练 - 所有店铺")
    product_payload_all = {
        'product_id': 'P123',
        'model_type': model_type,
        'training_mode': 'product',
        'store_id': None  # explicitly test the None case
    }
    payload = product_payload_all.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_product_all = path_manager.get_model_paths(training_mode='product', model_type=model_type, **payload)
    print(f" - Identifier: {paths_product_all['identifier']}")
    print(f" - Version Dir: {paths_product_all['version_dir']}")
    assert f"product_P123_scope_all_{model_type}" == paths_product_all['identifier']
    expected_path = os.path.join(test_base_dir, 'product', 'P123_all', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_product_all['version_dir'])

    # b) product training - a specific store
    print("\n[2b] 场景: 药品训练 - 特定店铺")
    product_payload_specific = {
        'product_id': 'P456',
        'store_id': 'S003',
        'model_type': model_type,
        'training_mode': 'product'
    }
    payload = product_payload_specific.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_product_specific = path_manager.get_model_paths(training_mode='product', model_type=model_type, **payload)
    print(f" - Identifier: {paths_product_specific['identifier']}")
    print(f" - Version Dir: {paths_product_specific['version_dir']}")
    assert f"product_P456_scope_S003_{model_type}" == paths_product_specific['identifier']
    expected_path = os.path.join(test_base_dir, 'product', 'P456_S003', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_product_specific['version_dir'])

    # --- 3. Global-training tests ---
    print("\n--- 🧪 3. 全局训练 (Global Training) ---")

    # a) global training - all data
    print("\n[3a] 场景: 全局训练 - 所有数据")
    global_payload_all = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all',
        'aggregation_method': 'sum'
    }
    payload = global_payload_all.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_global_all = path_manager.get_model_paths(training_mode='global', model_type=model_type, **payload)
    print(f" - Identifier: {paths_global_all['identifier']}")
    print(f" - Version Dir: {paths_global_all['version_dir']}")
    assert f"global_all_agg_sum_{model_type}" == paths_global_all['identifier']
    expected_path = os.path.join(test_base_dir, 'global', 'all', 'sum', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_global_all['version_dir'])

    # a2) global training - all data (using 'all_stores_all_products')
    print("\n[3a2] 场景: 全局训练 - 所有数据 (使用 'all_stores_all_products')")
    global_payload_all_alt = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'all_stores_all_products',
        'aggregation_method': 'sum'
    }
    payload = global_payload_all_alt.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_global_all_alt = path_manager.get_model_paths(training_mode='global', model_type=model_type, **payload)
    assert f"global_all_agg_sum_{model_type}" == paths_global_all_alt['identifier']
    assert os.path.normpath(expected_path) == os.path.normpath(paths_global_all_alt['version_dir'])

    # b) global training - custom scope (hashed)
    print("\n[3b] 场景: 全局训练 - 自定义范围 (使用哈希)")
    global_payload_custom = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S001', 'S003'],
        'product_ids': ['P001', 'P002']
    }
    payload = global_payload_custom.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_global_custom = path_manager.get_model_paths(training_mode='global', model_type=model_type, **payload)
    s_hash = path_manager._hash_ids(['S001', 'S003'])
    p_hash = path_manager._hash_ids(['P001', 'P002'])
    print(f" - Store Hash: {s_hash}, Product Hash: {p_hash}")
    print(f" - Identifier: {paths_global_custom['identifier']}")
    print(f" - Version Dir: {paths_global_custom['version_dir']}")
    assert f"global_custom_s_{s_hash}_p_{p_hash}_agg_mean_{model_type}" == paths_global_custom['identifier']
    expected_path = os.path.join(test_base_dir, 'global', 'custom', s_hash, p_hash, 'mean', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_global_custom['version_dir'])

    # c) global training - a single store
    print("\n[3c] 场景: 全局训练 - 单个店铺")
    global_payload_single_store = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'selected_stores',
        'aggregation_method': 'mean',
        'store_ids': ['S007']
    }
    payload = global_payload_single_store.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_global_single_store = path_manager.get_model_paths(training_mode='global', model_type=model_type, **payload)
    print(f" - Identifier: {paths_global_single_store['identifier']}")
    print(f" - Version Dir: {paths_global_single_store['version_dir']}")
    assert f"global_stores_S007_agg_mean_{model_type}" == paths_global_single_store['identifier']
    expected_path = os.path.join(test_base_dir, 'global', 'stores', 'S007', 'mean', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_global_single_store['version_dir'])

    # d) global training - custom scope (single IDs)
    print("\n[3d] 场景: 全局训练 - 自定义范围 (单ID)")
    global_payload_custom_single = {
        'model_type': model_type,
        'training_mode': 'global',
        'training_scope': 'custom',
        'aggregation_method': 'mean',
        'store_ids': ['S008'],
        'product_ids': ['P888']
    }
    payload = global_payload_custom_single.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_global_custom_single = path_manager.get_model_paths(training_mode='global', model_type=model_type, **payload)
    print(f" - Identifier: {paths_global_custom_single['identifier']}")
    print(f" - Version Dir: {paths_global_custom_single['version_dir']}")
    assert f"global_custom_s_S008_p_P888_agg_mean_{model_type}" == paths_global_custom_single['identifier']
    expected_path = os.path.join(test_base_dir, 'global', 'custom', 'S008', 'P888', 'mean', model_type, 'v1')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_global_custom_single['version_dir'])

    # --- 4. Version-management tests ---
    print("\n--- 🧪 4. 版本管理测试 ---")
    print("\n[4a] 场景: 多次调用同一训练,版本号递增")

    # First training run
    path_manager.save_version_info(paths_store_all['identifier'], paths_store_all['version'])
    print(f" - 保存版本: {paths_store_all['identifier']} -> v{paths_store_all['version']}")

    # Second training run
    payload = store_payload_all.copy()
    payload.pop('model_type', None)
    payload.pop('training_mode', None)
    paths_store_all_v2 = path_manager.get_model_paths(training_mode='store', model_type=model_type, **payload)
    print(f" - 获取新版本: {paths_store_all_v2['identifier']} -> v{paths_store_all_v2['version']}")
    assert paths_store_all_v2['version'] == 2
    expected_path = os.path.join(test_base_dir, 'store', 'S001_all', model_type, 'v2')
    assert os.path.normpath(expected_path) == os.path.normpath(paths_store_all_v2['version_dir'])

    # Verify the versions.json file
    with open(path_manager.versions_file, 'r') as f:
        versions_data = json.load(f)
    print(f" - versions.json 内容: {versions_data}")
    assert versions_data[paths_store_all['identifier']] == 1

    print("\n" + "="*50)
    print("✅ 所有测试用例通过!")
    print("="*50)

    # --- Cleanup ---
    shutil.rmtree(test_base_dir)
    print(f"🗑️ 测试目录 '{test_base_dir}' 已清理。")

if __name__ == '__main__':
    run_tests()
@ -1,118 +0,0 @@
import unittest
import os
import shutil
import sys

# Add the project root to sys.path to resolve module imports,
# so the test script can be run directly without complex path configuration
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from server.utils.file_save import ModelPathManager

class TestModelPathManager(unittest.TestCase):
    """
    Test that ModelPathManager strictly follows the flat file-storage convention.
    """
    def setUp(self):
        """Set up the test environment before each test case."""
        self.test_base_dir = 'test_saved_models'
        # Clean up previous test directories and files
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)
        self.path_manager = ModelPathManager(base_dir=self.test_base_dir)

    def tearDown(self):
        """Clean up the test environment after each test case."""
        if os.path.exists(self.test_base_dir):
            shutil.rmtree(self.test_base_dir)

    def test_product_mode_path_generation(self):
        """Test that path generation in 'product' mode follows the convention."""
        print("\n--- 测试 'product' 模式 ---")
        params = {
            'training_mode': 'product',
            'model_type': 'mlstm',
            'product_id': 'P001',
            'store_id': 'all'
        }

        # First call: version should be 1
        paths_v1 = self.path_manager.get_model_paths(**params)

        # Check the version number
        self.assertEqual(paths_v1['version'], 1)

        # Check the filename prefix
        expected_prefix_v1 = 'product_P001_all_mlstm_v1'
        self.assertEqual(paths_v1['filename_prefix'], expected_prefix_v1)

        # Check the full path of each artifact
        self.assertEqual(paths_v1['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_model.pth'))
        self.assertEqual(paths_v1['metadata_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_metadata.json'))
        self.assertEqual(paths_v1['loss_curve_path'], os.path.join(self.test_base_dir, f'{expected_prefix_v1}_loss_curve.png'))

        # Check the checkpoint paths
        checkpoint_dir = os.path.join(self.test_base_dir, 'checkpoints')
        self.assertEqual(paths_v1['checkpoint_dir'], checkpoint_dir)
        self.assertEqual(paths_v1['best_checkpoint_path'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_best.pth'))
        self.assertEqual(paths_v1['epoch_checkpoint_template'], os.path.join(checkpoint_dir, f'{expected_prefix_v1}_checkpoint_epoch_{{N}}.pth'))

        print(f"生成的文件名前缀: {paths_v1['filename_prefix']}")
        print(f"生成的模型路径: {paths_v1['model_path']}")
        print("验证通过!")

        # Simulate a successful training run to trigger the version increment
        self.path_manager.save_version_info(paths_v1['identifier'], paths_v1['version'])

        # Second call: version should be 2
        paths_v2 = self.path_manager.get_model_paths(**params)
        self.assertEqual(paths_v2['version'], 2)
        expected_prefix_v2 = 'product_P001_all_mlstm_v2'
        self.assertEqual(paths_v2['filename_prefix'], expected_prefix_v2)
        print(f"\n版本递增后,生成的文件名前缀: {paths_v2['filename_prefix']}")
        print("版本递增验证通过!")

    def test_store_mode_path_generation_with_hash(self):
        """Test path generation in 'store' mode with hashed IDs."""
        print("\n--- 测试 'store' 模式 (多药品ID哈希) ---")
        params = {
            'training_mode': 'store',
            'model_type': 'kan',
            'store_id': 'S008',
            'product_scope': 'specific',
            'product_ids': ['P002', 'P005', 'P003']  # deliberately out of order
        }

        paths = self.path_manager.get_model_paths(**params)

        # The hash is stable because the ID list is sorted before hashing
        expected_hash = self.path_manager._hash_ids(sorted(['P002', 'P005', 'P003']))
        expected_prefix = f'store_S008_{expected_hash}_kan_v1'

        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

    def test_global_mode_path_generation(self):
        """Test path generation in 'global' mode."""
        print("\n--- 测试 'global' 模式 ---")
        params = {
            'training_mode': 'global',
            'model_type': 'transformer',
            'training_scope': 'all',
            'aggregation_method': 'mean'
        }

        paths = self.path_manager.get_model_paths(**params)

        expected_prefix = 'global_all_agg_mean_transformer_v1'
        self.assertEqual(paths['filename_prefix'], expected_prefix)
        self.assertEqual(paths['model_path'], os.path.join(self.test_base_dir, f'{expected_prefix}_model.pth'))
        print(f"生成的文件名前缀: {paths['filename_prefix']}")
        print("验证通过!")

if __name__ == '__main__':
    unittest.main()
334 xz修改记录日志和启动依赖.md
@ -1,9 +1,10 @@
### Root-directory startup
**1**: `uv venv`
**2**: `uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost -i https://pypi.tuna.tsinghua.edu.cn/simple`
**3**: `uv run .\server\api.py`
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost`

### UI
**1**: `npm install` `npm run dev`
`npm install` `npm run dev`

# Change log for the "Prediction Analysis" module UI refactor

@ -757,312 +758,31 @@
Following the steps above, you can switch the data source from local files to the server database without touching any other part of the project.

---
**Date**: 2025-07-15 11:43
**Subject**: Fix training failures caused by PyTorch version incompatibility
**Date**: 2025-07-15
**Subject**: Fix the "predict by product" feature and enhance the chart display
**Developer**: lyf

### Problem
After the path and dependency fixes, model training crashed on some machines with `TypeError: ReduceLROnPlateau.__init__() got an unexpected keyword argument 'verbose'`, while running fine on the local development machine.
The "Prediction Analysis" -> "Predict by product" page was unusable: the frontend called the wrong API address, and the chart-rendering logic did not match the data structure returned by the backend.

### Root cause
This is a classic compatibility error caused by **inconsistent environments**.
1. **PyTorch version difference**: the local development environment had an older PyTorch whose `ReduceLROnPlateau` scheduler accepted a `verbose` argument (printing a log line when the learning rate changes).
2. **New environments**: other machines, or freshly created virtual environments, installed a newer PyTorch in which the `verbose` argument of `ReduceLROnPlateau` has been removed.
3. **Code issue**: `server/trainers/mlstm_trainer.py` and `server/trainers/transformer_trainer.py` hard-coded `verbose=True` when constructing `ReduceLROnPlateau`, which raises a `TypeError` under the newer PyTorch.
### Solution
The following fixes and enhancements were made to `UI/src/views/prediction/ProductPredictionView.vue`:

### Solution: remove the deprecated argument
1. **Full sweep**: checked every trainer file in the project (`mlstm_trainer.py`, `transformer_trainer.py`, `kan_trainer.py`, `tcn_trainer.py`).
2. **Pinpointing**: confirmed that only `mlstm_trainer.py` and `transformer_trainer.py` use `ReduceLROnPlateau` with a `verbose` argument.
3. **The fix**:
   * **Files**: `server/trainers/mlstm_trainer.py` and `server/trainers/transformer_trainer.py`
   * **Location**: the `ReduceLROnPlateau` initialization call.
   * **Change**: removed the `verbose=True` argument.
   ```diff
   - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', ..., verbose=True)
   + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', ...)
   ```
   * **Why**: removing an argument that no longer exists in newer PyTorch resolves the `TypeError` at the root and keeps the code working across PyTorch versions. The change does not affect the scheduler's core behavior.
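If one wanted to keep the log output on older PyTorch while staying compatible with newer releases, a version-tolerant construction is also possible (a sketch, not what the fix above does; the fix simply drops the argument):

```python
import inspect
import torch

def make_scheduler(optimizer):
    kwargs = dict(mode='min', factor=0.5, patience=5)
    # Older PyTorch accepts verbose=...; newer releases removed it,
    # so only pass it when the installed version supports it.
    params = inspect.signature(
        torch.optim.lr_scheduler.ReduceLROnPlateau.__init__).parameters
    if 'verbose' in params:
        kwargs['verbose'] = True
    return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
```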
1. **API endpoint fix**:
   * **Location**: the `startPrediction` function.
   * **Change**: corrected the API request address from the wrong `/api/predict` to the correct `/api/prediction`.

2. **Data handling alignment**:
   * **Location**: the `startPrediction` and `renderChart` functions.
   * **Change**: updated the data-receiving logic so it correctly handles the `history_data` and `prediction_data` fields returned by the backend.

3. **Chart enhancements**:
   * **Location**: the `renderChart` function.
   * **Change**: rebuilt the chart-rendering logic; it now shows historical sales (solid green line) and predicted sales (dashed blue line) side by side, giving users a more intuitive comparison.

4. **Better error messages**:
   * **Location**: the `catch` block of the `startPrediction` function.
   * **Change**: improved error handling so that more specific backend error messages are extracted from the response and displayed.

### Outcome
Removing the deprecated `verbose` argument eliminates the version-compatibility problem caused by environment differences and ensures training runs reliably on all target machines.

---
**Date**: 2025-07-15 14:05
**Subject**: Dashboard UI adjustment

### Description
At the user's request, the "Data Management" card on the dashboard was replaced with "Store Management".

### Main changes
* **File**: `UI/src/views/DashboardView.vue`
* **Changes**:
  1. In the `featureCards` array, the original "Data Management" object was changed to "Store Management".
  2. The card's `title`, `description`, `icon` and `path` were updated to point to the store-management page (`/store-management`).
  3. The new `Shop` icon was imported in the script.

### Result
The dashboard now offers a direct shortcut to the "Store Management" page, improving efficiency.

---
**Date**: 2025-07-18
**Subject**: Model-saving logic refactor and centralized management

### Goal
Following `xz训练模型保存规则.md`, unify the model-file-saving logic that was scattered across the system, creating one centralized, robust and testable path-management module.

### Key results
1. **Created the `server/utils/file_save.py` module**: this new module is now the single authoritative source for model-file save paths in the system.
2. **Implemented path generation for all three training modes**: the system can now generate hierarchical, traceable directory structures for "by store", "by product" and "global" training.
3. **Integrated smart ID handling**:
   * For training scenarios involving **multiple IDs**, the system automatically computes a short hash as the directory name.
   * For global training involving only a **single store or product ID**, the system uses the ID itself as the directory name, improving path readability.
4. **Refactored the whole training pipeline**: the API layer, the process manager, and all model trainers were modified to work together with the new path-management module.
5. **Added automated tests**: created the `test/test_file_save_logic.py` script to verify all path-generation and version-management logic.

### Detailed file changes

1. **`server/utils/file_save.py`**
   * **Action**: created
   * **Content**: implemented the `ModelPathManager` class with the following core methods:
     * `_hash_ids`: sorts and hashes an ID list.
     * `_generate_identifier`: generates a unique model identifier from the training mode and parameters.
     * `get_next_version` / `save_version_info`: manage the `versions.json` file in a thread-safe way to fetch and update version numbers.
     * `get_model_paths`: the main entry point, coordinating the methods above to produce a dictionary with the paths of all artifacts.

2. **`server/api.py`**
   * **Action**: modified
   * **Location**: the `start_training` function (the `/api/training` endpoint).
   * **Content**:
     * Imported and instantiated `ModelPathManager`.
     * After receiving a training request, calls `path_manager.get_model_paths()` to obtain all path information.
     * Passes the resulting `path_info` dict together with the raw request parameters `training_params` to the background training task manager.
     * Fixed a `TypeError` caused by passing the keyword arguments `model_type` and `training_mode` twice.
     * Fixed an `UnboundLocalError` in the `except` block caused by the missing `traceback` import.

3. **`server/utils/training_process_manager.py`**
   * **Action**: modified
   * **Content**:
     * Changed the `submit_task` method so it accepts the `training_params` and `path_info` dicts.
     * Added a `path_info` field to the `TrainingTask` dataclass to store the path info.
     * In `TrainingWorker`, passes `path_info` on to the actual training function.
     * In `_monitor_results`, calls `path_manager.save_version_info` when a task completes successfully, updating `versions.json` and closing the version-management loop.

4. **All trainer files** (`mlstm_trainer.py`, `kan_trainer.py`, `tcn_trainer.py`, `transformer_trainer.py`)
   * **Action**: modified
   * **Content**:
     * Unified the signature of the main training functions, adding a `path_info=None` parameter.
     * Removed all internal logic that built file paths by hand.
     * All save operations (final model, checkpoints, loss-curve plot) now take their pre-generated paths directly from the `path_info` dict.
     * Simplified the `save_checkpoint` helper so it also relies on `path_info`.

5. **`test/test_file_save_logic.py`**
   * **Action**: created
   * **Content**:
     * Wrote a standalone test script that verifies all `ModelPathManager` functionality.
     * Covers all training modes and their sub-scenarios (including single-ID and multi-ID hashing).
     * Tests correct version incrementing and `versions.json` writes.
     * Fixed several `AssertionError`s and `TypeError`s in the test script caused by absolute/relative path mismatches and duplicated keyword arguments.

---
**Date**: 2025-07-18 (follow-up fix)
**Subject**: Fix the `TypeError` when the API layer calls the path manager

### Problem
After all the refactoring and testing, the live API crashed in the `POST /api/training` endpoint when calling `path_manager.get_model_paths`, raising `TypeError: get_model_paths() got multiple values for keyword argument 'training_mode'`.

### Root cause
A regression: while fixing the same problem in the `test_file_save_logic.py` test script, the identical fix was not applied back to `server/api.py`. The code passed `training_mode=...` explicitly as a keyword argument and then passed it again through `**data`, causing the conflict.

### Solution
1. **File**: `server/api.py`
2. **Location**: the `start_training` function.
3. **Action**: changed how `get_model_paths` is called.
4. **Content**:
   ```python
   # Remove model_type and training_mode to avoid duplicate keyword arguments
   data_for_path = data.copy()
   data_for_path.pop('model_type', None)
   data_for_path.pop('training_mode', None)
   path_info = path_manager.get_model_paths(
       training_mode=training_mode,
       model_type=model_type,
       **data_for_path  # pass the remaining payload
   )
   ```
5. **Why**: removing every explicitly specified keyword from a copy of the dict before unpacking it with `**` guarantees a valid call signature.

---
**Date**: 2025-07-18 (final fix)
**Subject**: Fix a `TypeError` caused by an un-updated middle-layer function signature

### Problem
After the refactor, triggering a training task crashed the background process with `TypeError: train_model() got an unexpected keyword argument 'path_info'`.

### Root cause
A classic "middleman" omission: both ends of the call chain were updated (`api.py` -> `training_process_manager.py` and the `*_trainer.py` files), but the layer between them, the `train_model` method in `server/core/predictor.py`, was not. `training_process_manager` tried to pass `path_info` to `predictor.train_model`, whose signature did not include the new parameter, hence the `TypeError`.

### Solution
1. **File**: `server/core/predictor.py`
2. **Location**: the definition of the `train_model` function.
3. **Action**: added a `path_info=None` parameter to the signature.
4. **Content**:
   ```python
   def train_model(self, ..., progress_callback=None, path_info=None):
       # ...
   ```
5. **Location**: inside `train_model`, at every call to a concrete trainer (`train_product_model_with_mlstm`, `_with_kan`, etc.).
6. **Action**: forwards the received `path_info` argument in every call.
7. **Content**:
   ```python
   # ...
   metrics = train_product_model_with_transformer(
       ...,
       path_info=path_info
   )
   # ...
   ```
8. **Why**: "punching through" the `path_info` parameter at the middle layer completes the data flow from the API layer down to the trainers, resolving the `TypeError`.

---
**Date**: 2025-07-18 (final fix)
**Subject**: Fix path generation in "product training - aggregate all stores" mode

### Problem
In practice, when training "by product" with "aggregate all stores", the generated model save path contained the wrong suffix `_None` instead of the expected `_all` (e.g. `.../17002608_None/...`).

### Root cause
In `_generate_identifier` and `get_model_paths` in `server/utils/file_save.py`, when the frontend sends `store_id` as `None`, the expression `scope = store_id if store_id else 'all'` correctly yields `'all'`. In `get_model_paths`, however, `kwargs.get('store_id', 'all')` was used instead, which still returns `None` when the key exists with a `None` value, producing the broken path.

### Solution
1. **File**: `server/utils/file_save.py`
2. **Location**: the parts of `_generate_identifier` and `get_model_paths` handling the `product` training mode.
3. **Action**: changed the logic from `scope = kwargs.get('store_id', 'all')` to the stricter `scope = store_id if store_id is not None else 'all'`.
4. **Content**:
   ```python
   # in _generate_identifier
   scope = store_id if store_id is not None else 'all'

   # in get_model_paths
   store_id = kwargs.get('store_id')
   scope = store_id if store_id is not None else 'all'
   scope_folder = f"{product_id}_{scope}"
   ```
5. **Why**: this form handles both cases correctly, whether the `store_id` key is absent or present with a `None` value, so `scope` is always set to `'all'` when it should be, producing paths that follow the convention.

---
**Date**: 2025-07-18 (final fix)
**Subject**: Fix `KeyError: 'price'` and single-ID hashing errors

### Problem
After the large refactor, two hidden bugs surfaced at runtime:
1. In "store training" mode, training failed with `KeyError: 'price'`.
2. In "store training" mode, when only one "specific product" was selected, the system still hashed that product ID instead of using it directly.

### Root cause
1. **`KeyError`**: `get_store_product_sales_data` in `server/utils/multi_store_data_utils.py` contained a hard-coded column check requiring a `price` column, which does not match the current data source.
2. **Hashing error**: the `get_model_paths` method in `server/utils/file_save.py` did not reuse the single-ID check already written in `_generate_identifier` when handling the `store` training mode, leaving the two code paths inconsistent.

### Solution
1. **Fix the `KeyError`**:
   * **File**: `server/utils/multi_store_data_utils.py`
   * **Location**: the `get_store_product_sales_data` function.
   * **Change**: removed `'price'` from the `required_columns` list, eliminating the hard dependency.
2. **Fix the hashing logic** (sketched below):
   * **File**: `server/utils/file_save.py`
   * **Location**: the parts of `_generate_identifier` and `get_model_paths` handling the `store` training mode.
   * **Change**: unified the logic so both places use `scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)`, using the ID directly when exactly one product is selected.
3. **Test updates**:
   * **File**: `test/test_file_save_logic.py`
   * **Change**: added a new test case specifically verifying path generation for the "store training - single specified product" scenario.
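The single-ID rule referenced above, as a standalone sketch (the real `_hash_ids` lives on `ModelPathManager`; the digest shown here is an assumption of its behavior):

```python
import hashlib

def _hash_ids(ids):
    # Assumed behavior: sort, join, and keep a short digest (the real
    # implementation may differ in detail).
    return hashlib.md5('_'.join(sorted(ids)).encode()).hexdigest()[:8]

def product_scope(product_ids):
    # A single ID stays human-readable; several IDs collapse to a hash.
    return product_ids[0] if len(product_ids) == 1 else _hash_ids(product_ids)

print(product_scope(['P789']))                  # P789
print(product_scope(['P001', 'P005', 'P002']))  # a short hash, e.g. '8c6f...'
```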

---
**Date**: 2025-07-18 (final fix)
**Subject**: Fix a `ValueError` caused by a mismatched global training-scope value

### Problem
After the refactor, triggering "global training - all stores, all products" through the live API crashed with `ValueError: 未知的全局训练范围: all_stores_all_products`.

### Root cause
The frontend sends `training_scope` as `all_stores_all_products`, while `_generate_identifier` and `get_model_paths` in `server/utils/file_save.py` only handled the value `all`, so the branch check failed on the concrete string sent by the frontend.

### Solution
1. **File**: `server/utils/file_save.py`
2. **Location**: the parts of `_generate_identifier` and `get_model_paths` handling the `global` training mode.
3. **Action**: changed the check from `if training_scope == 'all':` to `if training_scope in ['all', 'all_stores_all_products']:`.
4. **Why**: the code now accepts both strings that mean "everything", so requests from the frontend are handled correctly.
5. **Test updates**:
   * **File**: `test/test_file_save_logic.py`
   * **Change**: added a new test case verifying path generation when `training_scope` is `all_stores_all_products`.

---
**Date**: 2025-07-18 (final optimization)
**Subject**: Use plain IDs for single-ID custom scopes in global training

### Problem
User feedback: in global training's "custom scope" mode, when only a single store and/or a single product is selected, the path should use the ID directly instead of a hash, for readability.

### Solution
1. **File**: `server/utils/file_save.py`
2. **Location**: the parts of `_generate_identifier` and `get_model_paths` handling the `custom` scope of the `global` training mode.
3. **Action**: added single-ID checks for both `store_ids` and `product_ids`.
4. **Content**:
   ```python
   # in _generate_identifier
   s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
   p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
   scope_part = f"custom_s_{s_id}_p_{p_id}"

   # in get_model_paths
   store_ids = kwargs.get('store_ids', [])
   product_ids = kwargs.get('product_ids', [])
   s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
   p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
   scope_parts.extend(['custom', s_id, p_id])
   ```
5. **Why**: this makes the `custom` mode consistent with the `selected_stores` and `selected_products` modes, preferring the plain ID when exactly one is selected, which improves path readability and consistency.
6. **Test updates**:
   * **File**: `test/test_file_save_logic.py`
   * **Change**: added a new test case verifying path generation for the "global training - custom scope - single ID" scenario.

---
**Date**: 2025-07-18
**Subject**: Unify the training pages' UI and fix backend data passing

### Problem
1. In the task lists of the "store training" and "global model training" pages, version numbers lacked the 'v' prefix, inconsistent with the "product training" page.
2. In the task list of the "global model training" page, the "aggregation method" column was always empty.

### Root cause
1. **UI layer**: `UI/src/views/StoreTrainingView.vue` and `UI/src/views/training/GlobalTrainingView.vue` rendered the version number without the 'v'-prefix template used by `ProductTrainingView.vue`.
2. **Backend layer**: the `TrainingTask` dataclass in `server/utils/training_process_manager.py` lacked an `aggregation_method` field, so the information was lost along the whole flow from task submission to the returned data.

### Solution
1. **Frontend UI fix**:
   * **Files**: `UI/src/views/StoreTrainingView.vue`, `UI/src/views/training/GlobalTrainingView.vue`
   * **Change**: modified the `el-table-column` for `version`, adding a `<template>` that renders the version with `<el-tag>v{{ row.version }}</el-tag>`, unifying the display format.

2. **Backend data-flow fix**:
   * **File**: `server/utils/training_process_manager.py`
   * **Changes**:
     1. Added an `aggregation_method: Optional[str] = None` field to the `TrainingTask` dataclass.
     2. Changed the `submit_task` method so it receives and sets `aggregation_method` when creating the `TrainingTask` object.
     3. Changed the `run_training_task` method to pass `task.aggregation_method` along in the call to `predictor.train_model`.

### Outcome
With the coordinated frontend and backend fixes, all training pages now look consistent, and the global training "aggregation method" is recorded and displayed correctly.
The "predict by product" feature is now wired to the backend, works as expected, and offers a richer, more robust visualization experience.

222 xz新模型添加流程.md Normal file
@ -0,0 +1,222 @@
# How to add a new model to the system

This guide explains in detail how to add a brand-new forecasting model to this prediction system. The system uses a flexible plugin-style architecture, so integrating a new model is highly modular and revolves around three core components: the **Model**, the **Trainer**, and the **Predictor**.

## Core idea

The heart of the system is `models/model_registry.py`, which maintains two independent registries: one for training functions and one for prediction functions (sketched after this list). Adding a new model boils down to:

1. **Define the model**: create the model architecture.
2. **Create a trainer**: write a function that trains the model, and register it in the trainer registry.
3. **Integrate a predictor**: make sure the system knows how to load the model and predict with it, then register the prediction logic in the predictor registry.
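A registry of this kind is typically just a pair of dicts keyed by model name. A sketch of the assumed shape of `models/model_registry.py`, for orientation only (the real module may differ in detail):

```python
# Assumed sketch of models/model_registry.py
TRAINER_REGISTRY = {}
PREDICTOR_REGISTRY = {}

def register_trainer(name, train_fn):
    TRAINER_REGISTRY[name] = train_fn

def register_predictor(name, predict_fn):
    PREDICTOR_REGISTRY[name] = predict_fn

def get_trainer(name):
    # Look up the training function registered for this model type
    if name not in TRAINER_REGISTRY:
        raise ValueError(f"unknown model type: {name}")
    return TRAINER_REGISTRY[name]
```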
||||
---
|
||||
|
||||
## 第 1 步:定义模型架构
|
||||
|
||||
首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。
|
||||
|
||||
**示例:创建 `ShopTRAINING/server/models/my_new_model.py`**
|
||||
|
||||
如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。
|
||||
|
||||
```python
|
||||
# file: ShopTRAINING/server/models/my_new_model.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class MyNewModel(nn.Module):
|
||||
def __init__(self, input_features, hidden_size, output_sequence_length):
|
||||
"""
|
||||
定义模型的层和结构。
|
||||
"""
|
||||
super(MyNewModel, self).__init__()
|
||||
self.layer1 = nn.Linear(input_features, hidden_size)
|
||||
self.relu = nn.ReLU()
|
||||
self.layer2 = nn.Linear(hidden_size, output_sequence_length)
|
||||
# ... 可添加更复杂的结构
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
定义数据通过模型的前向传播路径。
|
||||
x 的形状通常是 (batch_size, sequence_length, num_features)
|
||||
"""
|
||||
# 确保输入是正确的形状
|
||||
# 例如,对于简单的线性层,可能需要展平
|
||||
batch_size, seq_len, features = x.shape
|
||||
x = x.view(batch_size * seq_len, features) # 展平
|
||||
|
||||
out = self.layer1(x)
|
||||
out = self.relu(out)
|
||||
out = self.layer2(out)
|
||||
|
||||
# 恢复形状以匹配输出
|
||||
out = out.view(batch_size, seq_len, -1)
|
||||
# 通常我们只关心序列的最后一个预测
|
||||
return out[:, -1, :]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第 2 步:创建模型训练器
|
||||
|
||||
接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。
|
||||
|
||||
**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`**
|
||||
|
||||
这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。
|
||||
|
||||
```python
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py

import torch
import torch.optim as optim
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from models.my_new_model import MyNewModel  # import your new model

# Follow the system's standard function signature
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    print(f"🚀 MyNewModel trainer started: model_identifier='{model_identifier}'")

    # --- 1. Data preparation ---
    # (Data loading, scaling, and dataset creation are omitted here;
    # see xgboost_trainer.py or any other trainer for a full implementation.)
    # ...
    # Assume trainX, trainY, testX, testY, scaler_y, etc. are ready:
    # trainX = ...
    # trainY = ...
    # testX = ...
    # testY = ...
    # scaler_y = ...
    # features = [...]

    # --- 2. Instantiate the model and optimizer ---
    input_dim = trainX.shape[2]  # number of input features
    hidden_size = 64             # example hyperparameter

    model = MyNewModel(
        input_features=input_dim,
        hidden_size=hidden_size,
        output_sequence_length=forecast_horizon
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop ---
    print("Training MyNewModel...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(trainX)
        loss = criterion(outputs, trainY)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Model evaluation ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # Inverse-transform and compute metrics
    # ... (see the other trainers)
    metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0}  # placeholder

    # --- 5. Model saving ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': None,  # replace with your scaler_X
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'mynewmodel',  # KEY: use a unique model name
            'input_dim': input_dim,
            'hidden_size': hidden_size,
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'features': features
        },
        'metrics': metrics
    }

    model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='mynewmodel',  # KEY: the same model name again
        # ... other parameters
    )

    print("✅ MyNewModel trained and saved!")
    return model, metrics, "v1", "path/to/model"  # return values follow the unified format

# --- Key step: register the trainer with the system ---
register_trainer('mynewmodel', train_with_mynewmodel)
```
---

## Step 3: Integrate the Model Predictor

Finally, the system needs to know how to load your new model and use it for prediction. This requires two changes in `ShopTRAINING/server/predictors/model_predictor.py`.

**File: `ShopTRAINING/server/predictors/model_predictor.py`**

1. **Teach the system how to build an instance of your model**

   Inside the `load_model_and_predict` function there is an `if/elif` chain that instantiates different models by type. Add a new branch for `MyNewModel`.
```python
# in model_predictor.py

# First, import your new model class
from models.my_new_model import MyNewModel

# ... inside load_model_and_predict ...

    # ... elif branches for the other models ...
    elif loaded_model_type == 'tcn':
        model = TCNForecaster(...)

    # vvv add this new branch vvv
    elif loaded_model_type == 'mynewmodel':
        model = MyNewModel(
            input_features=config['input_dim'],
            hidden_size=config['hidden_size'],
            output_sequence_length=config['forecast_horizon']
        ).to(DEVICE)
    # ^^^ end of addition ^^^

    else:
        raise ValueError(f"Unsupported model type: {loaded_model_type}")

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
```
2. **Register the prediction logic**

   If your model is a standard PyTorch model and its prediction logic is the same as the existing models (Transformer, KAN, etc.), you can reuse `default_pytorch_predictor` directly. Just add one registration line at the end of the file.
```python
# at the end of model_predictor.py

# ...
# Register the enhanced default predictor for xgboost as well
register_predictor('xgboost', default_pytorch_predictor)

# vvv add this line vvv
# Let 'mynewmodel' use the generic PyTorch predictor too
register_predictor('mynewmodel', default_pytorch_predictor)
# ^^^ end of addition ^^^
```
If your model needs special prediction logic (for example, a different input format or calling convention, as XGBoost has), copy `default_pytorch_predictor` into a new function, adapt its internals, and register that new function for `'mynewmodel'`, as sketched below.
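The exact signature of `default_pytorch_predictor` is not shown in this guide, so the parameter list below is an assumption; the point is only the copy-adapt-register pattern:

```python
# in model_predictor.py (illustrative sketch; match the real predictor's signature)
import torch

def mynewmodel_predictor(model, checkpoint, X_input, **kwargs):
    """Custom prediction logic for 'mynewmodel' (parameter names are assumptions)."""
    scaler_y = checkpoint['scaler_y']
    with torch.no_grad():
        y_pred_scaled = model(X_input).cpu().numpy()
    # Inverse-transform back to the real sales scale
    return scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

# Register the custom function instead of the default one
register_predictor('mynewmodel', mynewmodel_predictor)
```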
---

## Summary

After completing the three steps above, your new model `MyNewModel` is fully integrated into the system. The system automatically discovers the new trainer in the `trainers` directory, and when you select `mynewmodel` for training or prediction through the API or the UI, all of the logic you just wrote and registered is invoked automatically.
xz模型添加流程.md (-61, file deleted)
@ -1,61 +0,0 @@
# Standard Process for Adding a New Model to the System

This document summarizes the standard process for adding a new forecasting model to the project (using XGBoost as the example), providing a clear, reusable roadmap for future development work.

---
### Step 1: Create the Model Trainer

This is the core step: implementing the new model's training logic.

1. **Create a new file**: under [`server/trainers/`](server/trainers/), create a new Python file, e.g. `new_model_trainer.py`.

2. **Define the training function**: in that file, define the core training function following the project's standard signature, accepting parameters such as `product_id`, `store_id`, `epochs`, and `path_info`.

3. **Implement the function body**:
    * **Data loading**: use [`utils.multi_store_data_utils.load_multi_store_data`](server/utils/multi_store_data_utils.py) to load the data, filtering by `product_id` and `store_id`.
    * **Preprocessing**: convert the time series into a supervised-learning format. For a model like XGBoost this means building a "sliding window" (like the `create_dataset` function we implemented).
    * **Scaling (critical)**: you **must** normalize the features (`X`) and the target (`y`) with `sklearn.preprocessing.MinMaxScaler`, creating and fitting two scalers, `scaler_X` and `scaler_y`.
    * **Training**: initialize your new model and train it on the **normalized** data.
    * **Loss curve (optional but recommended)**: if the model supports it, capture training and validation loss during training, plot the loss curve with `matplotlib`, and save it as `..._loss_curve.png`.
    * **Checkpoints (optional but recommended)**: if the model supports callbacks, implement a custom callback that saves checkpoints at a fixed epoch interval (`..._checkpoint_epoch_{N}.pth`).
    * **Evaluation**: compute the evaluation metrics (RMSE, R2, etc.) on the **de-normalized** predictions.
    * **Saving (critical)**:
        * Build a dictionary (payload) that **must** contain: `'model'` (the trained model object), `'config'` (the training configuration), `'scaler_X'` (the feature scaler), and `'scaler_y'` (the target scaler). A minimal sketch of this payload follows this list.
        * Save that dictionary to the path given by `path_info['model_path']` using the appropriate library (`torch.save` for PyTorch models, `joblib.dump` for others such as XGBoost). **Always use the `.pth` file extension.**
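As a concrete illustration of the payload rule above, here is a minimal sketch; the data, hyperparameters, and output path are made up, not taken from the actual trainer:

```python
import numpy as np
import joblib
from sklearn.preprocessing import MinMaxScaler
from xgboost import XGBRegressor

# Dummy sliding-window data standing in for the real prepared dataset
X_train = np.random.rand(100, 30)   # 100 samples, 30-day window
y_train = np.random.rand(100)

# Fit both scalers, as the rules above require
scaler_X = MinMaxScaler().fit(X_train)
scaler_y = MinMaxScaler().fit(y_train.reshape(-1, 1))

# Train on the NORMALIZED data
model = XGBRegressor(n_estimators=100)
model.fit(scaler_X.transform(X_train),
          scaler_y.transform(y_train.reshape(-1, 1)).ravel())

# The payload MUST contain model, config, and both scalers
payload = {
    'model': model,
    'config': {'model_type': 'xgboost', 'sequence_length': 30},
    'scaler_X': scaler_X,
    'scaler_y': scaler_y,
}
# Non-PyTorch models are saved with joblib, still using the .pth extension
joblib.dump(payload, 'example_model.pth')  # real code uses path_info['model_path']
```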
---

### Step 2: Integrate the Trainer into the System

1. **Register the trainer**: open [`server/trainers/__init__.py`](server/trainers/__init__.py).
    * At the top of the file, import the training function from your new trainer file, e.g. `from .new_model_trainer import train_product_model_with_new_model`.
    * Add the new function's name to the `__all__` list at the bottom of the file.

2. **Add dispatch logic**: open [`server/core/predictor.py`](server/core/predictor.py).
    * In the `train_model` method, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch that calls your new training function.
---

### Step 3: Implement the Prediction Logic

1. **Modify the predictor**: open [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py).
2. **Add a prediction branch**: in the `load_model_and_predict` function, find the `if/elif` block and add a new `elif model_type == 'new_model':` branch.
3. **Implement the branch**:
    * Load the `.pth` model file with the same library used to save it (e.g. `joblib.load`).
    * From the loaded dictionary, you **must** extract `model`, `config`, `scaler_X`, and `scaler_y`.
    * Prepare the input data for prediction (e.g. the most recent N days).
    * Before predicting, you **must** normalize the input with `scaler_X.transform`.
    * After getting the model's output, you **must** de-normalize it with `scaler_y.inverse_transform` to obtain real-scale predictions.
---

### Step 4: Update the API and Dependencies

1. **Update the API endpoints**: open [`server/api.py`](server/api.py).
    * Add your new model ID (e.g. `'new_model'`) to the `valid_model_types` list in the `/api/training` route (the `start_training` function).
    * Add a description of the new model to the list returned by the `/api/model_types` route (the `get_model_types` function) so it shows up in the frontend.

2. **Update dependencies**: add any Python libraries your new model needs (e.g. `xgboost`) to [`requirements.txt`](requirements.txt).

Following these four steps, any new forecasting model can be integrated into the existing system consistently and robustly.
xz模型预测修改.md (-50, file deleted)
@ -1,50 +0,0 @@
# Model Prediction Path Fix Record

**Modified**: 2025-07-18 18:43:50

## 1. Background

The system raised "file not found" errors during model prediction. Analysis showed the root cause: the model-loading logic (at prediction time) and the model-saving logic (at training time) followed inconsistent path rules.

- **Saving rule (new)**: follows `xz训练模型保存规则.md`, saving models in a structured hierarchy, e.g. `saved_models/product/{product_id}_all/mlstm/v1/model.pth`.
- **Loading logic (old)**: the code hard-coded a flat lookup, e.g. searching the `saved_models` root directly for a file named `{product_id}_{model_type}_v1.pth`.

This mismatch meant the prediction feature could not locate models that had already been trained.

## 2. Fix Strategy

To solve this, we centralized path management: the entire application now generates and retrieves model paths through one unified manager.

## 3. Code Changes in Detail
### Change 1: Strengthen the Path Manager

- **File**: [`server/utils/file_save.py`](server/utils/file_save.py)
- **Change**: added a `get_model_path_for_prediction` method to the `ModelPathManager` class.
- **Purpose**:
  - Provide a function dedicated to resolving model paths **at prediction time**.
  - The function builds the full model file path strictly according to the hierarchy defined in `xz训练模型保存规则.md`.
  - Path generation is now managed in one place, eliminating hard-coded paths scattered through the code. A sketch of what this method might look like follows.
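The method body is not included in this record; the following is a minimal sketch assuming the hierarchy shown above (`saved_models/product/{product_id}_all/{model_type}/v{version}/model.pth`), with parameter names that are assumptions:

```python
import os

class ModelPathManager:
    def __init__(self, base_dir='saved_models'):
        self.base_dir = base_dir

    def get_model_path_for_prediction(self, training_mode, product_id, model_type, version):
        """Build the hierarchical model path used at prediction time."""
        scope = f"{product_id}_all"
        return os.path.join(self.base_dir, training_mode, scope,
                            model_type, f"v{version}", 'model.pth')

# Usage:
# ModelPathManager().get_model_path_for_prediction('product', 'P001', 'mlstm', 1)
# -> 'saved_models/product/P001_all/mlstm/v1/model.pth'
```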
### Change 2: Fix the API Prediction Endpoint

- **File**: [`server/api.py`](server/api.py)
- **Changes**:
  1. Reworked the internal logic of the `/api/prediction` endpoint (the `predict` function).
  2. Updated the definition and implementation of the helper function `run_prediction`.
- **Purpose**:
  - **`predict`**: removed all of the old, incorrect logic that manually concatenated model file names. It now instantiates `ModelPathManager` and calls its new `get_model_path_for_prediction` method to obtain the single, correct model path.
  - **`run_prediction`**: added a `model_path` parameter to its signature so it can receive and pass down the path resolved by `predict`, and simplified its internals to call `load_model_and_predict` directly.
### Change 3: Fix the Model Loader

- **File**: [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py)
- **Change**: modified the `load_model_and_predict` function.
- **Purpose**:
  - Added a `model_path` parameter to the signature.
  - **Completely removed** all of the function's old internal logic for guessing model file locations.
  - The function now relies entirely on the `model_path` passed in from `api.py`, guaranteeing the load path is correct.

## 4. Conclusion

These three changes align the whole chain, from API request to model file loading, on a single, correct, structured path rule, fixing the model-loading failures caused by mismatched paths at the root.
@ -1,93 +0,0 @@
# Flat Model Data Handling Specification (Final)

**Version**: 4.0 (final)
**Core idea**: the logical path is encoded into the file name, giving fully flat file storage.

---

## 1. File Saving Rules

### 1.1. Core Principle

All metadata is encoded in the file name. A logical hierarchical path (e.g. `product/P001_all/mlstm/v2`) is converted into an underscore-joined file name prefix (`product_P001_all_mlstm_v2`).

### 1.2. Storage Locations

- **Final artifacts**: all final models, metadata files, loss plots, etc. live directly in the `saved_models/` root.
- **Intermediate files**: all training checkpoints live in `saved_models/checkpoints/`.
### 1.3. File Name Generation Rules

1. **Build the logical path** from the training parameters (mode, scope, type, version).
   - *Example*: `product/P001_all/mlstm/v2`

2. **Generate the file name prefix** by replacing every `/` in the logical path with `_`.
   - *Example*: `product_P001_all_mlstm_v2`

3. **Append the file-type suffix** to the prefix:
   - `_model.pth`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### **Complete examples:**

- **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

These rules reduce to a few lines of code, sketched below.
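The spec itself contains no code; as a minimal sketch, the three rules above amount to the following (the function name and parameters are illustrative):

```python
import os

def build_model_filename(logical_path: str, suffix: str, base_dir: str = 'saved_models') -> str:
    """Flatten a logical path into the spec's file name, e.g.
    build_model_filename('product/P001_all/mlstm/v2', '_model.pth')
    -> 'saved_models/product_P001_all_mlstm_v2_model.pth'"""
    prefix = logical_path.replace('/', '_')
    # Checkpoints go to the checkpoints/ subdirectory, per section 1.2
    subdir = 'checkpoints' if 'checkpoint' in suffix else ''
    return os.path.join(base_dir, subdir, prefix + suffix)
```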
---

## 2. File Reading Rules

1. **Determine the model metadata**: the training mode, scope, type, and version of the model to load.
2. **Build the file name prefix**: join the metadata exactly as when saving (e.g. `product_P001_all_mlstm_v2`).
3. **Locate the file**:
   - For the final model, look for `saved_models/{prefix}_model.pth`.
   - For the best checkpoint, look for `saved_models/checkpoints/{prefix}_checkpoint_best.pth`.

---
## 3. Database Storage Rules

The database serves as an index and should store just enough metadata to reconstruct the file name prefix.

#### **Suggested `models` table schema:**

| Field | Type | Description | Example |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | Primary key | 1 |
| `filename_prefix` | TEXT | **Full file name prefix, usable as a unique identifier** | `product_P001_all_mlstm_v2` |
| `model_identifier` | TEXT | Identifier used for versioning (without the version) | `product_P001_all_mlstm` |
| `version` | INTEGER | Version number | `2` |
| `status` | TEXT | Model status | `completed`, `training`, `failed` |
| `created_at` | TEXT | Creation time | `2025-07-21 02:29:00` |
| `metrics_summary` | TEXT | JSON string of key performance metrics | `{"rmse": 10.5, "r2": 0.89}` |

#### **Saving logic:**
- After training completes, insert one row into the table. The `filename_prefix` field is the key for finding all files related to that training run.

---
## 4. Version Tracking Rules

Version management relies on a `versions.json` file in the root directory, enabling atomic, thread-safe version increments.

- **File name**: `versions.json`
- **Location**: `saved_models/versions.json`
- **Structure**: a JSON object whose `key` is the identifier without the version number and whose `value` is the latest version number (integer) for that identifier.
  - **Key**: `{prefix_core}_{model_type}` (e.g. `product_P001_all_mlstm`)
  - **Value**: `Integer`

#### **`versions.json` example:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **Version management flow:**

1. **Get a new version**: before training, build the `key`, read `versions.json`, and look up its `value`. The new version is `value + 1` (or `1` if the key does not exist).
2. **Update the version**: after training succeeds, write the new version back to `versions.json`. This step **must use a file lock** to prevent concurrent conflicts. A sketch of such a locked update appears below.
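The spec mandates a file lock but does not show one; the following is a minimal sketch using the standard library's `fcntl` (an assumption, and POSIX-only; on Windows a third-party lock such as `portalocker` would be needed):

```python
import fcntl
import json
import os

VERSIONS_FILE = 'saved_models/versions.json'

def bump_version(key: str) -> int:
    """Atomically read, increment, and write back the version for `key`."""
    # Open (creating if needed) and hold an exclusive lock for the whole update
    fd = os.open(VERSIONS_FILE, os.O_RDWR | os.O_CREAT)
    with os.fdopen(fd, 'r+') as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # POSIX-only advisory lock
        raw = f.read()
        versions = json.loads(raw) if raw.strip() else {}
        new_version = versions.get(key, 0) + 1   # rule 1: value + 1, or 1 if absent
        versions[key] = new_version
        f.seek(0)
        f.truncate()
        json.dump(versions, f, indent=2)
        # The lock is released when the file is closed
    return new_version
```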
新需求开发流程.md · Normal file (+150)
@ -0,0 +1,150 @@
# Standard Workflow for New Feature Development

This document defines a standard, safe, and efficient workflow for developing new features, covering everything from creating a feature branch to merging back into the main development branch, with day-to-day best practices baked in.

## Core Development Principles

- **Protect the main branch**: `lyf-dev` is the team's main development branch and must stay stable and deployable at all times. All new feature work happens on dedicated feature branches.
- **Feature branches**: each new requirement (e.g. `req0001`) gets its own feature branch (e.g. `lyf-dev-req0001`). Branch names should be clear and meaningful.
- **Small, frequent steps**: commit often, push often, and sync with the mainline often (`rebase` or `merge`). This greatly reduces the difficulty and risk of later merges.
- **Clean history**: keep the Git history readable to ease code review and troubleshooting.

---
## First Thing Every Day: Sync the Latest Code

**Wherever you left off yesterday, start each day with the steps below. This is the foundation of smooth collaboration and the best defense against merge conflicts.**

1. **Update the main development branch `lyf-dev`**
   ```bash
   # Switch to the main development branch
   git checkout lyf-dev

   # Pull the latest code; --prune removes references to branches deleted on the remote
   git pull origin lyf-dev --prune
   ```

2. **Sync your feature branch (pick one approach as a team)**
   There are two mainstream ways to bring the latest mainline code into your feature branch; the team should standardize on one.

---
### Option 1 (recommended): `rebase` for a clean history

This keeps your branch history a single straight line, which is very easy to read.

```bash
# Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
git checkout lyf-dev-req0001

# Rebase the latest lyf-dev updates into your branch
git rebase lyf-dev
```
- **Pros**: the final history is clean and linear, which helps code review and troubleshooting.
- **Cons**: history is rewritten, so you must push with `git push --force-with-lease`.
- **Resolving conflicts**:
  1. Edit the conflicting files by hand.
  2. Run `git add <conflicting files>`.
  3. Run `git rebase --continue`.
  4. To abort, run `git rebase --abort`.

---
### Option 2: `merge` to preserve the full history

This faithfully records every merge and never rewrites existing commits.

```bash
# Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
git checkout lyf-dev-req0001

# Merge the latest lyf-dev into your current branch
git merge lyf-dev
```
- **Pros**: safe, no history rewriting, no force-push needed.
- **Cons**: adds extra merge commits to the feature branch (e.g. "Merge branch 'lyf-dev' into ..."), which clutters the history.
- **Resolving conflicts**:
  1. Edit the conflicting files by hand.
  2. Run `git add <conflicting files>`.
  3. Run `git commit` to complete the merge.

---
## Full Development Workflow

### 1. Starting a new requirement: create a feature branch

**When you begin work on a brand-new feature:**

1. **Make sure `lyf-dev` is up to date**
   (Already done in "First Thing Every Day"; repeated here as a reminder.)

2. **Create and switch to a new branch off `lyf-dev`**
   Assuming the new requirement number is `req0002`:
   ```bash
   # Creates lyf-dev-req0002 from the latest lyf-dev and switches to it
   git checkout -b lyf-dev-req0002
   ```

### 2. Day-to-day development: commit and push

**While developing on your feature branch (e.g. `lyf-dev-req0002`):**

1. **Code and commit locally**
   Commit each time you finish a small, complete unit of work.
   ```bash
   # Check what changed
   git status
   # Stage all related files
   git add .
   # Commit with a clear message (feat: feature, fix: bugfix, docs: documentation, etc.)
   git commit -m "feat: implement user authentication module"
   ```

2. **Push your work to the remote for backup**
   Push local commits frequently, both for safety and for team collaboration.
   ```bash
   # -u sets the local branch to track the remote branch; after that, plain git push suffices
   git push -u origin lyf-dev-req0002
   ```
### 3. Feature complete: merge back into the mainline

**Once the feature is finished and tested, merge it back into `lyf-dev`:**

1. **One last sync**
   Right before merging, sync once more so the branch contains everything that is in `lyf-dev`.
   (Repeat the sync flow from "First Thing Every Day".)

2. **Switch to the main branch and pull the latest code**
   ```bash
   git checkout lyf-dev
   git pull origin lyf-dev
   ```

3. **Merge the feature branch**
   We use `--no-ff` (no fast-forward) to create a merge commit, clearly recording that a feature was merged.
   ```bash
   # --no-ff creates a merge commit and preserves the branch history
   git merge --no-ff lyf-dev-req0002
   ```
   If you kept the branch synced, this step normally produces no conflicts.

4. **Push the merged main branch**
   ```bash
   git push origin lyf-dev
   ```

### 4. Clean up

**Once merged, the feature branch has served its purpose:**

1. **Delete the remote branch**
   ```bash
   git push origin --delete lyf-dev-req0002
   ```

2. **Delete the local branch**
   ```bash
   git branch -d lyf-dev-req0002
   ```

Following this workflow keeps the team's development process clear, safe, and efficient.
系统调用逻辑与核心代码分析.md · Normal file (+466)
@ -0,0 +1,466 @@
# System Call Flow and Core Code Analysis

This document walks through the system's end-to-end call chain: system startup, frontend interaction, backend processing, and finally model training, prediction, and chart rendering.

## 1. System Startup

The system has two parts: a Vue.js frontend and a Flask backend.

### 1.1. Start the Backend API Service

From the project root, start the backend with:

```bash
python server/api.py
```

This starts a Flask application listening on `http://localhost:5000`, serving all API and WebSocket endpoints.

### 1.2. Start the Frontend Dev Server

Enter the `UI` directory and run:

```bash
cd UI
npm install
npm run dev
```

This starts the Vite dev server, usually at `http://localhost:5173`, and opens the frontend in a browser.
## 2. Core Call Chain Overview

Taking the core **"train by product -> predict by product"** flow as an example, the high-level call chain is:

**Training flow:**
`Frontend UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `background process` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `save model .pth`

**Prediction flow:**
`Frontend UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `load model .pth` -> `return prediction JSON` -> `frontend chart rendering`
## 3. Detailed Flow: Training by Product

The goal of this flow is to train a dedicated forecasting model for a specific product.

### 3.1. Frontend Interaction and API Request

1. **User action**: on the **"Train by Product"** page ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)), the user picks a product, a model type (e.g. Transformer), sets the number of epochs, and clicks **"Start Product Training"**.

2. **Handler**: the click invokes the [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) method.

3. **Build the payload**: `startTraining` builds a `payload` object with the training parameters. The key field is `training_mode: 'product'`, which tells the backend this is product-specific training.

   *Core code ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
```javascript
const startTraining = async () => {
  // ... form validation ...
  trainingLoading.value = true;
  try {
    const endpoint = "/api/training";

    const payload = {
      product_id: form.product_id,
      store_id: form.data_scope === 'global' ? null : form.store_id,
      model_type: form.model_type,
      epochs: form.epochs,
      training_mode: 'product' // marks this as product training mode
    };

    const response = await axios.post(endpoint, payload);
    // ... handle the response, start the WebSocket listener ...
  }
  // ... error handling ...
};
```
4. **API request**: `axios` sends a POST request to the backend at `/api/training`.

### 3.2. Backend API Receipt and Task Dispatch

1. **Routing**: the [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) decorator in [`server/api.py`](server/api.py:1) captures the request, which is handled by [`start_training()`](server/api.py:971).

2. **Task submission**: `start_training()` parses the request JSON and calls `training_manager.submit_task()` to submit the training task to a background process pool, so the main API thread is not blocked. The API can return a task ID immediately while training runs asynchronously in the background.

   *Core code ([`server/api.py`](server/api.py:971))*
```python
@app.route('/api/training', methods=['POST'])
def start_training():
    data = request.get_json()

    training_mode = data.get('training_mode', 'product')
    model_type = data.get('model_type')
    epochs = data.get('epochs', 50)
    product_id = data.get('product_id')
    store_id = data.get('store_id')

    if not model_type or (training_mode == 'product' and not product_id):
        return jsonify({'error': 'Missing required parameters'}), 400

    try:
        # Submit the task via the training process manager
        task_id = training_manager.submit_task(
            product_id=product_id or "unknown",
            model_type=model_type,
            training_mode=training_mode,
            store_id=store_id,
            epochs=epochs
        )

        logger.info(f"🚀 Training task submitted to the process manager: {task_id[:8]}")

        return jsonify({
            'message': 'Model training started (in a dedicated process)',
            'task_id': task_id,
        })

    except Exception as e:
        logger.error(f"❌ Failed to submit training task: {str(e)}")
        return jsonify({'error': f'Failed to start training task: {str(e)}'}), 500
```
### 3.3. Core Training Logic

1. **Call the core predictor**: the background process eventually calls [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) in [`server/core/predictor.py`](server/core/predictor.py:1).

2. **Data preparation**: based on `training_mode` (`'product'`) and `product_id`, `train_model` first loads and aggregates all stores' sales data for that product.

3. **Dispatch to the concrete trainer**: it then calls the matching training function based on `model_type`. For example, if `model_type` is `transformer`, it calls `train_product_model_with_transformer`.

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:63))*
```python
class PharmacyPredictor:
    def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
        # ...
        if training_mode == 'product':
            product_data = self.data[self.data['product_id'] == product_id].copy()
            # ...

        # Build the model identifier from the training mode
        model_identifier = product_id

        try:
            if model_type == 'transformer':
                model_result, metrics, actual_version = train_product_model_with_transformer(
                    product_id=product_id,
                    model_identifier=model_identifier,
                    product_df=product_data,
                    # ... other parameters ...
                )
            # ... elif branches for the other models ...

            return metrics
        except Exception as e:
            # ... error handling ...
            return None
```
### 3.4. Model Training and Saving

1. **Concrete trainer**: taking [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) as the example, `train_product_model_with_transformer` performs these steps:
    * **Preprocessing**: calls `prepare_data` and `prepare_sequences` to turn raw sales data into a supervised-learning format with time-series features (input sequences and target sequences).
    * **Model instantiation**: creates a `TimeSeriesTransformer` instance.
    * **Training loop**: runs the requested number of `epochs`, computing the loss and updating model weights with the optimizer.
    * **Progress updates**: during training, emits `training_progress` events to the frontend via `socketio.emit`, updating the progress bar and logs in real time.
    * **Model saving**: after training, packs the model weights (`model.state_dict()`), the full model configuration (`config`), and the data scalers (`scaler_X`, `scaler_y`) into a checkpoint dictionary saved with `torch.save()` to a `.pth` file. The file name is generated uniformly by `get_model_file_path` from `model_identifier`, `model_type`, and `version`.

   *Core code ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
```python
def train_product_model_with_transformer(...):
    # ... data preparation ...

    # Define the model configuration
    config = {
        'input_dim': input_dim,
        'output_dim': forecast_horizon,
        'hidden_size': hidden_size,
        # ... every required hyperparameter ...
        'model_type': 'transformer'
    }

    model = TimeSeriesTransformer(...)

    # ... training loop ...

    # Save the model
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'metrics': test_metrics
    }

    model_path = get_model_file_path(model_identifier, 'transformer', version)
    torch.save(checkpoint, model_path)

    return model, test_metrics, version
```
## 4. Detailed Flow: Prediction by Product

Once training is done, the user can run predictions with the saved model.

### 4.1. Frontend Interaction and API Request

1. **User action**: on the **"Predict by Product"** page ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)), the user selects the same product, the matching model and version, and clicks **"Start Prediction"**.

2. **Handler**: the click invokes the [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) method.

3. **Build the payload**: the method builds a `payload` with the prediction parameters.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
```javascript
const startPrediction = async () => {
  try {
    predicting.value = true
    const payload = {
      product_id: form.product_id,
      model_type: form.model_type,
      version: form.version,
      future_days: form.future_days,
      // training_mode is implicitly 'product' here
    }
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data
      await nextTick()
      renderChart()
    }
    // ... error handling ...
  }
  // ...
}
```
4. **API request**: `axios` sends a POST request to the backend at `/api/prediction`.

### 4.2. Backend API Receipt and Prediction Execution

1. **Routing**: [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) in [`server/api.py`](server/api.py:1) captures the request, handled by [`predict()`](server/api.py:1469).

2. **Call the core predictor**: `predict()` parses the parameters and calls the `run_prediction` helper, which in turn calls [`PharmacyPredictor.predict()`](server/core/predictor.py:295) in [`server/core/predictor.py`](server/core/predictor.py:1).

   *Core code ([`server/api.py`](server/api.py:1469))*
```python
@app.route('/api/prediction', methods=['POST'])
def predict():
    try:
        data = request.json
        # ... parse parameters ...
        training_mode = data.get('training_mode', 'product')
        product_id = data.get('product_id')
        # ...

        # Determine the model identifier from the mode
        if training_mode == 'product':
            model_identifier = product_id
        # ...

        # Run the prediction
        prediction_result = run_prediction(model_type, product_id, model_id, ...)

        # ... format the response ...
        return jsonify(response_data)
    except Exception as e:
        # ... error handling ...
```
3. **Dispatch to the model loader**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) mainly re-derives the `model_identifier` from `training_mode` and `product_id`, then passes all parameters on to [`load_model_and_predict()`](server/predictors/model_predictor.py:26) in [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1).

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:295))*
```python
class PharmacyPredictor:
    def predict(self, product_id, model_type, ..., training_mode='product', ...):
        if training_mode == 'product':
            model_identifier = product_id
        # ...

        return load_model_and_predict(
            model_identifier,
            model_type,
            # ... other parameters ...
        )
```
### 4.3. Model Loading and Prediction

[`load_model_and_predict()`](server/predictors/model_predictor.py:26) is the heart of the prediction flow. It performs these steps:

1. **Locate the model file**: uses `get_model_file_path` to find the previously saved `.pth` file from `product_id` (i.e. the `model_identifier`), `model_type`, and `version`.

2. **Load the checkpoint**: loads the file with `torch.load()`, obtaining the dictionary with `model_state_dict`, `config`, and the scalers.

3. **Rebuild the model**: recreates a model instance with exactly the same structure as at training time, using the hyperparameters in the loaded `config` (`hidden_size`, `num_layers`, etc.). **This was the key point of an earlier fix: every required parameter must be saved and loaded.**

4. **Load the weights**: applies the loaded `model_state_dict` to the new model instance.

5. **Prepare the input**: fetches the most recent `sequence_length` days of history from the data source as the prediction input.

6. **Normalize**: scales the input with the loaded `scaler_X`.

7. **Predict**: feeds the normalized data into the model (`model(X_input)`) to get the prediction.

8. **De-normalize**: inverse-transforms the model output with the loaded `scaler_y`, back to the real sales scale.

9. **Build the result**: combines the predictions with the corresponding future dates into a DataFrame and returns it together with the history data.

   *Core code ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
    # ... resolve the model file path model_path ...

    # Load the model and its configuration
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']

    # Create the model instance (Transformer as the example)
    model = TimeSeriesTransformer(
        num_features=config['input_dim'],
        d_model=config['hidden_size'],
        # ... all remaining parameters from config ...
    ).to(DEVICE)

    # Load the model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # ... prepare the input data ...

    # Normalize the input
    X_scaled = scaler_X.transform(X)
    X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)

    # Predict
    with torch.no_grad():
        y_pred_scaled = model(X_input).cpu().numpy()

    # De-normalize
    y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

    # ... build the return value ...
    return {
        'predictions': predictions_df,
        'history_data': recent_history,
        # ...
    }
```
### 4.4. Response Formatting and Frontend Chart Rendering

1. **API-level formatting**: in [`predict()`](server/api.py:1469) in [`server/api.py`](server/api.py:1), the result returned by `load_model_and_predict` is formatted into the JSON structure the frontend expects, with both `history_data` and `prediction_data` arrays at the top level.

2. **Frontend receipt**: [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) receives this JSON in `startPrediction` and stores it in the `predictionResult` ref.

3. **Chart rendering**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) is invoked. It extracts `history_data` and `prediction_data` from `predictionResult.value` and uses Chart.js to draw both on the same `<canvas>`, history as a solid line and predictions as a dashed line, forming one continuous trend chart.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
```javascript
const renderChart = () => {
  if (!chartCanvas.value || !predictionResult.value) return
  // ...

  // The backend provides history_data and prediction_data directly
  const historyData = predictionResult.value.history_data || []
  const predictionData = predictionResult.value.prediction_data || []

  const historyLabels = historyData.map(p => p.date)
  const historySales = historyData.map(p => p.sales)

  const predictionLabels = predictionData.map(p => p.date)
  const predictionSales = predictionData.map(p => p.predicted_sales)

  // ... combine labels and align the data points ...

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels: allLabels,
      datasets: [
        {
          label: 'Historical sales',
          data: alignedHistorySales,
          // ... styling ...
        },
        {
          label: 'Predicted sales',
          data: alignedPredictionSales,
          // ... styling ...
        }
      ]
    },
    // ... Chart.js options ...
  })
}
```
This completes the full "train -> predict -> display" call chain.

## 5. Model Saving and Version Management (Post-Refactor)

To eliminate version confusion and model-loading failures once and for all, the system underwent a major refactor. All logic related to model saving, naming, and version management is now **centralized** in the `ModelManager` class in [`server/utils/model_manager.py`](server/utils/model_manager.py:1).
### 5.1. The Single Manager: `ModelManager`

- **Single responsibility**: `ModelManager` is the only component in the system that performs model file I/O. Every trainer must go through it to save a model.
- **Core functions**:
  1. **Automatic versioning**: generates and increments version numbers that follow the spec.
  2. **Uniform naming**: generates standardized file names from model metadata (algorithm type, training mode, IDs, etc.).
  3. **Safe saving**: packs model data and metadata together into a `.pth` file.
  4. **Reliable retrieval**: provides a uniform interface for listing and finding models.

### 5.2. Unified Version Scheme

All model versions now follow one strict, predictable format:

- **Numbered versions**: `v{number}`, e.g. `v1`, `v2`, `v3`...
  - **Creation**: when a training run **completes normally**, `ModelManager` computes the next available version for that model (e.g. if `v1` and `v2` exist, the new one is `v3`) and names the final model file accordingly.
  - **Use**: represents one complete, stable training output.
- **Special version**: `best`
  - **Creation**: during training, whenever an `epoch` beats all previous epochs on the validation set, the trainer asks `ModelManager` to save that model as the `best` version, overwriting the old `best` model.
  - **Use**: always points at the best-performing version of the model so far, for quick, high-quality predictions.
### 5.3. Unified Naming Convention (v2)

With the addition of "by store" and "global" training modes, `ModelManager`'s `generate_model_filename` method has been upgraded to support richer, unambiguous naming:

- **Product models**: `{model_type}_product_{product_id}_{version}.pth`
  - *Example*: `transformer_product_17002608_best.pth`
- **Store models**: `{model_type}_store_{store_id}_{version}.pth`
  - *Example*: `mlstm_store_01010023_v2.pth`
- **Global models**: `{model_type}_global_{aggregation_method}_{version}.pth`
  - *Example*: `tcn_global_sum_v1.pth`

This naming system ensures models from different training modes can be identified and managed without ambiguity. A sketch of the rule as code follows.
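The method body is not reproduced in this document; the three rules above amount to roughly the following sketch (the real `generate_model_filename` may take different parameters):

```python
def generate_model_filename(model_type, version, training_mode,
                            product_id=None, store_id=None, aggregation_method=None):
    """Build the v2 model file name for the given training mode."""
    if training_mode == 'product':
        return f"{model_type}_product_{product_id}_{version}.pth"
    elif training_mode == 'store':
        return f"{model_type}_store_{store_id}_{version}.pth"
    elif training_mode == 'global':
        return f"{model_type}_global_{aggregation_method}_{version}.pth"
    raise ValueError(f"Unknown training mode: {training_mode}")

# generate_model_filename('transformer', 'best', 'product', product_id='17002608')
# -> 'transformer_product_17002608_best.pth'
```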
### 5.4. Checkpoint File Contents (Unchanged Structure)

Each `.pth` file is still a PyTorch checkpoint containing the model weights, the full configuration, and the data scalers. The refactor reinforced the rule that **every trainer must store the complete configuration in the `config` dictionary**, guaranteeing models are fully reproducible.

### 5.5. Key Benefits (Post-Refactor)

- **Centralized logic**: all the complexity of version management is encapsulated inside `ModelManager`; trainers just call `save_model` and never worry about how version numbers are produced.
- **Data consistency**: because version creation, saving, and retrieval all go through the same component with the same logic, "model not found" errors caused by mismatched names or version formats are eliminated at the root.
- **Maintainability**: future changes to the version policy or naming rules only touch `ModelManager`, not every trainer.
## 6. Evolution of the Core Flows: Store and Global Modes

Building on the original "by product" flow, the system has been refactored into a complete AI loop that also supports "by store" and "global" modes. This introduced a few key changes:

### 6.1. Changes to the Training Flow

- **Single entry point**: all training requests (product, store, global) go through `POST /api/training`, distinguished by the `training_mode` parameter.
- **Data aggregation**: in `train_model` in [`predictor.py`](server/core/predictor.py:1), `aggregate_multi_store_data` is called according to `training_mode` to prepare the correctly aggregated time series for store or global mode.
- **Model identifier**: `train_model` now generates a unique `model_identifier` (e.g. `product_17002608`, `store_01010023`, `global_sum`) and passes it to every downstream trainer. This is the key to models being named correctly.

### 6.2. Major Fixes to the Prediction Flow

The prediction flow was overhauled to fix the `404` errors previously caused by inconsistent logic.

- **Deprecated helpers**: the old, flawed helpers in `core/config.py`, such as `get_model_file_path` and `get_model_versions`, are **fully deprecated**.
- **Unified lookup**: the `predict` function in [`api.py`](server/api.py:1) now **must** use `model_manager.list_models()` to find models.
- **Reliable path passing**: once `predict` resolves the correct model file path, that path is passed as a parameter all the way down through `run_prediction` to `load_model_and_predict`.
- **Defect removed**: all of the manual, outdated file-lookup logic inside `load_model_and_predict` has been **completely removed**; it now just receives an explicit path and loads the model.

This fix makes the entire prediction chain depend on `ModelManager` as the single source of truth, eliminating prediction failures caused by path mismatches at the root.
项目快速上手指南.md · Normal file (+127)
@ -0,0 +1,127 @@
# Project Quick-Start Guide (for New Developers)

Welcome to the project! This guide helps you quickly understand the core features, technical architecture, and development workflow, with clear entry points for you as a developer coming from a Java background.

## 1. What Does the Project Do? (Features)

This is an **intelligent sales forecasting system** built on historical sales data.

It has three core features, all operated through the web UI:
1. **Model training**: the user selects a **product**, a **store**, or the **global** dataset, then picks a machine-learning algorithm (Transformer, mLSTM, etc.) to train, producing a forecasting model.
2. **Sales prediction**: use a trained model to forecast future sales.
3. **Result visualization**: plot historical and predicted sales in the same chart so trends are easy to see.

In short, it is a complete closed loop: **"data -> training -> model -> prediction -> visualization"**.
## 2. What Technologies Does It Use? (Tech Stack)

You can map this project's stack onto the Java world like this:

| Layer | This project | Java-world analogue | Notes |
| :--- | :--- | :--- | :--- |
| **Backend framework** | **Flask** | Spring Boot | A lightweight web framework that serves the API. |
| **Frontend framework** | **Vue.js** | React / Angular | A modern JavaScript framework for building the UI. |
| **Core ML library** | **PyTorch** | (no direct analogue) | Comparable to Java's Deeplearning4j; the core of the deep-learning algorithms. |
| **Data processing** | **Pandas** | (no direct analogue) | Python's "Swiss army knife" for data analysis; think of it as a powerful in-memory data table. |
| **Build/packaging** | **Vite** (frontend) | Maven / Gradle | Build and dependency management for the frontend project. |
| **Database** | **SQLite** | H2 / MySQL | A lightweight file database used to record prediction history, etc. |
| **Real-time messaging** | **Socket.IO** | WebSocket / STOMP | Lets the backend push training progress to the frontend in real time. |
## 3. What Is the Architecture? (Layers and Design)

The project is a classic frontend/backend-separated architecture with four main layers:

```
+------------------------------------------------------+
|                    User (Browser)                    |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
|  1. Frontend Layer (Vue.js)                          |
|   - Views (page components, e.g.                     |
|     ProductPredictionView.vue)                       |
|   - API calls (axios, talking to the backend)        |
|   - Charting (Chart.js for chart rendering)          |
+------------------------------------------------------+
                          | (HTTP/S, WebSocket)
+------------------------------------------------------+
|  2. Backend API Layer (Flask)                        |
|   - api.py (like a Controller; defines the RESTful   |
|     endpoints)                                       |
|   - Receives requests, validates parameters, calls   |
|     the business logic layer                         |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
|  3. Business Logic Layer (Python)                    |
|   - core/predictor.py (like a Service layer)         |
|   - Encapsulates core business, e.g. "pick the       |
|     right trainer for the given parameters"          |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
|  4. Data & Model Layer (PyTorch/Pandas)              |
|   - trainers/*.py (algorithm implementations and     |
|     training logic)                                  |
|   - predictors/model_predictor.py (model loading     |
|     and prediction logic)                            |
|   - saved_models/ (trained .pth model files)         |
|   - data/ (raw data, .parquet files)                 |
+------------------------------------------------------+
```
## 4. Key Execution Flow

Taking the most common "predict by product" flow as an example:

1. **Frontend**: the user selects a product and a model and clicks "Predict". The Vue component sends a POST request via `axios` to `/api/prediction`.
2. **API layer**: `api.py` receives the request and, like a Controller, parses out the product ID, model type, and other parameters.
3. **Business logic layer**: `api.py` calls the `predict` method in `core/predictor.py`, passing the parameters down. This layer is the business "dispatcher".
4. **Model layer**: `core/predictor.py` ultimately calls `load_model_and_predict` in `predictors/model_predictor.py`.
5. **Model loading and execution**:
    * Find the matching model file under `saved_models/` (e.g. `transformer_store_01010023_best.pth` or `mlstm_product_17002608_v3.pth`).
    * Load the file and restore the **model structure**, **model weights**, and **data scalers** from it.
    * Prepare the latest history data as input and run the prediction.
    * Return the prediction results.
6. **Return and render**: the result travels back up to `api.py`, is formatted as JSON there, and is sent to the frontend, which draws the history and predictions on a chart with `Chart.js`.
## 5. How Do I Add a New Algorithm? (Developer Guide)

This is the new-feature work you are most likely to do. Suppose you want to add a new algorithm called `NewNet`; follow these steps:

**Goal**: make `NewNet` appear in the frontend's "model type" dropdown, and train and predict with it successfully.

1. **Create the trainer file**:
    * In `server/trainers/`, copy an existing trainer (e.g. `tcn_trainer.py`) and rename it `newnet_trainer.py`.
    * In `newnet_trainer.py`:
        * Define your `NewNet` model class (inheriting from `torch.nn.Module`).
        * Rename the `train_..._with_tcn` function to `train_..._with_newnet`.
        * Inside the new function, make sure it instantiates your `NewNet` model.
        * **The most critical step**: when saving the checkpoint, make sure the `config` dictionary contains every hyperparameter needed to rebuild `NewNet` (layer count, node count, etc.).
    * **Important convention: parameter naming rules**
      To prevent parameter-mismatch errors at model load time (e.g. `KeyError: 'num_layers'`), we adopted this convention:
      > **Rule:** for hyperparameters specific to one algorithm, the key in the `config` dictionary must carry that algorithm's name as a prefix or unique identifier.

      **Examples:**
      * The layer count of an `mLSTM` model is keyed `mlstm_layers`.
      * The channel count of a `TCN` model can be keyed `tcn_channels`.
      * The encoder layer count of a `Transformer` can be keyed `num_encoder_layers` (unambiguous in a Transformer context).

      When **loading a model** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)), you must read these parameters with exactly the same keys used when saving. Following this rule eliminates model-loading failures caused by inconsistent parameter names at the root; a small illustration follows.
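      As a minimal illustration of the convention (the values are made up, and `NewNet` below is just a stub standing in for the hypothetical class from this guide):

      ```python
      import torch.nn as nn

      class NewNet(nn.Module):  # stub for illustration only
          def __init__(self, input_dim, num_layers, hidden_size):
              super().__init__()
              self.layers = nn.ModuleList(
                  [nn.Linear(input_dim if i == 0 else hidden_size, hidden_size)
                   for i in range(num_layers)]
              )

      # At save time (newnet_trainer.py): algorithm-specific keys carry the prefix
      config = {
          'model_type': 'newnet',
          'input_dim': 8,
          'newnet_layers': 4,        # prefixed: specific to NewNet
          'newnet_hidden_size': 64,  # prefixed: specific to NewNet
      }

      # At load time (model_predictor.py): read with exactly the same keys
      model = NewNet(
          input_dim=config['input_dim'],
          num_layers=config['newnet_layers'],
          hidden_size=config['newnet_hidden_size'],
      )
      ```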
2. **Register the new model**:
    * Open `server/core/config.py`.
    * Find the `SUPPORTED_MODELS` list.
    * Add your new model identifier `'newnet'` to the list.

3. **Hook into the business logic layer (training)**:
    * Open `server/core/predictor.py`.
    * In the `train_model` method, find the `if/elif` model-selection logic.
    * Add a new `elif model_type == 'newnet':` branch that calls the `train_..._with_newnet` function you created in step 1.

4. **Hook into the model layer (prediction)**:
    * Open `server/predictors/model_predictor.py`.
    * In the `load_model_and_predict` function, find the `if/elif` model-instantiation logic.
    * Add a new `elif model_type == 'newnet':` branch that creates a `NewNet` instance correctly from `config`.

5. **Update the frontend**:
    * Open the relevant Vue files under `UI/src/views/training/` and `UI/src/views/prediction/` (e.g. `ProductTrainingView.vue`).
    * Find where the model options are defined (usually an array or object).
    * Add a new option such as `{ label: '新网络模型 (NewNet)', value: 'newnet' }`.

After these steps, restart the services and you can select and use your new algorithm in the UI.