Compare commits
23 commits: 87df49f764 ... e1980b3755

| SHA1 |
| --- |
| e1980b3755 |
| 751de9b548 |
| 038289ae32 |
| 0d3b89abf6 |
| ec636896da |
| 244393670d |
| e4d170d667 |
| 311d71b653 |
| ca7dc432c6 |
| ada4e8e108 |
| 120caba3cd |
| c64343fe95 |
| 9d439c36ba |
| 54428c80ca |
| 6f3240c723 |
| e437658b9d |
| ee9ba299fa |
| a1d9c60e61 |
| a18c8dddf9 |
| 398e949935 |
| cc30295f1d |
| 066a0429e5 |
| 6c11aff234 |
@@ -248,41 +248,9 @@ import { ref, onMounted, reactive, watch, nextTick } from 'vue';
import axios from 'axios';
import { ElMessage, ElMessageBox } from 'element-plus';
import { QuestionFilled, Search, View, Delete, ArrowUp, ArrowDown, Minus, Download } from '@element-plus/icons-vue';
import * as echarts from 'echarts/core';
import { LineChart, BarChart } from 'echarts/charts';
import {
  TitleComponent,
  TooltipComponent,
  GridComponent,
  DatasetComponent,
  TransformComponent,
  LegendComponent,
  ToolboxComponent,
  MarkLineComponent,
  MarkPointComponent
} from 'echarts/components';
import { LabelLayout, UniversalTransition } from 'echarts/features';
import { CanvasRenderer } from 'echarts/renderers';
import Chart from 'chart.js/auto'; // << Key change: import Chart.js
import { computed, onUnmounted } from 'vue';

// Register the required components
echarts.use([
  TitleComponent,
  TooltipComponent,
  GridComponent,
  DatasetComponent,
  TransformComponent,
  LegendComponent,
  ToolboxComponent,
  MarkLineComponent,
  MarkPointComponent,
  LineChart,
  BarChart,
  LabelLayout,
  UniversalTransition,
  CanvasRenderer
]);

const loading = ref(false);
const history = ref([]);
const products = ref([]);
@@ -292,8 +260,8 @@ const currentPrediction = ref(null);
const rawResponseData = ref(null);
const showRawDataFlag = ref(false);

const fullscreenPredictionChart = ref(null);
const fullscreenHistoryChart = ref(null);
let predictionChart = null; // << Key change: use a single chart instance
let historyChart = null;

const filters = reactive({
  product_id: '',
@@ -982,104 +950,133 @@ const getFactorsArray = computed(() => {
watch(detailsVisible, (newVal) => {
  if (newVal && currentPrediction.value) {
    nextTick(() => {
      // Init Prediction Chart
      if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
      const predChartDom = document.getElementById('fullscreen-prediction-chart-history');
      if (predChartDom) {
        fullscreenPredictionChart.value = echarts.init(predChartDom);
        if (currentPrediction.value.chart_data) {
          updatePredictionChart(currentPrediction.value.chart_data, fullscreenPredictionChart.value, true);
        }
      }

      // Init History Chart
      if (currentPrediction.value.analysis) {
        if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
        const histChartDom = document.getElementById('fullscreen-history-chart-history');
        if (histChartDom) {
          fullscreenHistoryChart.value = echarts.init(histChartDom);
          updateHistoryChart(currentPrediction.value.analysis, fullscreenHistoryChart.value, true);
        }
      }
      renderChart();
      // The logic for rendering a second chart could be added here
      // renderHistoryAnalysisChart();
    });
  }
});

const updatePredictionChart = (chartData, chart, isFullscreen = false) => {
  if (!chart || !chartData) return;
  chart.showLoading();
  const dates = chartData.dates || [];
  const sales = chartData.sales || [];
  const types = chartData.types || [];
// << Key change: renderChart function copied and adapted from ProductPredictionView.vue
const renderChart = () => {
  const chartCanvas = document.getElementById('fullscreen-prediction-chart-history');
  if (!chartCanvas || !currentPrediction.value || !currentPrediction.value.data) return;

  const combinedData = [];
  for (let i = 0; i < dates.length; i++) {
    combinedData.push({ date: dates[i], sales: sales[i], type: types[i] });
  if (predictionChart) {
    predictionChart.destroy();
  }
  combinedData.sort((a, b) => new Date(a.date) - new Date(b.date));

  const allDates = combinedData.map(item => item.date);
  const historyDates = combinedData.filter(d => d.type === '历史销量').map(d => d.date);
  const historySales = combinedData.filter(d => d.type === '历史销量').map(d => d.sales);
  const predictionDates = combinedData.filter(d => d.type === '预测销量').map(d => d.date);
  const predictionSales = combinedData.filter(d => d.type === '预测销量').map(d => d.sales);

  const allSales = [...historySales, ...predictionSales].filter(val => !isNaN(val));
  const minSale = Math.max(0, Math.floor(Math.min(...allSales) * 0.9));
  const maxSale = Math.ceil(Math.max(...allSales) * 1.1);

  const option = {
    title: { text: '销量预测趋势图', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
    tooltip: { trigger: 'axis', axisPointer: { type: 'cross' },
      formatter: function(params) {
        if (!params || params.length === 0) return '';
        const date = params[0].axisValue;
        let html = `<div style="font-weight:bold">${date}</div>`;
        params.forEach(item => {
          if (item.value !== '-') {
            html += `<div style="display:flex;justify-content:space-between;align-items:center;margin:5px 0;">
              <span style="display:inline-block;margin-right:5px;width:10px;height:10px;border-radius:50%;background-color:${item.color};"></span>
              <span>${item.seriesName}:</span>
              <span style="font-weight:bold;margin-left:5px;">${item.value.toFixed(2)}</span>
            </div>`;
          }
        });
        return html;
      }
  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (currentPrediction.value.data.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (currentPrediction.value.data.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。');
    return;
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort();
  const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]));
  // Note: the 'sales' field is used here because the backend has unified it
  const predictionMap = new Map(predictionData.map(p => [p.date, p.sales]));

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null);
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null);

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date;
    const lastHistoryValue = historyData[historyData.length - 1].sales;
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue;
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  predictionChart = new Chart(chartCanvas, {
    type: 'line',
    data: {
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    legend: { data: ['历史销量', '预测销量'], top: isFullscreen ? 40 : 30, textStyle: { color: '#e0e6ff' } },
    grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
    toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
    xAxis: { type: 'category', boundaryGap: false, data: allDates, axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
    yAxis: { type: 'value', name: '销量', min: minSale, max: maxSale, axisLabel: { color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
    series: [
      { name: '历史销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => historyDates.includes(date) ? historySales[historyDates.indexOf(date)] : null), areaStyle: { color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [{ offset: 0, color: 'rgba(64, 158, 255, 0.3)' }, { offset: 1, color: 'rgba(64, 158, 255, 0.1)' }]) }, lineStyle: { color: '#409EFF' } },
      { name: '预测销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => predictionDates.includes(date) ? predictionSales[predictionDates.indexOf(date)] : null), lineStyle: { color: '#F56C6C' } }
    ]
  };
  chart.hideLoading();
  chart.setOption(option, true);
};

const updateHistoryChart = (analysisData, chart, isFullscreen = false) => {
  if (!chart || !analysisData || !analysisData.history_chart_data) return;
  chart.showLoading();
  const { dates, changes } = analysisData.history_chart_data;

  const option = {
    title: { text: '销量日环比变化', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
    tooltip: { trigger: 'axis', axisPointer: { type: 'shadow' }, formatter: p => `${p[0].axisValue}<br/>环比: ${p[0].value.toFixed(2)}%` },
    grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
    toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
    xAxis: { type: 'category', data: dates.map(d => formatDate(d)), axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
    yAxis: { type: 'value', name: '环比变化(%)', axisLabel: { formatter: '{value}%', color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
    series: [{
      name: '日环比变化', type: 'bar',
      data: changes.map(val => ({ value: val, itemStyle: { color: val >= 0 ? '#67C23A' : '#F56C6C' } }))
    }]
  };
  chart.hideLoading();
  chart.setOption(option, true);
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        title: {
          display: true,
          text: `${currentPrediction.value.data.product_name} - 销量预测趋势图`,
          color: '#ffffff',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: '#6c757d',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期 (日)'
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          grid: {
            color: '#e9e9e9',
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
  });
};

const exportHistoryData = () => {
@@ -1102,16 +1099,24 @@ const exportHistoryData = () => {
};

const resizeCharts = () => {
  if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.resize();
  if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.resize();
  if (predictionChart) {
    predictionChart.resize();
  }
  if (historyChart) {
    historyChart.resize();
  }
};

window.addEventListener('resize', resizeCharts);

onUnmounted(() => {
  window.removeEventListener('resize', resizeCharts);
  if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
  if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
  if (predictionChart) {
    predictionChart.destroy();
  }
  if (historyChart) {
    historyChart.destroy();
  }
});

onMounted(() => {
@@ -34,7 +34,7 @@
  </el-row>

  <el-row :gutter="20" v-if="form.model_type">
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="模型版本">
        <el-select
          v-model="form.version"
@@ -52,7 +52,7 @@
        </el-select>
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="预测天数">
        <el-input-number
          v-model="form.future_days"
@@ -62,7 +62,17 @@
        />
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="历史天数">
        <el-input-number
          v-model="form.history_lookback_days"
          :min="7"
          :max="365"
          style="width: 100%"
        />
      </el-form-item>
    </el-col>
    <el-col :span="5">
      <el-form-item label="起始日期">
        <el-date-picker
          v-model="form.start_date"
@@ -75,7 +85,7 @@
        />
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="4">
      <el-form-item label="预测分析">
        <el-switch
          v-model="form.analyze_result"
@@ -135,6 +145,7 @@ const form = reactive({
  model_type: '',
  version: '',
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true
})
@@ -185,13 +196,15 @@ const startPrediction = async () => {
  try {
    predicting.value = true
    const payload = {
      training_mode: 'global', // explicitly specify the training mode
      model_type: form.model_type,
      version: form.version,
      future_days: form.future_days,
      history_lookback_days: form.history_lookback_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result
    }
    const response = await axios.post('/api/predict', payload)
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')
@@ -212,28 +225,113 @@ const renderChart = () => {
  if (chart) {
    chart.destroy()
  }
  const predictions = predictionResult.value.predictions
  const labels = predictions.map(p => p.date)
  const data = predictions.map(p => p.sales)

  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。')
    return
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
  const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date
    const lastHistoryValue = historyData[historyData.length - 1].sales
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels,
      datasets: [{
        label: '预测销量',
        data,
        borderColor: '#409EFF',
        backgroundColor: 'rgba(64, 158, 255, 0.1)',
        tension: 0.4,
        fill: true
      }]
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        title: {
          display: true,
          text: '销量预测趋势图'
          text: '全局销量预测趋势图',
          color: '#ffffff',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: '#6c757d',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期 (日)'
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          grid: {
            color: '#e9e9e9',
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
@@ -4,157 +4,147 @@
  <template #header>
    <div class="card-header">
      <span>按药品预测</span>
      <el-tooltip content="使用针对特定药品训练的模型进行销售预测">
      <el-tooltip content="对系统中的所有药品模型进行批量或单个预测">
        <el-icon><QuestionFilled /></el-icon>
      </el-tooltip>
    </div>
  </template>

  <div class="model-selection-section">
    <h4>🎯 选择预测模型</h4>
    <el-form :model="form" label-width="120px">
      <el-row :gutter="20">
        <el-col :span="8">
          <el-form-item label="目标药品">
            <ProductSelector
              v-model="form.product_id"
              @change="handleProductChange"
              :show-all-option="false"
            />
          </el-form-item>
        </el-col>
        <el-col :span="8">
          <el-form-item label="算法类型">
            <el-select
              v-model="form.model_type"
              placeholder="选择算法"
              @change="handleModelTypeChange"
              style="width: 100%"
              :disabled="!form.product_id"
            >
              <el-option
                v-for="item in modelTypes"
                :key="item.id"
                :label="item.name"
                :value="item.id"
              />
            </el-select>
          </el-form-item>
        </el-col>
      </el-row>

      <el-row :gutter="20" v-if="form.model_type">
        <el-col :span="6">
          <el-form-item label="模型版本">
            <el-select
              v-model="form.version"
              placeholder="选择版本"
              style="width: 100%"
              :disabled="!availableVersions.length"
              :loading="versionsLoading"
            >
              <el-option
                v-for="version in availableVersions"
                :key="version"
                :label="version"
                :value="version"
              />
            </el-select>
          </el-form-item>
        </el-col>
        <el-col :span="6">
          <el-form-item label="预测天数">
            <el-input-number
              v-model="form.future_days"
              :min="1"
              :max="365"
              style="width: 100%"
            />
          </el-form-item>
        </el-col>
        <el-col :span="6">
          <el-form-item label="起始日期">
            <el-date-picker
              v-model="form.start_date"
              type="date"
              placeholder="选择日期"
              format="YYYY-MM-DD"
              value-format="YYYY-MM-DD"
              style="width: 100%"
              :clearable="false"
            />
          </el-form-item>
        </el-col>
        <el-col :span="6">
          <el-form-item label="预测分析">
            <el-switch
              v-model="form.analyze_result"
              active-text="开启"
              inactive-text="关闭"
            />
          </el-form-item>
        </el-col>
      </el-row>
  <div class="controls-section">
    <el-form :model="filters" label-width="80px" inline>
      <el-form-item label="目标药品">
        <ProductSelector
          v-model="filters.product_id"
          :show-all-option="true"
          all-option-label="所有药品"
        />
      </el-form-item>
      <el-form-item label="算法类型">
        <el-select v-model="filters.model_type" placeholder="所有类型" clearable>
          <el-option
            v-for="item in modelTypes"
            :key="item.id"
            :label="item.name"
            :value="item.id"
          />
        </el-select>
      </el-form-item>
      <el-form-item label="预测天数">
        <el-input-number v-model="form.future_days" :min="1" :max="365" />
      </el-form-item>
      <el-form-item label="历史天数">
        <el-input-number v-model="form.history_lookback_days" :min="7" :max="365" />
      </el-form-item>
      <el-form-item label="起始日期">
        <el-date-picker
          v-model="form.start_date"
          type="date"
          placeholder="选择日期"
          format="YYYY-MM-DD"
          value-format="YYYY-MM-DD"
          :clearable="false"
        />
      </el-form-item>
    </el-form>
  </div>

  <div class="prediction-actions">
    <el-button
      type="primary"
      size="large"
      @click="startPrediction"
      :loading="predicting"
      :disabled="!canPredict"
    >
      <el-icon><TrendCharts /></el-icon>
      开始预测
    </el-button>
  <!-- Model list -->
  <div class="model-list-section">
    <h4>📦 可用药品模型列表</h4>
    <el-table :data="paginatedModelList" style="width: 100%" v-loading="modelsLoading">
      <el-table-column prop="product_name" label="药品名称" sortable />
      <el-table-column prop="model_type" label="模型类型" sortable />
      <el-table-column prop="version" label="版本" />
      <el-table-column prop="created_at" label="创建时间" />
      <el-table-column label="操作">
        <template #default="{ row }">
          <el-button
            type="primary"
            size="small"
            @click="startPrediction(row)"
            :loading="predicting[row.model_id]"
          >
            <el-icon><TrendCharts /></el-icon>
            开始预测
          </el-button>
        </template>
      </el-table-column>
    </el-table>
    <el-pagination
      background
      layout="prev, pager, next"
      :total="filteredModelList.length"
      :page-size="pagination.pageSize"
      @current-change="handlePageChange"
      style="margin-top: 20px; justify-content: center;"
    />
  </div>
</el-card>

<el-card v-if="predictionResult" style="margin-top: 20px">
  <template #header>
    <div class="card-header">
      <span>📈 预测结果</span>
    </div>
  </template>
  <!-- Prediction result dialog -->
  <el-dialog v-model="dialogVisible" title="📈 预测结果" width="70%">
    <div class="prediction-chart">
      <canvas ref="chartCanvas" width="800" height="400"></canvas>
    </div>
  </el-card>
    <template #footer>
      <el-button @click="dialogVisible = false">关闭</el-button>
    </template>
  </el-dialog>
</div>
</template>

<script setup>
import { ref, reactive, onMounted, computed, watch, nextTick } from 'vue'
import { ref, reactive, onMounted, nextTick, computed } from 'vue'
import axios from 'axios'
import { ElMessage } from 'element-plus'
import { ElMessage, ElDialog, ElTable, ElTableColumn, ElButton, ElIcon, ElCard, ElTooltip, ElForm, ElFormItem, ElInputNumber, ElDatePicker, ElSelect, ElOption, ElRow, ElCol, ElPagination } from 'element-plus'
import { QuestionFilled, TrendCharts } from '@element-plus/icons-vue'
import Chart from 'chart.js/auto'
import ProductSelector from '../../components/ProductSelector.vue'

const modelList = ref([])
const modelTypes = ref([])
const availableVersions = ref([])
const versionsLoading = ref(false)
const predicting = ref(false)
const modelsLoading = ref(false)
const predicting = reactive({})
const dialogVisible = ref(false)
const predictionResult = ref(null)
const chartCanvas = ref(null)
let chart = null

const form = reactive({
  training_mode: 'product',
  product_id: '',
  model_type: '',
  version: '',
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true
  analyze_result: true // keep analysis enabled, but the switch is removed from the UI
})

const canPredict = computed(() => {
  return form.product_id && form.model_type && form.version
const filters = reactive({
  product_id: '',
  model_type: ''
})

const pagination = reactive({
  currentPage: 1,
  pageSize: 8
})

const filteredModelList = computed(() => {
  return modelList.value.filter(model => {
    const productMatch = !filters.product_id || model.product_id === filters.product_id
    const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type
    return productMatch && modelTypeMatch
  })
})

const paginatedModelList = computed(() => {
  const start = (pagination.currentPage - 1) * pagination.pageSize
  const end = start + pagination.pageSize
  return filteredModelList.value.slice(start, end)
})

const handlePageChange = (page) => {
  pagination.currentPage = page
}

const fetchModelTypes = async () => {
  try {
    const response = await axios.get('/api/model_types')
@@ -166,63 +156,48 @@ const fetchModelTypes = async () => {
  }
}

const fetchAvailableVersions = async () => {
  if (!form.product_id || !form.model_type) {
    availableVersions.value = []
    return
  }
const fetchModels = async () => {
  modelsLoading.value = true
  try {
    versionsLoading.value = true
    const url = `/api/models/${form.product_id}/${form.model_type}/versions`
    const response = await axios.get(url)
    const response = await axios.get('/api/models', { params: { training_mode: 'product' } })
    if (response.data.status === 'success') {
      availableVersions.value = response.data.data.versions || []
      if (response.data.data.latest_version) {
        form.version = response.data.data.latest_version
      }
      modelList.value = response.data.data
    } else {
      ElMessage.error('获取模型列表失败')
    }
  } catch (error) {
    availableVersions.value = []
    ElMessage.error('获取模型列表失败')
  } finally {
    versionsLoading.value = false
    modelsLoading.value = false
  }
}

const handleProductChange = () => {
  form.model_type = ''
  form.version = ''
  availableVersions.value = []
}

const handleModelTypeChange = () => {
  form.version = ''
  fetchAvailableVersions()
}

const startPrediction = async () => {
const startPrediction = async (model) => {
  predicting[model.model_id] = true
  try {
    predicting.value = true
    const payload = {
      model_type: form.model_type,
      version: form.version,
      product_id: model.product_id,
      model_type: model.model_type,
      version: model.version,
      future_days: form.future_days,
      history_lookback_days: form.history_lookback_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result,
      product_id: form.product_id
      include_visualization: true, // analysis is hard-coded to enabled
    }
    const response = await axios.post('/api/predict', payload)
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')
      dialogVisible.value = true
      await nextTick()
      renderChart()
    } else {
      ElMessage.error(response.data.message || '预测失败')
      ElMessage.error(response.data.error || '预测失败')
    }
  } catch (error) {
    ElMessage.error('预测请求失败')
    ElMessage.error(error.response?.data?.error || '预测请求失败')
  } finally {
    predicting.value = false
    predicting[model.model_id] = false
  }
}

@@ -231,28 +206,113 @@ const renderChart = () => {
  if (chart) {
    chart.destroy()
  }
  const predictions = predictionResult.value.predictions
  const labels = predictions.map(p => p.date)
  const data = predictions.map(p => p.sales)

  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。')
    return
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
  const simplifiedLabels = allLabels.map(date => date.split('-').slice(1).join('/'));

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date
    const lastHistoryValue = historyData[historyData.length - 1].sales
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels,
      datasets: [{
        label: '预测销量',
        data,
        borderColor: '#409EFF',
        backgroundColor: 'rgba(64, 158, 255, 0.1)',
        tension: 0.4,
        fill: true
      }]
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        title: {
          display: true,
          text: '销量预测趋势图'
          text: `${predictionResult.value.product_name} - 销量预测趋势图`,
          color: '#303133',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: '#606266',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期'
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          grid: {
            color: '#e9e9e9',
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
@@ -260,14 +320,11 @@ const renderChart = () => {
}

onMounted(() => {
  fetchModels()
  fetchModelTypes()
  const today = new Date()
  form.start_date = today.toISOString().split('T')[0]
})

watch([() => form.product_id, () => form.model_type], () => {
  fetchAvailableVersions()
})
</script>

<style scoped>
@@ -279,15 +336,11 @@ watch([() => form.product_id, () => form.model_type], () => {
  justify-content: space-between;
  align-items: center;
}
.model-selection-section h4 {
  margin-bottom: 16px;
}
.prediction-actions {
  display: flex;
  justify-content: center;
.filters-section, .global-settings-section, .model-list-section {
  margin-top: 20px;
  padding-top: 20px;
  border-top: 1px solid #ebeef5;
}
.filters-section h4, .global-settings-section h4, .model-list-section h4 {
  margin-bottom: 16px;
}
.prediction-chart {
  margin-top: 20px;
@@ -44,7 +44,7 @@
  </el-row>

  <el-row :gutter="20" v-if="form.model_type">
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="模型版本">
        <el-select
          v-model="form.version"
@@ -62,7 +62,7 @@
        </el-select>
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="预测天数">
        <el-input-number
          v-model="form.future_days"
@@ -72,7 +72,17 @@
        />
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="5">
      <el-form-item label="历史天数">
        <el-input-number
          v-model="form.history_lookback_days"
          :min="7"
          :max="365"
          style="width: 100%"
        />
      </el-form-item>
    </el-col>
    <el-col :span="5">
      <el-form-item label="起始日期">
        <el-date-picker
          v-model="form.start_date"
@@ -85,7 +95,7 @@
        />
      </el-form-item>
    </el-col>
    <el-col :span="6">
    <el-col :span="4">
      <el-form-item label="预测分析">
        <el-switch
          v-model="form.analyze_result"
@@ -147,6 +157,7 @@ const form = reactive({
  model_type: '',
  version: '',
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true
})
@@ -203,14 +214,16 @@ const startPrediction = async () => {
  try {
    predicting.value = true
    const payload = {
      training_mode: form.training_mode,
      store_id: form.store_id,
      model_type: form.model_type,
      version: form.version,
      future_days: form.future_days,
      history_lookback_days: form.history_lookback_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result,
      store_id: form.store_id
    }
    const response = await axios.post('/api/predict', payload)
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')
@@ -231,28 +244,113 @@ const renderChart = () => {
  if (chart) {
    chart.destroy()
  }
  const predictions = predictionResult.value.predictions
  const labels = predictions.map(p => p.date)
  const data = predictions.map(p => p.sales)

  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。')
    return
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
  const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date
    const lastHistoryValue = historyData[historyData.length - 1].sales
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels,
      datasets: [{
        label: '预测销量',
        data,
        borderColor: '#409EFF',
        backgroundColor: 'rgba(64, 158, 255, 0.1)',
        tension: 0.4,
        fill: true
      }]
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        title: {
          display: true,
          text: '销量预测趋势图'
          text: `${predictionResult.value.product_name} - 销量预测趋势图`,
          color: '#ffffff',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: '#6c757d',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期 (日)'
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          grid: {
            color: '#e9e9e9',
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
feature_branch_workflow.md (new file, 101 lines)
@@ -0,0 +1,101 @@
# Standard Workflow for Feature Branch Development and Merging

This document describes a standard, safe feature development workflow, covering the complete steps from creating a branch to the final merge.

## Workflow Overview

1. **Create a feature branch**: Create a new feature branch (e.g. `lyf-dev-req0001`) in the remote repository, based on the main development branch (e.g. `lyf-dev`).
2. **Sync it locally**: Fetch the new remote branch locally and switch to it for development.
3. **Develop and commit**: Write code on the feature branch and commit changes frequently.
4. **Push to the remote**: Regularly push local commits to the remote feature branch for backup and collaboration.
5. **Merge back into the main branch**: Once the feature is developed and tested, merge the feature branch back into the main development branch.

---

## Detailed Steps

### Step 1: Sync and switch to the feature branch

Once the new feature branch (e.g. `lyf-dev-req0001`) exists in the remote repository, run the following commands locally to sync and switch.

1. **Fetch the latest remote state**:
   ```bash
   git fetch
   ```
   This pulls all the latest information from the remote repository, including newly created branches.

2. **Create and switch to the local branch**:
   ```bash
   git checkout lyf-dev-req0001
   ```
   Git automatically detects the remote branch of the same name and creates a local branch that tracks it.

### Step 2: Develop and commit on the feature branch

You can now develop safely on the `lyf-dev-req0001` branch.

1. **Make code changes**: Add, modify, or delete files to implement the new feature.

2. **Commit the changes**:
   ```bash
   # Stage all modified files
   git add .

   # Commit to the local repository with a meaningful message
   git commit -m "feat: complete the user authentication module"
   ```
   > **Best practice**: Keep commits small and their messages clear, to make code review and troubleshooting easier.

### Step 3: Push the feature branch to the remote

To back up your work and collaborate with the team, push your local commits to the remote repository.

```bash
# Push the current branch (lyf-dev-req0001) to the remote branch of the same name
git push origin lyf-dev-req0001
```

### Step 4: Merge the feature into the main development branch (`lyf-dev`)

Once the feature is complete and has passed testing, it can be merged back into `lyf-dev`.

1. **Switch to the main development branch**:
   ```bash
   git checkout lyf-dev
   ```

2. **Make sure the main development branch is up to date**:
   Always pull the latest `lyf-dev` before merging, to reduce the chance of conflicts.
   ```bash
   git pull origin lyf-dev
   ```

3. **Merge the feature branch**:
   Merge all changes from `lyf-dev-req0001` into the current `lyf-dev` branch.
   ```bash
   git merge lyf-dev-req0001
   ```
   * **If there are conflicts**: Git reports which files conflict. Open those files, resolve the conflicting sections, then run `git add .` and `git commit` again to complete the merge commit.
   * **If there are no conflicts**: Git creates a merge commit automatically.

4. **Push the merged main branch to the remote**:
   ```bash
   git push origin lyf-dev
   ```

### Step 5: Cleanup (optional)

Once the feature branch is confirmed to be no longer needed, delete it to keep the repository tidy.

1. **Delete the remote branch**:
   ```bash
   git push origin --delete lyf-dev-req0001
   ```

2. **Delete the local branch**:
   ```bash
   git branch -d lyf-dev-req0001
   ```

---

Following this workflow keeps your development process clear, safe, and efficient.
lyf开发日志记录文档.md (new file, 326 lines)
@@ -0,0 +1,326 @@
# Development Log

This document records the major changes, problem fixes, and important decisions made during development of the project.

---

## 2025-07-13: Early Backend Fixes and Refactoring
**Developer**: lyf

### 13:30 - Fixed the data-loading path problem
- **Goal**: Resolve the data-loading failures during model training caused by a wrong data file path.
- **Core problem**: The `PharmacyPredictor` class in `server/core/predictor.py` hard-coded a wrong default data file path at initialization.
- **Fix**: Corrected the default data path to `'data/timeseries_training_data_sample_10s50p.parquet'` and updated all trainers to match.

### 14:00 - Data-flow refactoring
- **Goal**: Fix the root cause of model training failures: key features were lost because the data-processing pipeline was broken in the middle.
- **Core problem**: `predictor.py` did not pass the preprocessed data downstream, so each trainer re-loaded the data and processed it incorrectly.
- **Fix**: Refactored the core data flow so that data is loaded and preprocessed once in `predictor.py` and then passed explicitly, as a DataFrame, to every downstream trainer function.

---

## 2025-07-14: Concentrated Work on Model Training and Concurrency Issues
**Developer**: lyf

### 10:16 - Fixed a `KeyError` in the trainer layer
- **Problem**: All model training failed with `KeyError: "['sales', 'price'] not in index"`.
- **Analysis**: The hard-coded feature list in the trainers included a `'price'` column that does not exist in the data source.
- **Fix**: Removed the dependency on the nonexistent `'price'` column from the `features` lists of all four trainers (`mlstm`, `transformer`, `tcn`, `kan`).

### 10:38 - Fixed a `KeyError` in the data-standardization layer
- **Problem**: After the previous fix, a new error appeared: `KeyError: "['sales'] not in index"`.
- **Analysis**: The `standardize_column_names` function in `server/utils/multi_store_data_utils.py` mapped column names incorrectly and lacked a final column-selection step.
- **Fix**: Corrected the column-name mapping and added a column-selection step, so the function returns a `DataFrame` with a uniform structure that always contains the `sales` column.

### 11:04 - Fixed the JSON serialization failure
- **Problem**: After training completed, frontend-backend communication failed with `Object of type float32 is not JSON serializable`.
- **Analysis**: The evaluation metrics produced by training were NumPy `float32` values, which the standard `json` library cannot serialize.
- **Fix**: Added a `convert_numpy_types` helper function in `server/utils/training_process_manager.py` that converts all NumPy numeric types to native Python types before returning data over WebSocket or the API, resolving every serialization problem at the source.
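
A minimal sketch of what such a recursive conversion helper can look like (the actual implementation in `training_process_manager.py` may differ in detail):

```python
import numpy as np

def convert_numpy_types(obj):
    """Recursively convert NumPy scalars and arrays into JSON-serializable Python types."""
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
```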

### 11:15 - Fixed the MAPE calculation error
- **Problem**: The training log showed `MAPE: nan%` along with `RuntimeWarning: Mean of empty slice.`.
- **Analysis**: When every ground-truth value in the test set is 0, computing MAPE averages over an empty array.
- **Fix**: Added a guard in `server/analysis/metrics.py`: if there are no non-zero ground-truth values, MAPE is set to 0 directly.
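
An illustrative guard for the all-zero case (the assumed shape of the fix; the exact code in `server/analysis/metrics.py` may differ):

```python
import numpy as np

def mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """MAPE over non-zero ground-truth values; defined as 0 when none exist."""
    nonzero = y_true != 0
    if not np.any(nonzero):
        return 0.0  # avoids averaging an empty slice, which yields NaN
    return float(np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100)
```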

### 11:41 - Fixed the empty list on the "train by store" page
- **Problem**: The "select store" dropdown list was empty.
- **Analysis**: The `standardize_column_names` function wrongly removed columns not needed for training, including the store metadata.
- **Fix**: Moved the column-filtering logic out of the generic `standardize_column_names` function and applied it precisely in the function that prepares data for model training only.

### 13:00 - Fixed the "train by store - all products" mode
- **Problem**: Training with "all products" selected failed because `product_id` was incorrectly handled as the string `"unknown"`.
- **Fix**: Intercepted the `"unknown"` ID in `server/core/predictor.py` and translated its intent correctly into "aggregate all product data for this store". Also extended the `aggregate_multi_store_data` function to support aggregation by store ID.

### 14:19 - Fixed stability problems in concurrent training
- **Problem**: Concurrent training produced an API list-sorting error and WebSocket connection errors.
- **Fixes**:
  1. **Sorting**: Provided a default value for `start_time` entries of type `None` in `api.py`, resolving the `TypeError`.
  2. **Connection**: Added the `allow_unsafe_werkzeug=True` argument to the `socketio.run()` call, resolving the conflict between Socket.IO and Werkzeug in debug mode.

### 15:30 - Root-caused the dimension mismatch in model training
- **Problem**: After every model finished training, the evaluation metric `R²` was always 0.0.
- **Root cause**: The `create_dataset` function in `server/utils/data_utils.py` kept a redundant dimension when building the target dataset `dataY`. The model files (`mlstm_model.py`, `transformer_model.py`) also had dimension problems in their outputs.
- **Final fixes** (illustrated in the sketch after this list):
  1. **Data layer**: Used `.flatten()` in `create_dataset` to correct the dimension of the `y` labels.
  2. **Model layer**: Added `.squeeze(-1)` at the end of every model's `forward` method, so model outputs have the correct dimension.
  3. **Trainer layer**: Reverted all the temporary dimension adjustments made while chasing this problem, restoring the most direct loss computation.
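
The shape problem can be illustrated with a small sketch (array shapes are placeholders; the real tensors are larger):

```python
import numpy as np
import torch

# Data layer: without .flatten(), dataY keeps a redundant dimension, (N, 1) instead of (N,).
dataY = np.array([[1.0], [2.0], [3.0]])
y = dataY.flatten()               # -> shape (3,)

# Model layer: squeeze the trailing dimension so the output matches the (batch,) targets.
raw_output = torch.randn(32, 1)   # stand-in for a model's forward() result
output = raw_output.squeeze(-1)   # -> shape (32,), aligns with the loss targets
```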

### 16:10 - Fixed the "global training - all products" mode
- **Problem**: Like "train by store", global training's "all products" mode also failed because of `product_id="unknown"`.
- **Fix**: Applied exactly the same fix pattern as for store training: intercept `"unknown"` in `predictor.py` and translate it into a true global aggregation (`product_id=None`), extending `aggregate_multi_store_data` to support this.

---

## 2025-07-15: End-to-End Fix of the "Predict by Product" Chart
**Developer**: lyf

### 10:00 - Phase 1: Fixed the database write failure (`sqlite3.IntegrityError`)
- **Problem**: The backend log showed `datatype mismatch`.
- **Analysis**: The `save_prediction_result` function tried to store complex Python objects directly in the database.
- **Fix**: In `server/api.py`, serialize the complex objects into JSON strings with `json.dumps()` before performing the database insert.
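
A sketch of the serialization step (the table and column names here are assumptions, not the project's actual schema):

```python
import json
import sqlite3

def save_prediction_result(conn: sqlite3.Connection, prediction_id: str, result: dict) -> None:
    # Complex Python objects must be serialized to a JSON string before binding;
    # passing them directly makes sqlite3 raise "datatype mismatch".
    conn.execute(
        "INSERT INTO prediction_history (id, result_json) VALUES (?, ?)",
        (prediction_id, json.dumps(result, ensure_ascii=False)),
    )
    conn.commit()
```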

### 10:30 - Phase 2: Fixed the mismatch between the API response structure and the frontend
- **Problem**: The chart still would not render.
- **Analysis**: The frontend expected `history_data` at the top level, but the backend wrapped it inside a `data` sub-object.
- **Fix**: Modified the `predict` function in `server/api.py` to lift the key data to the root level of the response.

### 11:00 - Phase 3: Fixed the time discontinuity between history and prediction data
- **Problem**: The chart data was completely disconnected in time.
- **Analysis**: The logic that fetched history data always took the last 30 rows of the whole dataset, rather than the 30 rows before the prediction start date.
- **Fix**: Added the correct date-filtering logic in `server/api.py`.
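
The corrected selection amounts to filtering by the prediction start date before taking the trailing window; a pandas sketch with assumed column names:

```python
import pandas as pd

def history_before(df: pd.DataFrame, start_date: str, lookback_days: int = 30) -> pd.DataFrame:
    """Return the `lookback_days` rows immediately before `start_date`,
    rather than simply the last rows of the whole dataset."""
    df = df.sort_values('date')
    mask = pd.to_datetime(df['date']) < pd.to_datetime(start_date)
    return df.loc[mask].tail(lookback_days)
```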

### 14:00 - Phase 4: Refactored the data source to root out data inconsistency
- **Problem**: The history data (green line) and the prediction data (blue line) came from completely different data views.
- **Root cause**: The API layer independently loaded the **raw data** for plotting, while the predictor used the **aggregated data** for prediction.
- **Fix (refactor)**:
  1. Modified `server/predictors/model_predictor.py` so that, along with the prediction results, it also returns the history data it actually used, guaranteeing a consistent view.
  2. Deleted all the redundant code in `server/api.py` that independently loaded history data, ensuring a single source of truth.

### 15:00 - Phase 5: Fixed the chart's X-axis date format
- **Problem**: The X-axis showed garbled GMT-format timestamps.
- **Analysis**: The `Timestamp` objects in `history_data` were not being formatted.
- **Fix**: Added `.strftime('%Y-%m-%d')` formatting for `history_data` in `server/api.py`.

### 16:00 - Phase 6: Fixed the root cause of models "not learning" (hyperparameter passing was broken)
- **Problem**: Even with the pipeline correct, every model's prediction was still a flat line that had learned nothing.
- **Root cause**: When `server/core/predictor.py` called the trainers, it **did not pass down key hyperparameters such as `sequence_length`**, so every model used wrong default values.
- **Fixes**:
  1. Modified `server/core/predictor.py` to pass the hyperparameters through in the call.
  2. Modified all four trainer files so they receive and use these parameters.

---

## 2025-07-16: Final Verification and Project Summary
**Developer**: lyf

### 10:00 - Phase 7: Final verification and conclusion
- **Problem**: After all code issues were fixed, the prediction for a particular date was still a flat line.
- **Analysis**: A throwaway data-analysis script (`temp_check_parquet.py`) finally confirmed this was a problem with the **data itself**: the chosen prediction date fell exactly in a "zero sales" gap in the sample dataset.
- **Conclusion**: The system code is fully fixed. The flat line on the chart is the model's **correct and logical** response to a zero-sales history.

### 11:45 - Project summary and documentation
- **Task**: As requested, review the entire debugging process and consolidate all problems, solutions, optimization ideas, and final conclusions into this development log, in date and time order, as a high-quality technical record.
- **Result**: This document has been updated accordingly.

### 13:15 - Final fix: root-caused the inconsistent model identifier
- **Problem**: Further testing and log analysis showed that, even after the correction, the `model_identifier` for store models was still built incorrectly as `01010023_store_01010023` during training.
- **Root cause**: In the `train_model` method of `server/core/predictor.py`, the `training_mode == 'store'` branch contained redundant and incorrect logic for building the `model_identifier`.
- **Final solution**: Removed the wrong concatenation `model_identifier = f"{store_id}_{product_id}"` and used the `product_id` variable directly as the `model_identifier`, since it had already been correctly assigned the value `f"store_{store_id}"` in an earlier step. This keeps the store model's unique identifier consistent from training and saving through to the final API query.

### 13:30 - Final fix (round 2): root-caused the wrong model save path
- **Problem**: Even with the identifier fixed, model versions still could not be loaded.
- **Root cause**: Analysis of the training logs showed that the `save_checkpoint` function in every trainer (`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`) forcibly created a `checkpoints` subdirectory under `saved_models` and saved all model files there, while the `get_model_versions` function responsible for finding models only searched the root directory, so the models could never be discovered.
- **Final solution**: Modified the `save_checkpoint` function in each affected trainer to remove the logic that created and used the `checkpoints` subdirectory, so all models are saved directly in the `saved_models` root.
- **Conclusion**: The save path and the lookup path are now fully unified, fundamentally solving the problem of unloadable model versions.

### 13:40 - Final fix (round 3): unified the model-saving logic across all trainers
- **Problem**: After fixing `transformer_trainer.py`, it turned out that `mlstm_trainer.py` and `tcn_trainer.py` had exactly the same path and naming errors, so the problem persisted.
- **Root cause**: `save_checkpoint` was implemented incorrectly in every trainer: each forced a `checkpoints` subdirectory and used wrong logic to assemble the file name.
- **Final solution**:
  1. **Fixed each trainer**: Modified the `save_checkpoint` function in `transformer_trainer.py`, `mlstm_trainer.py`, and `tcn_trainer.py` one by one.
  2. **Path fix**: Removed the logic that created and used the `checkpoints` subdirectory, so models are saved directly in the root of `model_dir` (i.e. `saved_models`).
  3. **Filename fix**: Simplified and corrected the filename generation to use the `product_id` parameter directly as the unique identifier (upstream logic already assigns it correctly as either the product ID or `store_{store_id}`), with no further erroneous concatenation.
- **Conclusion**: All trainers now share the same model-saving logic, and the save paths and file names fully match the API's lookup logic, fundamentally solving the unloadable-version problem.

---

## 2025-07-16 (continued): End-to-End Fix of the "Store Prediction" Chart
**Developer**: lyf

### 15:30 - Final fix (round 4): unblocked the store-prediction data flow
- **Problem**: After the model-loading problems were solved, "store prediction" executed successfully, but the frontend chart remained blank, showing neither history nor prediction data.
- **Root cause**: Parameter passing broke at several points in the call chain:
  1. `server/api.py` did not pass `training_mode` when calling `run_prediction`.
  2. `server/core/predictor.py` did not pass `store_id` and `training_mode` when calling `load_model_and_predict`.
  3. The data-loading logic inside `server/predictors/model_predictor.py` wrongly used the model identifier (`store_{id}`) as a product ID to filter data in store mode, so no history data could be loaded.
- **Final solution (three fixes)**:
  1. **Fixed `model_predictor.py`**: `load_model_and_predict` now loads data intelligently based on the `training_mode` parameter. In `'store'` mode it correctly aggregates all of the store's sales data as history, exactly matching how the data was prepared during training.
  2. **Fixed `predictor.py`**: The `predict` method now passes `store_id` and `training_mode` down to the underlying `load_model_and_predict` function.
  3. **Fixed `api.py`**: The `predict` route and the `run_prediction` helper now pass `training_mode` through the entire call chain.
- **Conclusion**: Parameter passing from the API down to the lowest-level data loader is now complete and correct. Both product and store prediction load the right history data for plotting, which fully resolves the blank-chart problem.

### 16:16 - Project status update
- **Status**: **All known issues fixed.**
- **Confirmation**: The user confirmed that "the product and store prediction flows now work end to end".
- **Next**: Archive this round of fixes in this document.

---

### 2025-07-16 18:38 - Generic fix for prediction across all models

**Symptom**:
After the `Transformer` prediction problem was solved, a deeper systemic issue surfaced: in every prediction mode (by product, by store, global), only the `Transformer` algorithm could successfully predict and render a chart. The other four models (`mLSTM`, `KAN`, optimized `KAN`, `TCN`) trained successfully but all failed at prediction time with "no data available for the chart".

**Root-cause analysis**:
The core issue was that **model configuration persistence was incomplete and inconsistent**.

1. **Why Transformer "survived"**: The `Transformer` implementation happened not to depend on the specific hyperparameters that were omitted at save time, so it kept working.
2. **The shared defect in the other models**: The constructors of all the other models (`mLSTM`, `TCN`, `KAN`) depend on structural parameters that were defined at training time but **omitted** when saving the checkpoint file (`.pth`):
   * **mLSTM**: missing `mlstm_layers`, `embed_dim`, `dense_dim`, and other parameters.
   * **TCN**: missing `num_channels`, `kernel_size`, and other parameters.
   * **KAN**: missing the `hidden_sizes` list.
3. **The failure cascade**:
   * When `server/predictors/model_predictor.py` tried to load these checkpoints, `checkpoint['config']` did not contain all the parameters needed to instantiate the model.
   * Model instantiation failed with a `KeyError` or `TypeError`.
   * The exception made `load_model_and_predict` return `None` early, so the response sent to the frontend lacked `history_data` and the chart could not render.

**Systematic, extensible solution**:
To fix this for good, and to make adding new algorithms painless in the future, all non-Transformer trainers were standardized and fixed thoroughly.

1. **Fixed `mlstm_trainer.py`**: Added all missing parameters (`mlstm_layers`, `embed_dim`, `dense_dim`, ...) to the `config` dict.
2. **Fixed `tcn_trainer.py`**: Added all missing parameters (`num_channels`, `kernel_size`, ...) to the `config` dict.
3. **Fixed `kan_trainer.py`**: Added the `hidden_sizes` list to the `config` dict.

**Result**:
With this systematic fix, every trainer now writes complete configuration information into the checkpoint file, sufficient to re-instantiate the model. This fundamentally resolves the prediction failures across all model algorithms and makes the whole system uniform and robust across algorithms.
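
In concrete terms, the checkpoint's `config` must carry every constructor argument needed to rebuild the model. A sketch of the pattern, with placeholder values and a hypothetical constructor name:

```python
import torch

def save_checkpoint(model: torch.nn.Module, path: str, config: dict) -> None:
    # config must include all structural parameters, e.g. for mLSTM:
    # {'mlstm_layers': 2, 'embed_dim': 64, 'dense_dim': 32, ...}
    torch.save({'model_state_dict': model.state_dict(), 'config': config}, path)

def load_for_prediction(path: str) -> dict:
    checkpoint = torch.load(path)
    # The model can now be re-instantiated purely from the checkpoint, e.g.:
    # model = MLSTMTransformer(**checkpoint['config'])  # hypothetical constructor
    return checkpoint
```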

---

## 2025-07-17: Systematic Pipeline Unblocking and Standardization
**Developer**: lyf

### 15:00 - Created technical documentation and an onboarding guide
- **Task**: Created two core technical documents to help new members understand the system and to ease future maintenance.
- **Deliverables**:
  1. **`系统调用逻辑与核心代码分析.md`**: An end-to-end call-chain analysis that goes deep into the code, describing the full flow from frontend interaction through backend processing to model training and prediction.
  2. **`项目快速上手指南.md`**: A high-level guide for new members (especially those from a Java background) that builds a quick mental model of the project through tech-stack analogies, a layered architecture diagram, and a clear development workflow.

### 16:00 - Fixed the mLSTM model-loading chain
- **Problem**: The `mLSTM` model failed to load at prediction time because of inconsistent parameter names.
- **Analysis**:
  - First failure: the loader expected `num_layers`, but the trainer saved `mlstm_layers`.
  - Second failure: the loader expected `dropout`, but the trainer saved `dropout_rate`.
- **Fix**: Following the principle that "the saving side owns the naming", modified `server/predictors/model_predictor.py` to use `mlstm_layers` and `dropout_rate` at load time, matching the trainer.

### 16:45 - Fixed an algorithmic defect in the mLSTM model
- **Problem**: After the loading fix, the `mLSTM` prediction was still a useless flat line.
- **Root cause**: A design flaw in the model architecture in `server/models/mlstm_model.py`: its decoder logic wrongly copied the last time step of the input sequence multiple times as the prediction, so the model could not learn how the series changes over time.
- **Fix**: Refactored the `forward` method of the `MLSTMTransformer` class, removing the broken decoder logic and instead predicting directly from the encoder's final hidden state through a linear layer, fixing the algorithm at its root.

### 17:00 - Fixed the TCN model-loading chain
- **Problem**: TCN model loading contained a hard-coded parameter, a latent crash point.
- **Analysis**: `server/predictors/model_predictor.py` hard-coded `kernel_size=3` when creating the `TCNForecaster` instance instead of reading it from the model config.
- **Fix**: Modified `model_predictor.py` to read the parameter dynamically from `config['kernel_size']`, keeping the configuration complete and consistent.

### 17:15 - Fixed KAN model version discovery
- **Problem**: After `KAN` and optimized `KAN` trained successfully, no model versions could be found on the prediction page.
- **Root cause**: A mismatch between the **save** and **search** logic. `kan_trainer.py` saved models via `model_manager.py` using the `..._product_...` naming format, while the `get_model_versions` function in `server/core/config.py` only searched for the `..._epoch_...` format.
- **Fix**: Extended `get_model_versions` in `config.py` to recognize and search multiple naming formats, including the `..._product_...` format used by the `KAN` models; a sketch of such a tolerant lookup follows.
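
A sketch of a lookup that tolerates both naming formats (the filename patterns are assumptions based on the formats described above):

```python
import re
from pathlib import Path

# Matches "..._epoch_<version>.pth" as well as "..._product_<id>_v<N>.pth".
PATTERNS = [
    re.compile(r'_epoch_(?P<version>[^_]+)\.pth$'),
    re.compile(r'_product_[^_]+_(?P<version>v\d+)\.pth$'),
]

def get_model_versions(model_dir: str) -> list[str]:
    versions = set()
    for f in Path(model_dir).glob('*.pth'):
        for pattern in PATTERNS:
            match = pattern.search(f.name)
            if match:
                versions.add(match.group('version'))
                break
    return sorted(versions)
```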
|
||||
|
||||
### 17:25 - 修复 `KAN` 模型文件路径生成问题
|
||||
- **问题**: 修复版本发现问题后,点击预测依然失败,提示“未找到模型文件”。
|
||||
- **根本原因**: 只修复了**版本发现**逻辑,但未同步修复**文件路径生成**逻辑。`config.py` 中的 `get_model_file_path` 函数在为 `KAN` 模型生成路径时,依然错误地使用了 `_epoch_` 格式。
|
||||
- **修复**: 修改了 `get_model_file_path` 函数,为 `kan` 和 `optimized_kan` 模型增加了特殊处理,确保在生成其文件路径时使用正确的 `_product_` 命名格式。
|
||||
|
||||
### 17:40 - 升级 `KAN` 训练器的版本管理功能
|
||||
- **问题**: `KAN` 模型只有一个静态的 `'v1'` 版本,与其他模型(有 `best`, `final_epoch_...` 等版本)不一致。
|
||||
- **根本原因**: `kan_trainer.py` 的实现逻辑过于简单,缺少在训练过程中动态评估并保存多个版本的功能,仅在最后硬编码保存为 `'v1'`。
|
||||
- **修复 (功能升级)**: 重构了 `server/trainers/kan_trainer.py`,为其增加了与其他训练器完全一致的动态版本管理功能。现在它可以在训练时自动追踪并保存性能最佳的 `best` 版本,并在训练结束后保存 `final_epoch_...` 版本。
|
||||
|
||||
### 17:58 - 最终结论
|
||||
- **状态**: **所有已知问题已修复**。
|
||||
- **成果**:
|
||||
1. 所有模型的 **“数据 -> 训练 -> 保存 -> 加载 -> 预测 -> 可视化”** 执行链路已全面打通和验证。
|
||||
2. 统一并修复了所有模型在配置持久化和加载过程中的参数不一致问题。
|
||||
3. 将所有模型的版本管理逻辑和工程实现标准完全对齐。
|
||||
4. 创建并完善了核心技术文档,固化了开发规范。
|
||||
- **项目状态**: 系统现在处于一个健壮、一致且可扩展的稳定状态。
|
||||
|
||||
---

## 2025-07-18: Systematic refactor of the model version management mechanism

**Developer**: lyf

### 14:00 - Eradicate version chaos and model loading failures

- **Symptoms**: after training, `KAN` and the other algorithms showed corrupted version numbers at prediction time (bare digits such as `1` or `3`, or invalid entries like `best`), duplicate versions, and `404` "model file not found" errors caused by version mismatches.
- **Deep root-cause analysis**:
  1. **Scattered logic**: version generation lived in each trainer while version discovery lived in `config.py`; the two followed different standards, full of conflicting regular expressions and hard-coded rules.
  2. **Inconsistent naming**: the `KAN` trainer saved through `model_manager`, while the other trainers used a local `save_checkpoint` function, so incompatible naming formats such as `..._product_..._v1.pth` and `..._epoch_best.pth` coexisted.
  3. **Broken extraction**: the `get_model_versions` function in `config.py`, with its overly broad and conflicting matching rules, extracted invalid version numbers from filenames; this was the direct cause of the garbage in the frontend dropdown.
- **Systematic refactoring solution**:
  1. **A single authority**: [`server/utils/model_manager.py`](server/utils/model_manager.py:1) is now the only component in the system responsible for version management, model naming, and file IO.
  2. **Automatic versioning**: `ModelManager` gained an internal `_get_next_version` method that scans existing files and safely produces the next incrementing, `v`-prefixed version number (e.g. `v3`); see the sketch after this list.
  3. **All trainers unified**: `kan_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`, and `transformer_trainer.py` were all refactored. When saving the final model, every trainer now calls `model_manager.save_model` and **no longer picks its own version number**; `ModelManager` generates it. The best model found during training is explicitly saved as the `best` version.
  4. **Cleanup and hardening**: all of the old, broken version-management functions in `config.py` were removed, and `get_model_versions` was rewritten to use strict regular expressions that only find and parse model versions matching the new naming convention.
  5. **API polish**: `api.py` was updated to work entirely against the new `ModelManager`, and the error feedback on failed predictions was improved.
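A minimal sketch of the auto-increment logic from step 2, assuming models sit flat in one directory; the real `ModelManager` also has to account for store and global naming:

```python
import glob
import os
import re

def _get_next_version(model_dir: str, file_prefix: str) -> str:
    """Scan '<file_prefix>_v<N>.pth' files and return the next 'v'-prefixed version."""
    versions = []
    for path in glob.glob(os.path.join(model_dir, f"{file_prefix}_v*.pth")):
        match = re.search(r"_v(\d+)\.pth$", os.path.basename(path))
        if match:
            versions.append(int(match.group(1)))
    return f"v{max(versions) + 1}" if versions else "v1"
```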

- **Conclusion**: this refactor took version management from a scattered, chaotic, hard-coded state to a centralized, uniform, automated, and robust system. All known related bugs are fixed at the root.

---

## 2025-07-18 (cont.): Implement the per-store AI loop and fix the resulting chain of bugs

**Developer**: lyf

### 15:00 - Architecture upgrade: per-store training and prediction

- **Goal**: on top of the existing per-product mode, add and wire up a complete per-store AI loop.
- **Core challenge**: data handling, model identifiers, the training flow, and the API calls all needed systematic changes to support the new training mode.
- **Solution (four-step refactor)**:
  1. **Upgrade `ModelManager`**: the model naming rules were redesigned to give store and global models clear, unambiguous identifiers (e.g. `transformer_store_S001_v1.pth`), with the parsing logic updated to match.
  2. **Fix the core predictor**: a key logic flaw in `predictor.py` was corrected so that in store mode the system builds and uses the right `model_identifier` (e.g. `store_S001`; see the sketch after this list) and always calls the data aggregation function.
  3. **Adapt the API layer**: the training and prediction endpoints in `api.py` were adjusted to accept and correctly handle the new store-mode requests.
  4. **Unify all trainers**: all four trainer files were changed consistently so that they use the new `model_identifier` when saving models.
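A sketch of the identifier rules from step 2, consistent with the `predictor.py` changes further down in this diff; the helper name is illustrative, as the real logic is inline:

```python
def build_model_identifier(training_mode: str, product_id: str,
                           store_id: str = None, aggregation_method: str = 'sum') -> str:
    if training_mode == 'store':
        return f"store_{store_id}"             # store models depend only on the store
    if training_mode == 'global':
        return f"global_{aggregation_method}"  # global models ignore any single product_id
    return product_id                          # plain per-product mode
```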

### 15:30 - Chain bug #1: store model versions failing to load

- **Symptom**: the model version dropdown on the per-store prediction page was empty.
- **Root cause**: `get_store_model_versions_api`, the `api.py` endpoint that lists store model versions, was still using an old lookup function incompatible with the new naming convention.
- **Fix**: rewrote the endpoint to drop the old function and use `ModelManager` for a single, reliable model lookup.

### 15:40 - Chain bug #2: store predictions failing with `404`

- **Symptom**: with the version list loading correctly, clicking "start prediction" returned a `404` error.
- **Root cause**: `load_model_and_predict`, the worker function inside the backend `predict()` endpoint, contained stale, hand-rolled model file lookup logic that bypassed `ModelManager` entirely and built the wrong file paths.
- **Fix (joint refactor)**:
  1. **Rework `model_predictor.py`**: removed all of the stale file lookup code inside `load_model_and_predict` and changed its signature to take an explicit `model_path` parameter.
  2. **Rework `api.py`**: the `predict` endpoint now passes the correct model path, resolved via `ModelManager` at the API layer, all the way down to `load_model_and_predict`, keeping the call chain logically consistent.

### 15:50 - Chain bug #3: `NameError` on service startup

- **Symptom**: after the prediction fix, the API service failed to start with `NameError: name 'Optional' is not defined`.
- **Root cause**: the edit to `model_predictor.py` used the `Optional` type hint but forgot to import it from the `typing` module.
- **Fix**: added `from typing import Optional` at the top of `server/predictors/model_predictor.py`.
- **Final conclusion**: with this, all architecture upgrades and chained bugs around the per-store feature are resolved. The system now handles training and prediction in both dimensions stably and correctly, with more uniform and robust code.

---

## 2025-07-21: Joint frontend/backend debugging and UI fixes

**Developer**: lyf

### 15:45 - Fix the backend `DataFrame` serialization error

- **Symptom**: after clearing historical models and re-running a prediction, the frontend showed `Object of type DataFrame is not JSON serializable`.
- **Root cause**: in the return value of `load_model_and_predict` in `server/predictors/model_predictor.py`, the `'predictions'` field kept for old-interface compatibility still held the raw Pandas DataFrame (`predictions_df`).
- **Fix**: changed the returned dict so that `'predictions'` also carries the `prediction_data_json` produced by `.to_dict('records')`, making every part of the returned object JSON-compatible (sketched below).
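The returned payload after the fix, sketched from the `model_predictor.py` changes later in this diff; both prediction fields now hold the same JSON-safe list of records:

```python
prediction_data_json = predictions_df.to_dict('records')
history_data_json = history_for_chart_df.to_dict('records')

return {
    'product_id': product_id,
    'product_name': product_name,
    'model_type': loaded_model_type,
    'predictions': prediction_data_json,    # was the raw DataFrame; kept for old callers
    'prediction_data': prediction_data_json,
    'history_data': history_data_json,
    'analysis': analysis,
}
```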

### 16:00 - Uniformly fix chart rendering in all prediction views

- **Symptom**: once the backend serialization was fixed, the charts in all three prediction views (per-product, per-store, global) were blank, and the date subtitle below each chart showed a raw, unformatted JavaScript date string.
- **Deep root-cause analysis**:
  1. **Imprecise data access path**: the frontend read data straight off the root of the API response (`response.data`), while the reliable payload lives in `response.data.data`.
  2. **Mishandled date objects**: the frontend never normalized the dates coming from the backend (whether strings or Date objects auto-converted by axios) into a single string format. Deduplicating dates with a `Set` therefore failed on distinct object references, leaving the chart with no data points.
- **Uniform fix**:
  1. **File by file**: updated `ProductPredictionView.vue`, `StorePredictionView.vue`, and `GlobalPredictionView.vue` one by one.
  2. **Correct data access**: in the `startPrediction` method, the core payload `response.data.data` is now assigned to `predictionResult`.
  3. **Normalized dates**: added a `formatDate` helper at the top of the `renderChart` method and applied it as soon as data is processed, converting every date to a `'YYYY-MM-DD'` string; this solved both the missing data points and the broken subtitle format in one stroke.
- **Final conclusion**: with this, the end-to-end data path and UI rendering of every prediction view are fixed and the system works normally again.
Binary file not shown.
@ -56,3 +56,6 @@ tzdata==2025.2
werkzeug==3.1.3
win32-setctime==1.2.0
wsproto==1.2.0
python-dateutil
xgboost
scikit-learn
BIN	sales_trends.png (binary file not shown; before: 348 KiB)
830	server/api.py (file diff suppressed because it is too large)
@ -58,7 +58,9 @@ HIDDEN_SIZE = 64  # 隐藏层大小
NUM_LAYERS = 2  # 层数

# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# 支持的模型类型 (v2 - 动态加载)
from models.model_registry import TRAINER_REGISTRY
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())

# 版本管理配置
MODEL_VERSION_PREFIX = 'v'  # 版本前缀
@ -71,76 +73,30 @@ TRAINING_UPDATE_INTERVAL = 1  # 训练进度更新间隔(秒)
# 创建模型保存目录
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)

def get_next_model_version(product_id: str, model_type: str) -> str:
    """
    获取指定产品和模型类型的下一个版本号

    Args:
        product_id: 产品ID
        model_type: 模型类型

    Returns:
        下一个版本号,格式如 'v2', 'v3' 等
    """
    # 新格式:带版本号的文件
    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))

    # 旧格式:不带版本号的文件(兼容性支持)
    pattern_old = f"{model_type}_model_product_{product_id}.pth"
    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
    has_old_format = os.path.exists(old_file_path)

    # 如果没有任何格式的文件,返回默认版本
    if not existing_files_new and not has_old_format:
        return DEFAULT_VERSION

    # 提取新格式文件的版本号
    versions = []
    for file_path in existing_files_new:
        filename = os.path.basename(file_path)
        version_match = re.search(rf"_v(\d+)\.pth$", filename)
        if version_match:
            versions.append(int(version_match.group(1)))

    # 如果存在旧格式文件,将其视为v1
    if has_old_format:
        versions.append(1)
        print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1")

    if versions:
        next_version_num = max(versions) + 1
        return f"v{next_version_num}"
    else:
        return DEFAULT_VERSION

def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str:
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
    """
    生成模型文件路径
    根据产品ID、模型类型和版本号,生成模型文件的准确路径。

    Args:
        product_id: 产品ID
        product_id: 产品ID (纯数字)
        model_type: 模型类型
        version: 版本号,如果为None则获取下一个版本
        version: 版本字符串 (例如 'best', 'final_epoch_50', 'v1_legacy')

    Returns:
        模型文件的完整路径
    """
    if version is None:
        version = get_next_model_version(product_id, model_type)

    # 特殊处理v1版本:检查是否存在旧格式文件
    if version == "v1":
        # 检查旧格式文件是否存在
        old_format_filename = f"{model_type}_model_product_{product_id}.pth"
        old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename)

        if os.path.exists(old_format_path):
            print(f"找到旧格式模型文件: {old_format_path},将其作为v1版本")
            return old_format_path

    # 使用新格式文件名
    filename = f"{model_type}_model_product_{product_id}_{version}.pth"
    # 处理历史遗留的 "v1" 格式
    if version == "v1_legacy":
        filename = f"{model_type}_model_product_{product_id}.pth"
        return os.path.join(DEFAULT_MODEL_DIR, filename)

    # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名
    # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
    # 针对 KAN 和 optimized_kan,使用 model_manager 的命名约定
    # 统一所有模型的命名格式
    filename = f"{model_type}_product_{product_id}_{version}.pth"
    # 修正:直接在根模型目录查找,不再使用checkpoints子目录
    return os.path.join(DEFAULT_MODEL_DIR, filename)

def get_model_versions(product_id: str, model_type: str) -> list:
@ -148,54 +104,38 @@ def get_model_versions(product_id: str, model_type: str) -> list:
    获取指定产品和模型类型的所有版本

    Args:
        product_id: 产品ID
        product_id: 产品ID (现在应该是纯数字ID)
        model_type: 模型类型

    Returns:
        版本列表,按版本号排序
    """
    # 新格式:带版本号的文件
    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
    # 统一使用新的命名约定进行搜索
    pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
    existing_files = glob.glob(pattern)

    # 旧格式:不带版本号的文件(兼容性支持)
    pattern_old = f"{model_type}_model_product_{product_id}.pth"
    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
    has_old_format = os.path.exists(old_file_path)
    versions = set()

    versions = []

    # 处理新格式文件
    for file_path in existing_files_new:
    for file_path in existing_files:
        filename = os.path.basename(file_path)
        version_match = re.search(rf"_v(\d+)\.pth$", filename)
        if version_match:
            version_num = int(version_match.group(1))
            versions.append(f"v{version_num}")

    # 如果存在旧格式文件,将其视为v1
    if has_old_format:
        if "v1" not in versions:  # 避免重复添加
            versions.append("v1")
            print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1")

    # 按版本号排序
    versions.sort(key=lambda v: int(v[1:]))
    return versions

def get_latest_model_version(product_id: str, model_type: str) -> str:
    """
    获取指定产品和模型类型的最新版本

    Args:
        product_id: 产品ID
        model_type: 模型类型

    Returns:
        最新版本号,如果没有则返回None
    """
    versions = get_model_versions(product_id, model_type)
    return versions[-1] if versions else None
        # 严格匹配 _v<number> 或 'best'
        match = re.search(r'_(v\d+|best)\.pth$', filename)
        if match:
            versions.add(match.group(1))

    # 按数字版本降序排序,'best'始终在最前
    def sort_key(v):
        if v == 'best':
            return -1  # 'best' is always first
        if v.startswith('v'):
            return int(v[1:])
        return float('inf')  # Should not happen

    sorted_versions = sorted(list(versions), key=sort_key, reverse=True)

    return sorted_versions


def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
    """
@ -11,12 +11,13 @@ import time
import matplotlib.pyplot as plt
from datetime import datetime

from trainers import (
    train_product_model_with_mlstm,
    train_product_model_with_kan,
    train_product_model_with_tcn,
    train_product_model_with_transformer
)
# from trainers import (
#     train_product_model_with_mlstm,
#     train_product_model_with_kan,
#     train_product_model_with_tcn,
#     train_product_model_with_transformer
# )
# 上述导入已不再需要,因为我们现在通过模型注册表动态获取训练器
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
@ -132,8 +133,8 @@ class PharmacyPredictor:
                    file_path=self.data_path
                )
                log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
                # 将product_id设置为店铺ID,以便模型保存时使用有意义的标识
                product_id = store_id
                # 将product_id设置为'store_{store_id}',与API查找逻辑保持一致
                product_id = f"store_{store_id}"
            except Exception as e:
                log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
                return None
@ -177,82 +178,59 @@ class PharmacyPredictor:
            log_message(f"不支持的训练模式: {training_mode}", 'error')
            return None

        # 根据训练模式构建模型标识符
        # 根据训练模式构建模型标识符 (v2 修正)
        if training_mode == 'store':
            model_identifier = f"{store_id}_{product_id}"
            # 店铺模型的标识符只应基于店铺ID
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # 全局模型的标识符不应依赖于单个product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        # 调用相应的训练函数
        # 调用相应的训练函数 (重构为使用注册表)
        try:
            log_message(f"🤖 开始调用 {model_type} 训练器")
            if model_type == 'transformer':
                model_result, metrics, actual_version = train_product_model_with_transformer(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    version=version,
                    socketio=socketio,
                    task_id=task_id,
                    continue_training=continue_training
                )
                log_message(f"✅ {model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
            elif model_type == 'mlstm':
                _, metrics, _, _ = train_product_model_with_mlstm(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id,
                    progress_callback=progress_callback
                )
            elif model_type == 'kan':
                _, metrics = train_product_model_with_kan(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=use_optimized,
                    model_dir=self.model_dir
                )
            elif model_type == 'optimized_kan':
                _, metrics = train_product_model_with_kan(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=True,
                    model_dir=self.model_dir
                )
            elif model_type == 'tcn':
                _, metrics, _, _ = train_product_model_with_tcn(
                    product_id=product_id,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id
                )
            from models.model_registry import get_trainer
            log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
            trainer_function = get_trainer(model_type)
            log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")

            # 准备通用参数
            trainer_args = {
                'product_id': product_id,
                'model_identifier': model_identifier,
                'product_df': product_data,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'epochs': epochs,
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_dir': self.model_dir,
                'socketio': socketio,
                'task_id': task_id,
                'progress_callback': progress_callback,
                'version': version,
                'continue_training': continue_training,
                'use_optimized': use_optimized  # KAN模型需要
            }

            # 动态调用训练函数 (v2 - 智能参数过滤)
            import inspect
            sig = inspect.signature(trainer_function)
            valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}

            log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")

            result = trainer_function(**valid_args)

            # 根据返回值的数量解析metrics
            if isinstance(result, tuple) and len(result) >= 2:
                metrics = result[1]  # 通常第二个返回值是metrics
            else:
                log_message(f"不支持的模型类型: {model_type}", 'error')
                return None
                log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
                metrics = None


        # 检查和打印返回的metrics
        log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
@ -296,21 +274,24 @@ class PharmacyPredictor:
        返回:
        预测结果和分析(如果analyze_result为True)
        """
        # 根据训练模式构建模型标识符
        # 根据训练模式构建模型标识符 (v2 修正)
        if training_mode == 'store' and store_id:
            model_identifier = f"{store_id}_{product_id}"
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # 全局模型的标识符不应依赖于单个product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        return load_model_and_predict(
            model_identifier,
            model_type,
            future_days=future_days,
            start_date=start_date,
            model_identifier,
            model_type,
            store_id=store_id,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result,
            version=version
            version=version,
            training_mode=training_mode
        )

    def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
Binary files not shown (9 files)
102	server/models/cnn_bilstm_attention.py (new file)
@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型定义,适配药店销售预测系统。
原始代码来源: python机器学习回归全家桶
"""

import torch
import torch.nn as nn

# 注意:由于原始代码使用了 TensorFlow/Keras 的层,我们将在这里创建一个 PyTorch 的等效实现。
# 这是一个更健壮、更符合现有系统架构的做法。

class Attention(nn.Module):
    """
    PyTorch 实现的注意力机制。
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)


class CnnBiLstmAttention(nn.Module):
    """
    CNN-BiLSTM-Attention 模型的 PyTorch 实现。
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN 层
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM 层
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention 层
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # 全连接输出层
        self.dense = nn.Linear(lstm_units * 2, output_dim)

    def forward(self, x):
        # 输入 x 的形状: (batch_size, sequence_length, input_dim)

        # CNN 处理
        x = x.permute(0, 2, 1)  # 转换为 (batch_size, input_dim, sequence_length) 以适应 Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # 转换回 (batch_size, sequence_length, cnn_filters)

        # BiLSTM 处理
        lstm_out, _ = self.bilstm(x)  # lstm_out 形状: (batch_size, sequence_length, lstm_units * 2)

        # Attention 处理
        # 注意:这里的 Attention 实现可能需要根据具体任务微调
        # 一个简化的方法是直接使用 LSTM 的最终隐藏状态或输出
        # 这里我们先用一个简化的逻辑:直接展平 LSTM 输出
        attention_out = self.attention(lstm_out)

        # 全连接层输出
        output = self.dense(attention_out)

        return output
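# Illustrative shape check for the model above (not part of the committed file):
if __name__ == '__main__':
    _model = CnnBiLstmAttention(input_dim=7, output_dim=7, sequence_length=30)
    _x = torch.randn(4, 30, 7)            # (batch_size, sequence_length, input_dim)
    assert _model(_x).shape == (4, 7)     # (batch_size, output_dim)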
64	server/models/model_registry.py (new file)
@ -0,0 +1,64 @@
"""
模型注册表
用于解耦模型的调用和实现,支持插件式扩展新模型。
"""

# 训练器注册表
TRAINER_REGISTRY = {}

def register_trainer(name, func):
    """
    注册一个模型训练器。

    参数:
        name (str): 模型类型名称 (e.g., 'xgboost')
        func (function): 对应的训练函数
    """
    if name in TRAINER_REGISTRY:
        print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
    TRAINER_REGISTRY[name] = func
    print(f"✅ 已注册训练器: {name}")

def get_trainer(name):
    """
    根据模型类型名称获取一个已注册的训练器。
    """
    if name not in TRAINER_REGISTRY:
        # 在打印可用训练器之前,确保它们已经被加载
        from trainers import discover_trainers
        discover_trainers()
        if name not in TRAINER_REGISTRY:
            raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
    return TRAINER_REGISTRY[name]

# --- 预测器注册表 ---

# 预测器函数需要一个统一的接口,例如:
# def predictor_function(model, checkpoint, **kwargs): -> predictions

PREDICTOR_REGISTRY = {}

def register_predictor(name, func):
    """
    注册一个模型预测器。
    """
    if name in PREDICTOR_REGISTRY:
        print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
    PREDICTOR_REGISTRY[name] = func

def get_predictor(name):
    """
    根据模型类型名称获取一个已注册的预测器。
    如果找不到特定预测器,可以返回一个默认的。
    """
    return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))

# 默认的PyTorch预测逻辑可以被注册为 'default'
def register_default_predictors():
    from predictors.model_predictor import default_pytorch_predictor
    register_predictor('default', default_pytorch_predictor)
    # 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册
    # register_predictor('kan', kan_predictor_func)

# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。
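# Illustrative usage of the registry API above (not part of the committed file):
def _example_trainer(product_id, model_identifier, product_df, **kwargs):
    """A stand-in trainer with the standard signature."""
    return None, {}

register_trainer('example', _example_trainer)
assert get_trainer('example') is _example_trainer  # unknown names raise ValueError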
Binary file not shown.
@ -10,6 +10,7 @@ from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # 添加这一行以支持MinMaxScaler的反序列化
from typing import Optional

from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
@ -17,365 +18,174 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb

from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor

def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None):
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """
    加载已训练的模型并进行预测

    参数:
    product_id: 产品ID
    model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
    store_id: 店铺ID,为None时使用全局模型
    future_days: 预测未来天数
    start_date: 预测起始日期,如果为None则使用最后一个已知日期
    analyze_result: 是否分析预测结果
    version: 模型版本,如果为None则使用最新版本

    返回:
    预测结果和分析(如果analyze_result为True)
    默认的PyTorch模型预测逻辑,支持自动回归。
    """
    try:
        # 确定模型文件路径(支持多店铺)
        model_path = None

        if version:
            # 使用版本管理系统获取正确的文件路径
            model_path = get_model_file_path(product_id, model_type, version)
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        prediction_input_df = product_df.tail(sequence_length)
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()

    for _ in range(future_days):
        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
        # **核心改进**: 智能判断模型类型并调用相应的预测方法
        if isinstance(model, xgb.Booster):
            # XGBoost 模型预测路径
            X_input_reshaped = X_current_scaled.reshape(1, -1)
            d_input = xgb.DMatrix(X_input_reshaped)
            # **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略
            y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
        else:
            # 根据store_id确定搜索目录
            if store_id:
                # 查找特定店铺的模型
                possible_dirs = [
                    os.path.join('saved_models', model_type, store_id),
                    os.path.join('models', model_type, store_id)
                ]
            else:
                # 查找全局模型
                possible_dirs = [
                    os.path.join('saved_models', model_type, 'global'),
                    os.path.join('models', model_type, 'global'),
                    os.path.join('saved_models', model_type),  # 后向兼容
                    'saved_models'  # 最基本的目录
                ]

            # 文件名模式
            model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
            file_model_type = 'kan' if model_type == 'optimized_kan' else model_type

            possible_names = [
                f"{product_id}_{model_type}_v1_model.pt",  # 新多店铺格式
                f"{product_id}_{model_type}_v1_global_model.pt",  # 全局模型格式
                f"{product_id}_{model_type}_v1.pth",  # 旧版本格式
                f"{file_model_type}{model_suffix}_model_product_{product_id}.pth",  # 原始格式
                f"{model_type}_model_product_{product_id}.pth"  # 简化格式
            ]

            # 搜索模型文件
            for dir_path in possible_dirs:
                if not os.path.exists(dir_path):
                    continue
                for name in possible_names:
                    test_path = os.path.join(dir_path, name)
                    if os.path.exists(test_path):
                        model_path = test_path
                        break
                if model_path:
                    break

            if not model_path:
                scope_msg = f"店铺 {store_id}" if store_id else "全局"
                print(f"找不到产品 {product_id} 的 {model_type} 模型文件 ({scope_msg})")
                print(f"搜索目录: {possible_dirs}")
                return None

            print(f"尝试加载模型文件: {model_path}")

            if not os.path.exists(model_path):
                print(f"模型文件 {model_path} 不存在")
                return None

        # 加载销售数据(支持多店铺)
        try:
            if store_id:
                # 加载特定店铺的数据
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
                prediction_scope = f"店铺 '{store_name}' ({store_id})"
            else:
                # 聚合所有店铺的数据进行预测
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method='sum',
                    file_path='pharmacy_sales_multi_store.csv'
                )
                prediction_scope = "全部店铺(聚合数据)"
        except Exception as e:
            print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
            # 后向兼容:尝试加载原始数据格式
            try:
                df = pd.read_excel('pharmacy_sales.xlsx')
                product_df = df[df['product_id'] == product_id].sort_values('date')
                if store_id:
                    print(f"警告:原始数据不支持店铺过滤,将使用所有数据预测")
                prediction_scope = "默认数据"
            except Exception as e2:
                print(f"加载产品数据失败: {str(e2)}")
                return None

        if product_df.empty:
            print(f"产品 {product_id} 没有销售数据")
            return None

        product_name = product_df['product_name'].iloc[0]
        print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
        print(f"预测范围: {prediction_scope}")

        # 添加安全的全局变量以支持MinMaxScaler的反序列化
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception as e:
            print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")

        # 加载模型和配置
        try:
            # 首先尝试使用weights_only=False加载
            try:
                print("尝试使用 weights_only=False 加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            except Exception as e:
                print(f"使用weights_only=False加载失败: {str(e)}")
                print("尝试使用默认参数加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE)

            print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
            if isinstance(checkpoint, dict):
                print(f"checkpoint包含的键: {list(checkpoint.keys())}")
            else:
                print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
                return None
        except Exception as e:
            print(f"加载模型失败: {str(e)}")
            return None

        # 检查并获取配置
        if 'config' not in checkpoint:
            print("模型文件中没有配置信息")
            return None

        config = checkpoint['config']
        print(f"模型配置: {config}")

        # 检查并获取缩放器
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("模型文件中没有缩放器信息")
            return None

        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # 创建模型实例
        try:
            if model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=config['input_dim'],
                    d_model=config['hidden_size'],
                    nhead=config['num_heads'],
                    num_encoder_layers=config['num_layers'],
                    dim_feedforward=config['hidden_size'] * 2,
                    dropout=config['dropout'],
                    output_sequence_length=config['output_dim'],
                    seq_length=config['sequence_length'],
                    batch_size=32
                ).to(DEVICE)
            elif model_type == 'slstm':
                model = ScalarLSTM(
                    input_dim=config['input_dim'],
                    hidden_dim=config['hidden_size'],
                    output_dim=config['output_dim'],
                    num_layers=config['num_layers'],
                    dropout=config['dropout']
                ).to(DEVICE)
            elif model_type == 'mlstm':
                # 获取配置参数,如果不存在则使用默认值
                embed_dim = config.get('embed_dim', 32)
                dense_dim = config.get('dense_dim', 32)
                num_heads = config.get('num_heads', 4)
                num_blocks = config.get('num_blocks', 3)

                model = MatrixLSTM(
                    num_features=config['input_dim'],
                    hidden_size=config['hidden_size'],
                    mlstm_layers=config['num_layers'],
                    embed_dim=embed_dim,
                    dense_dim=dense_dim,
                    num_heads=num_heads,
                    dropout_rate=config['dropout'],
                    num_blocks=num_blocks,
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'optimized_kan':
                model = OptimizedKANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'tcn':
                model = TCNForecaster(
                    num_features=config['input_dim'],
                    output_sequence_length=config['output_dim'],
                    num_channels=[config['hidden_size']] * config['num_layers'],
                    kernel_size=3,
                    dropout=config['dropout']
                ).to(DEVICE)
            else:
                print(f"不支持的模型类型: {model_type}")
                return None

            print(f"模型实例创建成功: {type(model)}")
        except Exception as e:
            print(f"创建模型实例失败: {str(e)}")
            return None

        # 加载模型参数
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            print("模型参数加载成功")
        except Exception as e:
            print(f"加载模型参数失败: {str(e)}")
            return None

        # 准备输入数据
        try:
            features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
            sequence_length = config['sequence_length']

            # 获取最近的sequence_length天数据作为输入
            recent_data = product_df.iloc[-sequence_length:].copy()

            # 如果指定了起始日期,则使用该日期之后的数据
            if start_date:
                if isinstance(start_date, str):
                    start_date = datetime.strptime(start_date, '%Y-%m-%d')
                recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
                if len(recent_data) < sequence_length:
                    print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
                    # 补充数据
                    missing_days = sequence_length - len(recent_data)
                    additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
                    recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)

            print(f"输入数据准备完成,形状: {recent_data.shape}")
        except Exception as e:
            print(f"准备输入数据失败: {str(e)}")
            return None

        # 归一化输入数据
        try:
            X = recent_data[features].values
            X_scaled = scaler_X.transform(X)

            # 转换为模型输入格式
            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            print(f"输入张量准备完成,形状: {X_input.shape}")
        except Exception as e:
            print(f"归一化输入数据失败: {str(e)}")
            return None

        # 预测
        try:
            # 默认 PyTorch 模型预测路径
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            print(f"原始预测输出形状: {y_pred_scaled.shape}")

            # 处理TCN、Transformer、mLSTM和KAN模型的输出,确保形状正确
            if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
                y_pred_scaled = y_pred_scaled.squeeze(-1)
                print(f"处理后的预测输出形状: {y_pred_scaled.shape}")

            # 反归一化预测结果
            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
            print(f"反归一化后的预测结果: {y_pred}")

            # 生成预测日期
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
            print(f"预测日期: {pred_dates}")
        except Exception as e:
            print(f"执行预测失败: {str(e)}")
            return None

        # 创建预测结果DataFrame
        next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
        new_row_df = pd.DataFrame([new_row])
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    predictions_df = pd.DataFrame(all_predictions)
    return predictions_df, history_for_chart_df, prediction_input_df

# 注册默认的PyTorch预测器
register_predictor('default', default_pytorch_predictor)
# 将增强后的默认预测器也注册给xgboost
register_predictor('xgboost', default_pytorch_predictor)
# 将新模型也注册给默认预测器
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """
    加载已训练的模型并进行预测 (v4版 - 插件式架构)
    """
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")

        # --- 数据加载部分保持不变 ---
        from utils.multi_store_data_utils import aggregate_multi_store_data
        if training_mode == 'store' and store_id:
            from utils.multi_store_data_utils import load_multi_store_data
            store_df_for_name = load_multi_store_data(store_id=store_id)
            product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "全局销售数据"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id

        if product_df.empty:
            raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")

        # --- 模型加载与实例化 (重构) ---
        try:
            predictions_df = pd.DataFrame({
                'date': pred_dates,
                'sales': y_pred  # 使用sales字段名而不是predicted_sales,以便与历史数据兼容
            })
            print(f"预测结果DataFrame创建成功,形状: {predictions_df.shape}")
        except Exception as e:
            print(f"创建预测结果DataFrame失败: {str(e)}")
            return None
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception: pass

        # 绘制预测结果
        try:
            plt.figure(figsize=(12, 6))
            plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
            plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
            plt.title(f'{product_name} - {model_type}模型销量预测')
            plt.xlabel('日期')
            plt.ylabel('销量')
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint.get('config', {})
        loaded_model_type = config.get('model_type', model_type)  # 优先使用模型内保存的类型

        # 根据模型类型决定如何获取模型实例
        if loaded_model_type == 'xgboost':
            # 对于XGBoost, 模型对象直接保存在'model_state_dict'键中
            model = checkpoint['model_state_dict']
        else:
            # 对于PyTorch模型, 需要重新构建实例并加载state_dict
            if loaded_model_type == 'transformer':
                model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
            elif loaded_model_type == 'mlstm':
                model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'kan':
                model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'optimized_kan':
                model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'tcn':
                model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
            elif loaded_model_type == 'cnn_bilstm_attention':
                model = CnnBiLstmAttention(
                    input_dim=config['input_dim'],
                    output_dim=config['output_dim'],
                    sequence_length=config['sequence_length']
                ).to(DEVICE)
            else:
                raise ValueError(f"不支持的模型类型: {loaded_model_type}")

            # 保存图像
            plt.savefig(f'{product_id}_{model_type}_prediction.png')
            plt.close()

            print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png")
        except Exception as e:
            print(f"绘制预测结果图表失败: {str(e)}")
            # 这个错误不影响主要功能,继续执行

        # 分析预测结果
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        # --- 动态调用预测器 ---
        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")

        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days
        )

        # --- 分析与返回部分保持不变 ---
        analysis = None
        if analyze_result:
            try:
                analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
                print("\n预测结果分析:")
                if analysis and 'explanation' in analysis:
                    print(analysis['explanation'])
                else:
                    print("分析结果不包含explanation字段")
                analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
            except Exception as e:
                print(f"分析预测结果失败: {str(e)}")
                # 分析失败不影响主要功能,继续执行


        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
    return None
@ -2,18 +2,44 @@
药店销售预测系统 - 模型训练模块
"""

from .mlstm_trainer import train_product_model_with_mlstm
from .kan_trainer import train_product_model_with_kan
from .tcn_trainer import train_product_model_with_tcn
from .transformer_trainer import train_product_model_with_transformer
import os
import glob
import importlib

# 默认训练函数
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
_TRAINERS_LOADED = False

def discover_trainers():
    """
    自动发现并加载所有训练器插件。
    使用一个标志位确保这个过程只执行一次。
    """
    global _TRAINERS_LOADED
    if _TRAINERS_LOADED:
        return

    print("🚀 开始发现并加载训练器插件...")

    package_dir = os.path.dirname(__file__)
    module_name = __name__

    trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py"))

    for f in trainer_files:
        base_name = os.path.basename(f)
        if base_name.startswith('__'):
            continue

        module_stem = base_name.replace('.py', '')

        try:
            # 动态导入模块以触发自注册
            importlib.import_module(f".{module_stem}", package=module_name)
        except ImportError as e:
            print(f"⚠️ 加载训练器 {module_stem} 失败: {e}")

    _TRAINERS_LOADED = True
    print("✅ 所有训练器插件加载完成。")

# 在包被首次导入时,自动执行发现过程
discover_trainers()

__all__ = [
    'train_product_model',
    'train_product_model_with_mlstm',
    'train_product_model_with_kan',
    'train_product_model_with_tcn',
    'train_product_model_with_transformer'
]
118	server/trainers/cnn_bilstm_attention_trainer.py (new file)
@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器
"""

import torch
import torch.optim as optim
import numpy as np

from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler

# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention

def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    使用 CNN-BiLSTM-Attention 模型进行训练。
    函数签名遵循系统标准。
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. 数据准备 ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # 转换为 PyTorch Tensors
    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # --- 2. 实例化模型和优化器 ---
    input_dim = trainX.shape[2]

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )

    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. 训练循环 ---
    print("开始训练 CNN-BiLSTM-Attention 模型...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(trainX)
        loss = criterion(outputs, trainY.squeeze(-1))  # 确保目标维度匹配

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. 模型评估 ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")

    # --- 5. 模型保存 ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_df['product_name'].iloc[0]
    )

    print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
    return model, metrics, final_version, final_model_path

# --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
@ -21,7 +21,7 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON

def train_product_model_with_kan(product_id, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
    """
    使用KAN模型训练产品销售预测模型

@ -79,11 +79,11 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # 数据量检查
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"
@ -123,8 +123,8 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # 创建时间序列数据
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # 转换为PyTorch的Tensor
    trainX_tensor = torch.Tensor(trainX)
@ -142,7 +142,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

    # 初始化KAN模型
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 64

    if use_optimized:
@ -168,6 +168,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
    train_losses = []
    test_losses = []
    start_time = time.time()
    best_loss = float('inf')

    for epoch in range(epochs):
        model.train()
@ -225,6 +226,43 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # 检查是否为最佳模型
        model_type_name = 'optimized_kan' if use_optimized else 'kan'
        if test_loss < best_loss:
            best_loss = test_loss
            print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")

            # 为保存最佳模型准备数据
            best_model_data = {
                'model_state_dict': model.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'input_dim': input_dim,
                    'output_dim': output_dim,
                    'hidden_size': hidden_size,
                    'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
                    'sequence_length': sequence_length,
                    'forecast_horizon': forecast_horizon,
                    'model_type': model_type_name,
                    'use_optimized': use_optimized
                },
                'epoch': epoch + 1
            }

            # 使用模型管理器保存 'best' 版本
            from utils.model_manager import model_manager
            model_manager.save_model(
                model_data=best_model_data,
                product_id=model_identifier,  # 修正:使用唯一的标识符
                model_type=model_type_name,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version='best'  # 显式覆盖版本为'best'
            )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
@ -282,9 +320,9 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'hidden_sizes': [hidden_size, hidden_size*2, hidden_size],
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': model_type_name,
            'use_optimized': use_optimized
        },
@ -297,15 +335,23 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
        'loss_curve_path': loss_curve_path
    }

    model_path = model_manager.save_model(
    # 保存最终模型,让 model_manager 自动处理版本号
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        product_id=model_identifier,  # 修正:使用唯一的标识符
        model_type=model_type_name,
        version='v1',  # KAN训练器默认使用v1
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
        # 注意:此处不传递version参数,由管理器自动生成
    )

    return model, metrics
    print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")

    return model, metrics

# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('kan', train_product_model_with_kan)
register_trainer('optimized_kan', train_product_model_with_kan)
@@ -20,97 +20,21 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import (
    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
    get_next_model_version, get_model_file_path, get_latest_model_version
    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """
    Save a training checkpoint

    Args:
        checkpoint_data: checkpoint payload
        epoch_or_label: epoch number or a label (e.g. 'best')
        product_id: product ID
        model_type: model type
        model_dir: model save directory
        store_id: store ID
        training_mode: training mode
        aggregation_method: aggregation method
    """
    # Create the checkpoint directory
    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build the checkpoint file name
    if training_mode == 'store' and store_id:
        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
    elif training_mode == 'global' and aggregation_method:
        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
    else:
        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    # Save the checkpoint
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[mLSTM] Checkpoint saved: {checkpoint_path}", flush=True)

    return checkpoint_path


def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
                    model_dir: str, store_id=None, training_mode: str = 'product',
                    aggregation_method=None):
    """
    Load a training checkpoint

    Args:
        product_id: product ID
        model_type: model type
        epoch_or_label: epoch number or label
        model_dir: model save directory
        store_id: store ID
        training_mode: training mode
        aggregation_method: aggregation method

    Returns:
        checkpoint_data: checkpoint payload, or None if not found
    """
    checkpoint_dir = os.path.join(model_dir, 'checkpoints')

    # Build the checkpoint file name
    if training_mode == 'store' and store_id:
        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
    elif training_mode == 'global' and aggregation_method:
        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
    else:
        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    if os.path.exists(checkpoint_path):
        try:
            checkpoint_data = torch.load(checkpoint_path, map_location=DEVICE)
            print(f"[mLSTM] Checkpoint loaded: {checkpoint_path}", flush=True)
            return checkpoint_data
        except Exception as e:
            print(f"[mLSTM] Failed to load checkpoint: {e}", flush=True)
            return None
    else:
        print(f"[mLSTM] Checkpoint file does not exist: {checkpoint_path}", flush=True)
        return None
from utils.model_manager import model_manager

def train_product_model_with_mlstm(
    product_id,
    model_identifier,
    product_df,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
@@ -174,15 +98,9 @@ def train_product_model_with_mlstm(
emit_progress("Starting mLSTM model training...")

# Determine the version number
if version is None:
    if continue_training:
        version = get_latest_model_version(product_id, 'mlstm')
        if version is None:
            version = get_next_model_version(product_id, 'mlstm')
    else:
        version = get_next_model_version(product_id, 'mlstm')

emit_progress(f"Starting training of mLSTM model version {version}")
emit_progress(f"Starting mLSTM model training")
if version:
    emit_progress(f"Using specified version: {version}")

# Initialize the training progress manager (if not already initialized)
if socketio and task_id:
@@ -215,11 +133,11 @@ def train_product_model_with_mlstm(
training_scope = "all stores"

# Data volume check
min_required_samples = LOOK_BACK + FORECAST_HORIZON
min_required_samples = sequence_length + forecast_horizon
if len(product_df) < min_required_samples:
    error_msg = (
        f"❌ Insufficient training data\n"
        f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
        f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
        f"Actual data volume: {len(product_df)} days\n"
        f"Product ID: {product_id}, training mode: {training_mode}\n"
        f"Suggested fixes:\n"
@@ -235,7 +153,7 @@ def train_product_model_with_mlstm(

print(f"[mLSTM] Training a sales forecasting model for product '{product_name}' (ID: {product_id}) with mLSTM", flush=True)
print(f"[mLSTM] Training scope: {training_scope}", flush=True)
print(f"[mLSTM] Version: {version}", flush=True)
# print(f"[mLSTM] Version: {version}", flush=True)  # Version is now handled by model_manager
print(f"[mLSTM] Using device: {DEVICE}", flush=True)
print(f"[mLSTM] Models will be saved to directory: {model_dir}", flush=True)
print(f"[mLSTM] Data volume: {len(product_df)} records", flush=True)
@@ -269,8 +187,8 @@ def train_product_model_with_mlstm(
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

# Build the time-series samples
trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

# Convert to PyTorch tensors
trainX_tensor = torch.Tensor(trainX)
@@ -295,7 +213,7 @@ def train_product_model_with_mlstm(

# Initialize the mLSTM + Transformer model
input_dim = X_train.shape[1]
output_dim = FORECAST_HORIZON
output_dim = forecast_horizon
hidden_size = 128
num_heads = 4
dropout_rate = 0.1
@@ -324,23 +242,15 @@ def train_product_model_with_mlstm(

# If continuing training, load the existing model
if continue_training and version != 'v1':
    try:
        existing_model_path = get_model_file_path(product_id, 'mlstm', version)
        if os.path.exists(existing_model_path):
            checkpoint = torch.load(existing_model_path, map_location=DEVICE)
            model.load_state_dict(checkpoint['model_state_dict'])
            print(f"Loaded existing model: {existing_model_path}")
            emit_progress(f"Loaded existing model version {version} for continued training")
    except Exception as e:
        print(f"Could not load the existing model, training from scratch: {e}")
        emit_progress("Could not load the existing model, restarting training")
# TODO: Implement continue_training logic with the new model_manager
pass

# Move the model to the device
model = model.to(DEVICE)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5, verbose=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

emit_progress("Data preprocessing complete, starting model training...", progress=10)

@@ -432,12 +342,13 @@ def train_product_model_with_mlstm(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_heads': num_heads,
'dropout': dropout_rate,
'dropout_rate': dropout_rate,
'num_blocks': num_blocks,
'embed_dim': embed_dim,
'dense_dim': dense_dim,
'sequence_length': LOOK_BACK,
'forecast_horizon': FORECAST_HORIZON,
'mlstm_layers': 2,  # make sure this parameter gets saved
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'mlstm'
},
'training_info': {
@@ -451,21 +362,23 @@ def train_product_model_with_mlstm(
}
}

# Save a checkpoint
save_checkpoint(checkpoint_data, epoch + 1, product_id, 'mlstm',
                model_dir, store_id, training_mode, aggregation_method)

# If this is the best model so far, save an extra copy
if test_loss < best_loss:
    best_loss = test_loss
    save_checkpoint(checkpoint_data, 'best', product_id, 'mlstm',
                    model_dir, store_id, training_mode, aggregation_method)
    model_manager.save_model(
        model_data=checkpoint_data,
        product_id=model_identifier,  # Fix: use the unique identifier
        model_type='mlstm',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'
    )
    emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
    epochs_no_improve = 0
else:
    epochs_no_improve += 1

emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

if (epoch + 1) % 10 == 0:
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)
@@ -524,7 +437,6 @@ def train_product_model_with_mlstm(
# Compute evaluation metrics
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
metrics['version'] = version

# Print evaluation metrics
print("\nModel evaluation metrics:")
@@ -553,12 +465,13 @@ def train_product_model_with_mlstm(
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_heads': num_heads,
'dropout': dropout_rate,
'dropout_rate': dropout_rate,
'num_blocks': num_blocks,
'embed_dim': embed_dim,
'dense_dim': dense_dim,
'sequence_length': LOOK_BACK,
'forecast_horizon': FORECAST_HORIZON,
'mlstm_layers': 2,  # make sure this parameter gets saved
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'mlstm'
},
'metrics': metrics,
@@ -575,10 +488,15 @@ def train_product_model_with_mlstm(
}
}

# Save the final model (tagged by epoch)
final_model_path = save_checkpoint(
    final_model_data, f"final_epoch_{epochs}", product_id, 'mlstm',
    model_dir, store_id, training_mode, aggregation_method
# Save the final model and let model_manager assign the version number
final_model_path, final_version = model_manager.save_model(
    model_data=final_model_data,
    product_id=model_identifier,  # Fix: use the unique identifier
    model_type='mlstm',
    store_id=store_id,
    training_mode=training_mode,
    aggregation_method=aggregation_method,
    product_name=product_name
)

# Send the training-completed message
@@ -590,9 +508,14 @@ def train_product_model_with_mlstm(
'mape': metrics['mape'],
'training_time': training_time,
'final_epoch': epochs,
'model_path': final_model_path
'model_path': final_model_path,
'version': final_version
}

emit_progress(f"✅ mLSTM model training complete! Final epoch: {epochs} saved", progress=100, metrics=final_metrics)
emit_progress(f"✅ mLSTM model training complete! Version {final_version} saved", progress=100, metrics=final_metrics)

return model, metrics, epochs, final_model_path
return model, metrics, epochs, final_model_path

# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('mlstm', train_product_model_with_mlstm)
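Every trainer in this diff reports progress through the same Socket.IO event ('training_progress' on the '/training' namespace, as seen in the `emit_progress` helpers). A hedged sketch of a matching client; the payload keys come from this diff, everything else is illustrative:

    # Illustrative python-socketio client for the training progress feed
    import socketio

    sio = socketio.Client()

    @sio.on('training_progress', namespace='/training')
    def on_progress(data):
        # 'message' is always set; 'progress' and 'metrics' are optional.
        pct = data.get('progress')
        print(f"[{pct if pct is not None else '--'}%] {data['message']}")

    sio.connect('http://localhost:5000', namespaces=['/training'])  # hypothetical URL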
@@ -20,50 +20,18 @@ from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """
    Save a training checkpoint

    Args:
        checkpoint_data: checkpoint payload
        epoch_or_label: epoch number or a label (e.g. 'best')
        product_id: product ID
        model_type: model type
        model_dir: model save directory
        store_id: store ID
        training_mode: training mode
        aggregation_method: aggregation method
    """
    # Create the checkpoint directory
    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build the checkpoint file name
    if training_mode == 'store' and store_id:
        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
    elif training_mode == 'global' and aggregation_method:
        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
    else:
        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    # Save the checkpoint
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[TCN] Checkpoint saved: {checkpoint_path}", flush=True)

    return checkpoint_path
from utils.model_manager import model_manager

def train_product_model_with_tcn(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
@@ -72,21 +40,6 @@ def train_product_model_with_tcn(
):
    """
    Train a product sales forecasting model with TCN

    Parameters:
        product_id: product ID
        epochs: number of training epochs
        model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the config
        version: explicit version number; auto-generated when None
        socketio: WebSocket object for real-time feedback
        task_id: training task ID
        continue_training: whether to continue training an existing model

    Returns:
        model: the trained model
        metrics: model evaluation metrics
        version: the version number actually used
        model_path: path to the model file
    """

    def emit_progress(message, progress=None, metrics=None):
@@ -103,73 +56,28 @@ def train_product_model_with_tcn(
            data['metrics'] = metrics
        socketio.emit('training_progress', data, namespace='/training')

    # Determine the version number
    if version is None:
        from core.config import get_latest_model_version, get_next_model_version
        if continue_training:
            version = get_latest_model_version(product_id, 'tcn')
            if version is None:
                version = get_next_model_version(product_id, 'tcn')
        else:
            version = get_next_model_version(product_id, 'tcn')
    emit_progress(f"Starting TCN model training")

    emit_progress(f"Starting training of TCN model version {version}")

    # If no product_df was passed in, load data according to the training mode
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data

        try:
            if training_mode == 'store' and store_id:
                # Load data for the specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                training_scope = f"store {store_id}"
            elif training_mode == 'global':
                # Aggregate data across all stores
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method=aggregation_method,
                    file_path='pharmacy_sales_multi_store.csv'
                )
                training_scope = f"global aggregation ({aggregation_method})"
            else:
                # Default: load the product's data across all stores
                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
                training_scope = "all stores"
        except Exception as e:
            print(f"Multi-store data loading failed: {e}")
            # Fallback: try the original data
            df = pd.read_excel('pharmacy_sales.xlsx')
            product_df = df[df['product_id'] == product_id].sort_values('date')
            training_scope = "original data"
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )
        training_scope = f"global aggregation ({aggregation_method})"
    else:
        # If product_df was provided, use it directly
        if training_mode == 'store' and store_id:
            training_scope = f"store {store_id}"
        elif training_mode == 'global':
            training_scope = f"global aggregation ({aggregation_method})"
        else:
            training_scope = "all stores"
        training_scope = "all stores"

    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")

    # Data volume check
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
            f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual data volume: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
            f"1. Generate more data: uv run generate_multi_store_data.py\n"
            f"2. Adjust the configuration: reduce LOOK_BACK or FORECAST_HORIZON\n"
            f"3. Use global training mode to aggregate more data"
        )
        print(error_msg)
        emit_progress(f"Training failed: insufficient data ({len(product_df)}/{min_required_samples} days)")
@@ -180,48 +88,39 @@ def train_product_model_with_tcn(

    print(f"Training a sales forecasting model for product '{product_name}' (ID: {product_id}) with TCN")
    print(f"Training scope: {training_scope}")
    print(f"Version: {version}")
    print(f"Using device: {DEVICE}")
    print(f"Models will be saved to directory: {model_dir}")

    emit_progress(f"Training product: {product_name} (ID: {product_id})")

    # Build the features and target variables
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Preprocess the data
    X = product_df[features].values
    y = product_df[['sales']].values  # keep as a 2-D array
    y = product_df[['sales']].values

    # Enter the data-preprocessing stage
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("Preprocessing data...")

    # Normalize the data
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Split into training and test sets (80% train, 20% test)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    progress_manager.set_stage("data_preprocessing", 50)

    # Build the time-series samples
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    # Create data loaders
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

@@ -229,7 +128,6 @@ def train_product_model_with_tcn(
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Update the progress manager's batch info
    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    progress_manager.total_batches_per_epoch = total_batches
@@ -238,9 +136,8 @@ def train_product_model_with_tcn(

    progress_manager.set_stage("data_preprocessing", 100)

    # Initialize the TCN model
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 64
    num_layers = 3
    kernel_size = 3
@@ -254,21 +151,8 @@ def train_product_model_with_tcn(
        dropout=dropout_rate
    )

    # If continuing training, load the existing model
    if continue_training and version != 'v1':
        try:
            from core.config import get_model_file_path
            existing_model_path = get_model_file_path(product_id, 'tcn', version)
            if os.path.exists(existing_model_path):
                checkpoint = torch.load(existing_model_path, map_location=DEVICE)
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"Loaded existing model: {existing_model_path}")
                emit_progress(f"Loaded existing model version {version} for continued training")
        except Exception as e:
            print(f"Could not load the existing model, training from scratch: {e}")
            emit_progress("Could not load the existing model, restarting training")
    # TODO: Implement continue_training logic with the new model_manager

    # Move the model to the device
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
@@ -276,20 +160,17 @@ def train_product_model_with_tcn(

    emit_progress("Starting model training...")

    # Train the model
    train_losses = []
    test_losses = []
    start_time = time.time()

    # Configure checkpoint saving
    checkpoint_interval = max(1, epochs // 10)  # save every 10% of progress, at least every epoch
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')

    progress_manager.set_stage("model_training", 0)
    emit_progress(f"Training started - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}")

    for epoch in range(epochs):
        # Start a new epoch
        progress_manager.start_epoch(epoch)

        model.train()
@@ -298,43 +179,34 @@ def train_product_model_with_tcn(
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # Make sure the target tensor has the right shape (batch_size, forecast_horizon, 1)
            if y_batch.dim() == 2:
                y_batch = y_batch.unsqueeze(-1)

            # Forward pass
            outputs = model(X_batch)

            # Make sure the output and target shapes match
            loss = criterion(outputs, y_batch)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Update batch progress (every 10 batches)
            if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        # Compute the training loss
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Enter the validation stage
        progress_manager.set_stage("validation", 0)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # Make sure the target tensor has the right shape
                if y_batch.dim() == 2:
                    y_batch = y_batch.unsqueeze(-1)

@@ -342,7 +214,6 @@ def train_product_model_with_tcn(
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                # Update validation progress
                if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
                    val_progress = (batch_idx / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)
@@ -350,10 +221,8 @@ def train_product_model_with_tcn(
        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # Finish the current epoch
        progress_manager.finish_epoch(train_loss, test_loss)

        # Send training progress (kept for compatibility with the old system)
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
@@ -365,7 +234,6 @@ def train_product_model_with_tcn(
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # Periodically save checkpoints
        if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
            checkpoint_data = {
                'epoch': epoch + 1,
@@ -382,10 +250,11 @@ def train_product_model_with_tcn(
                'output_dim': output_dim,
                'hidden_size': hidden_size,
                'num_layers': num_layers,
                'num_channels': [hidden_size] * num_layers,
                'dropout': dropout_rate,
                'kernel_size': kernel_size,
                'sequence_length': LOOK_BACK,
                'forecast_horizon': FORECAST_HORIZON,
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_type': 'tcn'
            },
            'training_info': {
@@ -398,30 +267,28 @@ def train_product_model_with_tcn(
                }
            }

            # Save a checkpoint
            save_checkpoint(checkpoint_data, epoch + 1, product_id, 'tcn',
                            model_dir, store_id, training_mode, aggregation_method)

            # If this is the best model so far, save an extra copy
            if test_loss < best_loss:
                best_loss = test_loss
                save_checkpoint(checkpoint_data, 'best', product_id, 'tcn',
                                model_dir, store_id, training_mode, aggregation_method)
                model_manager.save_model(
                    model_data=checkpoint_data,
                    product_id=model_identifier,  # Fix: use the unique identifier
                    model_type='tcn',
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    product_name=product_name,
                    version='best'
                )
                emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")

            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    # Compute the training time
    training_time = time.time() - start_time

    # Enter the model-saving stage
    progress_manager.set_stage("model_saving", 0)
    emit_progress("Training complete, saving the model...")

    # Plot the loss curve and save it to the model directory
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
@@ -431,23 +298,17 @@ def train_product_model_with_tcn(
    )
    print(f"Loss curve saved to: {loss_curve_path}")

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        # Make sure the test data has the right shape
        test_pred = model(testX_tensor.to(DEVICE))
        # Convert the output to a 2-D array [samples, forecast_horizon]
        test_pred = test_pred.squeeze(-1).cpu().numpy()

    # Denormalize the predictions and the ground truth
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    # Compute evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # Print evaluation metrics
    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
@@ -456,9 +317,8 @@ def train_product_model_with_tcn(
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f}s")

    # Save the final trained model (based on the final epoch)
    final_model_data = {
        'epoch': epochs,  # final epoch
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
@@ -472,10 +332,11 @@ def train_product_model_with_tcn(
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_channels': [hidden_size] * num_layers,
        'dropout': dropout_rate,
        'kernel_size': kernel_size,
        'sequence_length': LOOK_BACK,
        'forecast_horizon': FORECAST_HORIZON,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'tcn'
    },
    'metrics': metrics,
@@ -493,10 +354,14 @@ def train_product_model_with_tcn(

    progress_manager.set_stage("model_saving", 50)

    # Save the final model (tagged by epoch)
    final_model_path = save_checkpoint(
        final_model_data, f"final_epoch_{epochs}", product_id, 'tcn',
        model_dir, store_id, training_mode, aggregation_method
    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # Fix: use the unique identifier
        model_type='tcn',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    progress_manager.set_stage("model_saving", 100)
@@ -508,9 +373,14 @@ def train_product_model_with_tcn(
    'r2': metrics['r2'],
    'mape': metrics['mape'],
    'training_time': training_time,
    'final_epoch': epochs
    'final_epoch': epochs,
    'version': final_version
    }

    emit_progress(f"Model training complete! Final epoch: {epochs}", progress=100, metrics=final_metrics)
    emit_progress(f"Model training complete! Version {final_version} saved", progress=100, metrics=final_metrics)

    return model, metrics, epochs, final_model_path
    return model, metrics, epochs, final_model_path

# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('tcn', train_product_model_with_tcn)
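All four trainers build their supervised samples with `utils.data_utils.create_dataset`, whose implementation is not part of this diff. The call sites imply a sliding-window transform along these lines (a sketch under that assumption, not the project's actual code):

    import numpy as np

    def create_dataset(X, y, sequence_length, forecast_horizon):
        """Assumed behavior: inputs of shape (samples, sequence_length, n_features)
        and targets of shape (samples, forecast_horizon)."""
        Xs, ys = [], []
        for i in range(len(X) - sequence_length - forecast_horizon + 1):
            Xs.append(X[i:i + sequence_length])
            ys.append(y[i + sequence_length:i + sequence_length + forecast_horizon, 0])
        return np.array(Xs), np.array(ys)

This also explains the `min_required_samples = sequence_length + forecast_horizon` check seen in every trainer: the window needs room for at least one full (input, target) pair.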
@@ -21,55 +21,21 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import (
    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
    get_next_model_version, get_model_file_path, get_latest_model_version
    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager

def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """
    Save a training checkpoint

    Args:
        checkpoint_data: checkpoint payload
        epoch_or_label: epoch number or a label (e.g. 'best')
        product_id: product ID
        model_type: model type
        model_dir: model save directory
        store_id: store ID
        training_mode: training mode
        aggregation_method: aggregation method
    """
    # Create the checkpoint directory
    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Build the checkpoint file name
    if training_mode == 'store' and store_id:
        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
    elif training_mode == 'global' and aggregation_method:
        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
    else:
        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

    # Save the checkpoint
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[Transformer] Checkpoint saved: {checkpoint_path}", flush=True)

    return checkpoint_path

def train_product_model_with_transformer(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
@@ -81,23 +47,8 @@ def train_product_model_with_transformer(
):
    """
    Train a product sales forecasting model with Transformer

    Parameters:
        product_id: product ID
        epochs: number of training epochs
        model_dir: model save directory, defaults to DEFAULT_MODEL_DIR from the config
        version: explicit version number; auto-generated when None
        socketio: WebSocket object for real-time feedback
        task_id: training task ID
        continue_training: whether to continue training an existing model

    Returns:
        model: the trained model
        metrics: model evaluation metrics
        version: the version number actually used
    """

    # WebSocket progress-feedback function
    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the frontend"""
        if socketio and task_id:
@@ -112,18 +63,15 @@ def train_product_model_with_transformer(
                data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')
        print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
        # Force-flush the output buffers
        import sys
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress("Starting Transformer model training...")

    # Get the training progress manager instance
    try:
        from utils.training_progress import progress_manager
    except ImportError:
        # If it cannot be imported, create a no-op manager to avoid errors
        class DummyProgressManager:
            def set_stage(self, *args, **kwargs): pass
            def start_training(self, *args, **kwargs): pass
@@ -133,61 +81,26 @@ def train_product_model_with_transformer(
            def finish_training(self, *args, **kwargs): pass
        progress_manager = DummyProgressManager()

    # If no product_df was passed in, load data according to the training mode
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data

        try:
            if training_mode == 'store' and store_id:
                # Load data for the specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                training_scope = f"store {store_id}"
            elif training_mode == 'global':
                # Aggregate data across all stores
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method=aggregation_method,
                    file_path='pharmacy_sales_multi_store.csv'
                )
                training_scope = f"global aggregation ({aggregation_method})"
            else:
                # Default: load the product's data across all stores
                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
                training_scope = "all stores"
        except Exception as e:
            print(f"Multi-store data loading failed: {e}")
            # Fallback: try the original data
            df = pd.read_excel('pharmacy_sales.xlsx')
            product_df = df[df['product_id'] == product_id].sort_values('date')
            training_scope = "original data"
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )
        training_scope = f"global aggregation ({aggregation_method})"
    else:
        # If product_df was provided, use it directly
        if training_mode == 'store' and store_id:
            training_scope = f"store {store_id}"
        elif training_mode == 'global':
            training_scope = f"global aggregation ({aggregation_method})"
        else:
            training_scope = "all stores"
        training_scope = "all stores"

    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")

    # Data volume check
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
            f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"Current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual data volume: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
            f"1. Generate more data: uv run generate_multi_store_data.py\n"
            f"2. Adjust the configuration: reduce LOOK_BACK or FORECAST_HORIZON\n"
            f"3. Use global training mode to aggregate more data"
        )
        print(error_msg)
        raise ValueError(error_msg)
@@ -199,18 +112,14 @@ def train_product_model_with_transformer(
    print(f"[Device] Using device: {DEVICE}", flush=True)
    print(f"[Model] Models will be saved to directory: {model_dir}", flush=True)

    # Build the features and target variables
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Enter the data-preprocessing stage
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("Preprocessing data...")

    # Preprocess the data
    X = product_df[features].values
    y = product_df[['sales']].values  # keep as a 2-D array
    y = product_df[['sales']].values

    # Normalize the data
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

@@ -219,24 +128,20 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("data_preprocessing", 40)

    # Split into training and test sets (80% train, 20% test)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # Build the time-series samples
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    progress_manager.set_stage("data_preprocessing", 70)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    # Create data loaders
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

@@ -244,7 +149,6 @@ def train_product_model_with_transformer(
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Update the progress manager's batch info
    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    progress_manager.total_batches_per_epoch = total_batches
@@ -254,9 +158,8 @@ def train_product_model_with_transformer(
    progress_manager.set_stage("data_preprocessing", 100)
    emit_progress("Data preprocessing complete, starting model training...")

    # Initialize the Transformer model
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    output_dim = forecast_horizon
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1
@@ -270,24 +173,21 @@ def train_product_model_with_transformer(
        dim_feedforward=hidden_size * 2,
        dropout=dropout_rate,
        output_sequence_length=output_dim,
        seq_length=LOOK_BACK,
        seq_length=sequence_length,
        batch_size=batch_size
    )

    # Move the model to the device
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5, verbose=True)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    # Train the model
    train_losses = []
    test_losses = []
    start_time = time.time()

    # Configure checkpoint saving
    checkpoint_interval = max(1, epochs // 10)  # save every 10% of progress, at least every epoch
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    epochs_no_improve = 0

@@ -295,7 +195,6 @@ def train_product_model_with_transformer(
    emit_progress(f"Training started - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}, patience: {patience}")

    for epoch in range(epochs):
        # Start a new epoch
        progress_manager.start_epoch(epoch)

        model.train()
@@ -304,12 +203,9 @@ def train_product_model_with_transformer(
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # Make sure the target tensor has the right shape
            # Forward pass
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            if clip_norm:
@@ -318,31 +214,25 @@ def train_product_model_with_transformer(

            epoch_loss += loss.item()

            # Update batch progress
            if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        # Compute the training loss
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Enter the validation stage
        progress_manager.set_stage("validation", 0)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # Make sure the target tensor has the right shape
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                # Update validation progress
                if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
                    val_progress = (batch_idx / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)
@@ -350,13 +240,10 @@ def train_product_model_with_transformer(
        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # Update the learning rate
        scheduler.step(test_loss)

        # Finish the current epoch
        progress_manager.finish_epoch(train_loss, test_loss)

        # Send training progress
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
@@ -368,7 +255,6 @@ def train_product_model_with_transformer(
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # Periodically save checkpoints
        if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
            checkpoint_data = {
                'epoch': epoch + 1,
@@ -387,8 +273,8 @@ def train_product_model_with_transformer(
                'num_heads': num_heads,
                'dropout': dropout_rate,
                'num_layers': num_layers,
                'sequence_length': LOOK_BACK,
                'forecast_horizon': FORECAST_HORIZON,
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_type': 'transformer'
            },
            'training_info': {
@@ -401,38 +287,35 @@ def train_product_model_with_transformer(
                }
            }

            # Save a checkpoint
            save_checkpoint(checkpoint_data, epoch + 1, product_id, 'transformer',
                            model_dir, store_id, training_mode, aggregation_method)

            # If this is the best model so far, save an extra copy
            if test_loss < best_loss:
                best_loss = test_loss
                save_checkpoint(checkpoint_data, 'best', product_id, 'transformer',
                                model_dir, store_id, training_mode, aggregation_method)
                model_manager.save_model(
                    model_data=checkpoint_data,
                    product_id=model_identifier,  # Fix: use the unique identifier
                    model_type='transformer',
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    product_name=product_name,
                    version='best'
                )
                emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

        if (epoch + 1) % 10 == 0:
            print(f"📊 Epoch {epoch+1}/{epochs}, train loss: {train_loss:.4f}, test loss: {test_loss:.4f}", flush=True)

        # Early-stopping logic
        if epochs_no_improve >= patience:
            emit_progress(f"Test loss has not improved for {patience} consecutive epochs, stopping early.")
            break

    # Compute the training time
    training_time = time.time() - start_time

    # Enter the model-saving stage
    progress_manager.set_stage("model_saving", 0)
    emit_progress("Training complete, saving the model...")

    # Plot the loss curve and save it to the model directory
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
@@ -442,21 +325,17 @@ def train_product_model_with_transformer(
    )
    print(f"📈 Loss curve saved to: {loss_curve_path}", flush=True)

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
        test_true = testY

        # Denormalize the predictions and the ground truth
        test_pred_inv = scaler_y.inverse_transform(test_pred)
        test_true_inv = scaler_y.inverse_transform(test_true)

    # Compute evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # Print evaluation metrics
    print(f"\n📊 Model evaluation metrics:", flush=True)
    print(f" MSE: {metrics['mse']:.4f}", flush=True)
    print(f" RMSE: {metrics['rmse']:.4f}", flush=True)
@@ -465,9 +344,8 @@ def train_product_model_with_transformer(
    print(f" MAPE: {metrics['mape']:.2f}%", flush=True)
    print(f" ⏱️ Training time: {training_time:.2f}s", flush=True)

    # Save the final trained model (based on the final epoch)
    final_model_data = {
        'epoch': epochs,  # final epoch
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
@@ -483,8 +361,8 @@ def train_product_model_with_transformer(
        'num_heads': num_heads,
        'dropout': dropout_rate,
        'num_layers': num_layers,
        'sequence_length': LOOK_BACK,
        'forecast_horizon': FORECAST_HORIZON,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'transformer'
    },
    'metrics': metrics,
@@ -502,10 +380,14 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("model_saving", 50)

    # Save the final model (tagged by epoch)
    final_model_path = save_checkpoint(
        final_model_data, f"final_epoch_{epochs}", product_id, 'transformer',
        model_dir, store_id, training_mode, aggregation_method
    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # Fix: use the unique identifier
        model_type='transformer',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    progress_manager.set_stage("model_saving", 100)
@@ -513,7 +395,6 @@ def train_product_model_with_transformer(

    print(f"💾 Model saved to {final_model_path}", flush=True)

    # Prepare the metrics to return
    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
@@ -521,7 +402,12 @@ def train_product_model_with_transformer(
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs
        'final_epoch': epochs,
        'version': final_version
    }

    return model, final_metrics, epochs
    return model, final_metrics, epochs

# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('transformer', train_product_model_with_transformer)
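A recurring change across the mLSTM, TCN, and Transformer trainers is dropping `verbose=True` from `ReduceLROnPlateau`, presumably because the `verbose` argument is deprecated in newer PyTorch releases. The cleaned-up scheduler, shown in isolation with stand-in parameters:

    import torch
    from torch.optim.lr_scheduler import ReduceLROnPlateau

    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
    optimizer = torch.optim.Adam(params, lr=1e-3)
    patience = 10  # illustrative; the trainers take this as a kwarg

    # Halve the learning rate when the monitored loss plateaus
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  patience=patience // 2, factor=0.5)
    scheduler.step(0.42)  # call once per epoch with the test loss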
server/trainers/xgboost_trainer.py (new file, 142 lines)
@@ -0,0 +1,142 @@
"""
Pharmacy sales forecasting system - XGBoost model trainer (plugin style)
"""

import time
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping  # note: unused here; early stopping is passed straight to xgb.train

# Import core utilities
from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer

def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    Train a product sales forecasting model with XGBoost.
    The signature matches the other trainers so the registry can call them interchangeably.
    """
    print(f"🚀 XGBoost trainer started: model_identifier='{model_identifier}'")

    # --- 1. Data preparation and validation ---
    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"Insufficient data: need {min_required_samples} rows, got {len(product_df)}.")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    # --- 2. Data preprocessing and adaptation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # **Key adaptation step**: XGBoost needs 2-D input
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

    # **Key adaptation**: convert to XGBoost's core DMatrix format so we can use the stable xgb.train API
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

    # --- 3. Model training (using the core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',  # eval_metric is natively supported here
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)

    print("Training the XGBoost model (core xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,  # early_stopping_rounds is natively supported here
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost model training finished in {training_time:.2f}s")

    # --- 4. Model evaluation ---
    # Use model.best_iteration to predict with the best boosting round
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration))

    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time

    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Model saving (through utils.model_manager) ---
    # **Key adaptation**: store the complete XGBoost model object in the dict
    # torch.save can serialize many kinds of Python objects, including sklearn-style models
    model_data = {
        'model_state_dict': model,  # save the model object directly
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result
    }

    # Save through the global manager, reusing its naming and versioning logic
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )

    print(f"XGBoost model saved through the unified manager, version: {final_version}, path: {final_model_path}")

    # Return values follow the unified format
    return model, metrics, final_version, final_model_path

# --- Register this trainer with the system ---
register_trainer('xgboost', train_product_model_with_xgboost)
@@ -8,6 +8,7 @@ import json
import torch
import glob
from datetime import datetime
import re
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR

@@ -24,56 +25,91 @@ class ModelManager:
if not os.path.exists(self.model_dir):
    os.makedirs(self.model_dir)

def generate_model_filename(self,
                            product_id: str,
                            model_type: str,
def _get_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> int:
    """Get the next model version number (a plain integer)"""
    search_pattern = self.generate_model_filename(
        model_type=model_type,
        version='v*',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )

    full_search_path = os.path.join(self.model_dir, search_pattern)
    existing_files = glob.glob(full_search_path)

    max_version = 0
    for f in existing_files:
        match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
        if match:
            max_version = max(max_version, int(match.group(1)))

    return max_version + 1

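A quick illustration of the version-scanning logic above (the file names are hypothetical):

    import re

    existing = ['tcn_product_P001_v1.pth', 'tcn_product_P001_v3.pth',
                'tcn_product_P001_best.pth']  # 'best' carries no _v<N> suffix
    max_version = 0
    for f in existing:
        m = re.search(r'_v(\d+)\.pth$', f)
        if m:
            max_version = max(max_version, int(m.group(1)))
    print(max_version + 1)  # -> 4; 'best' files never advance the counter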
def generate_model_filename(self,
                            model_type: str,
                            version: str,
                            store_id: Optional[str] = None,
                            training_mode: str = 'product',
                            product_id: Optional[str] = None,
                            store_id: Optional[str] = None,
                            aggregation_method: Optional[str] = None) -> str:
    """
    Generate a unified model file name

    Naming scheme:

    Naming scheme (v2):
    - product mode: {model_type}_product_{product_id}_{version}.pth
    - store mode: {model_type}_store_{store_id}_{product_id}_{version}.pth
    - global mode: {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
    - store mode: {model_type}_store_{store_id}_{version}.pth
    - global mode: {model_type}_global_{aggregation_method}_{version}.pth
    """
    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
        return f"{model_type}_store_{store_id}_{version}.pth"
    elif training_mode == 'global' and aggregation_method:
        return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
    else:
        # default: product mode
        return f"{model_type}_global_{aggregation_method}_{version}.pth"
    elif training_mode == 'product' and product_id:
        return f"{model_type}_product_{product_id}_{version}.pth"
    else:
        # Fall back or raise, to avoid generating an invalid file name
        raise ValueError(f"Cannot generate a valid file name for training mode '{training_mode}': a required ID is missing.")

def save_model(self,
def save_model(self,
               model_data: dict,
               product_id: str,
               model_type: str,
               version: str,
               model_type: str,
               store_id: Optional[str] = None,
               training_mode: str = 'product',
               aggregation_method: Optional[str] = None,
               product_name: Optional[str] = None) -> str:
               product_name: Optional[str] = None,
               version: Optional[str] = None) -> Tuple[str, str]:
    """
    Save a model to the unified location
    Save a model to the unified location, with automatic version management.

    Parameters:
        model_data: dict containing the model state and configuration
        product_id: product ID
        model_type: model type
        version: version number
        store_id: store ID (optional)
        training_mode: training mode
        aggregation_method: aggregation method (optional)
        product_name: product name (optional)
        ...
        version: (optional) overrides automatic versioning when provided (e.g. 'best').

    Returns:
        model file path
        (model file path, version used)
    """
    if version is None:
        next_version_num = self._get_next_version(
            model_type=model_type,
            product_id=product_id,
            store_id=store_id,
            training_mode=training_mode,
            aggregation_method=aggregation_method
        )
        version_str = f"v{next_version_num}"
    else:
        version_str = version

    filename = self.generate_model_filename(
        product_id, model_type, version, store_id, training_mode, aggregation_method
        model_type=model_type,
        version=version_str,
        training_mode=training_mode,
        product_id=product_id,
        store_id=store_id,
        aggregation_method=aggregation_method
    )

    # Save everything to the root directory to avoid a complex subdirectory structure
@@ -86,7 +122,7 @@ class ModelManager:
    'product_id': product_id,
    'product_name': product_name or product_id,
    'model_type': model_type,
    'version': version,
    'version': version_str,
    'store_id': store_id,
    'training_mode': training_mode,
    'aggregation_method': aggregation_method,
@@ -99,7 +135,7 @@ class ModelManager:
    torch.save(enhanced_model_data, model_path)

    print(f"Model saved: {model_path}")
    return model_path
    return model_path, version_str

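Taken together, the new `save_model` supports two calling styles (the payload and IDs below are illustrative):

    # Auto-versioned save: omit `version` and the manager scans existing
    # files to pick the next v<N>.
    path, ver = model_manager.save_model(model_data=payload,
                                         product_id='P001',
                                         model_type='tcn')
    # ver == 'v1' on the first save, 'v2' the next time, and so on.

    # Pinned save: an explicit label such as 'best' bypasses the counter
    # and overwrites the same file on every call.
    path, ver = model_manager.save_model(model_data=payload,
                                         product_id='P001',
                                         model_type='tcn',
                                         version='best')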
def list_models(self,
                product_id: Optional[str] = None,
@@ -228,127 +264,58 @@ class ModelManager:

def parse_model_filename(self, filename: str) -> Optional[Dict]:
    """
    Parse a model file name and extract the model info

    Parse a model file name and extract the model info (v2)

    Supported formats:
    - {model_type}_product_{product_id}_{version}.pth
    - {model_type}_store_{store_id}_{product_id}_{version}.pth
    - {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
    - plus backward compatibility with the old formats
    - product: {model_type}_product_{product_id}_{version}.pth
    - store: {model_type}_store_{store_id}_{version}.pth
    - global: {model_type}_global_{aggregation_method}_{version}.pth
    """
    if not filename.endswith('.pth'):
        return None


    base_name = filename.replace('.pth', '')

    parts = base_name.split('_')

    if len(parts) < 3:
        return None  # does not meet the basic format requirements

    # **Core fix**: parse back-to-front, which is more robust and supports model names containing underscores
    try:
        # Parse the new format
        if '_product_' in base_name:
            # product mode: model_type_product_product_id_version
            parts = base_name.split('_product_')
            model_type = parts[0]
            rest = parts[1]

            # Separate the product ID and the version
            if '_v' in rest:
                last_v_index = rest.rfind('_v')
                product_id = rest[:last_v_index]
                version = rest[last_v_index+1:]
            else:
                product_id = rest
                version = 'v1'

        version = parts[-1]
        identifier = parts[-2]
        mode_candidate = parts[-3]

        if mode_candidate == 'product':
            model_type = '_'.join(parts[:-3])
            return {
                'model_type': model_type,
                'product_id': product_id,
                'version': version,
                'training_mode': 'product',
                'store_id': None,
                'aggregation_method': None
                'product_id': identifier,
                'version': version,
            }

        elif '_store_' in base_name:
            # store mode: model_type_store_store_id_product_id_version
            parts = base_name.split('_store_')
            model_type = parts[0]
            rest = parts[1]

            # Separate the store ID, the product ID, and the version
            rest_parts = rest.split('_')
            if len(rest_parts) >= 3:
                store_id = rest_parts[0]
                if rest_parts[-1].startswith('v'):
                    # the last part is the version
                    version = rest_parts[-1]
                    product_id = '_'.join(rest_parts[1:-1])
                else:
                    version = 'v1'
                    product_id = '_'.join(rest_parts[1:])

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'store',
                    'store_id': store_id,
                    'aggregation_method': None
                }

        elif '_global_' in base_name:
            # global mode: model_type_global_product_id_aggregation_method_version
            parts = base_name.split('_global_')
            model_type = parts[0]
            rest = parts[1]

            rest_parts = rest.split('_')
            if len(rest_parts) >= 3:
                if rest_parts[-1].startswith('v'):
                    # the last part is the version
                    version = rest_parts[-1]
                    aggregation_method = rest_parts[-2]
                    product_id = '_'.join(rest_parts[:-2])
                else:
                    version = 'v1'
                    aggregation_method = rest_parts[-1]
                    product_id = '_'.join(rest_parts[:-1])

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'global',
                    'store_id': None,
                    'aggregation_method': aggregation_method
                }

        # Backward compatibility with the old format
        else:
            # Try to parse other formats
            if 'model_product_' in base_name:
                parts = base_name.split('_model_product_')
                model_type = parts[0]
                product_part = parts[1]

                if '_v' in product_part:
                    last_v_index = product_part.rfind('_v')
                    product_id = product_part[:last_v_index]
                    version = product_part[last_v_index+1:]
                else:
                    product_id = product_part
                    version = 'v1'

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'product',
                    'store_id': None,
                    'aggregation_method': None
                }

        elif mode_candidate == 'store':
            model_type = '_'.join(parts[:-3])
            return {
                'model_type': model_type,
                'training_mode': 'store',
|
||||
'store_id': identifier,
|
||||
'version': version,
|
||||
}
|
||||
elif mode_candidate == 'global':
|
||||
model_type = '_'.join(parts[:-3])
|
||||
return {
|
||||
'model_type': model_type,
|
||||
'training_mode': 'global',
|
||||
'aggregation_method': identifier,
|
||||
'version': version,
|
||||
}
|
||||
except IndexError:
|
||||
# 如果文件名部分不够,则解析失败
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"解析文件名失败 {filename}: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
     def delete_model(self, model_file: str) -> bool:
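A quick illustration of why the backwards parse matters. The filenames below are hypothetical, and `model_manager` stands for the module-level `ModelManager` instance used elsewhere in these documents:

```python
# Hypothetical filenames; splitting from the end keeps underscored
# model type names intact, which the old left-to-right parse broke.
info = model_manager.parse_model_filename("transformer_store_01010023_v2.pth")
# -> {'model_type': 'transformer', 'training_mode': 'store',
#     'store_id': '01010023', 'version': 'v2'}

info = model_manager.parse_model_filename("my_custom_net_product_17002608_v3.pth")
# -> {'model_type': 'my_custom_net', 'training_mode': 'product',
#     'product_id': '17002608', 'version': 'v3'}
```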
@@ -1,5 +1,5 @@
 ### Start from the project root
-`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn`
+`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost`
 
 ### UI
 `npm install` `npm run dev`
@@ -755,4 +755,34 @@
# ... subsequent processing logic stays the same ...
```

With the steps above, you can switch the data source from local files to the server database without touching any other part of the project.

---
**Date**: 2025-07-15
**Subject**: Fix the "predict by product" feature and enhance the chart display
**Developer**: lyf

### Problem
The "Prediction Analysis" -> "Predict by Product" page was unusable. The frontend called the wrong API endpoint, and the chart-rendering logic did not match the data structure returned by the backend.

### Solution
The following fixes and enhancements were made to `UI/src/views/prediction/ProductPredictionView.vue`:

1. **API endpoint fix**:
   * **Location**: the `startPrediction` function.
   * **Change**: corrected the request URL from the wrong `/api/predict` to the correct `/api/prediction`.

2. **Data-handling alignment**:
   * **Location**: the `startPrediction` and `renderChart` functions.
   * **Change**: updated the data-receiving logic so it correctly handles the `history_data` and `prediction_data` fields returned by the backend.

3. **Chart enhancement**:
   * **Location**: the `renderChart` function.
   * **Change**: reworked the rendering logic so historical sales (solid green line) and predicted sales (dashed blue line) are shown together, giving users a direct visual comparison.

4. **Better error messages**:
   * **Location**: the `catch` block of `startPrediction`.
   * **Change**: improved error handling so more specific error messages from the backend are extracted from the response and displayed.

### Outcome
The "predict by product" feature now works end to end against the backend and offers a richer, more robust visualization experience.
xz新模型添加流程.md (new file, 222 lines)
@@ -0,0 +1,222 @@
# How to Add a New Model to the System

This guide explains in detail how to add a brand-new forecasting model to this prediction system. The system uses a flexible, plugin-style architecture, so integrating a new model is highly modular and revolves around three core components: the **Model**, the **Trainer**, and the **Predictor**.

## Core Concept

The heart of the system is `models/model_registry.py`, which maintains two independent registries: one for training functions and one for prediction functions (a minimal sketch of this pattern follows the list below). Adding a new model therefore boils down to:

1. **Define the model**: create the model architecture.
2. **Create a trainer**: write a function that trains the model, and register it in the trainer registry.
3. **Integrate a predictor**: make sure the system knows how to load the model and predict with it, then register that prediction logic in the predictor registry.
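A minimal sketch of the registry pattern, assuming the simplest possible implementation; the real `models/model_registry.py` may carry more metadata. `register_trainer` and `register_predictor` are the names used later in this guide, while `get_trainer` is hypothetical:

```python
# file: models/model_registry.py — illustrative sketch, not the actual source

# Two independent registries: one for training functions, one for predictors.
TRAINER_REGISTRY: dict = {}
PREDICTOR_REGISTRY: dict = {}

def register_trainer(name: str, train_func) -> None:
    """Map a model-type name (e.g. 'mynewmodel') to its training function."""
    TRAINER_REGISTRY[name] = train_func

def register_predictor(name: str, predict_func) -> None:
    """Map a model-type name to its prediction function."""
    PREDICTOR_REGISTRY[name] = predict_func

def get_trainer(name: str):
    """Look up a trainer by model-type name, failing loudly if unknown."""
    if name not in TRAINER_REGISTRY:
        raise ValueError(f"No trainer registered for model type '{name}'")
    return TRAINER_REGISTRY[name]
```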
---

## Step 1: Define the Model Architecture

First, create a new Python file under `ShopTRAINING/server/models/` that defines your model.

**Example: create `ShopTRAINING/server/models/my_new_model.py`**

If your new model is PyTorch-based, it should be a class inheriting from `torch.nn.Module`.

```python
# file: ShopTRAINING/server/models/my_new_model.py

import torch
import torch.nn as nn

class MyNewModel(nn.Module):
    def __init__(self, input_features, hidden_size, output_sequence_length):
        """
        Define the model's layers and structure.
        """
        super(MyNewModel, self).__init__()
        self.layer1 = nn.Linear(input_features, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_sequence_length)
        # ... more complex structures can be added here

    def forward(self, x):
        """
        Define the forward pass through the model.
        x usually has shape (batch_size, sequence_length, num_features).
        """
        # Make sure the input has the right shape.
        # For plain linear layers, for example, it may need to be flattened.
        batch_size, seq_len, features = x.shape
        x = x.view(batch_size * seq_len, features)  # flatten

        out = self.layer1(x)
        out = self.relu(out)
        out = self.layer2(out)

        # Restore the shape to match the expected output
        out = out.view(batch_size, seq_len, -1)
        # Usually only the prediction at the last position of the sequence matters
        return out[:, -1, :]
```
---

## Step 2: Create the Model Trainer

Next, create a new trainer file under `ShopTRAINING/server/trainers/`. This file owns the model's whole training, evaluation, and saving flow.

**Example: create `ShopTRAINING/server/trainers/my_new_model_trainer.py`**

The training function must follow the same unified signature as the system's other trainers (such as `xgboost_trainer.py`), and it must be registered, either with a `@register_trainer` decorator or by calling `register_trainer` at the end of the file.

```python
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py

import torch
import torch.optim as optim
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from models.my_new_model import MyNewModel  # import your new model

# Follow the system's standard function signature
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data preparation ---
    # (Data loading, scaling, and dataset construction are omitted here;
    # see xgboost_trainer.py or any other trainer for a full implementation.)
    # ...
    # Assume trainX, trainY, testX, testY, scaler_y etc. are ready:
    # trainX = ...
    # trainY = ...
    # testX = ...
    # testY = ...
    # scaler_y = ...
    # features = [...]

    # --- 2. Instantiate the model and optimizer ---
    input_dim = trainX.shape[2]  # number of input features
    hidden_size = 64             # example hyperparameter

    model = MyNewModel(
        input_features=input_dim,
        hidden_size=hidden_size,
        output_sequence_length=forecast_horizon
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop ---
    print("开始训练 MyNewModel...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(trainX)
        loss = criterion(outputs, trainY)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Model evaluation ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # Inverse-transform and compute the metrics
    # ... (see the other trainers)
    metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0}  # example

    # --- 5. Model saving ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': None,  # replace with your scaler_X
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'mynewmodel',  # **key point**: use a unique model name
            'input_dim': input_dim,
            'hidden_size': hidden_size,
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'features': features
        },
        'metrics': metrics
    }

    model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='mynewmodel',  # **key point**: the same model name again
        # ... other parameters
    )

    print("✅ MyNewModel 模型训练并保存完成!")
    return model, metrics, "v1", "path/to/model"  # return values follow the unified format

# --- Key step: register the trainer with the system ---
register_trainer('mynewmodel', train_with_mynewmodel)
```
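Once registered, a dispatch site can look the trainer up by name. The call below is hypothetical, mirroring the registry sketch earlier; the keyword values are example data and `df` stands for a prepared sales DataFrame:

```python
# Hypothetical dispatch through the registry from models/model_registry.py.
from models.model_registry import TRAINER_REGISTRY

train_func = TRAINER_REGISTRY['mynewmodel']
model, metrics, version, path = train_func(
    product_id='17002608',        # example ID borrowed from these docs
    model_identifier='17002608',
    product_df=df,
    store_id=None,
    training_mode='product',
    aggregation_method=None,
    epochs=50,
    sequence_length=30,
    forecast_horizon=7,
    model_dir='saved_models',
)
```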
---

## Step 3: Integrate the Model Predictor

Finally, the system needs to know how to load your new model and predict with it. This takes two changes in `ShopTRAINING/server/predictors/model_predictor.py`.

**File: `ShopTRAINING/server/predictors/model_predictor.py`**

1. **Teach the system how to build an instance of your model**

   Inside `load_model_and_predict` there is an `if/elif` chain that instantiates the right model class for each model type. Add a new branch for `MyNewModel`.

```python
# in model_predictor.py

# First, import your new model class
from models.my_new_model import MyNewModel

# ... inside the load_model_and_predict function ...

# ... elif branches for the other models ...
elif loaded_model_type == 'tcn':
    model = TCNForecaster(...)

# vvv add this new branch vvv
elif loaded_model_type == 'mynewmodel':
    model = MyNewModel(
        input_features=config['input_dim'],
        hidden_size=config['hidden_size'],
        output_sequence_length=config['forecast_horizon']
    ).to(DEVICE)
# ^^^ end of addition ^^^

else:
    raise ValueError(f"不支持的模型类型: {loaded_model_type}")

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
```

2. **Register the prediction logic**

   If your model is a standard PyTorch model whose prediction logic matches the existing models (Transformer, KAN, ...), you can reuse `default_pytorch_predictor` directly; one registration line at the end of the file is enough.

```python
# at the end of model_predictor.py

# ...
# The enhanced default predictor is also registered for xgboost
register_predictor('xgboost', default_pytorch_predictor)

# vvv add this line vvv
# Let 'mynewmodel' use the generic PyTorch predictor as well
register_predictor('mynewmodel', default_pytorch_predictor)
# ^^^ end of addition ^^^
```

If your model needs special prediction logic (for example a different input format or call convention, as XGBoost has), copy `default_pytorch_predictor` into a new function, adapt its internals, and register that new function for `'mynewmodel'`.

---

## Summary

After these three steps, your new model `MyNewModel` is fully integrated. The system automatically discovers your new trainer in the `trainers` directory, and whenever `mynewmodel` is selected for training or prediction through the API or the UI, the logic you just wrote and registered is invoked automatically.
新需求开发流程.md (new file, 150 lines)
@@ -0,0 +1,150 @@
# Standard Workflow for New Feature Development

This document describes a standard, safe, and efficient workflow for feature development, from creating a feature branch to merging back into the main development branch, with day-to-day best practices baked in.

## Core Development Principles

- **Protect the main branch**: `lyf-dev` is the team's main development branch and should always stay stable and deployable. All new feature work happens on separate feature branches.
- **Feature branches**: each requirement (e.g. `req0001`) gets its own feature branch (e.g. `lyf-dev-req0001`). Branch names should be clear and meaningful.
- **Small, frequent steps**: commit often, push often, and sync with the mainline often (`rebase` or `merge`). This sharply reduces the difficulty and risk of later merges.
- **Readable history**: keep the Git history easy to follow, for code review and for tracing problems.

---

## First Thing Every Day: Sync the Latest Code

**Wherever you stopped yesterday, start every day with the following steps. This is the foundation of smooth collaboration and the way to avoid merge conflicts.**

1. **Update the main development branch `lyf-dev`**
   ```bash
   # Switch to the main development branch
   git checkout lyf-dev

   # Pull the latest code; --prune removes references to deleted remote branches
   git pull origin lyf-dev --prune
   ```

2. **Sync your feature branch (pick one approach as a team)**
   There are two mainstream ways to bring the latest mainline code into your feature branch; the team should settle on one.

---

### Option 1 (recommended): `rebase` for a clean history

This keeps your branch history as a single straight line, which is very easy to read.

```bash
# Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
git checkout lyf-dev-req0001

# Rebase the latest lyf-dev changes into your branch
git rebase lyf-dev
```
- **Pros**: the final history is clean and linear, which helps code review and debugging.
- **Cons**: history is rewritten, so the push must use `git push --force-with-lease` (see the example after this list).
- **Resolving conflicts**:
  1. Edit the conflicting files by hand.
  2. Run `git add <conflicting file>`.
  3. Run `git rebase --continue`.
  4. To bail out, run `git rebase --abort`.
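Because the rebase rewrote your branch's history, the follow-up push looks like this (a minimal example using the branch name from above):

```bash
# A plain `git push` is rejected after a rebase because the history changed.
# --force-with-lease force-pushes safely: it refuses to overwrite the remote
# branch if it contains commits you haven't fetched yet.
git push --force-with-lease origin lyf-dev-req0001
```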
---

### Option 2: `merge` to preserve the full history

This faithfully records every merge and never rewrites existing commits.

```bash
# Switch back to the feature branch you are working on (e.g. lyf-dev-req0001)
git checkout lyf-dev-req0001

# Merge the latest lyf-dev into your current branch
git merge lyf-dev
```
- **Pros**: safe; history is never rewritten, and pushes need no force flag.
- **Cons**: extra merge commits accumulate on the feature branch (e.g. "Merge branch 'lyf-dev' into ..."), which clutters the history.
- **Resolving conflicts**:
  1. Edit the conflicting files by hand.
  2. Run `git add <conflicting file>`.
  3. Run `git commit` to complete the merge.

---

## The Full Development Flow

### 1. Starting a new requirement: create a feature branch

**When you begin a brand-new piece of feature work:**

1. **Make sure `lyf-dev` is up to date**
   (Already done in "First Thing Every Day"; repeated here as a reminder.)

2. **Create and switch to a new branch off `lyf-dev`**
   Say the new requirement number is `req0002`:
   ```bash
   # Creates lyf-dev-req0002 from the latest lyf-dev and switches to it
   git checkout -b lyf-dev-req0002
   ```

### 2. Day-to-day development: commit and push

**While developing on your feature branch (e.g. `lyf-dev-req0002`):**

1. **Code and commit locally**
   Commit whenever you finish a small but complete piece of work.
   ```bash
   # Check what changed
   git status
   # Stage all relevant files
   git add .
   # Commit with a clear message (feat: feature, fix: bugfix, docs: documentation, ...)
   git commit -m "feat: 实现用户认证模块"
   ```

2. **Push your changes to the remote as a backup**
   Push local commits frequently, both for code safety and for team visibility.
   ```bash
   # -u sets the local branch to track the remote one; afterwards a plain git push suffices
   git push -u origin lyf-dev-req0002
   ```

### 3. Feature complete: merge back to the mainline

**Once the feature is finished and has passed testing, merge it back into `lyf-dev`:**

1. **One last sync**
   Right before merging, sync once more so the branch contains everything that has landed on `lyf-dev`.
   (Repeat the sync steps from "First Thing Every Day".)

2. **Switch to the main branch and pull the latest code**
   ```bash
   git checkout lyf-dev
   git pull origin lyf-dev
   ```

3. **Merge the feature branch**
   Use `--no-ff` (no fast-forward) to create a merge commit, which records "a feature was merged here" explicitly.
   ```bash
   # --no-ff creates a merge commit and preserves the branch history
   git merge --no-ff lyf-dev-req0002
   ```
   If you kept the branch well synced, this step rarely conflicts.

4. **Push the merged main branch**
   ```bash
   git push origin lyf-dev
   ```

### 4. Cleanup

**After the merge, the feature branch has served its purpose:**

1. **Delete the remote branch**
   ```bash
   git push origin --delete lyf-dev-req0002
   ```

2. **Delete the local branch**
   ```bash
   git branch -d lyf-dev-req0002
   ```

Following this flow keeps the team's development workflow clear, safe, and efficient.
系统调用逻辑与核心代码分析.md (new file, 466 lines)
@@ -0,0 +1,466 @@
# System Call Flow and Core Code Analysis

This document walks through the system's end-to-end call chain: system startup, frontend interaction, backend processing, and finally model training, prediction, and chart display.

## 1. System Startup

The system has two parts: a Vue.js frontend and a Flask backend.

### 1.1. Start the backend API service

From the project root, start the backend with:

```bash
python server/api.py
```

This starts a Flask application listening on `http://localhost:5000`, serving all API and WebSocket endpoints.

### 1.2. Start the frontend dev server

Enter the `UI` directory and run:

```bash
cd UI
npm install
npm run dev
```

This starts the Vite dev server, usually at `http://localhost:5173`, and opens the frontend in the browser.

## 2. Core Call Chain at a Glance

Taking the central **"train by product -> predict by product"** flow as the example, the high-level call chain is:

**Training flow:**
`frontend UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `background process` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `save model .pth`

**Prediction flow:**
`frontend UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `load model .pth` -> `return prediction JSON` -> `frontend chart rendering`
## 3. Flow in Detail: Training by Product

The goal of this flow is to train a dedicated forecasting model for one specific product.

### 3.1. Frontend interaction and API request

1. **User action**: on the **"Train by Product"** page ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) the user picks a product, a model type (e.g. Transformer), sets the number of epochs, and clicks **"Start product training"**.

2. **Handler**: the click invokes the [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) method.

3. **Build the payload**: `startTraining` assembles a `payload` object with the training parameters. The key field is `training_mode: 'product'`, which tells the backend this is product-specific training.

   *Core code ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
   ```javascript
   const startTraining = async () => {
     // ... form validation ...
     trainingLoading.value = true;
     try {
       const endpoint = "/api/training";

       const payload = {
         product_id: form.product_id,
         store_id: form.data_scope === 'global' ? null : form.store_id,
         model_type: form.model_type,
         epochs: form.epochs,
         training_mode: 'product' // marks this as product-training mode
       };

       const response = await axios.post(endpoint, payload);
       // ... handle the response, start listening on the WebSocket ...
     }
     // ... error handling ...
   };
   ```

4. **API request**: `axios` sends the request to the backend via `POST /api/training`.

### 3.2. Backend API: receiving and dispatching the task

1. **Routing**: in [`server/api.py`](server/api.py:1), the [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) decorator catches the request, which is handled by [`start_training()`](server/api.py:971).

2. **Task submission**: `start_training()` parses the request JSON and calls `training_manager.submit_task()` to push the training task into a background process pool, keeping the API thread unblocked. The API can therefore return a task ID immediately while training runs asynchronously.

   *Core code ([`server/api.py`](server/api.py:971))*
   ```python
   @app.route('/api/training', methods=['POST'])
   def start_training():
       data = request.get_json()

       training_mode = data.get('training_mode', 'product')
       model_type = data.get('model_type')
       epochs = data.get('epochs', 50)
       product_id = data.get('product_id')
       store_id = data.get('store_id')

       if not model_type or (training_mode == 'product' and not product_id):
           return jsonify({'error': '缺少必要参数'}), 400

       try:
           # Submit the task through the training process manager
           task_id = training_manager.submit_task(
               product_id=product_id or "unknown",
               model_type=model_type,
               training_mode=training_mode,
               store_id=store_id,
               epochs=epochs
           )

           logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")

           return jsonify({
               'message': '模型训练已开始(使用独立进程)',
               'task_id': task_id,
           })

       except Exception as e:
           logger.error(f"❌ 提交训练任务失败: {str(e)}")
           return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
   ```
### 3.3. Core training logic

1. **Into the core predictor**: the background process eventually calls [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) in [`server/core/predictor.py`](server/core/predictor.py:1).

2. **Data preparation**: based on `training_mode` (`'product'`) and `product_id`, `train_model` first loads and aggregates that product's sales data across all stores.

3. **Dispatch to a concrete trainer**: it then calls the training function matching `model_type`; for `transformer`, that is `train_product_model_with_transformer`.

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:63))*
   ```python
   class PharmacyPredictor:
       def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
           # ...
           if training_mode == 'product':
               product_data = self.data[self.data['product_id'] == product_id].copy()
               # ...

           # Build the model identifier according to the training mode
           model_identifier = product_id

           try:
               if model_type == 'transformer':
                   model_result, metrics, actual_version = train_product_model_with_transformer(
                       product_id=product_id,
                       model_identifier=model_identifier,
                       product_df=product_data,
                       # ... other parameters ...
                   )
               # ... elif branches for the other models ...

               return metrics
           except Exception as e:
               # ... error handling ...
               return None
   ```
### 3.4. Model training and saving

1. **Concrete trainer**: taking [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) as the example, `train_product_model_with_transformer` does the following:
   * **Data preprocessing**: `prepare_data` and `prepare_sequences` turn the raw sales data into a supervised, time-series format the model understands (input sequences and target sequences).
   * **Model instantiation**: creates the `TimeSeriesTransformer` instance.
   * **Training loop**: runs the requested number of `epochs`, computing the loss and updating the model weights with the optimizer.
   * **Progress updates**: during training, `socketio.emit` sends `training_progress` events to the frontend, updating the progress bar and logs in real time.
   * **Model saving**: after training, the model weights (`model.state_dict()`), the full model `config`, and the data scalers (`scaler_X`, `scaler_y`) are packed into one dict (the checkpoint) and written to a `.pth` file with `torch.save()`. The filename comes from `get_model_file_path`, built uniformly from `model_identifier`, `model_type`, and `version`.

   *Core code ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
   ```python
   def train_product_model_with_transformer(...):
       # ... data preparation ...

       # Define the model configuration
       config = {
           'input_dim': input_dim,
           'output_dim': forecast_horizon,
           'hidden_size': hidden_size,
           # ... every necessary hyperparameter ...
           'model_type': 'transformer'
       }

       model = TimeSeriesTransformer(...)

       # ... training loop ...

       # Save the model
       checkpoint = {
           'model_state_dict': model.state_dict(),
           'config': config,
           'scaler_X': scaler_X,
           'scaler_y': scaler_y,
           'metrics': test_metrics
       }

       model_path = get_model_file_path(model_identifier, 'transformer', version)
       torch.save(checkpoint, model_path)

       return model, test_metrics, version
   ```
## 4. Flow in Detail: Prediction by Product

Once training has finished, the saved model can be used for prediction.

### 4.1. Frontend interaction and API request

1. **User action**: on the **"Predict by Product"** page ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) the user selects the same product, the matching model and version, then clicks **"Start prediction"**.

2. **Handler**: the click invokes the [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) method.

3. **Build the payload**: the method assembles a `payload` with the prediction parameters.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
   ```javascript
   const startPrediction = async () => {
     try {
       predicting.value = true
       const payload = {
         product_id: form.product_id,
         model_type: form.model_type,
         version: form.version,
         future_days: form.future_days,
         // training_mode is implicitly 'product' here
       }
       const response = await axios.post('/api/prediction', payload)
       if (response.data.status === 'success') {
         predictionResult.value = response.data
         await nextTick()
         renderChart()
       }
       // ... error handling ...
     }
     // ...
   }
   ```

4. **API request**: `axios` sends the request to the backend via `POST /api/prediction`.

### 4.2. Backend API: receiving the request and running the prediction

1. **Routing**: [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) in [`server/api.py`](server/api.py:1) catches the request, handled by [`predict()`](server/api.py:1469).

2. **Into the core predictor**: `predict()` parses the parameters and calls the `run_prediction` helper, which in turn calls [`PharmacyPredictor.predict()`](server/core/predictor.py:295) in [`server/core/predictor.py`](server/core/predictor.py:1).

   *Core code ([`server/api.py`](server/api.py:1469))*
   ```python
   @app.route('/api/prediction', methods=['POST'])
   def predict():
       try:
           data = request.json
           # ... parse parameters ...
           training_mode = data.get('training_mode', 'product')
           product_id = data.get('product_id')
           # ...

           # Determine the model identifier from the mode
           if training_mode == 'product':
               model_identifier = product_id
           # ...

           # Run the prediction
           prediction_result = run_prediction(model_type, product_id, model_id, ...)

           # ... format the response ...
           return jsonify(response_data)
       except Exception as e:
           # ... error handling ...
   ```
3. **Dispatch to the model loader**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) mainly derives the `model_identifier` from `training_mode` and `product_id` once more, then hands every parameter to [`load_model_and_predict()`](server/predictors/model_predictor.py:26) in [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1).

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:295))*
   ```python
   class PharmacyPredictor:
       def predict(self, product_id, model_type, ..., training_mode='product', ...):
           if training_mode == 'product':
               model_identifier = product_id
           # ...

           return load_model_and_predict(
               model_identifier,
               model_type,
               # ... other parameters ...
           )
   ```
### 4.3. Loading the model and running the prediction

[`load_model_and_predict()`](server/predictors/model_predictor.py:26) is the heart of the prediction flow. It performs these steps:

1. **Locate the model file**: `get_model_file_path` finds the previously saved `.pth` file from `product_id` (i.e. the `model_identifier`), `model_type`, and `version`.

2. **Load the checkpoint**: `torch.load()` reads the file into a dict holding `model_state_dict`, `config`, and the scalers.

3. **Rebuild the model**: using the hyperparameters in the loaded `config` (`hidden_size`, `num_layers`, ...), a model instance with exactly the training-time architecture is recreated. **This was the key point of an earlier fix: every required parameter must be saved and loaded.**

4. **Load the weights**: the loaded `model_state_dict` is applied to the fresh model instance.

5. **Prepare the input**: the most recent `sequence_length` days of history are fetched from the data source as the prediction input.

6. **Normalize**: the input is scaled with the loaded `scaler_X`.

7. **Predict**: the normalized data is fed through the model (`model(X_input)`).

8. **Denormalize**: the loaded `scaler_y` maps the model output back to the original sales scale.

9. **Assemble the result**: predicted values are paired with their future dates in a DataFrame and returned together with the history.

*Core code ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
    # ... resolve the model file path model_path ...

    # Load the model and its configuration
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']

    # Create the model instance (Transformer as the example)
    model = TimeSeriesTransformer(
        num_features=config['input_dim'],
        d_model=config['hidden_size'],
        # ... every parameter comes from config ...
    ).to(DEVICE)

    # Load the model weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # ... prepare the input data ...

    # Normalize the input
    X_scaled = scaler_X.transform(X)
    X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)

    # Predict
    with torch.no_grad():
        y_pred_scaled = model(X_input).cpu().numpy()

    # Denormalize
    y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

    # ... assemble the return value ...
    return {
        'predictions': predictions_df,
        'history_data': recent_history,
        # ...
    }
```
### 4.4. Response formatting and frontend chart rendering

1. **Formatting at the API layer**: in [`predict()`](server/api.py:1469) in [`server/api.py`](server/api.py:1), the result of `load_model_and_predict` is shaped into the JSON structure the frontend expects, with `history_data` and `prediction_data` arrays side by side at the top level.

2. **Frontend receives the data**: [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) receives this JSON in `startPrediction` and stores it in the `predictionResult` ref.

3. **Chart rendering**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) extracts `history_data` and `prediction_data` from `predictionResult.value` and draws both on the same `<canvas>` with Chart.js, history as a solid line and prediction as a dashed one, forming a single continuous trend chart.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
   ```javascript
   const renderChart = () => {
     if (!chartCanvas.value || !predictionResult.value) return
     // ...

     // The backend provides history_data and prediction_data directly
     const historyData = predictionResult.value.history_data || []
     const predictionData = predictionResult.value.prediction_data || []

     const historyLabels = historyData.map(p => p.date)
     const historySales = historyData.map(p => p.sales)

     const predictionLabels = predictionData.map(p => p.date)
     const predictionSales = predictionData.map(p => p.predicted_sales)

     // ... combine the labels and align the data points ...

     chart = new Chart(chartCanvas.value, {
       type: 'line',
       data: {
         labels: allLabels,
         datasets: [
           {
             label: '历史销量',
             data: alignedHistorySales,
             // ... styling ...
           },
           {
             label: '预测销量',
             data: alignedPredictionSales,
             // ... styling ...
           }
         ]
       },
       // ... Chart.js options ...
     })
   }
   ```
With that, a full "train -> predict -> display" call chain is complete.

## 5. Core Logic of Model Saving and Version Management (post-refactor)

To root out version chaos and model-loading failures, the system went through an important refactor. All logic for saving, naming, and versioning models is now **centralized** in the `ModelManager` class in [`server/utils/model_manager.py`](server/utils/model_manager.py:1).

### 5.1. The single manager: `ModelManager`

- **Single responsibility**: `ModelManager` is the only component in the system that performs model file IO. Every trainer must go through it to save a model.
- **Core capabilities** (see the skeleton sketch below):
  1. **Automatic versioning**: generates and increments well-formed version numbers automatically.
  2. **Uniform naming**: builds standardized filenames from the model's metadata (algorithm type, training mode, IDs, ...).
  3. **Safe saving**: packs the model data and its metadata together into the `.pth` file.
  4. **Reliable retrieval**: offers one uniform interface for listing and finding models.
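A skeleton of the manager's public surface, with method names taken from the `ModelManager` diff at the top of this page; bodies are omitted and some signatures abbreviated:

```python
# Method names match the ModelManager diff earlier on this page;
# bodies are omitted and some signatures abbreviated.
class ModelManager:
    def generate_model_filename(self, model_type, version, training_mode,
                                product_id=None, store_id=None,
                                aggregation_method=None) -> str: ...

    def save_model(self, model_data, product_id, model_type,
                   store_id=None, training_mode='product',
                   aggregation_method=None, product_name=None,
                   version=None): ...  # returns (model_path, version_str)

    def list_models(self, product_id=None): ...  # further filters omitted
    def parse_model_filename(self, filename): ...  # returns a dict or None
    def delete_model(self, model_file): ...  # returns bool
```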
### 5.2. Uniform version scheme

Every model version now follows one strict, predictable format:

- **Numeric versions**: `v{number}`, e.g. `v1`, `v2`, `v3`...
  - **Generation**: when a training run **completes normally**, `ModelManager` computes the model's next free version number automatically (e.g. if `v1` and `v2` exist, the new version is `v3`; a tiny sketch follows this list) and names the final model file with it.
  - **Purpose**: represents one complete, stable training output.
- **Special version**: `best`
  - **Generation**: during training, whenever an `epoch` yields a model that beats every previous epoch on the validation set, the trainer asks `ModelManager` to save it as the `best` version, overwriting the old `best`.
  - **Purpose**: always points at the best-performing version of the model so far, handy for quick, high-quality predictions.
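A tiny sketch of the auto-increment rule just described, assuming numeric versions are plain `v{n}` strings. The helper name here is hypothetical; the diff above shows `ModelManager` using a private `_get_next_version` for this:

```python
import re

def next_numeric_version(existing_versions: list[str]) -> str:
    """Given e.g. ['v1', 'v2', 'best'], return 'v3'; non-numeric tags are ignored."""
    nums = [int(m.group(1))
            for v in existing_versions
            if (m := re.fullmatch(r"v(\d+)", v))]
    return f"v{max(nums, default=0) + 1}"

assert next_numeric_version(["v1", "v2", "best"]) == "v3"
assert next_numeric_version([]) == "v1"
```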
### 5.3. Uniform naming convention (v2)

With the addition of the "per-store" and "global" training modes, `ModelManager`'s `generate_model_filename` method was upgraded to a richer, unambiguous naming scheme:

- **Product models**: `{model_type}_product_{product_id}_{version}.pth`
  - *Example*: `transformer_product_17002608_best.pth`
- **Store models**: `{model_type}_store_{store_id}_{version}.pth`
  - *Example*: `mlstm_store_01010023_v2.pth`
- **Global models**: `{model_type}_global_{aggregation_method}_{version}.pth`
  - *Example*: `tcn_global_sum_v1.pth`

This naming system keeps models from the different training modes clearly identifiable and manageable.
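The v2 parser shown in the diff at the top of this page recovers the metadata from each of these example names:

```python
# Round-trip of the three naming examples above through parse_model_filename (v2).
for name in ["transformer_product_17002608_best.pth",
             "mlstm_store_01010023_v2.pth",
             "tcn_global_sum_v1.pth"]:
    print(model_manager.parse_model_filename(name))
# {'model_type': 'transformer', 'training_mode': 'product', 'product_id': '17002608', 'version': 'best'}
# {'model_type': 'mlstm', 'training_mode': 'store', 'store_id': '01010023', 'version': 'v2'}
# {'model_type': 'tcn', 'training_mode': 'global', 'aggregation_method': 'sum', 'version': 'v1'}
```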
### 5.4. Checkpoint file contents (structure unchanged)

Each `.pth` file is still a PyTorch checkpoint containing the model weights, the full configuration, and the data scalers. The refactor tightened the rule that **every trainer must store the complete configuration in the `config` dict**, which guarantees models are fully reproducible.

### 5.5. Key benefits (post-refactor)

- **Centralized logic**: all versioning complexity is encapsulated inside `ModelManager`; trainers just call `save_model` and never worry about how version numbers are produced.
- **Data consistency**: generation, saving, and retrieval of versions all run through one component with one logic, eliminating "model not found" failures caused by mismatched names or version formats.
- **Maintainability**: changing the version policy or naming rules in the future means editing `ModelManager` alone, not every trainer.

## 6. Evolution of the Core Flows: Store and Global Modes

On top of the original per-product flow, the system was refactored into a complete AI loop that also covers "per-store" and "global" modes. This brought a few key changes:
### 6.1. Changes to the training flow

- **One entry point**: all training requests (product, store, global) go through `POST /api/training`, distinguished by the `training_mode` parameter.
- **Data aggregation**: in the `train_model` method of [`predictor.py`](server/core/predictor.py:1), `aggregate_multi_store_data` is called according to `training_mode` to build the correctly aggregated time series for store or global mode.
- **Model identifier**: `train_model` now produces a unique `model_identifier` (e.g. `product_17002608`, `store_01010023`, `global_sum`) and passes it to every downstream trainer. This is what guarantees models get named correctly.
### 6.2. Major fixes to the prediction flow

The prediction flow was substantially repaired to eliminate the `404` errors caused by inconsistent logic.

- **Old helpers retired**: the old, defective helpers in `core/config.py`, such as `get_model_file_path` and `get_model_versions`, are now **fully deprecated**.
- **One lookup path**: the `predict` function in [`api.py`](server/api.py:1) **must** now locate models through `model_manager.list_models()` (sketched below).
- **Reliable path passing**: once `predict` has the correct model file path, it is passed as an explicit argument down through `run_prediction` to the final `load_model_and_predict`.
- **Defect removed at the root**: all of the manual, outdated file-lookup logic inside `load_model_and_predict` has been **completely removed**; it now only receives an explicit path and loads the model.

This fix makes the whole prediction chain depend on `ModelManager` as the single source of truth, eliminating prediction failures caused by mismatched paths.
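A hedged sketch of that unified lookup inside `predict()`. Only the `product_id` filter of `list_models` is confirmed by the diff above; the metadata keys and the `model_path` parameter of `run_prediction` are assumptions:

```python
# Illustrative only — inside api.py's predict().
candidates = model_manager.list_models(product_id=model_identifier)
match = next((m for m in candidates
              if m.get('model_type') == model_type and m.get('version') == version),
             None)
if match is None:
    return jsonify({'error': 'Model not found'}), 404

# The resolved path travels down the chain explicitly:
prediction_result = run_prediction(model_type, product_id, model_id,
                                   model_path=match['file_path'])
```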
项目快速上手指南.md (new file, 127 lines)
@@ -0,0 +1,127 @@
# Project Quick-Start Guide (for new developers)

Welcome to the project! This guide gets you up to speed on the core features, the technical architecture, and the development workflow, with clear entry points for you as a developer coming from a Java background.

## 1. What does the project do? (Features)

It is an **intelligent sales forecasting system** built on historical sales data.

There are three core features, all driven from the web UI:
1. **Model training**: pick a **product**, a **store**, or the **global** dataset, choose a machine-learning algorithm (Transformer, mLSTM, ...), and train a forecasting model.
2. **Sales prediction**: use a trained model to forecast future sales.
3. **Visualization**: plot historical and predicted sales on one chart so users can see the trend at a glance.

In short, it is a complete **"data -> training -> model -> prediction -> visualization"** loop.
## 2. What is it built with? (Tech stack)

You can map this project's stack onto the Java world:

| Layer | This project | Java-world analogue | Notes |
| :--- | :--- | :--- | :--- |
| **Backend framework** | **Flask** | Spring Boot | A lightweight web framework serving the API endpoints. |
| **Frontend framework** | **Vue.js** | React / Angular | A modern JavaScript framework for building the UI. |
| **Core ML library** | **PyTorch** | (no direct equivalent) | Comparable to Java's Deeplearning4j; implements the deep-learning algorithms. |
| **Data processing** | **Pandas** | (no direct equivalent) | Python's "Swiss-army knife" for data analysis; think of it as a powerful in-memory data table. |
| **Build/packaging** | **Vite** (frontend) | Maven / Gradle | Build and dependency tooling for the frontend project. |
| **Database** | **SQLite** | H2 / MySQL | A lightweight file-based database, used for prediction history and the like. |
| **Realtime messaging** | **Socket.IO** | WebSocket / STOMP | Lets the backend push training progress to the frontend in real time. |
## 3. How is it architected? (Layers and design)

The project is a classic frontend/backend split, in four main layers:

```
+------------------------------------------------------------+
|                       User (Browser)                       |
+------------------------------------------------------------+
                              |
+------------------------------------------------------------+
|  1. Frontend layer (Vue.js)                                |
|  - Views (page components, e.g. ProductPredictionView.vue) |
|  - API calls (axios talks to the backend)                  |
|  - Charting (Chart.js renders the charts)                  |
+------------------------------------------------------------+
                              | (HTTP/S, WebSocket)
+------------------------------------------------------------+
|  2. Backend API layer (Flask)                              |
|  - api.py (Controller-like, defines the RESTful API)       |
|  - receives requests, validates params, calls business layer |
+------------------------------------------------------------+
                              |
+------------------------------------------------------------+
|  3. Business logic layer (Python)                          |
|  - core/predictor.py (Service-like)                        |
|  - wraps core business, e.g. "pick the right trainer       |
|    for the given parameters"                               |
+------------------------------------------------------------+
                              |
+------------------------------------------------------------+
|  4. Data & model layer (PyTorch/Pandas)                    |
|  - trainers/*.py (algorithm implementations and training)  |
|  - predictors/model_predictor.py (model loading/prediction)|
|  - saved_models/ (trained .pth model files)                |
|  - data/ (raw .parquet data files)                         |
+------------------------------------------------------------+
```
## 4. Key Execution Flow

Taking the most common case, "predict by product":

1. **Frontend**: the user picks a product and a model and clicks "Predict". The Vue component POSTs to `/api/prediction` via `axios`.
2. **API layer**: `api.py` receives the request and, Controller-style, extracts the product ID, model type, and other parameters.
3. **Business layer**: `api.py` calls the `predict` method in `core/predictor.py`, passing the parameters down. This layer is the business "dispatcher".
4. **Model layer**: `core/predictor.py` ultimately calls `load_model_and_predict` in `predictors/model_predictor.py`.
5. **Model loading and execution**:
   * Locate the matching model file under `saved_models/` (e.g. `transformer_store_01010023_best.pth` or `mlstm_product_17002608_v3.pth`).
   * Load it, restoring the **model architecture**, **weights**, and **data scalers**.
   * Prepare the latest history as input and run the prediction.
   * Return the prediction result.
6. **Return and render**: the result bubbles back up to `api.py`, is formatted as JSON, and is sent to the frontend, which draws history and prediction on the chart with `Chart.js`.
## 5. How do I add a new algorithm? (Developer guide)

This is the new-feature work you are most likely to do. Say you want to add an algorithm called `NewNet`; proceed as follows.

**Goal**: make `NewNet` appear in the frontend's "model type" dropdown and train and predict successfully.

1. **Create the trainer file**:
   * Under `server/trainers/`, copy an existing trainer (e.g. `tcn_trainer.py`) and rename it to `newnet_trainer.py`.
   * In `newnet_trainer.py`:
     * Define your `NewNet` model class (inheriting from `torch.nn.Module`).
     * Rename the `train_..._with_tcn` function to `train_..._with_newnet`.
     * Inside the new function, make sure it instantiates your `NewNet` model.
     * **The most critical step**: when saving the checkpoint, make sure the `config` dict contains every hyperparameter needed to rebuild `NewNet` (layer count, widths, and so on).

   * **Important convention: parameter naming**
     To prevent parameter-mismatch errors when loading models (e.g. `KeyError: 'num_layers'`), we adopted the following rule:
     > **Rule:** any hyperparameter specific to one algorithm must use a key in the `config` dict that carries that algorithm's name as a prefix or unique marker.

     **Examples** (see the sample config after this block):
     * The layer count of an `mLSTM` model should use the key `mlstm_layers`.
     * The channel count of a `TCN` model can be `tcn_channels`.
     * The encoder layer count of a `Transformer` can be `num_encoder_layers` (unambiguous in the Transformer context).

     When **loading the model** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)), these parameters must be read back with exactly the keys used at save time. Following this rule eliminates model-loading failures caused by inconsistent parameter names at the root.
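For instance, a checkpoint `config` for an mLSTM model that follows the rule might look like this (values are illustrative):

```python
# Illustrative config dict following the naming rule above; values are examples.
config = {
    'model_type': 'mlstm',
    'input_dim': 8,
    'hidden_size': 128,
    'mlstm_layers': 2,      # algorithm-specific key carries the 'mlstm' prefix
    'sequence_length': 30,
    'forecast_horizon': 7,
}

# At load time, read it back with exactly the same key:
num_layers = config['mlstm_layers']
```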
2. **Register the new model**:
   * Open `server/core/config.py`.
   * Find the `SUPPORTED_MODELS` list.
   * Add your new model identifier `'newnet'` to the list, for example:
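A sketch of that edit; the surrounding list contents are assumed, and only identifiers mentioned in these documents are shown:

```python
# server/core/config.py — illustrative; the real list may differ in contents.
SUPPORTED_MODELS = [
    'transformer',
    'mlstm',
    'kan',
    'tcn',
    'xgboost',
    'newnet',   # <- your new model identifier
]
```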
3. **Hook into the business layer (training)**:
   * Open `server/core/predictor.py`.
   * In the `train_model` method, find the `if/elif` model-selection chain.
   * Add an `elif model_type == 'newnet':` branch that calls the `train_..._with_newnet` function you created in step 1.

4. **Hook into the model layer (prediction)**:
   * Open `server/predictors/model_predictor.py`.
   * In `load_model_and_predict`, find the `if/elif` model-instantiation chain.
   * Add an `elif model_type == 'newnet':` branch that builds a `NewNet` instance correctly from `config`.

5. **Update the frontend**:
   * Open the relevant Vue files under `UI/src/views/training/` and `UI/src/views/prediction/` (e.g. `ProductTrainingView.vue`).
   * Find where the model options are defined (usually an array or object).
   * Add a new option such as `{ label: '新网络模型 (NewNet)', value: 'newnet' }`.

After these steps, restart the services and your new algorithm is available for selection and use in the UI.