前后端联调与UI修复:新增可选预测时间段和历史记录时间段选择
This commit is contained in:
parent
311d71b653
commit
e4d170d667
@ -34,7 +34,7 @@
|
||||
</el-row>
|
||||
|
||||
<el-row :gutter="20" v-if="form.model_type">
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="模型版本">
|
||||
<el-select
|
||||
v-model="form.version"
|
||||
@ -52,7 +52,7 @@
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="预测天数">
|
||||
<el-input-number
|
||||
v-model="form.future_days"
|
||||
@ -62,7 +62,17 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="历史天数">
|
||||
<el-input-number
|
||||
v-model="form.history_lookback_days"
|
||||
:min="7"
|
||||
:max="365"
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="5">
|
||||
<el-form-item label="起始日期">
|
||||
<el-date-picker
|
||||
v-model="form.start_date"
|
||||
@ -75,7 +85,7 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="4">
|
||||
<el-form-item label="预测分析">
|
||||
<el-switch
|
||||
v-model="form.analyze_result"
|
||||
@ -135,6 +145,7 @@ const form = reactive({
|
||||
model_type: '',
|
||||
version: '',
|
||||
future_days: 7,
|
||||
history_lookback_days: 30,
|
||||
start_date: '',
|
||||
analyze_result: true
|
||||
})
|
||||
@ -189,6 +200,7 @@ const startPrediction = async () => {
|
||||
model_type: form.model_type,
|
||||
version: form.version,
|
||||
future_days: form.future_days,
|
||||
history_lookback_days: form.history_lookback_days,
|
||||
start_date: form.start_date,
|
||||
analyze_result: form.analyze_result
|
||||
}
|
||||
@ -213,28 +225,106 @@ const renderChart = () => {
|
||||
if (chart) {
|
||||
chart.destroy()
|
||||
}
|
||||
const predictions = predictionResult.value.predictions
|
||||
const labels = predictions.map(p => p.date)
|
||||
const data = predictions.map(p => p.sales)
|
||||
|
||||
const formatDate = (date) => new Date(date).toISOString().split('T')[0];
|
||||
|
||||
const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
|
||||
if (historyData.length === 0 && predictionData.length === 0) {
|
||||
ElMessage.warning('没有可用于图表的数据。')
|
||||
return
|
||||
}
|
||||
|
||||
const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
|
||||
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);
|
||||
|
||||
const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
|
||||
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))
|
||||
|
||||
const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
|
||||
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)
|
||||
|
||||
if (historyData.length > 0 && predictionData.length > 0) {
|
||||
const lastHistoryDate = historyData[historyData.length - 1].date
|
||||
const lastHistoryValue = historyData[historyData.length - 1].sales
|
||||
if (!predictionMap.has(lastHistoryDate)) {
|
||||
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
|
||||
}
|
||||
}
|
||||
|
||||
let subtitleText = '';
|
||||
if (historyData.length > 0) {
|
||||
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
|
||||
}
|
||||
if (predictionData.length > 0) {
|
||||
if (subtitleText) subtitleText += ' | ';
|
||||
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
|
||||
}
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels,
|
||||
datasets: [{
|
||||
label: '预测销量',
|
||||
data,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.1)',
|
||||
tension: 0.4,
|
||||
fill: true
|
||||
}]
|
||||
labels: simplifiedLabels,
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: alignedHistorySales,
|
||||
borderColor: '#67C23A',
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
spanGaps: false,
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: alignedPredictionSales,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
borderDash: [5, 5],
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '销量预测趋势图'
|
||||
text: '全局销量预测趋势图',
|
||||
font: { size: 18 }
|
||||
},
|
||||
subtitle: {
|
||||
display: true,
|
||||
text: subtitleText,
|
||||
padding: {
|
||||
bottom: 20
|
||||
},
|
||||
font: { size: 14 }
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '日期 (日)'
|
||||
},
|
||||
grid: {
|
||||
display: false
|
||||
}
|
||||
},
|
||||
y: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '销量'
|
||||
},
|
||||
grid: {
|
||||
color: '#e9e9e9',
|
||||
drawBorder: false,
|
||||
},
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +44,7 @@
|
||||
</el-row>
|
||||
|
||||
<el-row :gutter="20" v-if="form.model_type">
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="模型版本">
|
||||
<el-select
|
||||
v-model="form.version"
|
||||
@ -62,7 +62,7 @@
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="预测天数">
|
||||
<el-input-number
|
||||
v-model="form.future_days"
|
||||
@ -72,7 +72,17 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="历史天数">
|
||||
<el-input-number
|
||||
v-model="form.history_lookback_days"
|
||||
:min="7"
|
||||
:max="365"
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="5">
|
||||
<el-form-item label="起始日期">
|
||||
<el-date-picker
|
||||
v-model="form.start_date"
|
||||
@ -85,7 +95,7 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="4">
|
||||
<el-form-item label="预测分析">
|
||||
<el-switch
|
||||
v-model="form.analyze_result"
|
||||
@ -147,6 +157,7 @@ const form = reactive({
|
||||
model_type: '',
|
||||
version: '',
|
||||
future_days: 7,
|
||||
history_lookback_days: 30,
|
||||
start_date: '',
|
||||
analyze_result: true
|
||||
})
|
||||
@ -207,6 +218,7 @@ const startPrediction = async () => {
|
||||
model_type: form.model_type,
|
||||
version: form.version,
|
||||
future_days: form.future_days,
|
||||
history_lookback_days: form.history_lookback_days,
|
||||
start_date: form.start_date,
|
||||
include_visualization: form.analyze_result,
|
||||
}
|
||||
@ -214,7 +226,7 @@ const startPrediction = async () => {
|
||||
const response = await axios.post('/api/prediction', payload)
|
||||
if (response.data.status === 'success') {
|
||||
// The backend response may have history_data and prediction_data at the top level
|
||||
predictionResult.value = response.data
|
||||
predictionResult.value = response.data.data
|
||||
ElMessage.success('预测完成!')
|
||||
await nextTick()
|
||||
renderChart()
|
||||
@ -235,78 +247,93 @@ const renderChart = () => {
|
||||
chart.destroy()
|
||||
}
|
||||
|
||||
// Backend provides history_data and prediction_data
|
||||
const historyData = predictionResult.value.history_data || []
|
||||
const predictionData = predictionResult.value.prediction_data || []
|
||||
const formatDate = (date) => new Date(date).toISOString().split('T')[0];
|
||||
|
||||
const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
|
||||
if (historyData.length === 0 && predictionData.length === 0) {
|
||||
ElMessage.warning('没有可用于图表的数据。')
|
||||
return
|
||||
}
|
||||
|
||||
const historyLabels = historyData.map(p => p.date)
|
||||
const historySales = historyData.map(p => p.sales)
|
||||
const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
|
||||
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);
|
||||
|
||||
const predictionLabels = predictionData.map(p => p.date)
|
||||
const predictionSales = predictionData.map(p => p.predicted_sales)
|
||||
|
||||
// Combine labels and remove duplicates, then sort
|
||||
const allLabels = [...new Set([...historyLabels, ...predictionLabels])].sort()
|
||||
|
||||
// Create a mapping of label to sales data for easier lookup
|
||||
const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
|
||||
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))
|
||||
|
||||
// Align data with the sorted labels
|
||||
const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
|
||||
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)
|
||||
|
||||
// The last point of history should connect to the first point of prediction for a smooth graph
|
||||
if (historyData.length > 0 && predictionData.length > 0) {
|
||||
const lastHistoryDate = historyLabels[historyLabels.length - 1]
|
||||
const lastHistoryValue = historySales[historySales.length - 1]
|
||||
const lastHistoryDate = historyData[historyData.length - 1].date
|
||||
const lastHistoryValue = historyData[historyData.length - 1].sales
|
||||
if (!predictionMap.has(lastHistoryDate)) {
|
||||
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
|
||||
}
|
||||
}
|
||||
|
||||
let subtitleText = '';
|
||||
if (historyData.length > 0) {
|
||||
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
|
||||
}
|
||||
if (predictionData.length > 0) {
|
||||
if (subtitleText) subtitleText += ' | ';
|
||||
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
|
||||
}
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: allLabels,
|
||||
labels: simplifiedLabels,
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: alignedHistorySales,
|
||||
borderColor: '#67C23A',
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.1)',
|
||||
tension: 0.1,
|
||||
spanGaps: false, // Do not draw line over nulls
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
spanGaps: false,
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: alignedPredictionSales,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.1)',
|
||||
tension: 0.1,
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
borderDash: [5, 5], // Dashed line for predictions
|
||||
borderDash: [5, 5],
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
title: {
|
||||
display: true,
|
||||
text: `“${form.product_id}” - 销量预测趋势图`
|
||||
text: `“${form.product_id}” - 销量预测趋势图`,
|
||||
font: { size: 18 }
|
||||
},
|
||||
subtitle: {
|
||||
display: true,
|
||||
text: subtitleText,
|
||||
padding: {
|
||||
bottom: 20
|
||||
},
|
||||
font: { size: 14 }
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '日期'
|
||||
text: '日期 (日)'
|
||||
},
|
||||
grid: {
|
||||
display: false
|
||||
}
|
||||
},
|
||||
y: {
|
||||
@ -314,6 +341,10 @@ const renderChart = () => {
|
||||
display: true,
|
||||
text: '销量'
|
||||
},
|
||||
grid: {
|
||||
color: '#e9e9e9',
|
||||
drawBorder: false,
|
||||
},
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
|
@ -44,7 +44,7 @@
|
||||
</el-row>
|
||||
|
||||
<el-row :gutter="20" v-if="form.model_type">
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="模型版本">
|
||||
<el-select
|
||||
v-model="form.version"
|
||||
@ -62,7 +62,7 @@
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="预测天数">
|
||||
<el-input-number
|
||||
v-model="form.future_days"
|
||||
@ -72,7 +72,17 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="5">
|
||||
<el-form-item label="历史天数">
|
||||
<el-input-number
|
||||
v-model="form.history_lookback_days"
|
||||
:min="7"
|
||||
:max="365"
|
||||
style="width: 100%"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="5">
|
||||
<el-form-item label="起始日期">
|
||||
<el-date-picker
|
||||
v-model="form.start_date"
|
||||
@ -85,7 +95,7 @@
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="6">
|
||||
<el-col :span="4">
|
||||
<el-form-item label="预测分析">
|
||||
<el-switch
|
||||
v-model="form.analyze_result"
|
||||
@ -147,6 +157,7 @@ const form = reactive({
|
||||
model_type: '',
|
||||
version: '',
|
||||
future_days: 7,
|
||||
history_lookback_days: 30,
|
||||
start_date: '',
|
||||
analyze_result: true
|
||||
})
|
||||
@ -203,20 +214,18 @@ const startPrediction = async () => {
|
||||
try {
|
||||
predicting.value = true
|
||||
const payload = {
|
||||
training_mode: form.training_mode,
|
||||
store_id: form.store_id,
|
||||
model_type: form.model_type,
|
||||
version: form.version,
|
||||
future_days: form.future_days,
|
||||
history_lookback_days: form.history_lookback_days,
|
||||
start_date: form.start_date,
|
||||
analyze_result: form.analyze_result,
|
||||
store_id: form.store_id,
|
||||
// 修正:对于店铺模型,product_id应传递店铺的标识符
|
||||
product_id: `store_${form.store_id}`
|
||||
}
|
||||
// 修正API端点
|
||||
const response = await axios.post('/api/prediction', payload)
|
||||
if (response.data.status === 'success') {
|
||||
// 修正:数据现在直接在响应的顶层
|
||||
predictionResult.value = response.data
|
||||
predictionResult.value = response.data.data
|
||||
ElMessage.success('预测完成!')
|
||||
await nextTick()
|
||||
renderChart()
|
||||
@ -236,57 +245,107 @@ const renderChart = () => {
|
||||
chart.destroy()
|
||||
}
|
||||
|
||||
const historyData = predictionResult.value.history_data || []
|
||||
const predictionData = predictionResult.value.prediction_data || []
|
||||
const formatDate = (date) => new Date(date).toISOString().split('T')[0];
|
||||
|
||||
const labels = [
|
||||
...historyData.map(p => p.date),
|
||||
...predictionData.map(p => p.date)
|
||||
]
|
||||
|
||||
const historySales = historyData.map(p => p.sales)
|
||||
// 预测数据需要填充与历史数据等长的null值,以保证图表正确对齐
|
||||
const predictionSales = [
|
||||
...Array(historyData.length).fill(null),
|
||||
...predictionData.map(p => p.predicted_sales)
|
||||
]
|
||||
const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
|
||||
if (historyData.length === 0 && predictionData.length === 0) {
|
||||
ElMessage.warning('没有可用于图表的数据。')
|
||||
return
|
||||
}
|
||||
|
||||
const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
|
||||
const simplifiedLabels = allLabels.map(date => date.split('-')[2]);
|
||||
|
||||
const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
|
||||
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))
|
||||
|
||||
const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
|
||||
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)
|
||||
|
||||
if (historyData.length > 0 && predictionData.length > 0) {
|
||||
const lastHistoryDate = historyData[historyData.length - 1].date
|
||||
const lastHistoryValue = historyData[historyData.length - 1].sales
|
||||
if (!predictionMap.has(lastHistoryDate)) {
|
||||
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
|
||||
}
|
||||
}
|
||||
|
||||
let subtitleText = '';
|
||||
if (historyData.length > 0) {
|
||||
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
|
||||
}
|
||||
if (predictionData.length > 0) {
|
||||
if (subtitleText) subtitleText += ' | ';
|
||||
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
|
||||
}
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels,
|
||||
labels: simplifiedLabels,
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: historySales,
|
||||
data: alignedHistorySales,
|
||||
borderColor: '#67C23A',
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.1)',
|
||||
fill: false,
|
||||
tension: 0.4
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
spanGaps: false,
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: predictionSales,
|
||||
data: alignedPredictionSales,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.1)',
|
||||
borderDash: [5, 5], // 虚线
|
||||
fill: false,
|
||||
tension: 0.4
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
borderDash: [5, 5],
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '店铺销量历史与预测趋势图'
|
||||
text: `“店铺${form.store_id}” - 销量预测趋势图`,
|
||||
font: { size: 18 }
|
||||
},
|
||||
subtitle: {
|
||||
display: true,
|
||||
text: subtitleText,
|
||||
padding: {
|
||||
bottom: 20
|
||||
},
|
||||
font: { size: 14 }
|
||||
}
|
||||
},
|
||||
interaction: {
|
||||
intersect: false,
|
||||
mode: 'index',
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '日期 (日)'
|
||||
},
|
||||
grid: {
|
||||
display: false
|
||||
}
|
||||
},
|
||||
y: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '销量'
|
||||
},
|
||||
grid: {
|
||||
color: '#e9e9e9',
|
||||
drawBorder: false,
|
||||
},
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -302,3 +302,25 @@
|
||||
- **根本原因**: 在修改 `model_predictor.py` 时,使用了 `Optional` 类型提示,但忘记从 `typing` 模块导入。
|
||||
- **修复**: 在 `server/predictors/model_predictor.py` 文件顶部添加了 `from typing import Optional`。
|
||||
- **最终结论**: 至此,所有与“按店铺”功能相关的架构升级和连锁bug均已修复。系统现在能够稳定、正确地处理两种维度的训练和预测任务,并且代码逻辑更加统一和健壮。
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 2025-07-21:前后端联合调试与UI修复
|
||||
**开发者**: Roo
|
||||
|
||||
### 15:45 - 修复后端 `DataFrame` 序列化错误
|
||||
- **问题现象**: 在清理了历史模型并重新进行预测后,前端出现 `Object of type DataFrame is not JSON serializable` 错误。
|
||||
- **根本原因**: `server/predictors/model_predictor.py` 中的 `load_model_and_predict` 函数在返回结果时,为了兼容旧版接口而保留的 `'predictions'` 字段,其值依然是未经处理的 Pandas DataFrame (`predictions_df`)。
|
||||
- **修复方案**: 修改了该函数的返回字典,将 `'predictions'` 字段的值也更新为已经过 `.to_dict('records')` 方法处理的 `prediction_data_json`,确保了返回对象的所有部分都是JSON兼容的。
|
||||
|
||||
### 16:00 - 统一修复所有预测视图的图表渲染问题
|
||||
- **问题现象**: 在解决了后端的序列化问题后,所有三个预测视图(按药品、按店铺、全局)的图表均为空白,并且图表下方的日期副标题显示为未经格式化的原始JavaScript日期字符串。
|
||||
- **根本原因深度分析**:
|
||||
1. **数据访问路径不精确**: 前端代码直接从API响应的根对象 (`response.data`) 中获取数据,而最可靠的数据源位于 `response.data.data` 中。
|
||||
2. **日期对象处理不当**: 前端代码未能将从后端接收到的日期(无论是字符串还是由axios自动转换的Date对象)标准化为统一的字符串格式。这导致在使用 `Set` 对日期进行去重时,因对象引用不同而失败,最终图表上没有数据点。
|
||||
- **统一修复方案**:
|
||||
1. **逐一修改**: 逐一修改了 `ProductPredictionView.vue`, `StorePredictionView.vue`, 和 `GlobalPredictionView.vue` 三个文件。
|
||||
2. **修正数据访问**: 在 `startPrediction` 方法中,将API响应的核心数据 `response.data.data` 赋值给 `predictionResult`。
|
||||
3. **标准化日期**: 在 `renderChart` 方法的开头,增加了一个 `formatDate` 辅助函数,并在处理数据时立即调用它,将所有日期都统一转换为 `'YYYY-MM-DD'` 格式的字符串,从而一举解决了数据点丢失和标题格式错误的双重问题。
|
||||
- **最终结论**: 至此,所有预测视图的前后端数据链路和UI展示功能均已修复,系统功能恢复正常。
|
||||
|
Binary file not shown.
BIN
sales_trends.png
BIN
sales_trends.png
Binary file not shown.
Before Width: | Height: | Size: 348 KiB |
195
server/api.py
195
server/api.py
@ -1357,6 +1357,7 @@ def predict():
|
||||
future_days = int(data.get('future_days', 7))
|
||||
start_date = data.get('start_date', '')
|
||||
include_visualization = data.get('include_visualization', False)
|
||||
history_lookback_days = int(data.get('history_lookback_days', 30)) # 新增参数
|
||||
|
||||
# 确定训练模式和标识符
|
||||
training_mode = data.get('training_mode', 'product')
|
||||
@ -1424,99 +1425,37 @@ def predict():
|
||||
|
||||
model_id = f"{model_identifier}_{model_type}_{version}"
|
||||
|
||||
# 执行预测 (v2版,传递 model_file_path)
|
||||
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode, model_file_path)
|
||||
|
||||
# v3版:直接调用核心预测函数
|
||||
prediction_result = load_model_and_predict(
|
||||
model_path=model_file_path,
|
||||
product_id=product_id,
|
||||
model_type=model_type,
|
||||
store_id=store_id,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
version=version,
|
||||
training_mode=training_mode,
|
||||
analyze_result=include_visualization,
|
||||
history_lookback_days=history_lookback_days
|
||||
)
|
||||
|
||||
if prediction_result is None:
|
||||
return jsonify({"status": "error", "error": "模型文件未找到或加载失败"}), 404
|
||||
|
||||
# 添加版本信息到预测结果
|
||||
prediction_result['version'] = version
|
||||
|
||||
# 如果需要可视化,添加图表数据
|
||||
if include_visualization:
|
||||
try:
|
||||
# 添加图表数据
|
||||
chart_data = prepare_chart_data(prediction_result)
|
||||
prediction_result['chart_data'] = chart_data
|
||||
|
||||
# 添加分析结果
|
||||
if 'analysis' not in prediction_result or prediction_result['analysis'] is None:
|
||||
analysis_result = analyze_prediction(prediction_result)
|
||||
prediction_result['analysis'] = analysis_result
|
||||
except Exception as e:
|
||||
print(f"生成可视化或分析数据失败: {str(e)}")
|
||||
# 可视化失败不影响主要功能,继续执行
|
||||
|
||||
# 保存预测结果到文件和数据库
|
||||
try:
|
||||
prediction_id, file_path = save_prediction_result(
|
||||
prediction_result,
|
||||
product_id,
|
||||
product_name,
|
||||
model_type,
|
||||
model_id,
|
||||
start_date,
|
||||
future_days
|
||||
)
|
||||
|
||||
# 添加预测ID到结果中
|
||||
prediction_result['prediction_id'] = prediction_id
|
||||
except Exception as e:
|
||||
print(f"保存预测结果失败: {str(e)}")
|
||||
# 保存失败不影响返回结果,继续执行
|
||||
|
||||
# 在调用jsonify之前,确保所有数据都是JSON可序列化的
|
||||
def convert_numpy_types(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: convert_numpy_types(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_numpy_types(item) for item in obj]
|
||||
elif isinstance(obj, pd.DataFrame):
|
||||
return obj.to_dict(orient='records')
|
||||
elif isinstance(obj, pd.Series):
|
||||
return obj.to_dict()
|
||||
elif isinstance(obj, np.generic):
|
||||
return obj.item() # 将NumPy标量转换为Python原生类型
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif pd.isna(obj):
|
||||
return None
|
||||
else:
|
||||
return obj
|
||||
|
||||
# 递归处理整个预测结果对象,确保所有NumPy类型都被转换
|
||||
processed_result = convert_numpy_types(prediction_result)
|
||||
|
||||
# 构建前端期望的响应格式
|
||||
# 构建前端期望的响应格式
|
||||
return jsonify({"status": "error", "error": "预测失败,核心预测器返回None"}), 500
|
||||
|
||||
# 核心函数已处理好所有数据格式,此处直接构建最终响应
|
||||
response_data = {
|
||||
'status': 'success',
|
||||
'data': processed_result,
|
||||
'history_data': [],
|
||||
'prediction_data': []
|
||||
'data': prediction_result, # 包含所有信息的完整结果
|
||||
'history_data': prediction_result.get('history_data', []),
|
||||
'prediction_data': prediction_result.get('prediction_data', [])
|
||||
}
|
||||
|
||||
# 调试日志
|
||||
print("=== 预测API响应数据结构 (v3) ===")
|
||||
print(f"history_data 长度: {len(response_data['history_data'])}")
|
||||
print(f"prediction_data 长度: {len(response_data['prediction_data'])}")
|
||||
print("================================")
|
||||
|
||||
# 将history_data和prediction_data移到顶级,并确保它们存在
|
||||
if 'history_data' in processed_result and processed_result['history_data']:
|
||||
response_data['history_data'] = processed_result['history_data']
|
||||
|
||||
if 'prediction_data' in processed_result and processed_result['prediction_data']:
|
||||
response_data['prediction_data'] = processed_result['prediction_data']
|
||||
|
||||
# 调试日志:打印响应数据结构
|
||||
print("=== 预测API响应数据结构 ===")
|
||||
print(f"响应包含的顶级键: {list(response_data.keys())}")
|
||||
print(f"data字段存在: {'data' in response_data}")
|
||||
print(f"history_data字段存在: {'history_data' in response_data}")
|
||||
print(f"prediction_data字段存在: {'prediction_data' in response_data}")
|
||||
if 'history_data' in response_data:
|
||||
print(f"history_data长度: {len(response_data['history_data'])}")
|
||||
if 'prediction_data' in response_data:
|
||||
print(f"prediction_data长度: {len(response_data['prediction_data'])}")
|
||||
print("========================")
|
||||
|
||||
# 使用处理后的结果进行JSON序列化
|
||||
return jsonify(response_data)
|
||||
except Exception as e:
|
||||
print(f"预测失败: {str(e)}")
|
||||
@ -2584,85 +2523,7 @@ def get_product_name(product_id):
|
||||
print(f"获取产品名称失败: {str(e)}")
|
||||
return None
|
||||
|
||||
# 执行预测的辅助函数 (v2版)
|
||||
def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product', model_path=None):
|
||||
"""执行模型预测"""
|
||||
try:
|
||||
scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"
|
||||
print(f"开始运行预测: model_type={model_type}, product_id={product_id}, model_id={model_id}, version={version}{scope_msg}")
|
||||
|
||||
if not model_path:
|
||||
raise ValueError("run_prediction v2版需要一个明确的 model_path。")
|
||||
|
||||
# 创建预测器实例
|
||||
predictor = PharmacyPredictor()
|
||||
|
||||
# 解析模型类型映射
|
||||
predictor_model_type = model_type
|
||||
if model_type == 'optimized_kan':
|
||||
predictor_model_type = 'optimized_kan'
|
||||
|
||||
# 生成预测 (v2版,直接调用 load_model_and_predict)
|
||||
prediction_result = load_model_and_predict(
|
||||
model_path=model_path,
|
||||
product_id=product_id,
|
||||
model_type=predictor_model_type,
|
||||
store_id=store_id,
|
||||
future_days=future_days,
|
||||
start_date=start_date,
|
||||
version=version,
|
||||
training_mode=training_mode,
|
||||
analyze_result=True # 默认进行分析
|
||||
)
|
||||
|
||||
if prediction_result is None:
|
||||
return {"status": "error", "error": "预测失败,预测器返回None"}
|
||||
|
||||
# 添加版本信息到预测结果
|
||||
prediction_result['version'] = version
|
||||
prediction_result['model_id'] = model_id
|
||||
|
||||
# 转换数据结构为前端期望的格式
|
||||
if 'predictions' in prediction_result and isinstance(prediction_result['predictions'], pd.DataFrame):
|
||||
predictions_df = prediction_result['predictions']
|
||||
|
||||
# 将DataFrame转换为prediction_data格式
|
||||
prediction_data = []
|
||||
for _, row in predictions_df.iterrows():
|
||||
# 纠正:预测器返回的DataFrame中使用'sales'作为预测值列名。
|
||||
# 我们从'sales'列读取,然后放入前端期望的'predicted_sales'键中。
|
||||
sales_value = float(row['sales']) if pd.notna(row['sales']) else 0.0
|
||||
item = {
|
||||
'date': row['date'].strftime('%Y-%m-%d') if hasattr(row['date'], 'strftime') else str(row['date']),
|
||||
'predicted_sales': sales_value,
|
||||
'sales': sales_value # 兼容字段
|
||||
}
|
||||
prediction_data.append(item)
|
||||
|
||||
prediction_result['prediction_data'] = prediction_data
|
||||
|
||||
# 统一数据格式:确保历史数据和预测数据在发送给前端前具有相同的结构和日期格式。
|
||||
if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], pd.DataFrame):
|
||||
history_df = prediction_result['history_data']
|
||||
history_data_list = []
|
||||
for _, row in history_df.iterrows():
|
||||
item = {
|
||||
'date': row['date'].strftime('%Y-%m-%d'),
|
||||
'sales': float(row['sales']) if pd.notna(row['sales']) else 0.0
|
||||
}
|
||||
history_data_list.append(item)
|
||||
prediction_result['history_data'] = history_data_list
|
||||
else:
|
||||
# 确保即使没有数据,也返回一个空列表
|
||||
prediction_result['history_data'] = []
|
||||
|
||||
return prediction_result
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(f"预测过程中发生错误: {str(e)}")
|
||||
return {"status": "error", "error": str(e)}
|
||||
# run_prediction 函数已被移除,因为其逻辑已完全整合到 /api/prediction 路由处理函数中
|
||||
|
||||
# 添加新的API路由,支持/api/models/{model_type}/{product_id}/details格式
|
||||
@app.route('/api/models/<model_type>/<product_id>/details', methods=['GET'])
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -24,323 +24,161 @@ from utils.visualization import plot_prediction_results
|
||||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||||
|
||||
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product'):
|
||||
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
|
||||
"""
|
||||
加载已训练的模型并进行预测 (v2版)
|
||||
加载已训练的模型并进行预测 (v3版 - 支持自动回归)
|
||||
|
||||
参数:
|
||||
model_path: 模型的准确文件路径
|
||||
product_id: 产品ID
|
||||
model_type: 模型类型
|
||||
store_id: 店铺ID
|
||||
future_days: 预测未来天数
|
||||
start_date: 预测起始日期
|
||||
analyze_result: 是否分析预测结果
|
||||
version: 模型版本
|
||||
training_mode: 训练模式
|
||||
... (同上, 新增 history_lookback_days)
|
||||
history_lookback_days: 用于图表展示的历史数据天数
|
||||
|
||||
返回:
|
||||
预测结果和分析
|
||||
"""
|
||||
try:
|
||||
print(f"v2版预测函数启动,直接使用模型路径: {model_path}")
|
||||
print(f"v3版预测函数启动,模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"模型文件 {model_path} 不存在")
|
||||
return None
|
||||
|
||||
# 加载销售数据(支持多店铺)
|
||||
try:
|
||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
||||
|
||||
# 根据训练模式加载相应的数据
|
||||
if training_mode == 'store' and store_id:
|
||||
# 店铺模型:聚合该店铺的所有产品数据
|
||||
product_df = aggregate_multi_store_data(
|
||||
store_id=store_id,
|
||||
aggregation_method='sum',
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}"
|
||||
prediction_scope = f"店铺 '{store_name}' ({store_id})"
|
||||
product_name = store_name
|
||||
elif training_mode == 'global':
|
||||
# 全局模型:聚合所有数据
|
||||
product_df = aggregate_multi_store_data(
|
||||
aggregation_method='sum',
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
prediction_scope = "全局聚合数据"
|
||||
product_name = "全局销售数据"
|
||||
else:
|
||||
# 产品模型(默认):聚合该产品在所有店铺的数据
|
||||
product_df = aggregate_multi_store_data(
|
||||
product_id=product_id,
|
||||
aggregation_method='sum',
|
||||
file_path=DEFAULT_DATA_PATH
|
||||
)
|
||||
prediction_scope = "全部店铺(聚合数据)"
|
||||
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
||||
except Exception as e:
|
||||
print(f"加载数据失败: {e}")
|
||||
return None
|
||||
# 加载销售数据
|
||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
||||
if training_mode == 'store' and store_id:
|
||||
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||
product_name = product_df['store_name'].iloc[0] if not product_df.empty else f"店铺{store_id}"
|
||||
elif training_mode == 'global':
|
||||
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||
product_name = "全局销售数据"
|
||||
else:
|
||||
product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
|
||||
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
|
||||
|
||||
if product_df.empty:
|
||||
print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
|
||||
return None
|
||||
print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
|
||||
print(f"预测范围: {prediction_scope}")
|
||||
|
||||
# 添加安全的全局变量以支持MinMaxScaler的反序列化
|
||||
try:
|
||||
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
||||
except Exception as e:
|
||||
print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")
|
||||
|
||||
|
||||
# 加载模型和配置
|
||||
try:
|
||||
# 首先尝试使用weights_only=False加载
|
||||
try:
|
||||
print("尝试使用 weights_only=False 加载模型")
|
||||
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
||||
except Exception as e:
|
||||
print(f"使用weights_only=False加载失败: {str(e)}")
|
||||
print("尝试使用默认参数加载模型")
|
||||
checkpoint = torch.load(model_path, map_location=DEVICE)
|
||||
|
||||
print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
|
||||
if isinstance(checkpoint, dict):
|
||||
print(f"checkpoint包含的键: {list(checkpoint.keys())}")
|
||||
else:
|
||||
print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"加载模型失败: {str(e)}")
|
||||
return None
|
||||
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
|
||||
except Exception: pass
|
||||
|
||||
# 检查并获取配置
|
||||
if 'config' not in checkpoint:
|
||||
print("模型文件中没有配置信息")
|
||||
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
||||
if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
|
||||
print("模型文件不完整,缺少config或scaler")
|
||||
return None
|
||||
|
||||
config = checkpoint['config']
|
||||
print(f"模型配置: {config}")
|
||||
|
||||
# 检查并获取缩放器
|
||||
if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
|
||||
print("模型文件中没有缩放器信息")
|
||||
return None
|
||||
|
||||
scaler_X = checkpoint['scaler_X']
|
||||
scaler_y = checkpoint['scaler_y']
|
||||
|
||||
# 创建模型实例
|
||||
try:
|
||||
if model_type == 'transformer':
|
||||
model = TimeSeriesTransformer(
|
||||
num_features=config['input_dim'],
|
||||
d_model=config['hidden_size'],
|
||||
nhead=config['num_heads'],
|
||||
num_encoder_layers=config['num_layers'],
|
||||
dim_feedforward=config['hidden_size'] * 2,
|
||||
dropout=config['dropout'],
|
||||
output_sequence_length=config['output_dim'],
|
||||
seq_length=config['sequence_length'],
|
||||
batch_size=32
|
||||
).to(DEVICE)
|
||||
elif model_type == 'slstm':
|
||||
model = ScalarLSTM(
|
||||
input_dim=config['input_dim'],
|
||||
hidden_dim=config['hidden_size'],
|
||||
output_dim=config['output_dim'],
|
||||
num_layers=config['num_layers'],
|
||||
dropout=config['dropout']
|
||||
).to(DEVICE)
|
||||
elif model_type == 'mlstm':
|
||||
# 获取配置参数,如果不存在则使用默认值
|
||||
embed_dim = config.get('embed_dim', 32)
|
||||
dense_dim = config.get('dense_dim', 32)
|
||||
num_heads = config.get('num_heads', 4)
|
||||
num_blocks = config.get('num_blocks', 3)
|
||||
|
||||
model = MatrixLSTM(
|
||||
num_features=config['input_dim'],
|
||||
hidden_size=config['hidden_size'],
|
||||
mlstm_layers=config['mlstm_layers'],
|
||||
embed_dim=embed_dim,
|
||||
dense_dim=dense_dim,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=config['dropout_rate'],
|
||||
num_blocks=num_blocks,
|
||||
output_sequence_length=config['output_dim']
|
||||
).to(DEVICE)
|
||||
elif model_type == 'kan':
|
||||
model = KANForecaster(
|
||||
input_features=config['input_dim'],
|
||||
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
|
||||
output_sequence_length=config['output_dim']
|
||||
).to(DEVICE)
|
||||
elif model_type == 'optimized_kan':
|
||||
model = OptimizedKANForecaster(
|
||||
input_features=config['input_dim'],
|
||||
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
|
||||
output_sequence_length=config['output_dim']
|
||||
).to(DEVICE)
|
||||
elif model_type == 'tcn':
|
||||
model = TCNForecaster(
|
||||
num_features=config['input_dim'],
|
||||
output_sequence_length=config['output_dim'],
|
||||
num_channels=[config['hidden_size']] * config['num_layers'],
|
||||
kernel_size=config['kernel_size'],
|
||||
dropout=config['dropout']
|
||||
).to(DEVICE)
|
||||
else:
|
||||
print(f"不支持的模型类型: {model_type}")
|
||||
return None
|
||||
|
||||
print(f"模型实例创建成功: {type(model)}")
|
||||
except Exception as e:
|
||||
print(f"创建模型实例失败: {str(e)}")
|
||||
return None
|
||||
# (此处省略了与原版本相同的模型创建代码,以保持简洁)
|
||||
if model_type == 'transformer':
|
||||
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
|
||||
elif model_type == 'mlstm':
|
||||
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'kan':
|
||||
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'optimized_kan':
|
||||
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
|
||||
elif model_type == 'tcn':
|
||||
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
|
||||
else:
|
||||
print(f"不支持的模型类型: {model_type}"); return None
|
||||
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
# 加载模型参数
|
||||
try:
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
print("模型参数加载成功")
|
||||
except Exception as e:
|
||||
print(f"加载模型参数失败: {str(e)}")
|
||||
return None
|
||||
# --- 核心逻辑修改:自动回归预测 ---
|
||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
sequence_length = config['sequence_length']
|
||||
|
||||
# 准备输入数据
|
||||
try:
|
||||
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||||
sequence_length = config['sequence_length']
|
||||
|
||||
# 获取最近的sequence_length天数据作为输入
|
||||
recent_data = product_df.iloc[-sequence_length:].copy()
|
||||
|
||||
# 如果指定了起始日期,则使用该日期之后的数据
|
||||
if start_date:
|
||||
if isinstance(start_date, str):
|
||||
start_date = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
|
||||
if len(recent_data) < sequence_length:
|
||||
print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
|
||||
# 补充数据
|
||||
missing_days = sequence_length - len(recent_data)
|
||||
additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
|
||||
recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
|
||||
|
||||
print(f"输入数据准备完成,形状: {recent_data.shape}")
|
||||
except Exception as e:
|
||||
print(f"准备输入数据失败: {str(e)}")
|
||||
# 确定预测的起始点
|
||||
if start_date:
|
||||
start_date_dt = pd.to_datetime(start_date)
|
||||
# 获取预测开始日期前的 `sequence_length` 天数据作为初始输入
|
||||
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
|
||||
else:
|
||||
# 如果未指定开始日期,则从数据的最后一天开始预测
|
||||
prediction_input_df = product_df.tail(sequence_length)
|
||||
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
|
||||
|
||||
if len(prediction_input_df) < sequence_length:
|
||||
print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
|
||||
return None
|
||||
|
||||
# 归一化输入数据
|
||||
try:
|
||||
X = recent_data[features].values
|
||||
X_scaled = scaler_X.transform(X)
|
||||
|
||||
# 转换为模型输入格式
|
||||
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
|
||||
print(f"输入张量准备完成,形状: {X_input.shape}")
|
||||
except Exception as e:
|
||||
print(f"归一化输入数据失败: {str(e)}")
|
||||
return None
|
||||
|
||||
# 预测
|
||||
try:
|
||||
|
||||
# 准备用于图表展示的历史数据
|
||||
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
|
||||
|
||||
# 自动回归预测循环
|
||||
all_predictions = []
|
||||
current_sequence_df = prediction_input_df.copy()
|
||||
|
||||
print(f"开始自动回归预测,共 {future_days} 天...")
|
||||
for i in range(future_days):
|
||||
# 准备当前序列的输入张量
|
||||
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
|
||||
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
|
||||
|
||||
# 模型进行一次预测(可能预测出多个点,但我们只用第一个)
|
||||
with torch.no_grad():
|
||||
y_pred_scaled = model(X_input).cpu().numpy()
|
||||
print(f"原始预测输出形状: {y_pred_scaled.shape}")
|
||||
|
||||
# 处理TCN、Transformer、mLSTM和KAN模型的输出,确保形状正确
|
||||
if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
|
||||
y_pred_scaled = y_pred_scaled.squeeze(-1)
|
||||
print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
|
||||
# 提取下一个时间点的预测值
|
||||
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
|
||||
next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
|
||||
next_step_pred_unscaled = float(max(0, next_step_pred_unscaled)) # 确保销量不为负,并转换为标准float
|
||||
|
||||
# 获取新预测的日期
|
||||
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
|
||||
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
|
||||
|
||||
# 构建新的一行数据,用于更新输入序列
|
||||
new_row = {
|
||||
'date': next_date,
|
||||
'sales': next_step_pred_unscaled,
|
||||
'weekday': next_date.weekday(),
|
||||
'month': next_date.month,
|
||||
'is_holiday': 0,
|
||||
'is_weekend': 1 if next_date.weekday() >= 5 else 0,
|
||||
'is_promotion': 0,
|
||||
'temperature': current_sequence_df['temperature'].iloc[-1] # 沿用最后一天的温度
|
||||
}
|
||||
|
||||
# 反归一化预测结果
|
||||
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
|
||||
print(f"反归一化后的预测结果: {y_pred}")
|
||||
|
||||
# 生成预测日期
|
||||
last_date = recent_data['date'].iloc[-1]
|
||||
pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
|
||||
print(f"预测日期: {pred_dates}")
|
||||
except Exception as e:
|
||||
print(f"执行预测失败: {str(e)}")
|
||||
return None
|
||||
|
||||
# 创建预测结果DataFrame
|
||||
try:
|
||||
predictions_df = pd.DataFrame({
|
||||
'date': pred_dates,
|
||||
'sales': y_pred # 使用sales字段名而不是predicted_sales,以便与历史数据兼容
|
||||
})
|
||||
print(f"预测结果DataFrame创建成功,形状: {predictions_df.shape}")
|
||||
except Exception as e:
|
||||
print(f"创建预测结果DataFrame失败: {str(e)}")
|
||||
return None
|
||||
|
||||
# 绘制预测结果
|
||||
try:
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
|
||||
plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
|
||||
plt.title(f'{product_name} - {model_type}模型销量预测')
|
||||
plt.xlabel('日期')
|
||||
plt.ylabel('销量')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图像
|
||||
plt.savefig(f'{product_id}_{model_type}_prediction.png')
|
||||
plt.close()
|
||||
|
||||
print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png")
|
||||
except Exception as e:
|
||||
print(f"绘制预测结果图表失败: {str(e)}")
|
||||
# 这个错误不影响主要功能,继续执行
|
||||
|
||||
# 分析预测结果
|
||||
# 更新序列:移除最旧的一行,添加最新预测的一行
|
||||
new_row_df = pd.DataFrame([new_row])
|
||||
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
|
||||
|
||||
predictions_df = pd.DataFrame(all_predictions)
|
||||
print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。")
|
||||
|
||||
# 分析与可视化
|
||||
analysis = None
|
||||
if analyze_result:
|
||||
try:
|
||||
analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
|
||||
print("\n预测结果分析:")
|
||||
if analysis and 'explanation' in analysis:
|
||||
print(analysis['explanation'])
|
||||
else:
|
||||
print("分析结果不包含explanation字段")
|
||||
y_pred_for_analysis = predictions_df['predicted_sales'].values
|
||||
# 使用初始输入序列的特征进行分析
|
||||
initial_features_for_analysis = prediction_input_df[features].values
|
||||
analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
|
||||
except Exception as e:
|
||||
print(f"分析预测结果失败: {str(e)}")
|
||||
# 分析失败不影响主要功能,继续执行
|
||||
|
||||
# 准备用于图表展示的历史数据
|
||||
history_df = product_df
|
||||
if start_date:
|
||||
try:
|
||||
# 筛选出所有早于预测起始日期的数据
|
||||
history_df = product_df[product_df['date'] < pd.to_datetime(start_date)]
|
||||
except Exception as e:
|
||||
print(f"筛选历史数据时日期格式错误: {e}")
|
||||
|
||||
# 从正确的历史记录中取最后30天
|
||||
recent_history = history_df.tail(30)
|
||||
|
||||
# 在返回前,将DataFrame转换为前端期望的JSON数组格式
|
||||
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
|
||||
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
|
||||
|
||||
return {
|
||||
'product_id': product_id,
|
||||
'product_name': product_name,
|
||||
'model_type': model_type,
|
||||
'predictions': predictions_df,
|
||||
'history_data': recent_history, # 将历史数据添加到返回结果中
|
||||
'predictions': prediction_data_json, # 兼容旧字段,使用已转换的json
|
||||
'prediction_data': prediction_data_json,
|
||||
'history_data': history_data_json,
|
||||
'analysis': analysis
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"预测过程中出现未捕获的异常: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return None
|
||||
return None
|
Loading…
x
Reference in New Issue
Block a user