Compare commits

33 Commits

782bd40dd4
08b26b5fa0
3aaddcd658
87cc7b4d03
54f3fc6f61
9d7dcae1c8
af3d174ac6
e1980b3755
751de9b548
038289ae32
0d3b89abf6
ec636896da
244393670d
e4d170d667
311d71b653
ca7dc432c6
ada4e8e108
120caba3cd
c64343fe95
9d439c36ba
54428c80ca
6f3240c723
e437658b9d
ee9ba299fa
a1d9c60e61
a18c8dddf9
398e949935
cc30295f1d
066a0429e5
6c11aff234
b1b697117b
cfb50d0573
484f39e12f
@@ -17,8 +17,9 @@
    <el-scrollbar>
      <el-menu
        :default-openeds="['1']"
        router
        :default-active="activeMenu"
        :default-openeds="['1']"
        router
        class="futuristic-menu"
        background-color="transparent"
        text-color="#e0e6ff"
@@ -31,8 +32,8 @@
        <el-menu-item index="/">
          <el-icon><House /></el-icon>首页概览
        </el-menu-item>
        <el-menu-item index="/data">
          <el-icon><FolderOpened /></el-icon>数据管理
        <el-menu-item index="/store-management">
          <el-icon><Shop /></el-icon>店铺管理
        </el-menu-item>
        <el-sub-menu index="training-submenu">
          <template #title>
@@ -49,18 +50,27 @@
            <el-icon><Operation /></el-icon>全局模型训练
          </el-menu-item>
        </el-sub-menu>
        <el-menu-item index="/prediction">
          <el-icon><MagicStick /></el-icon>预测分析
        </el-menu-item>
        <el-sub-menu index="prediction-submenu">
          <template #title>
            <el-icon><MagicStick /></el-icon>
            <span>预测分析</span>
          </template>
          <el-menu-item index="/prediction/product">
            <el-icon><Coin /></el-icon>按药品预测
          </el-menu-item>
          <el-menu-item index="/prediction/store">
            <el-icon><Shop /></el-icon>按店铺预测
          </el-menu-item>
          <el-menu-item index="/prediction/global">
            <el-icon><Operation /></el-icon>全局模型预测
          </el-menu-item>
        </el-sub-menu>
        <el-menu-item index="/history">
          <el-icon><Histogram /></el-icon>历史预测
        </el-menu-item>
        <el-menu-item index="/management">
          <el-icon><Files /></el-icon>模型管理
        </el-menu-item>
        <el-menu-item index="/store-management">
          <el-icon><Shop /></el-icon>店铺管理
        </el-menu-item>
        </el-sub-menu>
      </el-menu>
    </el-scrollbar>
@@ -100,7 +110,12 @@
</template>

<script setup>
import { computed } from 'vue'
import { useRoute } from 'vue-router'
import { DataAnalysis, Refresh, DataLine, House, FolderOpened, Cpu, MagicStick, Files, Histogram, Coin, Shop, Operation } from '@element-plus/icons-vue'

const route = useRoute()
const activeMenu = computed(() => route.path)
</script>

<style>
@@ -225,8 +225,8 @@ body > .el-popper,

/* Overlay */
.el-overlay {
  background-color: rgba(6, 15, 28, 0.7) !important;
  backdrop-filter: blur(2px) !important;
  background-color: rgba(0, 0, 0, 0) !important; /* fully transparent */
  pointer-events: auto !important; /* make sure clicks are still intercepted */
}

/* Pagination controls */
@@ -595,3 +595,4 @@ textarea::placeholder {
.el-drawer__close-btn:hover {
  color: #5d9cff !important;
}
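The layout component above drives the sidebar highlight from the router: the menu sets the router flag, every el-menu-item uses a route path as its index, and :default-active="activeMenu" is bound to activeMenu = computed(() => route.path). A minimal sketch of that wiring, with one hypothetical refinement (the fallback and its path list are illustrative assumptions, not part of this diff) that keeps a parent entry highlighted on deeper child routes:

// Sketch only; mirrors the pattern added in the diff above.
import { computed } from 'vue'
import { useRoute } from 'vue-router'

const route = useRoute()

// Hypothetical fallback: map a child path such as /prediction/product/42
// back to the closest known menu index so the parent item stays active.
const menuIndexes = ['/store-management', '/prediction/product', '/prediction/store',
                     '/prediction/global', '/history', '/management']

const activeMenu = computed(() =>
  menuIndexes.find(p => route.path === p || route.path.startsWith(p + '/')) ?? route.path)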
@@ -9,11 +9,6 @@ const router = createRouter({
      name: 'dashboard',
      component: DashboardView
    },
    {
      path: '/data',
      name: 'data',
      component: () => import('../views/DataView.vue')
    },
    {
      path: '/training',
      name: 'training',
@@ -37,7 +32,22 @@ const router = createRouter({
    {
      path: '/prediction',
      name: 'prediction',
      component: () => import('../views/NewPredictionView.vue')
      redirect: '/prediction/product'
    },
    {
      path: '/prediction/product',
      name: 'product-prediction',
      component: () => import('../views/prediction/ProductPredictionView.vue')
    },
    {
      path: '/prediction/store',
      name: 'store-prediction',
      component: () => import('../views/prediction/StorePredictionView.vue')
    },
    {
      path: '/prediction/global',
      name: 'global-prediction',
      component: () => import('../views/prediction/GlobalPredictionView.vue')
    },
    {
      path: '/history',
@@ -102,10 +102,10 @@ const data = ref({
// Feature card data
const featureCards = [
  {
    title: '数据管理',
    description: '管理产品和销售数据',
    title: '店铺管理',
    description: '查看店铺信息',
    icon: 'FolderOpened',
    path: '/data',
    path: '/store-management',
    type: 'data'
  },
  {
@@ -1,461 +0,0 @@
<template>
  <el-card>
    <template #header>
      <div class="card-header">
        <span>销售数据管理</span>
        <el-upload
          :show-file-list="false"
          :http-request="handleUpload"
        >
          <el-button type="primary">上传销售数据</el-button>
        </el-upload>
      </div>
    </template>

    <!-- Query filters -->
    <div class="filter-section">
      <el-row :gutter="20">
        <el-col :span="6">
          <el-select v-model="filters.store_id" placeholder="选择店铺" clearable @change="handleFilterChange">
            <el-option label="全部店铺" value=""></el-option>
            <el-option
              v-for="store in stores"
              :key="store.store_id"
              :label="store.store_name"
              :value="store.store_id">
            </el-option>
          </el-select>
        </el-col>
        <el-col :span="6">
          <el-select v-model="filters.product_id" placeholder="选择产品" clearable @change="handleFilterChange">
            <el-option label="全部产品" value=""></el-option>
            <el-option
              v-for="product in allProducts"
              :key="product.product_id"
              :label="product.product_name"
              :value="product.product_id">
            </el-option>
          </el-select>
        </el-col>
        <el-col :span="8">
          <el-date-picker
            v-model="filters.dateRange"
            type="daterange"
            range-separator="至"
            start-placeholder="开始日期"
            end-placeholder="结束日期"
            format="YYYY-MM-DD"
            value-format="YYYY-MM-DD"
            @change="handleFilterChange"
          />
        </el-col>
        <el-col :span="4">
          <el-button type="primary" @click="handleFilterChange">查询</el-button>
        </el-col>
      </el-row>
    </div>

    <!-- Sales data table -->
    <el-table :data="salesData" stripe v-loading="loading" class="mt-4">
      <el-table-column prop="date" label="日期" width="120"></el-table-column>
      <el-table-column prop="store_name" label="店铺名称" width="150"></el-table-column>
      <el-table-column prop="store_id" label="店铺ID" width="100"></el-table-column>
      <el-table-column prop="product_name" label="产品名称" width="150"></el-table-column>
      <el-table-column prop="product_id" label="产品ID" width="100"></el-table-column>
      <el-table-column prop="quantity_sold" label="销量" width="80" align="right"></el-table-column>
      <el-table-column prop="unit_price" label="单价" width="80" align="right">
        <template #default="{ row }">
          ¥{{ row.unit_price?.toFixed(2) }}
        </template>
      </el-table-column>
      <el-table-column prop="sales_amount" label="销售额" width="100" align="right">
        <template #default="{ row }">
          ¥{{ row.sales_amount?.toFixed(2) }}
        </template>
      </el-table-column>
      <el-table-column prop="store_type" label="店铺类型" width="100"></el-table-column>
      <el-table-column label="操作" width="120">
        <template #default="{ row }">
          <el-button link @click="viewStoreDetails(row.store_id)">店铺详情</el-button>
        </template>
      </el-table-column>
    </el-table>

    <!-- Pagination -->
    <el-pagination
      v-if="total > 0"
      layout="total, sizes, prev, pager, next, jumper"
      :total="total"
      :page-size="pageSize"
      :page-sizes="[10, 20, 50, 100]"
      @current-change="handlePageChange"
      @size-change="handleSizeChange"
      class="mt-4"
    />

    <!-- Statistics -->
    <div class="statistics-section mt-4" v-if="statistics">
      <el-row :gutter="20">
        <el-col :span="6">
          <el-statistic title="总记录数" :value="statistics.total_records" />
        </el-col>
        <el-col :span="6">
          <el-statistic title="总销售额" :value="statistics.total_sales_amount" :precision="2" prefix="¥" />
        </el-col>
        <el-col :span="6">
          <el-statistic title="总销量" :value="statistics.total_quantity" />
        </el-col>
        <el-col :span="6">
          <el-statistic title="店铺数量" :value="statistics.stores" />
        </el-col>
      </el-row>
    </div>

    <!-- Product details dialog -->
    <el-dialog
      v-model="dialogVisible"
      :title="`${selectedProduct?.product_name} - 销售详情`"
      width="60%"
    >
      <div v-loading="detailLoading">
        <div v-if="salesData.length > 0">
          <div class="chart-container">
            <canvas ref="salesChartCanvas"></canvas>
          </div>
          <el-table :data="paginatedSalesData" stripe>
            <el-table-column prop="date" label="日期"></el-table-column>
            <el-table-column prop="sales" label="销量"></el-table-column>
            <el-table-column prop="price" label="价格"></el-table-column>
          </el-table>
          <el-pagination
            layout="prev, pager, next"
            :total="salesData.length"
            :page-size="pageSize"
            @current-change="handlePageChange"
            class="mt-4"
          />
        </div>
        <el-empty v-else description="暂无销售数据"></el-empty>
      </div>
    </el-dialog>
  </el-card>
</template>

<script setup>
import { ref, onMounted, nextTick } from 'vue'
import axios from 'axios'
import { ElMessage } from 'element-plus'
import Chart from 'chart.js/auto';
import zoomPlugin from 'chartjs-plugin-zoom';

Chart.register(zoomPlugin);

// Data
const stores = ref([])
const allProducts = ref([])
const salesData = ref([])
const statistics = ref(null)
const loading = ref(true)

// Pagination
const pageSize = ref(20)
const currentPage = ref(1)
const total = ref(0)

// Filters
const filters = ref({
  store_id: '',
  product_id: '',
  dateRange: null
})

// Dialog state
const dialogVisible = ref(false)
const detailLoading = ref(false)
const selectedProduct = ref(null)
const paginatedSalesData = ref([])
const salesChartCanvas = ref(null)
let salesChart = null;

// Fetch the store list
const fetchStores = async () => {
  try {
    const response = await axios.get('/api/stores')
    if (response.data.status === 'success') {
      stores.value = response.data.data
    } else {
      ElMessage.error('获取店铺列表失败')
    }
  } catch (error) {
    console.error('获取店铺列表失败:', error)
  }
}

// Fetch the product list
const fetchProducts = async () => {
  try {
    const response = await axios.get('/api/products')
    if (response.data.status === 'success') {
      allProducts.value = response.data.data
    } else {
      ElMessage.error('获取产品列表失败')
    }
  } catch (error) {
    console.error('获取产品列表失败:', error)
  }
}

// Fetch sales data
const fetchSalesData = async () => {
  try {
    loading.value = true

    // Build query parameters
    const params = {
      page: currentPage.value,
      page_size: pageSize.value
    }

    if (filters.value.store_id) {
      params.store_id = filters.value.store_id
    }

    if (filters.value.product_id) {
      params.product_id = filters.value.product_id
    }

    if (filters.value.dateRange && filters.value.dateRange.length === 2) {
      params.start_date = filters.value.dateRange[0]
      params.end_date = filters.value.dateRange[1]
    }

    const response = await axios.get('/api/sales/data', { params })

    if (response.data.status === 'success') {
      salesData.value = response.data.data
      total.value = response.data.total || 0
      statistics.value = response.data.statistics
    } else {
      ElMessage.error('获取销售数据失败')
      salesData.value = []
      total.value = 0
      statistics.value = null
    }
  } catch (error) {
    ElMessage.error('请求销售数据时出错')
    console.error(error)
    salesData.value = []
    total.value = 0
    statistics.value = null
  } finally {
    loading.value = false
  }
}

// Handle filter changes
const handleFilterChange = () => {
  currentPage.value = 1
  fetchSalesData()
}

// Handle page changes
const handlePageChange = (page) => {
  currentPage.value = page
  fetchSalesData()
}

// Handle page-size changes
const handleSizeChange = (size) => {
  pageSize.value = size
  currentPage.value = 1
  fetchSalesData()
}

// View store details
const viewStoreDetails = async (storeId) => {
  try {
    const response = await axios.get(`/api/stores/${storeId}`)
    if (response.data.status === 'success') {
      const store = response.data.data
      ElMessage.info(`店铺:${store.store_name},位置:${store.location},类型:${store.type}`)
    }
  } catch (error) {
    ElMessage.error('获取店铺详情失败')
  }
}

// File upload
const handleUpload = async (options) => {
  const formData = new FormData()
  formData.append('file', options.file)
  try {
    const response = await axios.post('/api/data/upload', formData, {
      headers: {
        'Content-Type': 'multipart/form-data'
      }
    })
    if (response.data.status === 'success') {
      ElMessage.success('数据上传成功')
      await fetchStores()
      await fetchProducts()
      await fetchSalesData()
    } else {
      ElMessage.error(response.data.message || '数据上传失败')
    }
  } catch (error) {
    ElMessage.error('数据上传请求失败')
    console.error(error)
  }
}

const viewDetails = async (product) => {
  selectedProduct.value = product;
  dialogVisible.value = true;
  detailLoading.value = true;
  try {
    const response = await axios.get(`/api/products/${product.product_id}/sales`);
    if (response.data.status === 'success') {
      salesData.value = response.data.data;
      handlePageChange(1); // Show first page
      await nextTick();
      renderChart();
    } else {
      ElMessage.error('获取销售详情失败');
      salesData.value = [];
    }
  } catch (error) {
    ElMessage.error('请求销售详情时出错');
    salesData.value = [];
  } finally {
    detailLoading.value = false;
  }
}

// handlePageChange is defined above

const renderChart = () => {
  if (salesChart) {
    salesChart.destroy();
  }
  if (!salesChartCanvas.value || salesData.value.length === 0) return;

  const labels = salesData.value.map(d => d.date);
  const data = salesData.value.map(d => d.sales);

  salesChart = new Chart(salesChartCanvas.value, {
    type: 'line',
    data: {
      labels,
      datasets: [{
        label: '每日销量',
        data,
        borderColor: '#409EFF',
        tension: 0.1
      }]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        tooltip: {
          callbacks: {
            title: function(context) {
              return `日期: ${context[0].label}`;
            },
            label: (context) => {
              const label = context.dataset.label || '';
              const value = context.parsed.y;
              const fullData = salesData.value[context.dataIndex];
              let tooltipText = `${label}: ${value}`;
              if (fullData) {
                tooltipText += ` | 温度: ${fullData.temperature}°C`;
              }
              return tooltipText;
            }
          }
        },
        zoom: {
          pan: {
            enabled: true,
            mode: 'x',
          },
          zoom: {
            wheel: {
              enabled: true,
            },
            pinch: {
              enabled: true
            },
            mode: 'x',
          }
        }
      }
    }
  });
}

// Initialize data when the component mounts
onMounted(async () => {
  await fetchStores()
  await fetchProducts()
  await fetchSalesData()
})
</script>

<style scoped>
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
}

.filter-section {
  padding: 20px;
  background-color: #f8f9fa;
  border-radius: 8px;
  margin-bottom: 20px;
}

.statistics-section {
  padding: 20px;
  background-color: #f0f9ff;
  border-radius: 8px;
  border: 1px solid #e0f2fe;
}

.mt-4 {
  margin-top: 24px;
}

.chart-container {
  width: 100%;
  height: 400px;
  margin-bottom: 20px;
}

.el-statistic {
  text-align: center;
}

.el-table .el-table__cell {
  padding: 8px 0;
}

.filter-section .el-row {
  align-items: center;
}

.filter-section .el-col {
  margin-bottom: 10px;
}

@media (max-width: 768px) {
  .filter-section .el-col {
    margin-bottom: 15px;
  }

  .statistics-section .el-col {
    margin-bottom: 15px;
  }
}
</style>
@@ -248,41 +248,9 @@ import { ref, onMounted, reactive, watch, nextTick } from 'vue';
import axios from 'axios';
import { ElMessage, ElMessageBox } from 'element-plus';
import { QuestionFilled, Search, View, Delete, ArrowUp, ArrowDown, Minus, Download } from '@element-plus/icons-vue';
import * as echarts from 'echarts/core';
import { LineChart, BarChart } from 'echarts/charts';
import {
  TitleComponent,
  TooltipComponent,
  GridComponent,
  DatasetComponent,
  TransformComponent,
  LegendComponent,
  ToolboxComponent,
  MarkLineComponent,
  MarkPointComponent
} from 'echarts/components';
import { LabelLayout, UniversalTransition } from 'echarts/features';
import { CanvasRenderer } from 'echarts/renderers';
import Chart from 'chart.js/auto'; // << Key change: import Chart.js
import { computed, onUnmounted } from 'vue';

// Register the required components
echarts.use([
  TitleComponent,
  TooltipComponent,
  GridComponent,
  DatasetComponent,
  TransformComponent,
  LegendComponent,
  ToolboxComponent,
  MarkLineComponent,
  MarkPointComponent,
  LineChart,
  BarChart,
  LabelLayout,
  UniversalTransition,
  CanvasRenderer
]);

const loading = ref(false);
const history = ref([]);
const products = ref([]);
@@ -292,8 +260,8 @@ const currentPrediction = ref(null);
const rawResponseData = ref(null);
const showRawDataFlag = ref(false);

const fullscreenPredictionChart = ref(null);
const fullscreenHistoryChart = ref(null);
let predictionChart = null; // << Key change: use a single chart instance
let historyChart = null;

const filters = reactive({
  product_id: '',
@@ -982,104 +950,133 @@ const getFactorsArray = computed(() => {
watch(detailsVisible, (newVal) => {
  if (newVal && currentPrediction.value) {
    nextTick(() => {
      // Init Prediction Chart
      if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
      const predChartDom = document.getElementById('fullscreen-prediction-chart-history');
      if (predChartDom) {
        fullscreenPredictionChart.value = echarts.init(predChartDom);
        if (currentPrediction.value.chart_data) {
          updatePredictionChart(currentPrediction.value.chart_data, fullscreenPredictionChart.value, true);
        }
      }

      // Init History Chart
      if (currentPrediction.value.analysis) {
        if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
        const histChartDom = document.getElementById('fullscreen-history-chart-history');
        if (histChartDom) {
          fullscreenHistoryChart.value = echarts.init(histChartDom);
          updateHistoryChart(currentPrediction.value.analysis, fullscreenHistoryChart.value, true);
        }
      }
      renderChart();
      // Logic for rendering the second chart could be added here
      // renderHistoryAnalysisChart();
    });
  }
});

const updatePredictionChart = (chartData, chart, isFullscreen = false) => {
  if (!chart || !chartData) return;
  chart.showLoading();
  const dates = chartData.dates || [];
  const sales = chartData.sales || [];
  const types = chartData.types || [];
// << Key change: renderChart copied from ProductPredictionView.vue and adapted
const renderChart = () => {
  const chartCanvas = document.getElementById('fullscreen-prediction-chart-history');
  if (!chartCanvas || !currentPrediction.value || !currentPrediction.value.data) return;

  const combinedData = [];
  for (let i = 0; i < dates.length; i++) {
    combinedData.push({ date: dates[i], sales: sales[i], type: types[i] });
  if (predictionChart) {
    predictionChart.destroy();
  }
  combinedData.sort((a, b) => new Date(a.date) - new Date(b.date));

  const allDates = combinedData.map(item => item.date);
  const historyDates = combinedData.filter(d => d.type === '历史销量').map(d => d.date);
  const historySales = combinedData.filter(d => d.type === '历史销量').map(d => d.sales);
  const predictionDates = combinedData.filter(d => d.type === '预测销量').map(d => d.date);
  const predictionSales = combinedData.filter(d => d.type === '预测销量').map(d => d.sales);

  const allSales = [...historySales, ...predictionSales].filter(val => !isNaN(val));
  const minSale = Math.max(0, Math.floor(Math.min(...allSales) * 0.9));
  const maxSale = Math.ceil(Math.max(...allSales) * 1.1);

  const option = {
    title: { text: '销量预测趋势图', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
    tooltip: { trigger: 'axis', axisPointer: { type: 'cross' },
      formatter: function(params) {
        if (!params || params.length === 0) return '';
        const date = params[0].axisValue;
        let html = `<div style="font-weight:bold">${date}</div>`;
        params.forEach(item => {
          if (item.value !== '-') {
            html += `<div style="display:flex;justify-content:space-between;align-items:center;margin:5px 0;">
              <span style="display:inline-block;margin-right:5px;width:10px;height:10px;border-radius:50%;background-color:${item.color};"></span>
              <span>${item.seriesName}:</span>
              <span style="font-weight:bold;margin-left:5px;">${item.value.toFixed(2)}</span>
            </div>`;
          }
        });
        return html;
      }
  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (currentPrediction.value.data.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (currentPrediction.value.data.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。');
    return;
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort();
  const simplifiedLabels = allLabels.map(date => date.split('-')[2]);

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]));
  // Note: the 'sales' field is used here because the backend has unified on it
  const predictionMap = new Map(predictionData.map(p => [p.date, p.sales]));

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null);
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null);

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date;
    const lastHistoryValue = historyData[historyData.length - 1].sales;
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue;
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  predictionChart = new Chart(chartCanvas, {
    type: 'line',
    data: {
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    legend: { data: ['历史销量', '预测销量'], top: isFullscreen ? 40 : 30, textStyle: { color: '#e0e6ff' } },
    grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
    toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
    xAxis: { type: 'category', boundaryGap: false, data: allDates, axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
    yAxis: { type: 'value', name: '销量', min: minSale, max: maxSale, axisLabel: { color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
    series: [
      { name: '历史销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => historyDates.includes(date) ? historySales[historyDates.indexOf(date)] : null), areaStyle: { color: new echarts.graphic.LinearGradient(0, 0, 0, 1, [{ offset: 0, color: 'rgba(64, 158, 255, 0.3)' }, { offset: 1, color: 'rgba(64, 158, 255, 0.1)' }]) }, lineStyle: { color: '#409EFF' } },
      { name: '预测销量', type: 'line', smooth: true, connectNulls: true, data: allDates.map(date => predictionDates.includes(date) ? predictionSales[predictionDates.indexOf(date)] : null), lineStyle: { color: '#F56C6C' } }
    ]
  };
  chart.hideLoading();
  chart.setOption(option, true);
};

const updateHistoryChart = (analysisData, chart, isFullscreen = false) => {
  if (!chart || !analysisData || !analysisData.history_chart_data) return;
  chart.showLoading();
  const { dates, changes } = analysisData.history_chart_data;

  const option = {
    title: { text: '销量日环比变化', left: 'center', textStyle: { fontSize: isFullscreen ? 18 : 16, fontWeight: 'bold', color: '#e0e6ff' } },
    tooltip: { trigger: 'axis', axisPointer: { type: 'shadow' }, formatter: p => `${p[0].axisValue}<br/>环比: ${p[0].value.toFixed(2)}%` },
    grid: { left: '3%', right: '4%', bottom: '3%', containLabel: true },
    toolbox: { feature: { saveAsImage: { title: '保存图片' } }, iconStyle: { borderColor: '#e0e6ff' } },
    xAxis: { type: 'category', data: dates.map(d => formatDate(d)), axisLabel: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } } },
    yAxis: { type: 'value', name: '环比变化(%)', axisLabel: { formatter: '{value}%', color: '#e0e6ff' }, nameTextStyle: { color: '#e0e6ff' }, axisLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.5)' } }, splitLine: { lineStyle: { color: 'rgba(224, 230, 255, 0.1)' } } },
    series: [{
      name: '日环比变化', type: 'bar',
      data: changes.map(val => ({ value: val, itemStyle: { color: val >= 0 ? '#67C23A' : '#F56C6C' } }))
    }]
  };
  chart.hideLoading();
  chart.setOption(option, true);
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        title: {
          display: true,
          text: `${currentPrediction.value.data.product_name} - 销量预测趋势图`,
          color: '#ffffff',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: '#6c757d',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期 (日)'
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          grid: {
            color: '#e9e9e9',
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
  });
};
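Every renderChart in this compare shares one alignment step: both series are re-indexed onto a single sorted label axis, missing dates become null so Chart.js leaves a gap instead of dropping to zero, and the last history point is copied into the prediction series so the solid and dashed lines join visually. A minimal standalone sketch of just that step (the sample data is made up for illustration):

// Sketch of the date-alignment used by the renderChart functions above.
const history = [{ date: '2024-01-01', sales: 10 }, { date: '2024-01-02', sales: 12 }];
const prediction = [{ date: '2024-01-03', sales: 13 }, { date: '2024-01-04', sales: 15 }];

// One shared, sorted label axis covering both series.
const labels = [...new Set([...history, ...prediction].map(p => p.date))].sort();

const histMap = new Map(history.map(p => [p.date, p.sales]));
const predMap = new Map(prediction.map(p => [p.date, p.sales]));

// null (not 0) marks "no data", so each line simply stops at its edge.
const histSeries = labels.map(d => histMap.get(d) ?? null);
const predSeries = labels.map(d => predMap.get(d) ?? null);

// Bridge: duplicate the last history point into the prediction series
// so the dashed forecast line starts where the solid history line ends.
const lastDate = history[history.length - 1].date;
if (!predMap.has(lastDate)) {
  predSeries[labels.indexOf(lastDate)] = histMap.get(lastDate);
}

console.log(labels, histSeries, predSeries);
// ['2024-01-01','2024-01-02','2024-01-03','2024-01-04']
// [10, 12, null, null]
// [null, 12, 13, 15]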
const exportHistoryData = () => {
@@ -1102,16 +1099,24 @@ const exportHistoryData = () => {
};

const resizeCharts = () => {
  if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.resize();
  if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.resize();
  if (predictionChart) {
    predictionChart.resize();
  }
  if (historyChart) {
    historyChart.resize();
  }
};

window.addEventListener('resize', resizeCharts);

onUnmounted(() => {
  window.removeEventListener('resize', resizeCharts);
  if (fullscreenPredictionChart.value) fullscreenPredictionChart.value.dispose();
  if (fullscreenHistoryChart.value) fullscreenHistoryChart.value.dispose();
  if (predictionChart) {
    predictionChart.destroy();
  }
  if (historyChart) {
    historyChart.destroy();
  }
});

onMounted(() => {
UI/src/views/prediction/GlobalPredictionView.vue (new file, 372 lines)
@@ -0,0 +1,372 @@
<template>
  <div class="prediction-view">
    <el-card>
      <template #header>
        <div class="card-header">
          <span>全局模型预测</span>
          <el-tooltip content="对系统中的所有全局模型进行批量或单个预测">
            <el-icon><QuestionFilled /></el-icon>
          </el-tooltip>
        </div>
      </template>

      <div class="controls-section">
        <el-form :model="filters" label-width="80px" inline>
          <el-form-item label="算法类型">
            <el-select v-model="filters.model_type" placeholder="所有类型" clearable value-key="id" style="width: 200px;">
              <el-option
                v-for="item in modelTypes"
                :key="item.id"
                :label="item.name"
                :value="item"
              />
            </el-select>
          </el-form-item>
          <el-form-item label="预测天数">
            <el-input-number v-model="form.future_days" :min="1" :max="365" />
          </el-form-item>
          <el-form-item label="历史天数">
            <el-input-number v-model="form.history_lookback_days" :min="7" :max="365" />
          </el-form-item>
          <el-form-item label="起始日期">
            <el-date-picker
              v-model="form.start_date"
              type="date"
              placeholder="选择日期"
              format="YYYY-MM-DD"
              value-format="YYYY-MM-DD"
              :clearable="false"
            />
          </el-form-item>
        </el-form>
      </div>

      <!-- Model list -->
      <div class="model-list-section">
        <h4>🌍 可用全局模型列表</h4>
        <el-table :data="paginatedModelList" style="width: 100%" v-loading="modelsLoading">
          <el-table-column prop="model_type" label="模型类型" sortable />
          <el-table-column prop="version" label="版本" />
          <el-table-column prop="created_at" label="创建时间" :formatter="formatDateTime" />
          <el-table-column label="操作">
            <template #default="{ row }">
              <el-button
                type="primary"
                size="small"
                @click="startPrediction(row)"
                :loading="predicting[row.model_id]"
              >
                <el-icon><TrendCharts /></el-icon>
                开始预测
              </el-button>
            </template>
          </el-table-column>
        </el-table>
        <el-pagination
          background
          layout="prev, pager, next"
          :total="filteredModelList.length"
          :page-size="pagination.pageSize"
          @current-change="handlePageChange"
          style="margin-top: 20px; justify-content: center;"
        />
      </div>
    </el-card>

    <!-- Prediction result dialog -->
    <el-dialog v-model="dialogVisible" title="📈 预测结果" width="70%" :modal="false" modal-class="no-overlay">
      <div class="prediction-chart">
        <canvas ref="chartCanvas" width="800" height="400"></canvas>
      </div>
      <template #footer>
        <el-button @click="dialogVisible = false">关闭</el-button>
      </template>
    </el-dialog>

  </div>

</template>

<script setup>
import { ref, reactive, onMounted, nextTick, computed } from 'vue'
import axios from 'axios'
import { ElMessage, ElDialog, ElTable, ElTableColumn, ElButton, ElIcon, ElCard, ElTooltip, ElForm, ElFormItem, ElInputNumber, ElDatePicker, ElSelect, ElOption, ElPagination } from 'element-plus'
import { QuestionFilled, TrendCharts } from '@element-plus/icons-vue'
import Chart from 'chart.js/auto'

const modelList = ref([])
const modelTypes = ref([])
const modelsLoading = ref(false)
const predicting = reactive({})
const dialogVisible = ref(false)
const predictionResult = ref(null)
const chartCanvas = ref(null)
let chart = null

const form = reactive({
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true // Keep analysis enabled; the toggle was removed from the UI
})

const filters = reactive({
  model_type: null
})

const pagination = reactive({
  currentPage: 1,
  pageSize: 8
})

const filteredModelList = computed(() => {
  return modelList.value.filter(model => {
    const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
    return modelTypeMatch
  })
})

const paginatedModelList = computed(() => {
  const start = (pagination.currentPage - 1) * pagination.pageSize
  const end = start + pagination.pageSize
  return filteredModelList.value.slice(start, end)
})

const handlePageChange = (page) => {
  pagination.currentPage = page
}

const formatDateTime = (row, column, cellValue) => {
  if (!cellValue) return ''
  const date = new Date(cellValue)
  return date.toLocaleString('zh-CN', {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
    hour: '2-digit',
    minute: '2-digit',
    second: '2-digit',
    hour12: false
  }).replace(/\//g, '-')
}

const fetchModelTypes = async () => {
  try {
    const response = await axios.get('/api/model_types')
    if (response.data.status === 'success') {
      modelTypes.value = response.data.data
    }
  } catch (error) {
    ElMessage.error('获取模型类型失败')
  }
}

const fetchModels = async () => {
  modelsLoading.value = true
  try {
    const response = await axios.get('/api/models', { params: { training_mode: 'global' } })
    if (response.data.status === 'success') {
      modelList.value = response.data.data
    } else {
      ElMessage.error('获取模型列表失败')
    }
  } catch (error) {
    ElMessage.error('获取模型列表失败')
  } finally {
    modelsLoading.value = false
  }
}

const startPrediction = async (model) => {
  predicting[model.model_id] = true
  try {
    const payload = {
      training_mode: 'global',
      model_type: model.model_type,
      version: model.version,
      future_days: form.future_days,
      history_lookback_days: form.history_lookback_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result,
    }
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')
      dialogVisible.value = true
      await nextTick()
      renderChart()
    } else {
      ElMessage.error(response.data.error || '预测失败')
    }
  } catch (error) {
    ElMessage.error(error.response?.data?.error || '预测请求失败')
  } finally {
    predicting[model.model_id] = false
  }
}

const renderChart = () => {
  if (!chartCanvas.value || !predictionResult.value) return
  if (chart) {
    chart.destroy()
  }

  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。')
    return
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
  const simplifiedLabels = allLabels.map(date => date.split('-').slice(1).join('/'));

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date
    const lastHistoryValue = historyData[historyData.length - 1].sales
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        legend: {
          labels: {
            color: 'white' // Legend text color
          }
        },
        title: {
          display: true,
          text: `全局模型 (${predictionResult.value.model_type}) - 销量预测趋势图`,
          color: 'white',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: 'white',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期',
            color: 'white'
          },
          ticks: {
            color: 'white' // X-axis tick label color
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量',
            color: 'white'
          },
          ticks: {
            color: 'white' // Y-axis tick label color
          },
          grid: {
            color: 'rgba(255, 255, 255, 0.2)', // Grid line color
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
  })
}

onMounted(() => {
  fetchModels()
  fetchModelTypes()
  const today = new Date()
  form.start_date = today.toISOString().split('T')[0]
})
</script>

<style scoped>
.prediction-view {
  padding: 20px;
}
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
}
.controls-section, .model-list-section {
  margin-top: 20px;
}
.model-list-section h4 {
  margin-bottom: 16px;
}
.prediction-chart {
  margin-top: 20px;
}
</style>

<style>
.no-overlay {
  background-color: transparent !important;
}
</style>
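All three new prediction views share the same client-side filter-then-paginate chain: one computed narrows the full model list, a second slices the current page out of it, and el-pagination's :total binds to the filtered length so the page count tracks the filter. A distilled sketch of that chain with the names used above (the sample model list is made up):

import { ref, reactive, computed } from 'vue'

const modelList = ref([
  { model_id: 1, model_type: 'transformer' },
  { model_id: 2, model_type: 'tcn' },
  // ...normally loaded from /api/models as in the views above
])
const filters = reactive({ model_type: null })          // the el-select stores the whole { id, name } item
const pagination = reactive({ currentPage: 1, pageSize: 8 })

// Step 1: filter the whole list; :total binds to filteredModelList.length.
const filteredModelList = computed(() =>
  modelList.value.filter(m => !filters.model_type || m.model_type === filters.model_type.id))

// Step 2: slice out only the visible page.
const paginatedModelList = computed(() => {
  const start = (pagination.currentPage - 1) * pagination.pageSize
  return filteredModelList.value.slice(start, start + pagination.pageSize)
})

One design note: the views do not appear to reset pagination.currentPage when the filter changes, so a page index left over from a longer list can produce an empty slice until the user clicks back to page 1.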
UI/src/views/prediction/ProductPredictionView.vue (new file, 376 lines)
@@ -0,0 +1,376 @@
<template>
  <div class="prediction-view">
    <el-card>
      <template #header>
        <div class="card-header">
          <span>按药品预测</span>
          <el-tooltip content="对系统中的所有药品模型进行批量或单个预测">
            <el-icon><QuestionFilled /></el-icon>
          </el-tooltip>
        </div>
      </template>

      <div class="controls-section">
        <el-form :model="filters" label-width="80px" inline>
          <el-form-item label="目标药品">
            <ProductSelector
              v-model="filters.product_id"
              :show-all-option="true"
              all-option-label="所有药品"
              clearable
            />
          </el-form-item>
          <el-form-item label="算法类型">
            <el-select v-model="filters.model_type" placeholder="所有类型" clearable value-key="id" style="width: 200px;">
              <el-option
                v-for="item in modelTypes"
                :key="item.id"
                :label="item.name"
                :value="item"
              />
            </el-select>
          </el-form-item>
          <el-form-item label="预测天数">
            <el-input-number v-model="form.future_days" :min="1" :max="365" />
          </el-form-item>
          <el-form-item label="历史天数">
            <el-input-number v-model="form.history_lookback_days" :min="7" :max="365" />
          </el-form-item>
          <el-form-item label="起始日期">
            <el-date-picker
              v-model="form.start_date"
              type="date"
              placeholder="选择日期"
              format="YYYY-MM-DD"
              value-format="YYYY-MM-DD"
              :clearable="false"
            />
          </el-form-item>
        </el-form>
      </div>

      <!-- Model list -->
      <div class="model-list-section">
        <h4>📦 可用药品模型列表</h4>
        <el-table :data="paginatedModelList" style="width: 100%" v-loading="modelsLoading">
          <el-table-column prop="product_name" label="药品名称" sortable />
          <el-table-column prop="model_type" label="模型类型" sortable />
          <el-table-column prop="version" label="版本" />
          <el-table-column prop="created_at" label="创建时间" :formatter="formatDateTime" />
          <el-table-column label="操作">
            <template #default="{ row }">
              <el-button
                type="primary"
                size="small"
                @click="startPrediction(row)"
                :loading="predicting[row.model_id]"
              >
                <el-icon><TrendCharts /></el-icon>
                开始预测
              </el-button>
            </template>
          </el-table-column>
        </el-table>
        <el-pagination
          background
          layout="prev, pager, next"
          :total="filteredModelList.length"
          :page-size="pagination.pageSize"
          @current-change="handlePageChange"
          style="margin-top: 20px; justify-content: center;"
        />
      </div>
    </el-card>

    <!-- Prediction result dialog -->
    <el-dialog v-model="dialogVisible" title="📈 预测结果" width="70%">
      <div class="prediction-chart">
        <canvas ref="chartCanvas" width="800" height="400"></canvas>
      </div>
      <template #footer>
        <el-button @click="dialogVisible = false">关闭</el-button>
      </template>
    </el-dialog>
  </div>
</template>

<script setup>
import { ref, reactive, onMounted, nextTick, computed } from 'vue'
import axios from 'axios'
import { ElMessage, ElDialog, ElTable, ElTableColumn, ElButton, ElIcon, ElCard, ElTooltip, ElForm, ElFormItem, ElInputNumber, ElDatePicker, ElSelect, ElOption, ElRow, ElCol, ElPagination } from 'element-plus'
import { QuestionFilled, TrendCharts } from '@element-plus/icons-vue'
import Chart from 'chart.js/auto'
import ProductSelector from '../../components/ProductSelector.vue'

const modelList = ref([])
const modelTypes = ref([])
const modelsLoading = ref(false)
const predicting = reactive({})
const dialogVisible = ref(false)
const predictionResult = ref(null)
const chartCanvas = ref(null)
let chart = null

const form = reactive({
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true // Keep analysis enabled; the toggle was removed from the UI
})

const filters = reactive({
  product_id: '',
  model_type: null
})

const pagination = reactive({
  currentPage: 1,
  pageSize: 8
})

const filteredModelList = computed(() => {
  return modelList.value.filter(model => {
    const productMatch = !filters.product_id || model.product_id === filters.product_id
    const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
    return productMatch && modelTypeMatch
  })
})

const paginatedModelList = computed(() => {
  const start = (pagination.currentPage - 1) * pagination.pageSize
  const end = start + pagination.pageSize
  return filteredModelList.value.slice(start, end)
})

const handlePageChange = (page) => {
  pagination.currentPage = page
}

const formatDateTime = (row, column, cellValue) => {
  if (!cellValue) return ''
  const date = new Date(cellValue)
  return date.toLocaleString('zh-CN', {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
    hour: '2-digit',
    minute: '2-digit',
    second: '2-digit',
    hour12: false
  }).replace(/\//g, '-')
}

const fetchModelTypes = async () => {
  try {
    const response = await axios.get('/api/model_types')
    if (response.data.status === 'success') {
      modelTypes.value = response.data.data
    }
  } catch (error) {
    ElMessage.error('获取模型类型失败')
  }
}

const fetchModels = async () => {
  modelsLoading.value = true
  try {
    const response = await axios.get('/api/models', { params: { training_mode: 'product' } })
    if (response.data.status === 'success') {
      modelList.value = response.data.data
    } else {
      ElMessage.error('获取模型列表失败')
    }
  } catch (error) {
    ElMessage.error('获取模型列表失败')
  } finally {
    modelsLoading.value = false
  }
}

const startPrediction = async (model) => {
  predicting[model.model_id] = true
  try {
    const payload = {
      product_id: model.product_id,
      model_type: model.model_type,
      version: model.version,
      future_days: form.future_days,
      history_lookback_days: form.history_lookback_days,
      start_date: form.start_date,
      include_visualization: true, // Analysis is hard-coded to enabled
    }
    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')
      dialogVisible.value = true
      await nextTick()
      renderChart()
    } else {
      ElMessage.error(response.data.error || '预测失败')
    }
  } catch (error) {
    ElMessage.error(error.response?.data?.error || '预测请求失败')
  } finally {
    predicting[model.model_id] = false
  }
}

const renderChart = () => {
  if (!chartCanvas.value || !predictionResult.value) return
  if (chart) {
    chart.destroy()
  }

  const formatDate = (date) => new Date(date).toISOString().split('T')[0];

  const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
  const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));

  if (historyData.length === 0 && predictionData.length === 0) {
    ElMessage.warning('没有可用于图表的数据。')
    return
  }

  const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
  const simplifiedLabels = allLabels.map(date => date.split('-').slice(1).join('/'));

  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))

  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)

  if (historyData.length > 0 && predictionData.length > 0) {
    const lastHistoryDate = historyData[historyData.length - 1].date
    const lastHistoryValue = historyData[historyData.length - 1].sales
    if (!predictionMap.has(lastHistoryDate)) {
      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
    }
  }

  let subtitleText = '';
  if (historyData.length > 0) {
    subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
  }
  if (predictionData.length > 0) {
    if (subtitleText) subtitleText += ' | ';
    subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
  }

  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels: simplifiedLabels,
      datasets: [
        {
          label: '历史销量',
          data: alignedHistorySales,
          borderColor: '#67C23A',
          backgroundColor: 'rgba(103, 194, 58, 0.2)',
          tension: 0.4,
          fill: true,
          spanGaps: false,
        },
        {
          label: '预测销量',
          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.2)',
          tension: 0.4,
          fill: true,
          borderDash: [5, 5],
        }
      ]
    },
    options: {
      responsive: true,
      maintainAspectRatio: false,
      plugins: {
        legend: {
          labels: {
            color: 'white' // Legend text color
          }
        },
        title: {
          display: true,
          text: `${predictionResult.value.product_name} - 销量预测趋势图`,
          color: 'white',
          font: {
            size: 20,
            weight: 'bold',
          }
        },
        subtitle: {
          display: true,
          text: subtitleText,
          color: 'white',
          font: {
            size: 14,
          },
          padding: {
            bottom: 20
          }
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期',
            color: 'white'
          },
          ticks: {
            color: 'white' // X-axis tick label color
          },
          grid: {
            display: false
          }
        },
        y: {
          title: {
            display: true,
            text: '销量',
            color: 'white'
          },
          ticks: {
            color: 'white' // Y-axis tick label color
          },
          grid: {
            color: 'rgba(255, 255, 255, 0.2)', // Grid line color
            drawBorder: false,
          },
          beginAtZero: true
        }
      }
    }
  })
}

onMounted(() => {
  fetchModels()
  fetchModelTypes()
  const today = new Date()
  form.start_date = today.toISOString().split('T')[0]
})
</script>

<style scoped>
.prediction-view {
  padding: 20px;
}
.card-header {
  display: flex;
  justify-content: space-between;
  align-items: center;
}
.filters-section, .global-settings-section, .model-list-section {
  margin-top: 20px;
}
.filters-section h4, .global-settings-section h4, .model-list-section h4 {
  margin-bottom: 20px;
}
.prediction-chart {
  margin-top: 20px;
}
</style>
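The formatDateTime helper repeated in these views is an el-table column formatter (signature (row, column, cellValue)) that renders ISO timestamps as YYYY-MM-DD HH:mm:ss by formatting in the zh-CN locale and swapping its slashes for dashes. A quick standalone check of that behavior (the sample timestamp is made up, and the exact output depends on the runtime's timezone):

// Simplified to a single argument for the standalone check.
const formatDateTime = (cellValue) => {
  if (!cellValue) return ''
  return new Date(cellValue).toLocaleString('zh-CN', {
    year: 'numeric', month: '2-digit', day: '2-digit',
    hour: '2-digit', minute: '2-digit', second: '2-digit',
    hour12: false
  }).replace(/\//g, '-') // zh-CN renders dates like 2025/07/01; normalize to dashes
}

console.log(formatDateTime('2025-07-01T08:30:00Z'))
// e.g. "2025-07-01 16:30:00" in a UTC+8 environment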
UI/src/views/prediction/StorePredictionView.vue (new file, 404 lines)
@@ -0,0 +1,404 @@
<template>
  <div class="prediction-view">
    <el-card>
      <template #header>
        <div class="card-header">
          <span>按店铺预测</span>
          <el-tooltip content="对系统中的所有店铺模型进行批量或单个预测">
            <el-icon><QuestionFilled /></el-icon>
          </el-tooltip>
        </div>
      </template>

      <div class="controls-section">
        <el-form :model="filters" label-width="80px" inline>
          <el-form-item label="目标店铺">
            <StoreSelector
              v-model="filters.store_id"
              :show-all-option="true"
              all-option-label="所有店铺"
              clearable
            />
          </el-form-item>
          <el-form-item label="算法类型">
            <el-select v-model="filters.model_type" placeholder="所有类型" clearable value-key="id" style="width: 200px;">
              <el-option
                v-for="item in modelTypes"
                :key="item.id"
                :label="item.name"
                :value="item"
              />
            </el-select>
          </el-form-item>
          <el-form-item label="预测天数">
            <el-input-number v-model="form.future_days" :min="1" :max="365" />
          </el-form-item>
          <el-form-item label="历史天数">
            <el-input-number v-model="form.history_lookback_days" :min="7" :max="365" />
          </el-form-item>
          <el-form-item label="起始日期">
            <el-date-picker
              v-model="form.start_date"
              type="date"
              placeholder="选择日期"
              format="YYYY-MM-DD"
              value-format="YYYY-MM-DD"
              :clearable="false"
            />
          </el-form-item>
        </el-form>
      </div>

      <!-- Model list -->
      <div class="model-list-section">
        <h4>📦 可用店铺模型列表</h4>
        <el-table :data="paginatedModelList" style="width: 100%" v-loading="modelsLoading">
          <el-table-column prop="store_name" label="店铺名称" sortable />
          <el-table-column prop="model_type" label="模型类型" sortable />
          <el-table-column prop="version" label="版本" />
          <el-table-column prop="created_at" label="创建时间" :formatter="formatDateTime" />
          <el-table-column label="操作">
            <template #default="{ row }">
              <el-button
                type="primary"
                size="small"
                @click="startPrediction(row)"
                :loading="predicting[row.model_id]"
              >
                <el-icon><TrendCharts /></el-icon>
                开始预测
              </el-button>
            </template>
          </el-table-column>
        </el-table>
        <el-pagination
          background
          layout="prev, pager, next"
          :total="filteredModelList.length"
          :page-size="pagination.pageSize"
          @current-change="handlePageChange"
          style="margin-top: 20px; justify-content: center;"
        />
      </div>
    </el-card>

    <!-- Prediction result dialog -->
    <el-dialog v-model="dialogVisible" title="📈 预测结果" width="70%">
      <div class="prediction-chart">
        <canvas ref="chartCanvas" width="800" height="400"></canvas>
      </div>
      <template #footer>
        <el-button @click="dialogVisible = false">关闭</el-button>
      </template>
    </el-dialog>
  </div>
</template>

<script setup>
import { ref, reactive, onMounted, nextTick, computed } from 'vue'
import axios from 'axios'
import { ElMessage, ElDialog, ElTable, ElTableColumn, ElButton, ElIcon, ElCard, ElTooltip, ElForm, ElFormItem, ElInputNumber, ElDatePicker, ElSelect, ElOption, ElPagination } from 'element-plus'
import { QuestionFilled, TrendCharts } from '@element-plus/icons-vue'
import Chart from 'chart.js/auto'
import StoreSelector from '../../components/StoreSelector.vue'

const modelList = ref([])
const modelTypes = ref([])
const modelsLoading = ref(false)
const predicting = reactive({})
const stores = ref([])
const dialogVisible = ref(false)
const predictionResult = ref(null)
const chartCanvas = ref(null)
let chart = null

const form = reactive({
  future_days: 7,
  history_lookback_days: 30,
  start_date: '',
  analyze_result: true // Keep analysis enabled; the toggle was removed from the UI
})

const filters = reactive({
  store_id: '',
  model_type: null
})

const pagination = reactive({
  currentPage: 1,
  pageSize: 8
})

const storeNameMap = computed(() => {
  return stores.value.reduce((acc, store) => {
    acc[store.store_id] = store.store_name
    return acc
  }, {})
})

const modelsWithNames = computed(() => {
  return modelList.value.map(model => ({
    ...model,
    store_name: storeNameMap.value[model.store_id] || model.store_id
  }))
})

const filteredModelList = computed(() => {
  return modelsWithNames.value.filter(model => {
    const storeMatch = !filters.store_id || model.store_id === filters.store_id
    const modelTypeMatch = !filters.model_type || model.model_type === filters.model_type.id
    return storeMatch && modelTypeMatch
  })
})

const paginatedModelList = computed(() => {
  const start = (pagination.currentPage - 1) * pagination.pageSize
  const end = start + pagination.pageSize
  return filteredModelList.value.slice(start, end)
})

const handlePageChange = (page) => {
  pagination.currentPage = page
}

const formatDateTime = (row, column, cellValue) => {
  if (!cellValue) return ''
  const date = new Date(cellValue)
  return date.toLocaleString('zh-CN', {
    year: 'numeric',
    month: '2-digit',
    day: '2-digit',
    hour: '2-digit',
    minute: '2-digit',
    second: '2-digit',
    hour12: false
  }).replace(/\//g, '-')
}

const fetchModelTypes = async () => {
  try {
    const response = await axios.get('/api/model_types')
    if (response.data.status === 'success') {
      modelTypes.value = response.data.data
    }
  } catch (error) {
    ElMessage.error('获取模型类型失败')
  }
}

const fetchModels = async () => {
  modelsLoading.value = true
  try {
    const response = await axios.get('/api/models', { params: { training_mode: 'store' } })
    if (response.data.status === 'success') {
      modelList.value = response.data.data
    } else {
      ElMessage.error('获取模型列表失败')
    }
  } catch (error) {
    ElMessage.error('获取模型列表失败')
  } finally {
    modelsLoading.value = false
  }
}

const fetchStores = async () => {
  try {
    const response = await axios.get('/api/stores')
    if (response.data.status === 'success') {
      stores.value = response.data.data
    }
  } catch (error) {
    ElMessage.error('获取店铺列表失败')
  }
}

const startPrediction = async (model) => {
  predicting[model.model_id] = true
  try {
|
||||
const payload = {
|
||||
training_mode: 'store',
|
||||
store_id: model.store_id,
|
||||
model_type: model.model_type,
|
||||
version: model.version,
|
||||
future_days: form.future_days,
|
||||
history_lookback_days: form.history_lookback_days,
|
||||
start_date: form.start_date,
|
||||
analyze_result: form.analyze_result,
|
||||
}
|
||||
const response = await axios.post('/api/prediction', payload)
|
||||
if (response.data.status === 'success') {
|
||||
predictionResult.value = response.data.data
|
||||
ElMessage.success('预测完成!')
|
||||
dialogVisible.value = true
|
||||
await nextTick()
|
||||
renderChart()
|
||||
} else {
|
||||
ElMessage.error(response.data.error || '预测失败')
|
||||
}
|
||||
} catch (error) {
|
||||
ElMessage.error(error.response?.data?.error || '预测请求失败')
|
||||
} finally {
|
||||
predicting[model.model_id] = false
|
||||
}
|
||||
}
|
||||
|
||||
const renderChart = () => {
|
||||
if (!chartCanvas.value || !predictionResult.value) return
|
||||
if (chart) {
|
||||
chart.destroy()
|
||||
}
|
||||
|
||||
const formatDate = (date) => new Date(date).toISOString().split('T')[0];
|
||||
|
||||
const historyData = (predictionResult.value.history_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
const predictionData = (predictionResult.value.prediction_data || []).map(p => ({ ...p, date: formatDate(p.date) }));
|
||||
|
||||
if (historyData.length === 0 && predictionData.length === 0) {
|
||||
ElMessage.warning('没有可用于图表的数据。')
|
||||
return
|
||||
}
|
||||
|
||||
const allLabels = [...new Set([...historyData.map(p => p.date), ...predictionData.map(p => p.date)])].sort()
|
||||
const simplifiedLabels = allLabels.map(date => date.split('-').slice(1).join('/'));
|
||||
|
||||
const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
|
||||
const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))
|
||||
|
||||
const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
|
||||
const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)
|
||||
|
||||
if (historyData.length > 0 && predictionData.length > 0) {
|
||||
const lastHistoryDate = historyData[historyData.length - 1].date
|
||||
const lastHistoryValue = historyData[historyData.length - 1].sales
|
||||
if (!predictionMap.has(lastHistoryDate)) {
|
||||
alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
|
||||
}
|
||||
}
|
||||
|
||||
let subtitleText = '';
|
||||
if (historyData.length > 0) {
|
||||
subtitleText += `历史数据: ${historyData[0].date} ~ ${historyData[historyData.length - 1].date}`;
|
||||
}
|
||||
if (predictionData.length > 0) {
|
||||
if (subtitleText) subtitleText += ' | ';
|
||||
subtitleText += `预测数据: ${predictionData[0].date} ~ ${predictionData[predictionData.length - 1].date}`;
|
||||
}
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: simplifiedLabels,
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: alignedHistorySales,
|
||||
borderColor: '#67C23A',
|
||||
backgroundColor: 'rgba(103, 194, 58, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
spanGaps: false,
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: alignedPredictionSales,
|
||||
borderColor: '#409EFF',
|
||||
backgroundColor: 'rgba(64, 158, 255, 0.2)',
|
||||
tension: 0.4,
|
||||
fill: true,
|
||||
borderDash: [5, 5],
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
responsive: true,
|
||||
maintainAspectRatio: false,
|
||||
plugins: {
|
||||
legend: {
|
||||
labels: {
|
||||
color: 'white' // 图例文字颜色
|
||||
}
|
||||
},
|
||||
title: {
|
||||
display: true,
|
||||
text: `${predictionResult.value.store_name} - 销量预测趋势图`,
|
||||
color: 'white',
|
||||
font: {
|
||||
size: 20,
|
||||
weight: 'bold',
|
||||
}
|
||||
},
|
||||
subtitle: {
|
||||
display: true,
|
||||
text: subtitleText,
|
||||
color: 'white',
|
||||
font: {
|
||||
size: 14,
|
||||
},
|
||||
padding: {
|
||||
bottom: 20
|
||||
}
|
||||
}
|
||||
},
|
||||
scales: {
|
||||
x: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '日期',
|
||||
color: 'white'
|
||||
},
|
||||
ticks: {
|
||||
color: 'white' // X轴刻度文字颜色
|
||||
},
|
||||
grid: {
|
||||
display: false
|
||||
}
|
||||
},
|
||||
y: {
|
||||
title: {
|
||||
display: true,
|
||||
text: '销量',
|
||||
color: 'white'
|
||||
},
|
||||
ticks: {
|
||||
color: 'white' // Y轴刻度文字颜色
|
||||
},
|
||||
grid: {
|
||||
color: 'rgba(255, 255, 255, 0.2)', // 网格线颜色
|
||||
drawBorder: false,
|
||||
},
|
||||
beginAtZero: true
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
fetchStores()
|
||||
fetchModels()
|
||||
fetchModelTypes()
|
||||
const today = new Date()
|
||||
form.start_date = today.toISOString().split('T')[0]
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.prediction-view {
|
||||
padding: 20px;
|
||||
}
|
||||
.card-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
}
|
||||
.controls-section, .model-list-section {
|
||||
margin-top: 20px;
|
||||
}
|
||||
.model-list-section h4 {
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
.prediction-chart {
|
||||
margin-top: 20px;
|
||||
}
|
||||
</style>
|
@ -26,7 +26,20 @@ export default defineConfig({
      '/api': {
        target: 'http://127.0.0.1:5000',
        changeOrigin: true,
      }
      },
      // Added proxy rules for the Swagger UI
      '/swagger': {
        target: 'http://127.0.0.1:5000',
        changeOrigin: true,
      },
      '/apispec.json': {
        target: 'http://127.0.0.1:5000',
        changeOrigin: true,
      },
      '/flasgger_static': {
        target: 'http://127.0.0.1:5000',
        changeOrigin: true,
      },
    }
  }
})
180 Windows_快速启动.bat
@ -1,180 +0,0 @@
@echo off
chcp 65001 >nul
echo ====================================
echo 药店销售预测系统 - Windows 快速启动
echo ====================================
echo.

:: 检查Python
echo [1/6] 检查Python环境...
python --version >nul 2>&1
if errorlevel 1 (
    echo ❌ 未找到Python,请先安装Python 3.8+
    pause
    exit /b 1
)
echo ✓ Python环境正常

:: 检查虚拟环境
echo.
echo [2/6] 检查虚拟环境...
if not exist ".venv\Scripts\python.exe" (
    echo 🔄 创建虚拟环境...
    python -m venv .venv
    if errorlevel 1 (
        echo ❌ 虚拟环境创建失败
        pause
        exit /b 1
    )
)
echo ✓ 虚拟环境准备完成

:: 激活虚拟环境
echo.
echo [3/6] 激活虚拟环境...
call .venv\Scripts\activate.bat
if errorlevel 1 (
    echo ❌ 虚拟环境激活失败
    pause
    exit /b 1
)
echo ✓ 虚拟环境已激活

:: 安装依赖
echo.
echo [4/6] 检查Python依赖...
pip show flask >nul 2>&1
if errorlevel 1 (
    echo 🔄 安装Python依赖...
    pip install -r install\requirements.txt
    if errorlevel 1 (
        echo ❌ 依赖安装失败
        pause
        exit /b 1
    )
)
echo ✓ Python依赖已安装

:: 检查数据文件
echo.
echo [5/6] 检查数据文件...
if not exist "pharmacy_sales_multi_store.csv" (
    echo 🔄 生成示例数据...
    python generate_multi_store_data.py
    if errorlevel 1 (
        echo ❌ 数据生成失败
        pause
        exit /b 1
    )
)
echo ✓ 数据文件准备完成

:: 初始化数据库
echo.
echo [6/6] 初始化数据库...
if not exist "prediction_history.db" (
    echo 🔄 初始化数据库...
    python server\init_multi_store_db.py
    if errorlevel 1 (
        echo ❌ 数据库初始化失败
        pause
        exit /b 1
    )
)
echo ✓ 数据库准备完成

echo.
echo ====================================
echo ✅ 环境准备完成!
echo ====================================
echo.
echo 接下来请选择启动方式:
echo [1] 启动API服务器 (后端)
echo [2] 启动前端开发服务器
echo [3] 运行API测试
echo [4] 查看项目状态
echo [0] 退出
echo.

:menu
set /p choice="请选择 (0-4): "

if "%choice%"=="1" goto start_api
if "%choice%"=="2" goto start_frontend
if "%choice%"=="3" goto run_tests
if "%choice%"=="4" goto show_status
if "%choice%"=="0" goto end
echo 无效选择,请重新输入
goto menu

:start_api
echo.
echo 🚀 启动API服务器...
echo 服务器将在 http://localhost:5000 启动
echo API文档访问: http://localhost:5000/swagger
echo.
echo 按 Ctrl+C 停止服务器
echo.
cd server
python api.py
goto end

:start_frontend
echo.
echo 🚀 启动前端开发服务器...
cd UI
if not exist "node_modules" (
    echo 🔄 安装前端依赖...
    npm install
    if errorlevel 1 (
        echo ❌ 前端依赖安装失败
        pause
        goto menu
    )
)
echo 前端将在 http://localhost:5173 启动
echo.
npm run dev
goto end

:run_tests
echo.
echo 🧪 运行API测试...
python test_api_endpoints.py
echo.
pause
goto menu

:show_status
echo.
echo 📊 项目状态检查...
echo.
echo === 文件检查 ===
if exist "pharmacy_sales_multi_store.csv" (echo ✓ 多店铺数据文件) else (echo ❌ 多店铺数据文件缺失)
if exist "prediction_history.db" (echo ✓ 预测历史数据库) else (echo ❌ 预测历史数据库缺失)
if exist "server\api.py" (echo ✓ API服务器文件) else (echo ❌ API服务器文件缺失)
if exist "UI\package.json" (echo ✓ 前端项目文件) else (echo ❌ 前端项目文件缺失)

echo.
echo === 模型文件 ===
if exist "saved_models" (
    echo 已保存的模型:
    dir saved_models\*.pth /b 2>nul || echo 暂无已训练的模型
) else (
    echo ❌ 模型目录不存在
)

echo.
echo === 虚拟环境状态 ===
python -c "import sys; print('Python版本:', sys.version)"
python -c "import flask; print('Flask版本:', flask.__version__)" 2>nul || echo ❌ Flask未安装

echo.
pause
goto menu

:end
echo.
echo 感谢使用药店销售预测系统!
echo.
pause
82 data/old_5shops_50skus-据结构字典.md Normal file
@ -0,0 +1,82 @@
| Category | Field | Data Type | Description | Source |
| --- | --- | --- | --- | --- |
| **Identifiers** | `subbh` | String | Unique store identifier | Skeleton |
| | `hh` | String | Unique product identifier | Skeleton |
| | `kdrq` | Date | Order date (part of the primary key) | Skeleton |
| **Core metrics** | `sales_quantity` | Float | Daily sales quantity (0 when there are no sales) | Daily sales |
| | `return_quantity` | Float | Daily return quantity (0 when there are no sales) | Daily sales |
| | `net_sales_quantity` | Float | **Daily net sales quantity (target variable)** | Daily sales |
| | `gross_profit_total` | Float | Daily gross profit (0 when there are no sales) | Daily sales |
| | `transaction_count` | Integer | Daily transaction count (0 when there are no sales) | Daily sales |
| **Date features** | `date` | Date | Date (redundant field) | Time-series derived |
| | `is_weekend` | Boolean | Whether the day is a weekend (True/False) | Time-series derived |
| | `day_of_week` | Integer | Day of the week (0=Monday, 6=Sunday) | Time-series derived |
| | `day_of_month` | Integer | Day of the month (1-31) | Time-series derived |
| | `day_of_year` | Integer | Day of the year (1-366) | Time-series derived |
| | `week_of_month` | Integer | Week of the month (1-5) | Time-series derived |
| | `month` | Integer | Month (1-12) | Time-series derived |
| | `quarter` | Integer | Quarter (1-4) | Time-series derived |
| | `is_holiday` | Boolean | Whether the day is a public holiday (True/False) | Time-series derived |
| **Lifecycle features** | `first_sale_date` | Date | First sale date of the SKU in the store | Lifecycle |
| | `last_sale_date` | Date | Last sale date of the SKU in the store | Lifecycle |
| | `lifecycle_days` | Integer | Total lifecycle days of the SKU in the store | Lifecycle |
| | `sample_category` | String | Lifecycle category (new/medium/old) | Lifecycle |
| | `rolling_7d_valid` | Boolean | Whether the 7-day rolling window is valid (>= 7 days since first sale) | Lifecycle |
| | `rolling_15d_valid` | Boolean | Whether the 15-day rolling window is valid | Lifecycle |
| | `rolling_30d_valid` | Boolean | Whether the 30-day rolling window is valid | Lifecycle |
| | `rolling_90d_valid` | Boolean | Whether the 90-day rolling window is valid | Lifecycle |
| **Rolling features (7-day)** | `sales_quantity_rolling_mean_7d` | Float | Mean daily sales over the past 7 days | Historical rolling |
| | `return_quantity_rolling_mean_7d` | Float | Mean daily returns over the past 7 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_7d`| Float | Mean daily net sales over the past 7 days | Historical rolling |
| | `sales_quantity_rolling_sum_7d` | Float | Total sales over the past 7 days | Historical rolling |
| | `return_quantity_rolling_sum_7d` | Float | Total returns over the past 7 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_7d` | Float | Total net sales over the past 7 days | Historical rolling |
| **Rolling features (15-day)** | `sales_quantity_rolling_mean_15d` | Float | Mean daily sales over the past 15 days | Historical rolling |
| | `return_quantity_rolling_mean_15d` | Float | Mean daily returns over the past 15 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_15d`| Float | Mean daily net sales over the past 15 days | Historical rolling |
| | `sales_quantity_rolling_sum_15d` | Float | Total sales over the past 15 days | Historical rolling |
| | `return_quantity_rolling_sum_15d` | Float | Total returns over the past 15 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_15d` | Float | Total net sales over the past 15 days | Historical rolling |
| **Rolling features (30-day)** | `sales_quantity_rolling_mean_30d` | Float | Mean daily sales over the past 30 days | Historical rolling |
| | `return_quantity_rolling_mean_30d` | Float | Mean daily returns over the past 30 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_30d`| Float | Mean daily net sales over the past 30 days | Historical rolling |
| | `sales_quantity_rolling_sum_30d` | Float | Total sales over the past 30 days | Historical rolling |
| | `return_quantity_rolling_sum_30d` | Float | Total returns over the past 30 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_30d` | Float | Total net sales over the past 30 days | Historical rolling |
| **Rolling features (90-day)** | `sales_quantity_rolling_mean_90d` | Float | Mean daily sales over the past 90 days | Historical rolling |
| | `return_quantity_rolling_mean_90d` | Float | Mean daily returns over the past 90 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_90d`| Float | Mean daily net sales over the past 90 days | Historical rolling |
| | `sales_quantity_rolling_sum_90d` | Float | Total sales over the past 90 days | Historical rolling |
| | `return_quantity_rolling_sum_90d` | Float | Total returns over the past 90 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_90d` | Float | Total net sales over the past 90 days | Historical rolling |
| **Rolling features (180-day)** | `sales_quantity_rolling_mean_180d` | Float | Mean daily sales over the past 180 days | Historical rolling |
| | `return_quantity_rolling_mean_180d` | Float | Mean daily returns over the past 180 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_180d`| Float | Mean daily net sales over the past 180 days | Historical rolling |
| | `sales_quantity_rolling_sum_180d` | Float | Total sales over the past 180 days | Historical rolling |
| | `return_quantity_rolling_sum_180d` | Float | Total returns over the past 180 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_180d` | Float | Total net sales over the past 180 days | Historical rolling |
| **Rolling features (365-day)** | `sales_quantity_rolling_mean_365d` | Float | Mean daily sales over the past 365 days | Historical rolling |
| | `return_quantity_rolling_mean_365d` | Float | Mean daily returns over the past 365 days | Historical rolling |
| | `net_sales_quantity_rolling_mean_365d`| Float | Mean daily net sales over the past 365 days | Historical rolling |
| | `sales_quantity_rolling_sum_365d` | Float | Total sales over the past 365 days | Historical rolling |
| | `return_quantity_rolling_sum_365d` | Float | Total returns over the past 365 days | Historical rolling |
| | `net_sales_quantity_rolling_sum_365d` | Float | Total net sales over the past 365 days | Historical rolling |
| **Store features** | `province` | String | Store province | Store features |
| | `city` | String | Store city | Store features |
| | `district` | String | Store district | Store features |
| | `poi_residential_count` | Integer | Number of residential POIs nearby | Store features |
| | `poi_school_count` | Integer | Number of school POIs nearby | Store features |
| | `poi_mall_count` | Integer | Number of shopping-mall POIs nearby | Store features |
| | `temperature_2m_max` | Float | Daily maximum temperature | Store features |
| | `temperature_2m_min` | Float | Daily minimum temperature | Store features |
| | `temperature_2m_mean`| Float | Daily mean temperature | Store features |
| **Product features** | `零售大类代码_encoded` | Integer | Numeric encoding of the retail major-category code | Product features |
| | `零售中类代码_encoded` | Integer | Numeric encoding of the retail mid-category code | Product features |
| | `零售小类代码_encoded` | Integer | Numeric encoding of the retail sub-category code | Product features |
| | `商品ABC分类_encoded` | Integer | Numeric encoding of the product ABC classification | Product features |
| | `商品手册代码_encoded` | Integer | Numeric encoding of the product manual code | Product features |
| | `产地_encoded` | Integer | Numeric encoding of the place of origin | Product features |
| | `brand_encoded` | Integer | Numeric encoding of the brand | Product features |
| | `packaging_quantity` | Float | Packaging quantity (extracted from the specification) | Product features |
| | `approval_type_encoded` | Integer | Numeric encoding of the approval-number type | Product features |
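For reference, the rolling features above can be reproduced with pandas. The following is a minimal sketch, not part of the repository; it assumes the parquet file below and the column names listed in this dictionary, and the exact windowing convention (here: strictly historical, excluding the current day) is an assumption:

```python
import pandas as pd

df = pd.read_parquet("data/old_5shops_50skus.parquet")
df = df.sort_values(["subbh", "hh", "kdrq"])

# Past-7-day mean/sum of net sales per (store, SKU); shift(1) keeps the window
# strictly historical so the current day does not leak into its own feature.
grouped = df.groupby(["subbh", "hh"])["net_sales_quantity"]
df["net_sales_quantity_rolling_mean_7d"] = grouped.transform(
    lambda s: s.shift(1).rolling(window=7, min_periods=1).mean()
)
df["net_sales_quantity_rolling_sum_7d"] = grouped.transform(
    lambda s: s.shift(1).rolling(window=7, min_periods=1).sum()
)
```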
BIN data/old_5shops_50skus.parquet Normal file
Binary file not shown.
0 docs/UI_PREDICTION_FEATURE_CHANGELOG.md Normal file
101 feature_branch_workflow.md Normal file
@ -0,0 +1,101 @@
# Standard Workflow for Feature-Branch Development and Merging

This document describes a standard, safe feature-development workflow, covering the full path from branch creation to the final merge.

## Workflow Overview

1. **Create a feature branch**: Create a new feature branch (e.g. `lyf-dev-req0001`) in the remote repository, based on the main development branch (e.g. `lyf-dev`).
2. **Sync it locally**: Fetch the new remote branch and switch to it for development.
3. **Develop and commit**: Develop on the feature branch and commit changes frequently.
4. **Push to the remote**: Regularly push local commits to the remote feature branch for backup and collaboration.
5. **Merge back into the main branch**: Once the feature is developed and tested, merge the feature branch back into the main development branch.

---

## Detailed Steps

### Step 1: Sync and switch to the feature branch

Once the new feature branch (e.g. `lyf-dev-req0001`) exists in the remote repository, run the following locally to sync and switch.

1. **Fetch the latest remote state**:
   ```bash
   git fetch
   ```
   This pulls all the latest information from the remote repository, including newly created branches.

2. **Create and switch to the local branch**:
   ```bash
   git checkout lyf-dev-req0001
   ```
   Git detects that a remote branch of the same name exists and automatically creates a local branch that tracks it.

### Step 2: Develop and commit on the feature branch

You can now develop safely on the `lyf-dev-req0001` branch.

1. **Make code changes**: add, modify, or delete files to implement the new feature.

2. **Commit the changes**:
   ```bash
   # Stage all modified files
   git add .

   # Commit to the local repository with a meaningful message
   git commit -m "feat: 完成用户认证模块"
   ```
   > **Best practice**: keep commits small and their messages clear; this makes code review and troubleshooting much easier.

### Step 3: Push the feature branch to the remote

To back up your work and collaborate with the team, push local commits to the remote repository.

```bash
# Push the current branch (lyf-dev-req0001) to the remote branch of the same name
git push origin lyf-dev-req0001
```

### Step 4: Merge the feature into the main development branch (`lyf-dev`)

Once the feature is developed and tested, it is ready to be merged back into `lyf-dev`.

1. **Switch to the main development branch**:
   ```bash
   git checkout lyf-dev
   ```

2. **Make sure the main development branch is up to date**:
   Always pull the latest `lyf-dev` before merging to reduce the chance of conflicts.
   ```bash
   git pull origin lyf-dev
   ```

3. **Merge the feature branch**:
   Merge all changes from `lyf-dev-req0001` into the current `lyf-dev` branch.
   ```bash
   git merge lyf-dev-req0001
   ```
   * **If a conflict occurs**: Git reports which files conflict. Open those files, resolve the conflicting sections by hand, then run `git add .` and `git commit` again to complete the merge commit.
   * **If there is no conflict**: Git creates a merge commit automatically.

4. **Push the merged main branch to the remote**:
   ```bash
   git push origin lyf-dev
   ```

### Step 5: Clean up (optional)

Once the feature branch is no longer needed, delete it to keep the repository tidy.

1. **Delete the remote branch**:
   ```bash
   git push origin --delete lyf-dev-req0001
   ```

2. **Delete the local branch**:
   ```bash
   git branch -d lyf-dev-req0001
   ```

---
Following this workflow keeps your development process clear, safe, and efficient.
@ -1,5 +1,28 @@
@echo off
chcp 65001 >nul

REM Check whether the .venv directory exists
if exist .venv (
    echo 虚拟环境已存在。
) else (
    echo 正在创建虚拟环境...
    uv venv
    if errorlevel 1 (
        echo 创建虚拟环境失败。
        pause
        exit /b 1
    )
    echo.
    echo 虚拟环境已创建。请激活它后重新运行此脚本。
    echo - Windows (CMD): .\.venv\Scripts\activate
    echo - Windows (PowerShell): .\.venv\Scripts\Activate.ps1
    echo - Linux/macOS: source .venv/bin/activate
    echo.
    pause
    exit /b 0
)

echo 正在安装药店销售预测系统API依赖...
pip install flask==3.1.1 flask-cors==6.0.0 flasgger==0.9.7.1
uv pip install flask==3.1.1 flask-cors==6.0.0 flasgger==0.9.7.1 -i https://pypi.tuna.tsinghua.edu.cn/simple
echo 依赖安装完成,现在可以运行 python api.py 启动API服务
pause
pause
@ -1,8 +1,33 @@
@echo off
chcp 65001 >nul

REM Check whether the .venv directory exists
if exist .venv (
    echo 虚拟环境已存在。
) else (
    echo 正在创建虚拟环境...
    uv venv
    if errorlevel 1 (
        echo 创建虚拟环境失败。
        pause
        exit /b 1
    )
    echo.
    echo 虚拟环境已创建。请激活它后重新运行此脚本。
    echo - Windows (CMD): .\.venv\Scripts\activate
    echo - Windows (PowerShell): .\.venv\Scripts\Activate.ps1
    echo - Linux/macOS: source .venv/bin/activate
    echo.
    pause
    exit /b 0
)

echo.
echo 药店销售预测系统 - 依赖库安装脚本
echo ==================================
echo.
echo 虚拟环境已激活,准备安装依赖。
echo.

echo 请选择要安装的版本:
echo 1. CPU版本(适用于没有NVIDIA GPU的计算机)
@ -14,23 +39,23 @@ set /p choice=请输入选项 (1/2/3):

if "%choice%"=="1" (
    echo 正在安装CPU版本依赖...
    pip install -r requirements.txt
    uv pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
) else if "%choice%"=="2" (
    echo 正在安装GPU版本(CUDA 12.1)依赖...
    echo 首先安装基础依赖...
    pip install -r requirements-gpu.txt --no-deps
    uv pip install -r requirements-gpu.txt --no-deps -i https://pypi.tuna.tsinghua.edu.cn/simple
    echo 安装除PyTorch以外的其他依赖...
    pip install numpy==2.3.0 pandas==2.3.0 matplotlib==3.10.3 scikit-learn==1.7.0 tqdm==4.67.1 openpyxl==3.1.5
    uv pip install numpy==2.3.0 pandas==2.3.0 matplotlib==3.10.3 scikit-learn==1.7.0 tqdm==4.67.1 openpyxl==3.1.5 -i https://pypi.tuna.tsinghua.edu.cn/simple
    echo 从PyTorch官方源安装CUDA 12.1版本的PyTorch...
    pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu121
    uv pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu121
) else if "%choice%"=="3" (
    echo 正在安装GPU版本(CUDA 11.8)依赖...
    echo 首先安装基础依赖...
    pip install -r requirements-gpu-cu118.txt --no-deps
    uv pip install -r requirements-gpu-cu118.txt --no-deps -i https://pypi.tuna.tsinghua.edu.cn/simple
    echo 安装除PyTorch以外的其他依赖...
    pip install numpy==2.3.0 pandas==2.3.0 matplotlib==3.10.3 scikit-learn==1.7.0 tqdm==4.67.1 openpyxl==3.1.5
    uv pip install numpy==2.3.0 pandas==2.3.0 matplotlib==3.10.3 scikit-learn==1.7.0 tqdm==4.67.1 openpyxl==3.1.5 -i https://pypi.tuna.tsinghua.edu.cn/simple
    echo 从PyTorch官方源安装CUDA 11.8版本的PyTorch...
    pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu118
    uv pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cu118
) else (
    echo 无效的选项!请重新运行脚本并选择正确的选项。
    goto end
|
||||
@echo off
|
||||
chcp 65001 >nul
|
||||
|
||||
REM 检查.venv目录是否存在
|
||||
if exist .venv (
|
||||
echo 虚拟环境已存在。
|
||||
) else (
|
||||
echo 正在创建虚拟环境...
|
||||
uv venv
|
||||
if errorlevel 1 (
|
||||
echo 创建虚拟环境失败。
|
||||
pause
|
||||
exit /b 1
|
||||
)
|
||||
echo.
|
||||
echo 虚拟环境已创建。请激活它后重新运行此脚本。
|
||||
echo - Windows (CMD): .\.venv\Scripts\activate
|
||||
echo - Windows (PowerShell): .\.venv\Scripts\Activate.ps1
|
||||
echo - Linux/macOS: source .venv/bin/activate
|
||||
echo.
|
||||
pause
|
||||
exit /b 0
|
||||
)
|
||||
|
||||
echo 安装PyTorch GPU版本(通过官方源)
|
||||
echo ===================================
|
||||
echo.
|
||||
@ -14,15 +36,15 @@ set /p choice=请输入选项 (1/2):
|
||||
|
||||
if "%choice%"=="1" (
|
||||
echo 正在安装PyTorch CUDA 12.8版本...
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
)
|
||||
else if "%choice%"=="2" (
|
||||
echo 正在安装PyTorch CUDA 12.6版本...
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
)
|
||||
else if "%choice%"=="3" (
|
||||
echo 正在安装PyTorch CUDA 11.8版本...
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
uv pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
) else (
|
||||
echo 无效的选项!
|
||||
goto end
|
||||
|
353 lyf开发日志记录文档.md Normal file
@ -0,0 +1,353 @@
# Development Log

This document records the major changes, bug fixes, and key decisions made during project development.

---

## 2025-07-13: Early Backend Fixes and Refactoring
**Developer**: lyf

### 13:30 - Fixed the data-loading path issue
- **Goal**: Resolve data-loading failures during model training caused by a wrong data file path.
- **Core problem**: The `PharmacyPredictor` class in `server/core/predictor.py` hard-coded a wrong default data file path at initialization.
- **Fix**: Corrected the default data path to `'data/timeseries_training_data_sample_10s50p.parquet'` and updated all trainers accordingly.

### 14:00 - Data-flow refactoring
- **Goal**: Fix the root cause of model-training failures: key features were lost because the data-processing pipeline was interrupted midway.
- **Core problem**: `predictor.py` did not pass the preprocessed data downstream, so each trainer reloaded the data and processed it incorrectly.
- **Fix**: Refactored the core data flow so data is loaded and preprocessed once in `predictor.py` and then passed explicitly as a DataFrame to every downstream trainer function.

---

## 2025-07-14: Concentrated Work on Model Training and Concurrency Issues
**Developer**: lyf

### 10:16 - Fixed a `KeyError` in the trainer layer
- **Problem**: All model training failed with `KeyError: "['sales', 'price'] not in index"`.
- **Analysis**: The hard-coded feature lists in the trainers included a `'price'` column that does not exist in the data source.
- **Fix**: Removed the dependency on the non-existent `'price'` column from the `features` lists of all four trainers (`mlstm`, `transformer`, `tcn`, `kan`).

### 10:38 - Fixed a `KeyError` in the data-standardization layer
- **Problem**: After the previous fix, a new error appeared: `KeyError: "['sales'] not in index"`.
- **Analysis**: The `standardize_column_names` function in `server/utils/multi_store_data_utils.py` mapped column names incorrectly and lacked a final column-selection step.
- **Fix**: Corrected the column-name mapping and added a column-selection step, ensuring the function returns a uniformly structured `DataFrame` that contains the `sales` column.

### 11:04 - Fixed the JSON serialization failure
- **Problem**: After training completed, frontend-backend communication failed with `Object of type float32 is not JSON serializable`.
- **Analysis**: The evaluation metrics produced during training are NumPy `float32` values, which the standard `json` library cannot serialize.
- **Fix**: Added a `convert_numpy_types` helper in `server/utils/training_process_manager.py` that converts all NumPy numeric types to native Python types before data is returned over WebSocket or the API, eliminating this whole class of serialization errors at the source.
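The helper itself is not reproduced in this log; the following is a minimal sketch of the usual shape of such a converter (the actual implementation in `training_process_manager.py` may differ in detail):

```python
import numpy as np

def convert_numpy_types(obj):
    """Recursively replace NumPy scalars/arrays with native Python types."""
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
```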
### 11:15 - Fixed the MAPE calculation error
- **Problem**: Training logs showed `MAPE: nan%` together with `RuntimeWarning: Mean of empty slice.`.
- **Analysis**: When every ground-truth value in the test set is 0, computing MAPE takes the mean of an empty array.
- **Fix**: Added a guard in `server/analysis/metrics.py`: if there are no non-zero ground-truth values, MAPE is set to 0 directly.

### 11:41 - Fixed the empty list on the "train by store" page
- **Problem**: The "select store" dropdown was empty.
- **Analysis**: The `standardize_column_names` function wrongly removed columns not needed for training, including store metadata.
- **Fix**: Moved the column-filtering logic out of the generic `standardize_column_names` function and applied it precisely in the function that prepares data for model training only.

### 13:00 - Fixed the "train by store - all products" mode
- **Problem**: Training with "all products" selected failed because `product_id` was incorrectly handled as the string `"unknown"`.
- **Fix**: Intercepted the `"unknown"` ID in `server/core/predictor.py` and translated its intent correctly into "aggregate all product data for this store". Also extended the `aggregate_multi_store_data` function to support aggregation by store ID.

### 14:19 - Fixed stability issues in concurrent training
- **Problem**: Concurrent training produced an API list-sorting error and WebSocket connection errors.
- **Fix**:
  1. **Sorting**: Provided a default value in `api.py` for `start_time` entries of type `None`, resolving the `TypeError`.
  2. **Connection**: Added the `allow_unsafe_werkzeug=True` argument to the `socketio.run()` call, resolving the conflict between Socket.IO and Werkzeug in debug mode.

### 15:30 - Eliminated the dimension mismatch in model training
- **Problem**: After every training run, the `R²` metric was always 0.0.
- **Root cause**: The `create_dataset` function in `server/utils/data_utils.py` kept a superfluous dimension when building the target dataset `dataY`. The model files (`mlstm_model.py`, `transformer_model.py`) also had dimension problems in their outputs.
- **Final fix**:
  1. **Data layer**: Used `.flatten()` in `create_dataset` to correct the dimensionality of the `y` labels.
  2. **Model layer**: Appended `.squeeze(-1)` at the end of every model's `forward` method to ensure correct output dimensions.
  3. **Trainer layer**: Reverted all temporary dimension tweaks made while chasing this bug and restored the most direct loss computation (a small illustration of the pitfall follows below).
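As a small NumPy illustration of why the stray dimension is destructive (the project uses PyTorch, but the broadcasting pitfall is identical):

```python
import numpy as np

y_true = np.array([1.0, 2.0, 3.0])        # shape (3,)
y_pred = np.array([[1.1], [1.9], [3.2]])  # shape (3, 1): one stray dimension

# (3,) minus (3, 1) broadcasts to a (3, 3) matrix, so the "loss" silently
# averages 9 pairwise errors instead of 3 and the R² metric collapses.
bad = ((y_true - y_pred) ** 2).mean()
good = ((y_true - y_pred.squeeze(-1)) ** 2).mean()
print(bad, good)  # ~1.42 vs ~0.02
```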
### 16:10 - Fixed the "global training - all products" mode
- **Problem**: Like store training, the global "all products" mode also failed because of `product_id="unknown"`.
- **Fix**: Applied the same pattern as the store-training fix: intercept `"unknown"` in `predictor.py` and translate it into a true global aggregation (`product_id=None`), extending `aggregate_multi_store_data` to support it.

---

## 2025-07-15: End-to-End Fix of the "Predict by Product" Chart
**Developer**: lyf

### 10:00 - Stage 1: Fixed the database write failure (`sqlite3.IntegrityError`)
- **Problem**: Backend logs showed `datatype mismatch`.
- **Analysis**: The `save_prediction_result` function tried to store complex Python objects directly in the database.
- **Fix**: In `server/api.py`, serialize complex objects to JSON strings with `json.dumps()` before the database insert.

### 10:30 - Stage 2: Fixed the mismatch between the API response structure and the frontend
- **Problem**: The chart still did not render.
- **Analysis**: The frontend expected `history_data` at the top level, while the backend wrapped it inside a `data` sub-object.
- **Fix**: Modified the `predict` function in `server/api.py` to promote the key data to the root of the response.

### 11:00 - Stage 3: Fixed the time discontinuity between history and prediction data
- **Problem**: The two series on the chart were completely disconnected in time.
- **Analysis**: The history-fetching logic always took the last 30 rows of the whole dataset instead of the 30 rows before the prediction start date.
- **Fix**: Added the correct date-filtering logic in `server/api.py`.

### 14:00 - Stage 4: Refactored the data source to eliminate inconsistencies
- **Problem**: The history series (green) and the prediction series (blue) were built from entirely different data definitions.
- **Root cause**: The API layer loaded the **raw data** independently for plotting, while the predictor used the **aggregated data** for prediction.
- **Fix (refactor)**:
  1. Modified `server/predictors/model_predictor.py` so it returns, along with the predictions, the exact history data it used, guaranteeing a consistent definition.
  2. Deleted all redundant history-loading code from `server/api.py`, making the data source unique.

### 15:00 - Stage 5: Fixed the chart's X-axis date format
- **Problem**: The X axis showed garbled GMT-style timestamps.
- **Analysis**: The `Timestamp` objects in `history_data` were never formatted.
- **Fix**: Added `.strftime('%Y-%m-%d')` formatting for `history_data` in `server/api.py`.

### 16:00 - Stage 6: Fixed the real reason models "could not learn" (broken hyperparameter passing)
- **Problem**: Even with the pipeline correct, every model still predicted a flat, unlearned line.
- **Root cause**: `server/core/predictor.py` **did not pass key hyperparameters such as `sequence_length` down** when invoking the trainers, so every model ran with wrong defaults.
- **Fix**:
  1. Modified `server/core/predictor.py` to pass the hyperparameters through the call.
  2. Modified all four trainer files to accept and use them.

---

## 2025-07-16: Final Verification and Project Summary
**Developer**: lyf

### 10:00 - Stage 7: Final verification and conclusion
- **Problem**: With all code issues fixed, predictions for a particular date were still a flat line.
- **Analysis**: A throwaway data-analysis script (`temp_check_parquet.py`) finally confirmed this was a problem in the **data itself**: the chosen prediction date fell inside a "zero sales" gap in the sample dataset.
- **Conclusion**: The system code is fully fixed. The flat line on the chart is the model's **correct and logical** response to a zero-sales history.

### 11:45 - Project summary and documentation
- **Task**: As requested, review the whole debugging process and consolidate every problem, solution, optimization idea, and conclusion into this development log in chronological order, forming a high-quality technical record.
- **Result**: This document has been updated.


### 13:15 - Final fix: eliminated the inconsistent model identifier
- **Problem**: Further testing and log analysis showed that, even after correction, the `model_identifier` of store models was still wrongly built as `01010023_store_01010023` during training.
- **Root cause**: In the `train_model` method of `server/core/predictor.py`, the `training_mode == 'store'` branch contained redundant and wrong logic for building the `model_identifier`.
- **Final solution**: Deleted the wrong concatenation `model_identifier = f"{store_id}_{product_id}"` and used the `product_id` variable directly as the `model_identifier`, since an earlier step already sets it correctly to `f"store_{store_id}"`. This keeps a store model's unique identifier consistent from training through saving to the final API lookup.


### 13:30 - Final fix (round 2): eliminated the wrong model save path
- **Problem**: Even with the identifier fixed, model versions still failed to load.
- **Root cause**: Training logs revealed that the `save_checkpoint` function in every trainer (`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`) forcibly created a `checkpoints` subdirectory under `saved_models` and stored all model files there, while the `get_model_versions` function responsible for discovery only searched the root directory, so the models could never be found.
- **Final solution**: Modified `save_checkpoint` in every affected trainer to drop the `checkpoints` subdirectory and save all models directly under the `saved_models` root.
- **Conclusion**: The save path and the lookup path are now fully unified, resolving the version-loading failure at the root.


### 13:40 - Final fix (round 3): unified the save logic of all trainers
- **Problem**: After fixing `transformer_trainer.py`, it turned out `mlstm_trainer.py` and `tcn_trainer.py` had exactly the same path and naming bugs, so the problem persisted.
- **Root cause**: `save_checkpoint` was implemented incorrectly in every trainer: each forced a `checkpoints` subdirectory and used wrong logic to assemble file names.
- **Final solution**:
  1. **One-by-one fixes**: Updated `save_checkpoint` in `transformer_trainer.py`, `mlstm_trainer.py`, and `tcn_trainer.py`.
  2. **Path fix**: Removed the `checkpoints` subdirectory logic so models are saved directly under `model_dir` (i.e. `saved_models`).
  3. **File-name fix**: Simplified and corrected file-name generation to use the `product_id` parameter directly as the unique identifier (upstream logic already sets it to the product ID or `store_{store_id}`), with no further erroneous concatenation.
- **Conclusion**: All trainers now share the same save logic, and the save paths and file names match the API's lookup logic exactly, resolving the version-loading failure at the root.


---

## 2025-07-16 (cont.): End-to-End Fix of the "Store Prediction" Chart
**Developer**: lyf

### 15:30 - Final fix (round 4): connected the store-prediction data flow
- **Problem**: With model loading fixed, "store prediction" ran successfully, but the frontend chart stayed blank, showing neither history nor prediction data.
- **Root cause**: Parameter passing was broken along the call chain.
  1. `server/api.py` did not pass `training_mode` when calling `run_prediction`.
  2. `server/core/predictor.py` did not pass `store_id` and `training_mode` when calling `load_model_and_predict`.
  3. The data-loading logic inside `server/predictors/model_predictor.py` mistakenly used the model identifier (`store_{id}`) as a product ID to filter data during store prediction, so no history data could be loaded.
- **Final solution (three-step fix)**:
  1. **Fix `model_predictor.py`**: `load_model_and_predict` now loads data according to the `training_mode` parameter. In `'store'` mode it correctly aggregates all sales data of the store as history, exactly matching the data preparation used at training time.
  2. **Fix `predictor.py`**: The `predict` method now passes `store_id` and `training_mode` down to `load_model_and_predict`.
  3. **Fix `api.py`**: The `predict` route and the `run_prediction` helper now pass `training_mode` through the entire call chain.
- **Conclusion**: Parameter passing is complete and correct from the API interface down to the lowest data loader. Both product and store predictions load the right history data for the chart, fully resolving the blank-chart problem.

### 16:16 - Project status update
- **Status**: **All known issues fixed**.
- **Confirmation**: The user confirmed that both the product and store prediction flows now work end to end.
- **Follow-up**: This fix has been recorded in this document.


---

### 2025-07-16 18:38 - Generalized the prediction fix to all models

**Symptom**:
After resolving the `Transformer` prediction problem, a deeper, systemic issue surfaced: in every prediction mode (by product, by store, global), only the `Transformer` algorithm could predict successfully and render a chart. The other four models (`mLSTM`, `KAN`, optimized `KAN`, `TCN`) trained fine but all failed at prediction time with "no data available for the chart".

**Root-cause analysis**:
The core issue was **incomplete and inconsistent persistence of model configuration**.

1. **Why Transformer "survived"**: The `Transformer` implementation happens not to depend on the specific hyperparameters that were omitted at save time, so it kept working.
2. **The shared defect of the other models**: The constructors of `mLSTM`, `TCN`, and `KAN` all rely on structural parameters that were defined at training time but **omitted** from the checkpoint file (`.pth`).
   * **mLSTM**: missing `mlstm_layers`, `embed_dim`, `dense_dim`, among others.
   * **TCN**: missing `num_channels`, `kernel_size`, among others.
   * **KAN**: missing the `hidden_sizes` list.
3. **The failure cascade**:
   * When `server/predictors/model_predictor.py` loaded these checkpoints, `checkpoint['config']` lacked the parameters needed to instantiate the model.
   * Instantiation failed with a `KeyError` or `TypeError`.
   * The exception made `load_model_and_predict` return `None` early, so the response to the frontend lacked `history_data` and the chart could not render.

**Systemic, extensible solution**:
To fix this for good, and to make adding new algorithms painless, all non-Transformer trainers were standardized and fixed thoroughly.

1. **`mlstm_trainer.py`**: Added all missing parameters (`mlstm_layers`, `embed_dim`, `dense_dim`, etc.) to the `config` dict.
2. **`tcn_trainer.py`**: Added all missing parameters (`num_channels`, `kernel_size`, etc.) to the `config` dict.
3. **`kan_trainer.py`**: Added the `hidden_sizes` list to the `config` dict.

**Result**:
Every trainer now writes a complete, re-instantiation-ready configuration into the checkpoint file when saving a model. This resolves the prediction failures for all model algorithms at the root and makes the system uniform and robust across algorithms.
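A self-contained sketch of the pattern this fix standardizes on; `TinyForecaster` is a stand-in (the real classes and config keys are the ones named above):

```python
import torch
import torch.nn as nn

class TinyForecaster(nn.Module):
    """Stand-in model; the real classes are TCNForecaster, MLSTMTransformer, etc."""
    def __init__(self, input_dim, hidden_size, output_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, hidden_size),
                                 nn.ReLU(), nn.Linear(hidden_size, output_dim))
    def forward(self, x):
        return self.net(x)

# Persist every constructor argument next to the weights.
config = {'input_dim': 8, 'hidden_size': 64, 'output_dim': 7}
model = TinyForecaster(**config)
torch.save({'model_state_dict': model.state_dict(), 'config': config}, 'ckpt.pth')

# Re-instantiation at prediction time then needs nothing but the checkpoint.
ckpt = torch.load('ckpt.pth')
restored = TinyForecaster(**ckpt['config'])
restored.load_state_dict(ckpt['model_state_dict'])
```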
---

## 2025-07-17: Systematic Pipeline Cleanup and Standardization
**Developer**: lyf

### 15:00 - Created the technical documentation and onboarding guide
- **Task**: Created two core technical documents to help new members understand the system and to ease future maintenance.
- **Deliverables**:
  1. **`系统调用逻辑与核心代码分析.md`**: an end-to-end call-chain analysis that goes deep into the code, from frontend interaction through backend processing to model training and prediction.
  2. **`项目快速上手指南.md`**: a high-level guide for new members (especially those with a Java background) that builds a macro view of the project through stack analogies, a layered architecture diagram, and a clear development workflow.

### 16:00 - Fixed the `mLSTM` model-loading chain
- **Problem**: The `mLSTM` model failed to load at prediction time because of inconsistent parameter names.
- **Analysis**:
  - First failure: the loader expected `num_layers`, but the trainer saved `mlstm_layers`.
  - Second failure: the loader expected `dropout`, but the trainer saved `dropout_rate`.
- **Fix**: Following the principle that "the saving side owns the naming", updated `server/predictors/model_predictor.py` to load `mlstm_layers` and `dropout_rate`, matching the trainer.

### 16:45 - Fixed an algorithmic defect in the `mLSTM` model
- **Problem**: After the loading fix, the `mLSTM` model predicted a useless flat line.
- **Root cause**: The architecture in `server/models/mlstm_model.py` was flawed by design: its decoder logic wrongly copied the last time step of the input sequence several times as the prediction, so the model could not learn temporal dynamics.
- **Fix**: Refactored the `forward` method of the `MLSTMTransformer` class, removing the broken decoder logic and instead predicting directly from the encoder's final hidden state through a linear layer, correcting the algorithm at its core.
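A minimal sketch of the corrected decoding step; `EncoderOnlyHead` is an illustrative stand-in for the relevant part of `MLSTMTransformer.forward`, not the real class:

```python
import torch
import torch.nn as nn

class EncoderOnlyHead(nn.Module):
    """Predict the whole horizon from the encoder's final hidden state,
    instead of copying the last input step forward."""
    def __init__(self, hidden_size: int, forecast_horizon: int):
        super().__init__()
        self.head = nn.Linear(hidden_size, forecast_horizon)

    def forward(self, encoder_states: torch.Tensor) -> torch.Tensor:
        # encoder_states: (batch, seq_len, hidden_size)
        last_hidden = encoder_states[:, -1, :]   # final encoder state
        return self.head(last_hidden)            # (batch, forecast_horizon)
```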
### 17:00 - Fixed the `TCN` model-loading chain
- **Problem**: The `TCN` loader hard-coded a parameter, a latent crash point.
- **Analysis**: `server/predictors/model_predictor.py` hard-coded `kernel_size=3` when instantiating `TCNForecaster` instead of reading it from the model configuration.
- **Fix**: Changed `model_predictor.py` to read the value dynamically from `config['kernel_size']`, keeping the configuration complete and consistent.

### 17:15 - Fixed `KAN` model version discovery
- **Problem**: After `KAN` and optimized `KAN` trained successfully, the prediction page found no model versions at all.
- **Root cause**: The **save** and **search** logic did not match. `kan_trainer.py` saved through `model_manager.py` in the `..._product_...` format, while `get_model_versions` in `server/core/config.py` searched only the `..._epoch_...` format.
- **Fix**: Extended `get_model_versions` in `config.py` to recognize multiple naming formats, including the `..._product_...` format used by the `KAN` models.

### 17:25 - Fixed `KAN` model file-path generation
- **Problem**: With version discovery fixed, prediction still failed with "model file not found".
- **Root cause**: Only the **version discovery** logic had been fixed, not the **file-path generation**. `get_model_file_path` in `config.py` still used the `_epoch_` format when building paths for `KAN` models.
- **Fix**: Added special handling for `kan` and `optimized_kan` in `get_model_file_path` so their file paths use the correct `_product_` naming format.

### 17:40 - Upgraded version management in the `KAN` trainer
- **Problem**: The `KAN` model only ever had a static `'v1'` version, unlike the other models (`best`, `final_epoch_...`, etc.).
- **Root cause**: `kan_trainer.py` was too simplistic: it lacked the dynamic evaluate-and-save-versions logic during training and just hard-coded a final `'v1'` save.
- **Fix (feature upgrade)**: Refactored `server/trainers/kan_trainer.py` to use the same dynamic version management as the other trainers: it now tracks and saves the best-performing `best` version during training and a `final_epoch_...` version at the end.

### 17:58 - Final conclusion
- **Status**: **All known issues fixed**.
- **Achievements**:
  1. The full **"data -> training -> saving -> loading -> prediction -> visualization"** pipeline is verified for every model.
  2. All parameter inconsistencies in configuration persistence and loading are unified and fixed.
  3. Version-management logic and engineering standards are fully aligned across all models.
  4. Core technical documentation is in place, codifying the development conventions.
- **Project status**: The system is now in a robust, consistent, and extensible stable state.

---

## 2025-07-18: Systematic Refactoring of Model Version Management
**Developer**: lyf

### 14:00 - Eliminated version chaos and model-loading failures
- **Symptom**: After training, `KAN` and other algorithms showed chaotic version numbers at prediction time (bare digits like `1`, `3`, or invalid entries like `best`), duplicated versions, and `404` "model file not found" errors from version mismatches.
- **Root-cause analysis**:
  1. **Scattered logic**: Version generation lived in each trainer while version discovery lived in `config.py`, with incompatible standards, conflicting regular expressions, and hard-coded rules.
  2. **Inconsistent naming**: The `KAN` trainer saved through `model_manager` while the other trainers used a local `save_checkpoint`, leaving incompatible formats such as `..._product_..._v1.pth` and `..._epoch_best.pth` side by side.
  3. **Wrong extraction**: The overly broad, conflicting matching rules in `config.py`'s `get_model_versions` extracted invalid version numbers from file names, directly causing the garbage in the frontend dropdown.
- **Systematic refactoring**:
  1. **Single source of truth**: Made [`server/utils/model_manager.py`](server/utils/model_manager.py:1) the only component responsible for version management, model naming, and file IO.
  2. **Automatic versioning**: Added an internal `_get_next_version` method to `ModelManager` that scans existing files and safely produces the next incrementing, `v`-prefixed version number (e.g. `v3`); a sketch follows below.
  3. **Unified trainers**: Refactored `kan_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`, and `transformer_trainer.py`. Every trainer now saves its final model through `model_manager.save_model` and **no longer chooses version numbers itself**; they come entirely from `ModelManager`. The best in-training model is saved explicitly as the `best` version.
  4. **Cleanup and hardening**: Deleted the old, broken version-management functions from `config.py` and rewrote `get_model_versions` to find and parse only file names that match the new naming convention, using strict regular expressions.
  5. **API polish**: Updated `api.py` to work entirely through the new `ModelManager` and improved the error feedback when prediction fails.
- **Conclusion**: Version management went from a scattered, chaotic, hard-coded state to a centralized, uniform, automated, and robust system. All known related bugs are resolved at the root.
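The method itself is not shown in this log; the following is a sketch of the auto-versioning idea (the real `_get_next_version` in `server/utils/model_manager.py` may differ in signature and detail):

```python
import os
import re

def _get_next_version(model_dir: str, prefix: str) -> str:
    """Scan existing '<prefix>_v<N>.pth' files and return the next 'v<N+1>'."""
    pattern = re.compile(rf"^{re.escape(prefix)}_v(\d+)\.pth$")
    versions = []
    for name in os.listdir(model_dir):
        m = pattern.match(name)
        if m:
            versions.append(int(m.group(1)))
    return f"v{max(versions) + 1}" if versions else "v1"

# e.g. with transformer_store_S001_v1.pth and ..._v2.pth on disk:
# _get_next_version('saved_models', 'transformer_store_S001') -> 'v3'
```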
---

## 2025-07-18 (cont.): The "By Store" AI Loop and a Chain of Bug Fixes
**Developer**: lyf

### 15:00 - Architecture upgrade: store-level training and prediction
- **Goal**: On top of the existing "by product" mode, add and fully wire up a complete AI loop along the "by store" dimension.
- **Core challenge**: Data processing, model identification, the training flow, and the API calls all needed systematic changes to support the new training mode.
- **Solution (four-step refactor)**:
  1. **Upgraded `ModelManager`**: Redesigned the model-naming rules to give store and global models clear, unambiguous identifiers (e.g. `transformer_store_S001_v1.pth`) and updated the parsing logic to match.
  2. **Corrected the core predictor**: Fixed the key logic defect in `predictor.py` so that store mode generates and uses the right `model_identifier` (e.g. `store_S001`) and always calls the data-aggregation function.
  3. **Adapted the API layer**: Adjusted the training and prediction endpoints in `api.py` to accept and handle store-mode requests correctly.
  4. **Unified all trainers**: Updated all four trainer files so they save models under the new `model_identifier`.

### 15:30 - Chain fix 1: store model versions failed to load
- **Symptom**: The model-version dropdown on the "store prediction" page was empty.
- **Root cause**: The `get_store_model_versions_api` endpoint in `api.py` still used the old lookup function, which does not understand the new naming convention.
- **Fix**: Rewrote the endpoint to drop the old function and perform unified, reliable model lookups through `ModelManager`.

### 15:40 - Chain fix 2: store prediction returned `404`
- **Symptom**: With the version list loading correctly, clicking "start prediction" returned a `404`.
- **Root cause**: Inside the backend `predict()` endpoint, the `load_model_and_predict` worker still contained outdated, manual file-lookup logic that bypassed `ModelManager` entirely and built wrong file paths.
- **Fix (joint refactor)**:
  1. **`model_predictor.py`**: Removed all the outdated lookup code from `load_model_and_predict` and changed its signature to take an explicit `model_path` parameter.
  2. **`api.py`**: The `predict` endpoint now resolves the correct model path through `ModelManager` and passes it all the way down to `load_model_and_predict`, keeping the call chain consistent.

### 15:50 - Chain fix 3: service failed to start with `NameError`
- **Symptom**: After the prediction fix, the API service would not start, reporting `NameError: name 'Optional' is not defined`.
- **Root cause**: The edit to `model_predictor.py` used the `Optional` type hint without importing it from `typing`.
- **Fix**: Added `from typing import Optional` at the top of `server/predictors/model_predictor.py`.
- **Conclusion**: All architecture upgrades and chained bugs around the "by store" feature are resolved. The system now handles training and prediction along both dimensions stably, with more uniform and robust logic.


---

## 2025-07-21: Joint Frontend/Backend Debugging and UI Fixes
**Developer**: lyf

### 15:45 - Fixed the backend `DataFrame` serialization error
- **Symptom**: After clearing the old models and re-running a prediction, the frontend showed `Object of type DataFrame is not JSON serializable`.
- **Root cause**: In `server/predictors/model_predictor.py`, the `'predictions'` field that `load_model_and_predict` keeps for old-interface compatibility still held a raw pandas DataFrame (`predictions_df`).
- **Fix**: Updated the returned dict so the `'predictions'` field also holds `prediction_data_json`, the value already processed with `.to_dict('records')`, making every part of the return value JSON-compatible.
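A tiny, self-contained illustration of the conversion (the column names are taken from this project's prediction payload):

```python
import pandas as pd

predictions_df = pd.DataFrame({
    'date': pd.date_range('2025-07-01', periods=3).strftime('%Y-%m-%d'),
    'predicted_sales': [12.0, 15.5, 9.8],
})

# json.dumps(predictions_df) -> TypeError: Object of type DataFrame is not JSON serializable
prediction_data_json = predictions_df.to_dict('records')
# [{'date': '2025-07-01', 'predicted_sales': 12.0}, ...]  plain dicts/lists, JSON-safe
```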
### 16:00 - Fixed chart rendering across all prediction views
- **Symptom**: With the backend serialization fixed, all three prediction views (by product, by store, global) showed blank charts, and the date subtitle under the chart displayed raw, unformatted JavaScript date strings.
- **Root-cause analysis**:
  1. **Imprecise data access**: The frontend read data straight from the response root (`response.data`), while the reliable payload lives in `response.data.data`.
  2. **Mishandled date objects**: The frontend never normalized incoming dates (strings, or Date objects auto-converted by axios) to a single string format. Deduplicating dates with a `Set` therefore failed on distinct object references, leaving the chart with no data points.
- **Unified fix**:
  1. Updated `ProductPredictionView.vue`, `StorePredictionView.vue`, and `GlobalPredictionView.vue` one by one.
  2. In the `startPrediction` method, assigned the core payload `response.data.data` to `predictionResult`.
  3. Added a `formatDate` helper at the top of `renderChart` and applied it as soon as data is processed, converting every date to a `'YYYY-MM-DD'` string; this fixed both the missing data points and the broken subtitle format in one stroke.
- **Conclusion**: The frontend-backend data paths and UI rendering of all prediction views are fixed; the system works normally again.


---

## 2025-07-24: UI/UX Polish and Backend Logic Unification
**Developer**: Roo (AI Assistant) & lyf

### 14:30 - Unified the prediction-chart color scheme
- **Task**: Per user feedback, change all chart text (title, subtitle, legend, axes) to white to suit the dark UI background.
- **Implementation**:
  - Updated `ProductPredictionView.vue`, `StorePredictionView.vue`, and `GlobalPredictionView.vue`.
  - In the `Chart.js` `options`, set all relevant `color` properties to `'white'` and the grid lines to translucent white, `rgba(255, 255, 255, 0.2)`.

### 14:50 - Fixed the wrong title on the "store prediction" chart
- **Problem**: The chart title on "store prediction" showed `undefined`.
- **Root cause**: When handling store predictions, `server/api.py` never looked up and returned the real store name, so the frontend had nothing to display.
- **Fix (backend-first)**:
  1. Added a `get_store_name` helper in `server/api.py` that resolves a store name from a `store_id`.
  2. Modified the `/api/prediction` endpoint to call it when `training_mode` is `'store'` and return the name to the frontend in a `store_name` field.
- **Conclusion**: Unifying the logic on the backend guaranteed a correct data source and fixed the problem at the root.
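The helper's data source is not shown in this log; the following sketch assumes a CSV lookup with the `store_id`/`store_name` columns used elsewhere in this project (the real helper in `server/api.py` may read from a different source):

```python
import pandas as pd

def get_store_name(store_id: str, file_path: str) -> str:
    """Resolve a store's display name from its ID; fall back to the ID itself."""
    stores = pd.read_csv(file_path, dtype={'store_id': str})
    match = stores.loc[stores['store_id'] == store_id, 'store_name']
    return match.iloc[0] if not match.empty else store_id
```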
### 15:05 - Unified the timestamp format in model lists
- **Task**: Per user request, display "created at" in the model lists of all prediction pages in 24-hour `YYYY-MM-DD HH:MM:SS` format instead of ISO format.
- **Implementation**:
  - Updated `ProductPredictionView.vue`, `StorePredictionView.vue`, and `GlobalPredictionView.vue`.
  - Added a `formatDateTime` helper in `<script setup>`.
  - Applied it via the `:formatter="formatDateTime"` attribute on `<el-table-column>`, unifying the UI display.
File diff suppressed because it is too large
Binary file not shown.
Binary file not shown.
@ -56,3 +56,6 @@ tzdata==2025.2
werkzeug==3.1.3
win32-setctime==1.2.0
wsproto==1.2.0
python-dateutil
xgboost
scikit-learn
BIN sales_trends.png
Binary file not shown.
@ -35,7 +35,11 @@ def evaluate_model(y_true, y_pred):
    # Compute the mean absolute percentage error (MAPE)
    # Avoid division by zero
    mask = y_true != 0
    mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
    if np.any(mask):
        mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
    else:
        # If all ground-truth values are 0, MAPE is undefined; return 0
        mape = 0.0

    return {
        'mse': mse,
907 server/api.py
File diff suppressed because it is too large
@ -10,6 +10,13 @@ import os
import re
import glob

# Project root directory
# __file__ is the path of this file (config.py)
# os.path.dirname(__file__) is server/core
# os.path.join(..., '..') is server
# os.path.join(..., '..', '..') is the project root
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))

# Make matplotlib render Chinese characters correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
@ -26,8 +33,9 @@ def get_device():
DEVICE = get_device()

# Data-related configuration
DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx'
DEFAULT_MODEL_DIR = 'saved_models'
# Use os.path.join to build cross-platform paths
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

# Time-series parameters
@ -50,7 +58,9 @@ HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2  # number of layers

# Supported model types
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# Supported model types (v2 - dynamically loaded)
from models.model_registry import TRAINER_REGISTRY
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())

# Version-management configuration
MODEL_VERSION_PREFIX = 'v'  # version prefix
@ -63,76 +73,30 @@ TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# Create the model save directory
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)

def get_next_model_version(product_id: str, model_type: str) -> str:
    """
    Get the next version number for the given product and model type

    Args:
        product_id: product ID
        model_type: model type

    Returns:
        The next version number, e.g. 'v2', 'v3'
    """
    # New format: files with a version suffix
    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))

    # Old format: files without a version suffix (compatibility support)
    pattern_old = f"{model_type}_model_product_{product_id}.pth"
    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
    has_old_format = os.path.exists(old_file_path)

    # If no file of either format exists, return the default version
    if not existing_files_new and not has_old_format:
        return DEFAULT_VERSION

    # Extract version numbers from new-format files
    versions = []
    for file_path in existing_files_new:
        filename = os.path.basename(file_path)
        version_match = re.search(rf"_v(\d+)\.pth$", filename)
        if version_match:
            versions.append(int(version_match.group(1)))

    # If an old-format file exists, treat it as v1
    if has_old_format:
        versions.append(1)
        print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1")

    if versions:
        next_version_num = max(versions) + 1
        return f"v{next_version_num}"
    else:
        return DEFAULT_VERSION

def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str:
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
    """
    Generate the model file path
    Build the exact model file path from the product ID, model type, and version.

    Args:
        product_id: product ID
        product_id: product ID (digits only)
        model_type: model type
        version: version number; if None, the next version is used
        version: version string (e.g. 'best', 'final_epoch_50', 'v1_legacy')

    Returns:
        The full path of the model file
    """
    if version is None:
        version = get_next_model_version(product_id, model_type)

    # Special handling for v1: check whether an old-format file exists
    if version == "v1":
        # Check whether the old-format file exists
        old_format_filename = f"{model_type}_model_product_{product_id}.pth"
        old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename)

        if os.path.exists(old_format_path):
            print(f"找到旧格式模型文件: {old_format_path},将其作为v1版本")
            return old_format_path

    # Use the new-format file name
    filename = f"{model_type}_model_product_{product_id}_{version}.pth"
    # Handle the legacy "v1" format
    if version == "v1_legacy":
        filename = f"{model_type}_model_product_{product_id}.pth"
        return os.path.join(DEFAULT_MODEL_DIR, filename)

    # Fix: build the file name directly from the unique product_id (which may carry a store_ prefix)
    # Example file names: transformer_17002608_epoch_best.pth or transformer_store_01010023_epoch_best.pth
    # For KAN and optimized_kan, follow the model_manager naming convention
    # Unify the naming format across all models
    filename = f"{model_type}_product_{product_id}_{version}.pth"
    # Fix: look in the root model directory; the checkpoints subdirectory is no longer used
    return os.path.join(DEFAULT_MODEL_DIR, filename)

def get_model_versions(product_id: str, model_type: str) -> list:
@ -140,54 +104,38 @@ def get_model_versions(product_id: str, model_type: str) -> list:
    Get all versions for the given product and model type

    Args:
        product_id: product ID
        product_id: product ID (now expected to be digits only)
        model_type: model type

    Returns:
        List of versions, sorted by version number
    """
    # New format: files with a version suffix
    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
    # Search using the new, unified naming convention
    pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
    existing_files = glob.glob(pattern)

    # Old format: files without a version suffix (compatibility support)
    pattern_old = f"{model_type}_model_product_{product_id}.pth"
    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
    has_old_format = os.path.exists(old_file_path)
    versions = set()

    versions = []

    # Process new-format files
    for file_path in existing_files_new:
    for file_path in existing_files:
        filename = os.path.basename(file_path)
        version_match = re.search(rf"_v(\d+)\.pth$", filename)
        if version_match:
            version_num = int(version_match.group(1))
            versions.append(f"v{version_num}")

    # If an old-format file exists, treat it as v1
    if has_old_format:
        if "v1" not in versions:  # avoid duplicates
            versions.append("v1")
            print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1")

    # Sort by version number
    versions.sort(key=lambda v: int(v[1:]))
    return versions

def get_latest_model_version(product_id: str, model_type: str) -> str:
    """
    Get the latest version for the given product and model type

    Args:
        product_id: product ID
        model_type: model type

    Returns:
        The latest version, or None if there is none
    """
    versions = get_model_versions(product_id, model_type)
    return versions[-1] if versions else None
        # Strictly match _v<number> or 'best'
        match = re.search(r'_(v\d+|best)\.pth$', filename)
        if match:
            versions.add(match.group(1))

    # Sort numeric versions in descending order, with 'best' always first
    def sort_key(v):
        if v == 'best':
            return float('inf')  # 'best' sorts first under reverse=True
        if v.startswith('v'):
            return int(v[1:])
        return -1  # should not happen

    sorted_versions = sorted(list(versions), key=sort_key, reverse=True)

    return sorted_versions


def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
    """
@ -11,12 +11,13 @@ import time
import matplotlib.pyplot as plt
from datetime import datetime

from trainers import (
    train_product_model_with_mlstm,
    train_product_model_with_kan,
    train_product_model_with_tcn,
    train_product_model_with_transformer
)
# from trainers import (
#     train_product_model_with_mlstm,
#     train_product_model_with_kan,
#     train_product_model_with_tcn,
#     train_product_model_with_transformer
# )
# The imports above are no longer needed; trainers are now obtained dynamically from the model registry
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
@ -41,7 +42,7 @@ class PharmacyPredictor:
        """
        # Set the default data path
        if data_path is None:
            data_path = 'pharmacy_sales_multi_store.csv'
            data_path = DEFAULT_DATA_PATH

        self.data_path = data_path
        self.model_dir = model_dir
@ -117,30 +118,59 @@ class PharmacyPredictor:
            log_message(f"按产品训练模式: 产品 {product_id}, 数据量: {len(product_data)}")

        elif training_mode == 'store':
            # Store-level training: use data for a specific product of a specific store
            # Store-level training
            if not store_id:
                log_message("店铺训练模式需要指定 store_id", 'error')
                return None
            try:
                product_data = get_store_product_sales_data(
                    store_id=store_id,
                    product_id=product_id,
                    file_path=self.data_path
                )
                log_message(f"按店铺训练模式: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
            except Exception as e:
                log_message(f"获取店铺产品数据失败: {e}", 'error')
                return None

            # If product_id is 'unknown', train one aggregated model for all products of the store
            if product_id == 'unknown':
                try:
                    # Use the new aggregation function, aggregating by store
                    product_data = aggregate_multi_store_data(
                        store_id=store_id,
                        aggregation_method=aggregation_method,
                        file_path=self.data_path
                    )
                    log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
                    # Set product_id to 'store_{store_id}' to stay consistent with the API lookup logic
                    product_id = f"store_{store_id}"
                except Exception as e:
                    log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
                    return None
            else:
                # Train for a single specific product of the store
                try:
                    product_data = get_store_product_sales_data(
                        store_id=store_id,
                        product_id=product_id,
                        file_path=self.data_path
                    )
                    log_message(f"按店铺-产品训练: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
                except Exception as e:
                    log_message(f"获取店铺产品数据失败: {e}", 'error')
                    return None

        elif training_mode == 'global':
            # Global training: aggregate product data across all stores
            try:
                product_data = aggregate_multi_store_data(
                    product_id=product_id,
                    aggregation_method=aggregation_method,
                    file_path=self.data_path
                )
                log_message(f"全局训练模式: 产品 {product_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
                # If product_id is 'unknown', train one aggregated model over all products globally
                if product_id == 'unknown':
                    product_data = aggregate_multi_store_data(
                        product_id=None,  # pass None to trigger a true global aggregation
                        aggregation_method=aggregation_method,
                        file_path=self.data_path
                    )
                    log_message(f"全局训练模式: 所有产品, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
                    # Use a meaningful identifier as the product_id
                    product_id = 'all_products'
                else:
                    product_data = aggregate_multi_store_data(
                        product_id=product_id,
                        aggregation_method=aggregation_method,
                        file_path=self.data_path
                    )
                    log_message(f"全局训练模式: 产品 {product_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
            except Exception as e:
                log_message(f"聚合全局数据失败: {e}", 'error')
                return None
@ -148,77 +178,59 @@ class PharmacyPredictor:
            log_message(f"不支持的训练模式: {training_mode}", 'error')
            return None

        # Build the model identifier from the training mode
        # Build the model identifier from the training mode (v2 fix)
        if training_mode == 'store':
            model_identifier = f"{store_id}_{product_id}"
            # A store model's identifier should be based on the store ID only
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # A global model's identifier should not depend on a single product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        # Call the matching training function
        # Call the matching training function (refactored to use the registry)
        try:
            log_message(f"🤖 开始调用 {model_type} 训练器")
            if model_type == 'transformer':
                model_result, metrics, actual_version = train_product_model_with_transformer(
                    product_id,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    version=version,
                    socketio=socketio,
                    task_id=task_id,
                    continue_training=continue_training
                )
                log_message(f"✅ {model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
            elif model_type == 'mlstm':
                _, metrics, _, _ = train_product_model_with_mlstm(
                    product_id,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id,
                    progress_callback=progress_callback
                )
            elif model_type == 'kan':
                _, metrics = train_product_model_with_kan(
                    product_id,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=use_optimized,
                    model_dir=self.model_dir
                )
            elif model_type == 'optimized_kan':
                _, metrics = train_product_model_with_kan(
                    product_id,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    use_optimized=True,
                    model_dir=self.model_dir
                )
            elif model_type == 'tcn':
                _, metrics, _, _ = train_product_model_with_tcn(
                    product_id,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id
                )
            from models.model_registry import get_trainer
            log_message(f"🤖 正在从注册表获取 '{model_type}' 训练器...")
            trainer_function = get_trainer(model_type)
            log_message(f"✅ 成功获取训练器: {trainer_function.__name__}")

            # Assemble the common arguments
            trainer_args = {
                'product_id': product_id,
                'model_identifier': model_identifier,
                'product_df': product_data,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'epochs': epochs,
                'sequence_length': sequence_length,
                'forecast_horizon': forecast_horizon,
                'model_dir': self.model_dir,
                'socketio': socketio,
                'task_id': task_id,
                'progress_callback': progress_callback,
                'version': version,
                'continue_training': continue_training,
                'use_optimized': use_optimized  # needed by the KAN models
            }

            # Invoke the training function dynamically (v2 - smart argument filtering)
            import inspect
            sig = inspect.signature(trainer_function)
            valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}

            log_message(f"🔍 准备调用 {trainer_function.__name__},有效参数: {list(valid_args.keys())}")

            result = trainer_function(**valid_args)

            # Parse out metrics based on the number of return values
            if isinstance(result, tuple) and len(result) >= 2:
                metrics = result[1]  # the second return value is usually metrics
            else:
                log_message(f"不支持的模型类型: {model_type}", 'error')
                return None
                log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning')
                metrics = None


        # Inspect and print the returned metrics
        log_message(f"📊 训练完成,检查返回的metrics: {metrics}")
@ -262,21 +274,24 @@ class PharmacyPredictor:
        Returns:
            The prediction result, plus analysis when analyze_result is True.
        """
        # Build the model identifier from the training mode
        # Build the model identifier from the training mode (v2 fix)
        if training_mode == 'store' and store_id:
            model_identifier = f"{store_id}_{product_id}"
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:
            # A global model's identifier should not depend on a single product_id
            model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id

        return load_model_and_predict(
            model_identifier,
            model_type,
            future_days=future_days,
            start_date=start_date,
            model_identifier,
            model_type,
            store_id=store_id,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result,
            version=version
            version=version,
            training_mode=training_mode
        )

    def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
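The refactor above replaces the per-model if/elif dispatch with a registry lookup plus `inspect.signature` filtering, so each trainer receives only the keyword arguments it actually declares. A minimal sketch of that filtering pattern, using a hypothetical `train_demo` trainer (not part of this changeset):

    import inspect

    def train_demo(product_id, epochs=10):
        # Stand-in trainer that returns (model, metrics) like the real ones.
        return None, {'epochs': epochs}

    trainer_args = {'product_id': 'P001', 'epochs': 5, 'socketio': None}
    sig = inspect.signature(train_demo)
    valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}
    result = train_demo(**valid_args)  # 'socketio' is silently dropped
    metrics = result[1] if isinstance(result, tuple) and len(result) >= 2 else None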
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
102
server/models/cnn_bilstm_attention.py
Normal file
@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model definition, adapted for the pharmacy sales prediction system.
Original code source: "python机器学习回归全家桶"
"""

import torch
import torch.nn as nn

# Note: the original code used TensorFlow/Keras layers; here we build an equivalent PyTorch
# implementation, which is more robust and a better fit for the existing system architecture.

class Attention(nn.Module):
    """
    Attention mechanism implemented in PyTorch.
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)


class CnnBiLstmAttention(nn.Module):
    """
    PyTorch implementation of the CNN-BiLSTM-Attention model.
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN layer
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM layer
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention layer
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # Fully connected output layer
        self.dense = nn.Linear(lstm_units * 2, output_dim)

    def forward(self, x):
        # Input x has shape: (batch_size, sequence_length, input_dim)

        # CNN stage
        x = x.permute(0, 2, 1)  # to (batch_size, input_dim, sequence_length) for Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # back to (batch_size, sequence_length, cnn_filters)

        # BiLSTM stage
        lstm_out, _ = self.bilstm(x)  # lstm_out shape: (batch_size, sequence_length, lstm_units * 2)

        # Attention stage
        # Note: this Attention implementation may need task-specific tuning.
        # A simpler alternative is to use the LSTM's final hidden state or output directly;
        # for now we take the straightforward route of attending over the full LSTM output.
        attention_out = self.attention(lstm_out)

        # Fully connected output
        output = self.dense(attention_out)

        return output
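A quick shape check of the new module (a sketch; the dimensions below are illustrative, not values used by the system):

    import torch
    from models.cnn_bilstm_attention import CnnBiLstmAttention

    model = CnnBiLstmAttention(input_dim=7, output_dim=3, sequence_length=30)
    x = torch.randn(8, 30, 7)   # (batch, sequence_length, input_dim)
    y = model(x)                # Conv1d(kernel=1) keeps length 30; BiLSTM yields 2*128 features
    print(y.shape)              # torch.Size([8, 3])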
@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from typing import Tuple
from .transformer_model import TransformerEncoder, TransformerDecoder

# Definition of the mLSTM cell
@ -48,8 +49,8 @@ class mLSTMCell(nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        # Fetch the internal state
        C, n, m = internal_state

@ -112,7 +113,7 @@ class mLSTMCell(nn.Module):

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
@ -237,4 +238,5 @@ class MLSTMTransformer(nn.Module):
        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        return self.output_layer(decoder_outputs)
        # Drop the trailing dimension so the output is (B, H)
        return self.output_layer(decoder_outputs).squeeze(-1)
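The `tuple[...]` to `Tuple[...]` change here is a compatibility fix, not a behavior change: subscripting the built-in `tuple` in annotations requires Python 3.9+ (PEP 585), while `typing.Tuple` also works on older interpreters. For instance:

    from typing import Tuple

    def swap(state: Tuple[int, int]) -> Tuple[int, int]:  # fine on Python 3.8+
        a, b = state
        return b, a

    # def swap(state: tuple[int, int]) -> ...  # raises TypeError at import time on Python 3.8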
64
server/models/model_registry.py
Normal file
@ -0,0 +1,64 @@
"""
Model registry.
Decouples model invocation from implementation, supporting plugin-style extension with new models.
"""

# Trainer registry
TRAINER_REGISTRY = {}

def register_trainer(name, func):
    """
    Register a model trainer.

    Args:
        name (str): model type name (e.g., 'xgboost')
        func (function): the corresponding training function
    """
    if name in TRAINER_REGISTRY:
        print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
    TRAINER_REGISTRY[name] = func
    print(f"✅ 已注册训练器: {name}")

def get_trainer(name):
    """
    Look up a registered trainer by model type name.
    """
    if name not in TRAINER_REGISTRY:
        # Make sure trainers have been loaded before listing what is available
        from trainers import discover_trainers
        discover_trainers()
        if name not in TRAINER_REGISTRY:
            raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
    return TRAINER_REGISTRY[name]

# --- Predictor registry ---

# Predictor functions need a uniform interface, for example:
# def predictor_function(model, checkpoint, **kwargs): -> predictions

PREDICTOR_REGISTRY = {}

def register_predictor(name, func):
    """
    Register a model predictor.
    """
    if name in PREDICTOR_REGISTRY:
        print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
    PREDICTOR_REGISTRY[name] = func

def get_predictor(name):
    """
    Look up a registered predictor by model type name.
    Falls back to the default predictor when no specific one is found.
    """
    return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))

# The default PyTorch prediction logic can be registered as 'default'
def register_default_predictors():
    from predictors.model_predictor import default_pytorch_predictor
    register_predictor('default', default_pytorch_predictor)
    # Other PyTorch models with special prediction logic can also be registered here
    # register_predictor('kan', kan_predictor_func)

# Note: when this function is called matters; it must run once at application startup.
# For now we call it right after model_predictor.py imports the registry.
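With this registry in place, adding a new model type becomes a registration call at import time rather than another elif branch. A minimal sketch of a hypothetical plugin module (the trainer name and function below are illustrative, not part of this changeset):

    # hypothetical trainers/my_model_trainer.py
    from models.model_registry import register_trainer, get_trainer

    def train_product_model_with_my_model(product_id, product_df, epochs=50, **kwargs):
        # Fit a model on product_df here and return (model, metrics),
        # the tuple shape the PharmacyPredictor dispatch expects.
        return None, {'rmse': 0.0}

    register_trainer('my_model', train_product_model_with_my_model)

    # Callers then resolve it by name:
    # trainer = get_trainer('my_model')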
@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from typing import Tuple

# Definition of the sLSTM cell
class sLSTMCell(nn.Module):
@ -27,9 +28,9 @@ class sLSTMCell(nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[
        torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        # Unpack the internal state
        h, c, n, m = internal_state  # (batch_size, hidden_size)
@ -78,7 +79,7 @@ class sLSTMCell(nn.Module):

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
@ -104,4 +104,5 @@ class TimeSeriesTransformer(nn.Module):
        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        return self.output_layer(decoder_outputs)  # [batch_size, output_sequence_length, 1]
        # Drop the trailing dimension so the output is (B, H)
        return self.output_layer(decoder_outputs).squeeze(-1)
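Both decoder stacks now return `(batch, horizon)` rather than `(batch, horizon, 1)`; `.squeeze(-1)` drops the trailing singleton dimension so downstream loss and inverse-scaling code can treat predictions as a flat matrix:

    import torch
    out = torch.zeros(32, 7, 1)    # (batch, horizon, 1) from the output layer
    print(out.squeeze(-1).shape)   # torch.Size([32, 7])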
Binary file not shown.
@ -10,6 +10,7 @@ from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # needed so MinMaxScaler can be deserialized
from typing import Optional

from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
@ -17,365 +18,174 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb

from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor

def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None):
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """
    Load a trained model and run prediction.

    Args:
        product_id: product ID
        model_type: model type ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
        store_id: store ID; None means use the global model
        future_days: number of future days to predict
        start_date: prediction start date; None means start after the last known date
        analyze_result: whether to analyze the prediction result
        version: model version; None means use the latest version

    Returns:
        The prediction result, plus analysis when analyze_result is True.
    Default PyTorch model prediction logic, with autoregressive rollout.
    """
    try:
        # Resolve the model file path (multi-store aware)
        model_path = None

        if version:
            # Use the version management system to get the correct file path
            model_path = get_model_file_path(product_id, model_type, version)
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        prediction_input_df = product_df.tail(sequence_length)
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()

    for _ in range(future_days):
        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
        # **Core improvement**: detect the model type and call the matching prediction path
        if isinstance(model, xgb.Booster):
            # XGBoost prediction path
            X_input_reshaped = X_current_scaled.reshape(1, -1)
            d_input = xgb.DMatrix(X_input_reshaped)
            # **Key fix**: predict with best_iteration to match the early-stopping strategy
            y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
        else:
            # Pick the search directories based on store_id
            if store_id:
                # Look for store-specific models
                possible_dirs = [
                    os.path.join('saved_models', model_type, store_id),
                    os.path.join('models', model_type, store_id)
                ]
            else:
                # Look for global models
                possible_dirs = [
                    os.path.join('saved_models', model_type, 'global'),
                    os.path.join('models', model_type, 'global'),
                    os.path.join('saved_models', model_type),  # backward compatibility
                    'saved_models'  # the most basic directory
                ]

            # Filename patterns
            model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
            file_model_type = 'kan' if model_type == 'optimized_kan' else model_type

            possible_names = [
                f"{product_id}_{model_type}_v1_model.pt",  # new multi-store format
                f"{product_id}_{model_type}_v1_global_model.pt",  # global model format
                f"{product_id}_{model_type}_v1.pth",  # legacy format
                f"{file_model_type}{model_suffix}_model_product_{product_id}.pth",  # original format
                f"{model_type}_model_product_{product_id}.pth"  # simplified format
            ]

            # Search for the model file
            for dir_path in possible_dirs:
                if not os.path.exists(dir_path):
                    continue
                for name in possible_names:
                    test_path = os.path.join(dir_path, name)
                    if os.path.exists(test_path):
                        model_path = test_path
                        break
                if model_path:
                    break

            if not model_path:
                scope_msg = f"店铺 {store_id}" if store_id else "全局"
                print(f"找不到产品 {product_id} 的 {model_type} 模型文件 ({scope_msg})")
                print(f"搜索目录: {possible_dirs}")
                return None

        print(f"尝试加载模型文件: {model_path}")

        if not os.path.exists(model_path):
            print(f"模型文件 {model_path} 不存在")
            return None

        # Load the sales data (multi-store aware)
        try:
            if store_id:
                # Load data for a specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    'pharmacy_sales_multi_store.csv'
                )
                store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
                prediction_scope = f"店铺 '{store_name}' ({store_id})"
            else:
                # Aggregate data from all stores for prediction
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method='sum',
                    file_path='pharmacy_sales_multi_store.csv'
                )
                prediction_scope = "全部店铺(聚合数据)"
        except Exception as e:
            print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
            # Backward compatibility: try the original data format
            try:
                df = pd.read_excel('pharmacy_sales.xlsx')
                product_df = df[df['product_id'] == product_id].sort_values('date')
                if store_id:
                    print(f"警告:原始数据不支持店铺过滤,将使用所有数据预测")
                prediction_scope = "默认数据"
            except Exception as e2:
                print(f"加载产品数据失败: {str(e2)}")
                return None

        if product_df.empty:
            print(f"产品 {product_id} 没有销售数据")
            return None

        product_name = product_df['product_name'].iloc[0]
        print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
        print(f"预测范围: {prediction_scope}")

        # Add safe globals so MinMaxScaler can be deserialized
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception as e:
            print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")

        # Load the model and its config
        try:
            # First try loading with weights_only=False
            try:
                print("尝试使用 weights_only=False 加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            except Exception as e:
                print(f"使用weights_only=False加载失败: {str(e)}")
                print("尝试使用默认参数加载模型")
                checkpoint = torch.load(model_path, map_location=DEVICE)

            print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}")
            if isinstance(checkpoint, dict):
                print(f"checkpoint包含的键: {list(checkpoint.keys())}")
            else:
                print(f"checkpoint不是字典类型,而是: {type(checkpoint)}")
                return None
        except Exception as e:
            print(f"加载模型失败: {str(e)}")
            return None

        # Check for and fetch the config
        if 'config' not in checkpoint:
            print("模型文件中没有配置信息")
            return None

        config = checkpoint['config']
        print(f"模型配置: {config}")

        # Check for and fetch the scalers
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("模型文件中没有缩放器信息")
            return None

        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Create the model instance
        try:
            if model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=config['input_dim'],
                    d_model=config['hidden_size'],
                    nhead=config['num_heads'],
                    num_encoder_layers=config['num_layers'],
                    dim_feedforward=config['hidden_size'] * 2,
                    dropout=config['dropout'],
                    output_sequence_length=config['output_dim'],
                    seq_length=config['sequence_length'],
                    batch_size=32
                ).to(DEVICE)
            elif model_type == 'slstm':
                model = ScalarLSTM(
                    input_dim=config['input_dim'],
                    hidden_dim=config['hidden_size'],
                    output_dim=config['output_dim'],
                    num_layers=config['num_layers'],
                    dropout=config['dropout']
                ).to(DEVICE)
            elif model_type == 'mlstm':
                # Read config values, falling back to defaults when absent
                embed_dim = config.get('embed_dim', 32)
                dense_dim = config.get('dense_dim', 32)
                num_heads = config.get('num_heads', 4)
                num_blocks = config.get('num_blocks', 3)

                model = MatrixLSTM(
                    num_features=config['input_dim'],
                    hidden_size=config['hidden_size'],
                    mlstm_layers=config['num_layers'],
                    embed_dim=embed_dim,
                    dense_dim=dense_dim,
                    num_heads=num_heads,
                    dropout_rate=config['dropout'],
                    num_blocks=num_blocks,
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'optimized_kan':
                model = OptimizedKANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'tcn':
                model = TCNForecaster(
                    num_features=config['input_dim'],
                    output_sequence_length=config['output_dim'],
                    num_channels=[config['hidden_size']] * config['num_layers'],
                    kernel_size=3,
                    dropout=config['dropout']
                ).to(DEVICE)
            else:
                print(f"不支持的模型类型: {model_type}")
                return None

            print(f"模型实例创建成功: {type(model)}")
        except Exception as e:
            print(f"创建模型实例失败: {str(e)}")
            return None

        # Load the model parameters
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            print("模型参数加载成功")
        except Exception as e:
            print(f"加载模型参数失败: {str(e)}")
            return None

        # Prepare the input data
        try:
            features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
            sequence_length = config['sequence_length']

            # Use the most recent sequence_length days as input
            recent_data = product_df.iloc[-sequence_length:].copy()

            # If a start date was given, use data from that date onward
            if start_date:
                if isinstance(start_date, str):
                    start_date = datetime.strptime(start_date, '%Y-%m-%d')
                recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
                if len(recent_data) < sequence_length:
                    print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天")
                    # Pad with earlier data
                    missing_days = sequence_length - len(recent_data)
                    additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
                    recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)

            print(f"输入数据准备完成,形状: {recent_data.shape}")
        except Exception as e:
            print(f"准备输入数据失败: {str(e)}")
            return None

        # Normalize the input data
        try:
            X = recent_data[features].values
            X_scaled = scaler_X.transform(X)

            # Convert into the model's input format
            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            print(f"输入张量准备完成,形状: {X_input.shape}")
        except Exception as e:
            print(f"归一化输入数据失败: {str(e)}")
            return None

        # Predict
        try:
            # Default PyTorch model prediction path
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            print(f"原始预测输出形状: {y_pred_scaled.shape}")

            # Handle TCN, Transformer, mLSTM and KAN outputs to get the shape right
            if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
                y_pred_scaled = y_pred_scaled.squeeze(-1)
                print(f"处理后的预测输出形状: {y_pred_scaled.shape}")

            # Inverse-transform the predictions
            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
            print(f"反归一化后的预测结果: {y_pred}")

            # Generate the prediction dates
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
            print(f"预测日期: {pred_dates}")
        except Exception as e:
            print(f"执行预测失败: {str(e)}")
            return None

        # Build the prediction result DataFrame
        next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
        new_row_df = pd.DataFrame([new_row])
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    predictions_df = pd.DataFrame(all_predictions)
    return predictions_df, history_for_chart_df, prediction_input_df

# Register the default PyTorch predictor
register_predictor('default', default_pytorch_predictor)
# Register the enhanced default predictor for xgboost as well
register_predictor('xgboost', default_pytorch_predictor)
# Register the new model with the default predictor too
register_predictor('cnn_bilstm_attention', default_pytorch_predictor)


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """
    Load a trained model and run prediction (v4 - plugin architecture).
    """
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")

        # --- Data loading unchanged ---
        from utils.multi_store_data_utils import aggregate_multi_store_data
        if training_mode == 'store' and store_id:
            from utils.multi_store_data_utils import load_multi_store_data
            store_df_for_name = load_multi_store_data(store_id=store_id)
            product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "全局销售数据"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id

        if product_df.empty:
            raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")

        # --- Model loading and instantiation (refactored) ---
        try:
            predictions_df = pd.DataFrame({
                'date': pred_dates,
                'sales': y_pred  # use 'sales' rather than 'predicted_sales' for compatibility with history data
            })
            print(f"预测结果DataFrame创建成功,形状: {predictions_df.shape}")
        except Exception as e:
            print(f"创建预测结果DataFrame失败: {str(e)}")
            return None
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception: pass

        # Plot the prediction results
        try:
            plt.figure(figsize=(12, 6))
            plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
            plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
            plt.title(f'{product_name} - {model_type}模型销量预测')
            plt.xlabel('日期')
            plt.ylabel('销量')
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint.get('config', {})
        loaded_model_type = config.get('model_type', model_type)  # prefer the type stored with the model

        # Decide how to obtain the model instance based on its type
        if loaded_model_type == 'xgboost':
            # For XGBoost the model object itself is stored under the 'model_state_dict' key
            model = checkpoint['model_state_dict']
        else:
            # For PyTorch models, rebuild the instance and load its state_dict
            if loaded_model_type == 'transformer':
                model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
            elif loaded_model_type == 'mlstm':
                model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'kan':
                model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'optimized_kan':
                model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
            elif loaded_model_type == 'tcn':
                model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
            elif loaded_model_type == 'cnn_bilstm_attention':
                model = CnnBiLstmAttention(
                    input_dim=config['input_dim'],
                    output_dim=config['output_dim'],
                    sequence_length=config['sequence_length']
                ).to(DEVICE)
            else:
                raise ValueError(f"不支持的模型类型: {loaded_model_type}")

            # Save the figure
            plt.savefig(f'{product_id}_{model_type}_prediction.png')
            plt.close()

            print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png")
        except Exception as e:
            print(f"绘制预测结果图表失败: {str(e)}")
            # This error does not affect the main flow; keep going

        # Analyze the prediction results
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        # --- Dispatch to the registered predictor ---
        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")

        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days
        )

        # --- Analysis and return unchanged ---
        analysis = None
        if analyze_result:
            try:
                analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
                print("\n预测结果分析:")
                if analysis and 'explanation' in analysis:
                    print(analysis['explanation'])
                else:
                    print("分析结果不包含explanation字段")
                analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
            except Exception as e:
                print(f"分析预测结果失败: {str(e)}")
                # Analysis failure does not affect the main flow; keep going

        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
        return None
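End to end, the refactored entry point resolves a predictor via the registry and delegates the autoregressive loop to it. A minimal caller, assuming a trained checkpoint already exists on disk (the path below is illustrative):

    from predictors.model_predictor import load_model_and_predict

    result = load_model_and_predict(
        model_path='saved_models/transformer/P001_transformer_v1.pth',  # illustrative path
        product_id='P001',
        model_type='transformer',
        future_days=7,
        training_mode='product',
    )
    if result:
        print(result['model_type'], len(result['predictions']))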
992
server/swagger.json
Normal file
@ -0,0 +1,992 @@
{
  "openapi": "3.0.0",
  "info": {
    "title": "药店销售预测系统API",
    "description": "用于药店销售预测的RESTful API",
    "version": "1.0.0",
    "contact": {
      "name": "API开发团队",
      "email": "support@example.com"
    }
  },
  "tags": [
    {
      "name": "数据管理",
      "description": "数据上传和查询相关接口"
    },
    {
      "name": "模型训练",
      "description": "模型训练相关接口"
    },
    {
      "name": "模型预测",
      "description": "预测销售数据相关接口"
    },
    {
      "name": "模型管理",
      "description": "模型查询、导出和删除接口"
    }
  ],
  "paths": {
    "/api/products": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取所有产品列表",
        "description": "返回系统中所有产品的ID和名称",
        "responses": {
          "200": {
            "description": "成功获取产品列表",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "product_id": {"type": "string", "example": "P001"},
                          "product_name": {"type": "string", "example": "产品A"}
                        }
                      }
                    }
                  }
                }
              }
            }
          },
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/products/{product_id}": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取单个产品详情",
        "description": "返回指定产品ID的详细信息",
        "parameters": [
          {
            "name": "product_id",
            "in": "path",
            "required": true,
            "schema": {"type": "string"},
            "description": "产品ID,例如P001"
          }
        ],
        "responses": {
          "200": {
            "description": "成功获取产品详情",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "product_id": {"type": "string", "example": "P001"},
                        "product_name": {"type": "string", "example": "产品A"},
                        "data_points": {"type": "integer", "example": 365},
                        "date_range": {
                          "type": "object",
                          "properties": {
                            "start": {"type": "string", "example": "2023-01-01"},
                            "end": {"type": "string", "example": "2023-12-31"}
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "产品不存在"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/products/{product_id}/sales": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取产品销售数据",
        "description": "返回指定产品在特定日期范围内的销售数据",
        "parameters": [
          {
            "name": "product_id",
            "in": "path",
            "required": true,
            "schema": {"type": "string"},
            "description": "产品ID,例如P001"
          },
          {
            "name": "start_date",
            "in": "query",
            "schema": {"type": "string"},
            "description": "开始日期,格式为YYYY-MM-DD"
          },
          {
            "name": "end_date",
            "in": "query",
            "schema": {"type": "string"},
            "description": "结束日期,格式为YYYY-MM-DD"
          }
        ],
        "responses": {
          "200": {
            "description": "成功获取销售数据",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "date": {"type": "string", "example": "2023-12-01"},
                          "sales": {"type": "integer", "example": 150}
                        }
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "产品不存在"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/data/upload": {
      "post": {
        "tags": ["数据管理"],
        "summary": "上传销售数据",
        "description": "上传新的销售数据文件(Excel格式)",
        "requestBody": {
          "content": {
            "multipart/form-data": {
              "schema": {
                "type": "object",
                "properties": {
                  "file": {
                    "type": "string",
                    "format": "binary"
                  }
                }
              }
            }
          }
        },
        "responses": {
          "200": {
            "description": "数据上传成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "数据上传成功"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "products": {"type": "integer", "example": 10},
                        "rows": {"type": "integer", "example": 3650}
                      }
                    }
                  }
                }
              }
            }
          },
          "400": {"description": "请求错误"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/training": {
      "get": {
        "tags": ["模型训练"],
        "summary": "获取所有训练任务列表",
        "description": "返回所有正在进行、已完成或失败的训练任务",
        "responses": {
          "200": {
            "description": "成功获取任务列表",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "task_id": {"type": "string", "example": "uuid-1234"},
                          "product_id": {"type": "string", "example": "P001"},
                          "model_type": {"type": "string", "example": "mlstm"},
                          "status": {"type": "string", "example": "completed"},
                          "start_time": {"type": "string", "example": "2023-12-25T10:00:00Z"},
                          "metrics": {"type": "object", "example": {"R2": 0.95, "RMSE": 5.5}},
                          "error": {"type": "string", "nullable": true},
                          "model_path": {"type": "string", "example": "/path/to/model.pth"}
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      },
      "post": {
        "tags": ["模型训练"],
        "summary": "启动模型训练任务",
        "description": "为指定产品启动一个新的模型训练任务",
        "requestBody": {
          "required": true,
          "content": {
            "application/json": {
              "schema": {
                "type": "object",
                "properties": {
                  "product_id": {"type": "string", "example": "P001"},
                  "model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn", "xgboost"]},
                  "store_id": {"type": "string", "example": "S001"},
                  "epochs": {"type": "integer", "default": 50}
                },
                "required": ["product_id", "model_type"]
              }
            }
          }
        },
        "responses": {
          "200": {
            "description": "训练任务已启动",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "message": {"type": "string", "example": "模型训练已开始"},
                    "task_id": {"type": "string", "example": "new-uuid-5678"}
                  }
                }
              }
            }
          },
          "400": {"description": "请求错误"}
        }
      }
    },
    "/api/training/{task_id}": {
      "get": {
        "tags": ["模型训练"],
        "summary": "查询训练任务状态",
        "description": "获取特定训练任务的当前状态和详情",
        "parameters": [
          {
            "name": "task_id",
            "in": "path",
            "required": true,
            "schema": {"type": "string"},
            "description": "训练任务ID"
          }
        ],
        "responses": {
          "200": {
            "description": "成功获取任务状态",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "product_id": {"type": "string", "example": "P001"},
                        "model_type": {"type": "string", "example": "mlstm"},
                        "status": {"type": "string", "example": "running"},
                        "progress": {"type": "number", "example": 50.5},
                        "created_at": {"type": "string", "example": "2023-12-25T10:00:00Z"}
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "任务不存在"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/prediction": {
      "post": {
        "tags": ["模型预测"],
        "summary": "使用模型进行预测",
        "description": "使用指定模型预测未来销售数据",
        "requestBody": {
          "required": true,
          "content": {
            "application/json": {
              "schema": {
                "type": "object",
                "properties": {
                  "product_id": {"type": "string"},
                  "model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn"]},
                  "store_id": {"type": "string"},
                  "version": {"type": "string"},
                  "future_days": {"type": "integer"},
                  "include_visualization": {"type": "boolean"},
                  "start_date": {"type": "string"}
                },
                "required": ["product_id", "model_type"]
              }
            }
          }
        },
        "responses": {
          "200": {
            "description": "预测成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "product_id": {"type": "string", "example": "P001"},
                        "product_name": {"type": "string", "example": "产品A"},
                        "model_type": {"type": "string", "example": "mlstm"},
                        "predictions": {
                          "type": "array",
                          "items": {
                            "type": "object",
                            "properties": {
                              "date": {"type": "string", "example": "2024-01-01"},
                              "predicted_sales": {"type": "integer", "example": 100}
                            }
                          }
                        },
                        "visualization": {"type": "string", "example": "base64-encoded-image-string"}
                      }
                    }
                  }
                }
              }
            }
          },
          "400": {"description": "请求错误"},
          "404": {"description": "产品或模型不存在"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/prediction/compare": {
      "post": {
        "tags": ["模型预测"],
        "summary": "比较不同模型预测结果",
        "description": "比较不同模型对同一产品的预测结果",
        "requestBody": {
          "required": true,
          "content": {
            "application/json": {
              "schema": {
                "type": "object",
                "properties": {
                  "product_id": {"type": "string"},
                  "model_types": {"type": "array", "items": {"type": "string"}},
                  "versions": {"type": "array", "items": {"type": "string"}},
                  "include_visualization": {"type": "boolean"}
                },
                "required": ["product_id", "model_types"]
              }
            }
          }
        },
        "responses": {
          "200": {
            "description": "比较成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "product_id": {"type": "string", "example": "P001"},
                        "comparison": {
                          "type": "array",
                          "items": {
                            "type": "object",
                            "properties": {
                              "date": {"type": "string", "example": "2024-01-01"},
                              "mlstm": {"type": "integer", "example": 100},
                              "transformer": {"type": "integer", "example": 102}
                            }
                          }
                        },
                        "visualization": {"type": "string", "example": "base64-encoded-image-string"}
                      }
                    }
                  }
                }
              }
            }
          },
          "400": {"description": "请求错误"},
          "404": {"description": "产品或模型不存在"},
          "500": {"description": "服务器内部错误"}
        }
      }
    },
    "/api/prediction/history": {
      "get": {
        "tags": ["模型预测"],
        "summary": "获取历史预测记录",
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "prediction_id": {"type": "string", "example": "pred-uuid-1"},
                          "product_id": {"type": "string", "example": "P001"},
                          "model_type": {"type": "string", "example": "mlstm"},
                          "created_at": {"type": "string", "example": "2023-12-20T11:00:00Z"}
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/prediction/history/{prediction_id}": {
      "get": {
        "tags": ["模型预测"],
        "summary": "获取特定预测记录的详情",
        "parameters": [
          {
            "name": "prediction_id",
            "in": "path",
            "required": true,
            "schema": {"type": "string"}
          }
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "prediction_id": {"type": "string", "example": "pred-uuid-1"},
                        "product_id": {"type": "string", "example": "P001"},
                        "model_type": {"type": "string", "example": "mlstm"},
                        "predictions": {
                          "type": "array",
                          "items": {
                            "type": "object",
                            "properties": {
                              "date": {"type": "string", "example": "2023-12-21"},
                              "predicted_sales": {"type": "integer", "example": 110}
                            }
                          }
                        },
                        "analysis": {"type": "object", "example": {"trend": "upward"}}
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "记录不存在"}
        }
      },
      "delete": {
        "tags": ["模型预测"],
        "summary": "删除预测记录",
        "parameters": [
          {
            "name": "prediction_id",
            "in": "path",
            "required": true,
            "schema": {"type": "string"}
          }
        ],
        "responses": {
          "200": {
            "description": "删除成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "预测记录已删除"}
                  }
                }
              }
            }
          },
          "404": {"description": "记录不存在"}
        }
      }
    },
    "/api/models": {
      "get": {
        "tags": ["模型管理"],
        "summary": "获取模型列表",
        "parameters": [
          {"name": "product_id", "in": "query", "schema": {"type": "string"}},
          {"name": "model_type", "in": "query", "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "model_id": {"type": "string", "example": "P001_mlstm_v1"},
                          "product_id": {"type": "string", "example": "P001"},
                          "model_type": {"type": "string", "example": "mlstm"},
                          "version": {"type": "string", "example": "v1"},
                          "created_at": {"type": "string", "example": "2023-12-15T09:00:00Z"}
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/models/{model_id}": {
      "get": {
        "tags": ["模型管理"],
        "summary": "获取模型详情",
        "parameters": [
          {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "model_id": {"type": "string", "example": "P001_mlstm_v1"},
                        "product_id": {"type": "string", "example": "P001"},
                        "model_type": {"type": "string", "example": "mlstm"},
                        "version": {"type": "string", "example": "v1"},
                        "metrics": {"type": "object", "example": {"R2": 0.95, "RMSE": 5.5}}
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "模型不存在"}
        }
      },
      "delete": {
        "tags": ["模型管理"],
        "summary": "删除模型",
        "parameters": [
          {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "删除成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "模型已删除"}
                  }
                }
              }
            }
          },
          "404": {"description": "模型不存在"}
        }
      }
    },
    "/api/models/{model_id}/export": {
      "get": {
        "tags": ["模型管理"],
        "summary": "导出模型",
        "parameters": [
          {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {"description": "模型文件下载"},
          "404": {"description": "模型不存在"}
        }
      }
    },
    "/api/model_types": {
      "get": {
        "tags": ["模型管理"],
        "summary": "获取系统支持的所有模型类型",
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "id": {"type": "string", "example": "mlstm"},
                          "name": {"type": "string", "example": "mLSTM"}
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/models/{product_id}/{model_type}/versions": {
      "get": {
        "tags": ["模型管理"],
        "summary": "获取模型版本列表",
        "parameters": [
          {"name": "product_id", "in": "path", "required": true, "schema": {"type": "string"}},
          {"name": "model_type", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "product_id": {"type": "string", "example": "P001"},
                        "model_type": {"type": "string", "example": "mlstm"},
                        "versions": {"type": "array", "items": {"type": "string"}, "example": ["v1", "v2"]},
                        "latest_version": {"type": "string", "example": "v2"}
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/stores": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取所有店铺列表",
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "store_id": {"type": "string", "example": "S001"},
                          "store_name": {"type": "string", "example": "第一分店"}
                        }
                      }
                    },
                    "count": {"type": "integer", "example": 2}
                  }
                }
              }
            }
          }
        }
      },
      "post": {
        "tags": ["数据管理"],
        "summary": "创建新店铺",
        "responses": {
          "200": {
            "description": "创建成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "店铺创建成功"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "store_id": {"type": "string", "example": "S003"}
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/stores/{store_id}": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取单个店铺信息",
        "parameters": [
          {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "store_id": {"type": "string", "example": "S001"},
                        "store_name": {"type": "string", "example": "第一分店"},
                        "location": {"type": "string", "example": "市中心"},
                        "size": {"type": "number", "example": 120.5},
                        "type": {"type": "string", "example": "旗舰店"},
                        "opening_date": {"type": "string", "example": "2022-01-01"},
                        "status": {"type": "string", "example": "active"}
                      }
                    }
                  }
                }
              }
            }
          },
          "404": {"description": "店铺不存在"}
        }
      },
      "put": {
        "tags": ["数据管理"],
        "summary": "更新店铺信息",
        "parameters": [
          {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "更新成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "店铺更新成功"}
                  }
                }
              }
            }
          },
          "404": {"description": "店铺不存在"}
        }
      },
      "delete": {
        "tags": ["数据管理"],
        "summary": "删除店铺",
        "parameters": [
          {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "删除成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "message": {"type": "string", "example": "店铺删除成功"}
                  }
                }
              }
            }
          },
          "404": {"description": "店铺不存在"}
        }
      }
    },
    "/api/stores/{store_id}/products": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取店铺的产品列表",
        "parameters": [
          {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "array",
                      "items": {
                        "type": "object",
                        "properties": {
                          "product_id": {"type": "string", "example": "P001"},
                          "product_name": {"type": "string", "example": "产品A"}
                        }
                      }
                    },
                    "count": {"type": "integer", "example": 1}
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/stores/{store_id}/statistics": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取店铺销售统计信息",
        "parameters": [
          {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
        ],
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "status": {"type": "string", "example": "success"},
                    "data": {
                      "type": "object",
                      "properties": {
                        "total_sales": {"type": "number", "example": 150000.0},
                        "total_quantity": {"type": "integer", "example": 7500},
                        "products_count": {"type": "integer", "example": 50}
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    },
    "/api/sales/data": {
      "get": {
        "tags": ["数据管理"],
        "summary": "获取销售数据列表",
        "responses": {
          "200": {
            "description": "获取成功",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
|
||||
"properties": {
|
||||
"status": {"type": "string", "example": "success"},
|
||||
"data": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"date": {"type": "string", "example": "2023-12-01"},
|
||||
"store_id": {"type": "string", "example": "S001"},
|
||||
"product_id": {"type": "string", "example": "P001"},
|
||||
"sales": {"type": "integer", "example": 150},
|
||||
"price": {"type": "number", "example": 25.5}
|
||||
}
|
||||
}
|
||||
},
|
||||
"total": {"type": "integer", "example": 100},
|
||||
"page": {"type": "integer", "example": 1},
|
||||
"page_size": {"type": "integer", "example": 1}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
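These schemas document response envelopes only; the endpoints above take no request body beyond their path parameters. A minimal client-side sketch of how they are consumed — the base URL and port are assumptions, not part of this spec:

import requests

BASE = "http://localhost:5000"  # assumed development server address

# GET /api/stores -- list all stores
body = requests.get(f"{BASE}/api/stores").json()
if body["status"] == "success":
    for store in body["data"]:
        print(store["store_id"], store["store_name"])

# GET /api/models/{product_id}/{model_type}/versions -- version history
data = requests.get(f"{BASE}/api/models/P001/mlstm/versions").json()["data"]
print(data["versions"], "latest:", data["latest_version"])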
server/trainers/__init__.py
@@ -2,18 +2,44 @@
 药店销售预测系统 - 模型训练模块
 """
 
 from .mlstm_trainer import train_product_model_with_mlstm
 from .kan_trainer import train_product_model_with_kan
 from .tcn_trainer import train_product_model_with_tcn
 from .transformer_trainer import train_product_model_with_transformer
+import os
+import glob
+import importlib
 
 # 默认训练函数
 from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
+
+_TRAINERS_LOADED = False
+
+def discover_trainers():
+    """
+    自动发现并加载所有训练器插件。
+    使用一个标志位确保这个过程只执行一次。
+    """
+    global _TRAINERS_LOADED
+    if _TRAINERS_LOADED:
+        return
+
+    print("🚀 开始发现并加载训练器插件...")
+
+    package_dir = os.path.dirname(__file__)
+    module_name = __name__
+
+    trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py"))
+
+    for f in trainer_files:
+        base_name = os.path.basename(f)
+        if base_name.startswith('__'):
+            continue
+
+        module_stem = base_name.replace('.py', '')
+
+        try:
+            # 动态导入模块以触发自注册
+            importlib.import_module(f".{module_stem}", package=module_name)
+        except ImportError as e:
+            print(f"⚠️ 加载训练器 {module_stem} 失败: {e}")
+
+    _TRAINERS_LOADED = True
+    print("✅ 所有训练器插件加载完成。")
+
+# 在包被首次导入时,自动执行发现过程
+discover_trainers()
 
 __all__ = [
     'train_product_model',
     'train_product_model_with_mlstm',
     'train_product_model_with_kan',
     'train_product_model_with_tcn',
     'train_product_model_with_transformer'
 ]
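The self-registration these dynamic imports trigger lives in models/model_registry.py, which this comparison does not touch. A minimal sketch of the registry interface the trainers rely on — inferred from the register_trainer calls below, not the repository's actual code:

# models/model_registry.py -- hypothetical sketch, assumed interface only
TRAINER_REGISTRY = {}

def register_trainer(name, train_function):
    """Map a model-type key such as 'kan' or 'tcn' to its training function."""
    TRAINER_REGISTRY[name] = train_function

def get_trainer(name):
    """Look up a trainer by model type; unknown types fail fast."""
    if name not in TRAINER_REGISTRY:
        raise KeyError(f"no trainer registered for model type '{name}'")
    return TRAINER_REGISTRY[name]

The flag guard in discover_trainers() keeps repeated imports of the package from re-scanning the directory, and a failed import of one plugin only logs a warning instead of breaking the rest.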
server/trainers/cnn_bilstm_attention_trainer.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器
"""

import torch
import torch.optim as optim
import numpy as np
import time
import copy

from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler
from utils.visualization import plot_loss_curve  # 导入绘图函数

# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention

def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    使用 CNN-BiLSTM-Attention 模型进行训练,并实现早停和最佳模型保存。
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. 数据准备 ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # --- 2. 实例化模型和优化器 ---
    input_dim = trainX.shape[2]

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )

    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. 训练循环与早停 ---
    print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")
    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')
    best_model_state = None
    patience = kwargs.get('patience', 15)
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(trainX)
        train_loss = criterion(outputs, trainY.squeeze(-1))

        train_loss.backward()
        optimizer.step()

        # 验证
        model.eval()
        with torch.no_grad():
            val_outputs = model(testX)
            val_loss = criterion(val_outputs, testY.squeeze(-1))

        loss_history['train'].append(train_loss.item())
        loss_history['val'].append(val_loss.item())

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}')

        # 早停逻辑
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
                break

    training_time = time.time() - start_time
    print(f"模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. 使用最佳模型进行评估 ---
    if best_model_state:
        model.load_state_dict(best_model_state)
        print("最佳模型已加载用于最终评估。")

    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    metrics['training_time'] = training_time
    metrics['best_val_loss'] = best_val_loss
    metrics['stopped_epoch'] = epoch + 1

    print("\n最佳模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # 绘制损失曲线
    loss_curve_path = plot_loss_curve(
        loss_history['train'],
        loss_history['val'],
        product_name,
        'cnn_bilstm_attention',
        model_dir=model_dir
    )
    print(f"📈 损失曲线已保存到: {loss_curve_path}")

    # --- 5. 模型保存 ---
    model_data = {
        'model_state_dict': best_model_state,  # 保存最佳模型的状态
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics,
        'loss_history': loss_history,  # 保存损失历史
        'loss_curve_path': loss_curve_path  # 添加损失图路径
    }

    # 保存最终版本模型
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")

    # 保存最佳版本模型
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'  # 明确指定版本为 'best'
    )
    print(f"✅ CNN-BiLSTM-Attention 最佳模型已保存,版本: {best_version}")

    return model, metrics, final_version, final_model_path

# --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)
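The CnnBiLstmAttention class itself (models/cnn_bilstm_attention.py) is outside this comparison. A sketch of an architecture compatible with the constructor arguments and tensor shapes used above — the layer sizes and the additive attention form here are assumptions, not the repository's implementation:

import torch
import torch.nn as nn

class CnnBiLstmAttention(nn.Module):
    """Sketch: Conv1d feature extractor -> BiLSTM -> attention pooling -> linear head."""
    def __init__(self, input_dim, output_dim, sequence_length, hidden_size=64):
        super().__init__()
        self.conv = nn.Conv1d(input_dim, hidden_size, kernel_size=3, padding=1)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        self.attn = nn.Linear(hidden_size * 2, 1)   # one attention score per time step
        self.head = nn.Linear(hidden_size * 2, output_dim)

    def forward(self, x):                                    # x: (batch, seq_len, input_dim)
        z = self.conv(x.permute(0, 2, 1)).permute(0, 2, 1)   # (batch, seq_len, hidden)
        out, _ = self.lstm(z)                                # (batch, seq_len, 2*hidden)
        weights = torch.softmax(self.attn(out), dim=1)       # (batch, seq_len, 1)
        context = (weights * out).sum(dim=1)                 # attention-weighted pooling
        return self.head(context)                            # (batch, output_dim)

With this shape contract, the trainer's criterion(outputs, trainY.squeeze(-1)) compares (batch, forecast_horizon) against (batch, forecast_horizon) as intended.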
server/trainers/kan_trainer.py
@@ -21,7 +21,7 @@ from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 
-def train_product_model_with_kan(product_id, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
+def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
     """
     使用KAN模型训练产品销售预测模型
@@ -35,46 +35,55 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
     model: 训练好的模型
     metrics: 模型评估指标
     """
-    # 根据训练模式加载数据
-    from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-    
-    try:
-        if training_mode == 'store' and store_id:
-            # 加载特定店铺的数据
-            product_df = get_store_product_sales_data(
-                store_id,
-                product_id,
-                'pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"店铺 {store_id}"
-        elif training_mode == 'global':
-            # 聚合所有店铺的数据
-            product_df = aggregate_multi_store_data(
-                product_id,
-                aggregation_method=aggregation_method,
-                file_path='pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"全局聚合({aggregation_method})"
-        else:
-            # 默认:加载所有店铺的产品数据
-            product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-            training_scope = "所有店铺"
-    except Exception as e:
-        print(f"多店铺数据加载失败: {e}")
-        # 后备方案:尝试原始数据
-        df = pd.read_excel('pharmacy_sales.xlsx')
-        product_df = df[df['product_id'] == product_id].sort_values('date')
-        training_scope = "原始数据"
+    # 如果没有传入product_df,则根据训练模式加载数据
+    if product_df is None:
+        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
+        
+        try:
+            if training_mode == 'store' and store_id:
+                # 加载特定店铺的数据
+                product_df = get_store_product_sales_data(
+                    store_id,
+                    product_id,
+                    'pharmacy_sales_multi_store.csv'
+                )
+                training_scope = f"店铺 {store_id}"
+            elif training_mode == 'global':
+                # 聚合所有店铺的数据
+                product_df = aggregate_multi_store_data(
+                    product_id,
+                    aggregation_method=aggregation_method,
+                    file_path='pharmacy_sales_multi_store.csv'
+                )
+                training_scope = f"全局聚合({aggregation_method})"
+            else:
+                # 默认:加载所有店铺的产品数据
+                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
+                training_scope = "所有店铺"
+        except Exception as e:
+            print(f"多店铺数据加载失败: {e}")
+            # 后备方案:尝试原始数据
+            df = pd.read_excel('pharmacy_sales.xlsx')
+            product_df = df[df['product_id'] == product_id].sort_values('date')
+            training_scope = "原始数据"
+    else:
+        # 如果传入了product_df,直接使用,只需确定训练范围描述
+        if training_mode == 'store' and store_id:
+            training_scope = f"店铺 {store_id}"
+        elif training_mode == 'global':
+            training_scope = f"全局聚合({aggregation_method})"
+        else:
+            training_scope = "所有店铺"
 
     if product_df.empty:
         raise ValueError(f"产品 {product_id} 没有可用的销售数据")
 
     # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
         error_msg = (
             f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
             f"实际数据量: {len(product_df)} 天\n"
             f"产品ID: {product_id}, 训练模式: {training_mode}\n"
             f"建议解决方案:\n"
@@ -95,7 +104,7 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
     print(f"模型将保存到目录: {model_dir}")
 
     # 创建特征和目标变量
-    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
 
     # 预处理数据
     X = product_df[features].values
@@ -114,8 +123,8 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
 
     # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
 
     # 转换为PyTorch的Tensor
     trainX_tensor = torch.Tensor(trainX)
@@ -133,7 +142,7 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
 
     # 初始化KAN模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
     hidden_size = 64
 
     if use_optimized:
@@ -159,6 +168,7 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
     train_losses = []
     test_losses = []
     start_time = time.time()
+    best_loss = float('inf')
 
     for epoch in range(epochs):
         model.train()
@@ -216,6 +226,43 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
 
         test_loss = test_loss / len(test_loader)
         test_losses.append(test_loss)
+
+        # 检查是否为最佳模型
+        model_type_name = 'optimized_kan' if use_optimized else 'kan'
+        if test_loss < best_loss:
+            best_loss = test_loss
+            print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
+
+            # 为保存最佳模型准备数据
+            best_model_data = {
+                'model_state_dict': model.state_dict(),
+                'scaler_X': scaler_X,
+                'scaler_y': scaler_y,
+                'config': {
+                    'input_dim': input_dim,
+                    'output_dim': output_dim,
+                    'hidden_size': hidden_size,
+                    'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
+                    'sequence_length': sequence_length,
+                    'forecast_horizon': forecast_horizon,
+                    'model_type': model_type_name,
+                    'use_optimized': use_optimized
+                },
+                'epoch': epoch + 1
+            }
+
+            # 使用模型管理器保存 'best' 版本
+            from utils.model_manager import model_manager
+            model_manager.save_model(
+                model_data=best_model_data,
+                product_id=model_identifier,  # 修正:使用唯一的标识符
+                model_type=model_type_name,
+                store_id=store_id,
+                training_mode=training_mode,
+                aggregation_method=aggregation_method,
+                product_name=product_name,
+                version='best'  # 显式覆盖版本为'best'
+            )
 
         if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
@@ -273,9 +320,9 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
             'input_dim': input_dim,
             'output_dim': output_dim,
             'hidden_size': hidden_size,
-            'hidden_sizes': [hidden_size, hidden_size*2, hidden_size],
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
             'model_type': model_type_name,
             'use_optimized': use_optimized
         },
@@ -288,15 +335,23 @@ def train_product_model_with_kan(product_id, store_id=None, training_mode='produ
         'loss_curve_path': loss_curve_path
     }
 
-    model_path = model_manager.save_model(
+    # 保存最终模型,让 model_manager 自动处理版本号
+    final_model_path, final_version = model_manager.save_model(
         model_data=model_data,
-        product_id=product_id,
+        product_id=model_identifier,  # 修正:使用唯一的标识符
        model_type=model_type_name,
-        version='v1',  # KAN训练器默认使用v1
         store_id=store_id,
         training_mode=training_mode,
         aggregation_method=aggregation_method,
         product_name=product_name
+        # 注意:此处不传递version参数,由管理器自动生成
     )
 
-    return model, metrics
+    print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
+
+    return model, metrics
+
+# --- 将此训练器注册到系统中 ---
+from models.model_registry import register_trainer
+register_trainer('kan', train_product_model_with_kan)
+register_trainer('optimized_kan', train_product_model_with_kan)
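With the refactored signature, the caller supplies the data and a unique model identifier instead of letting the trainer load files itself. A hypothetical invocation — the DataFrame must contain the feature columns listed above, and the identifier format is an assumption:

import pandas as pd

df = pd.read_parquet("P001_sales.parquet")  # hypothetical data source

model, metrics = train_product_model_with_kan(
    product_id="P001",
    model_identifier="P001_S001",  # assumed format: unique per product/store scope
    product_df=df,
    store_id="S001",
    training_mode="store",
    epochs=100,
    sequence_length=30,
    forecast_horizon=7,
)
print(metrics["rmse"])

Because a 'best' version is re-saved on every improvement epoch, the checkpoint on disk can be inspected independently of the final model; the saved dict bundles the weights, both scalers, and the config (the path below is hypothetical — the real layout is decided by model_manager.save_model):

import torch

ckpt = torch.load("saved_models/kan_P001_S001_best.pth", map_location="cpu")
print(ckpt["config"]["sequence_length"], ckpt["epoch"])
scaler_y = ckpt["scaler_y"]  # scalers travel with the weights for later de-normalization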
server/trainers/mlstm_trainer.py
@@ -20,102 +20,30 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import (
-    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
-    get_next_model_version, get_model_file_path, get_latest_model_version
+    DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 )
 from utils.training_progress import progress_manager
-
-def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
-                    model_type: str, model_dir: str, store_id=None,
-                    training_mode: str = 'product', aggregation_method=None):
-    """
-    保存训练检查点
-
-    Args:
-        checkpoint_data: 检查点数据
-        epoch_or_label: epoch编号或标签(如'best')
-        product_id: 产品ID
-        model_type: 模型类型
-        model_dir: 模型保存目录
-        store_id: 店铺ID
-        training_mode: 训练模式
-        aggregation_method: 聚合方法
-    """
-    # 创建检查点目录
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
-
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    # 保存检查点
-    torch.save(checkpoint_data, checkpoint_path)
-    print(f"[mLSTM] 检查点已保存: {checkpoint_path}", flush=True)
-
-    return checkpoint_path
-
-def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
-                    model_dir: str, store_id=None, training_mode: str = 'product',
-                    aggregation_method=None):
-    """
-    加载训练检查点
-
-    Args:
-        product_id: 产品ID
-        model_type: 模型类型
-        epoch_or_label: epoch编号或标签
-        model_dir: 模型保存目录
-        store_id: 店铺ID
-        training_mode: 训练模式
-        aggregation_method: 聚合方法
-
-    Returns:
-        checkpoint_data: 检查点数据,如果未找到返回None
-    """
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
-
-    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
-
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    if os.path.exists(checkpoint_path):
-        try:
-            checkpoint_data = torch.load(checkpoint_path, map_location=DEVICE)
-            print(f"[mLSTM] 检查点已加载: {checkpoint_path}", flush=True)
-            return checkpoint_data
-        except Exception as e:
-            print(f"[mLSTM] 加载检查点失败: {e}", flush=True)
-            return None
-    else:
-        print(f"[mLSTM] 检查点文件不存在: {checkpoint_path}", flush=True)
-        return None
+from utils.model_manager import model_manager
 
 def train_product_model_with_mlstm(
     product_id,
+    model_identifier,
+    product_df,
     store_id=None,
     training_mode='product',
     aggregation_method='sum',
     epochs=50,
+    sequence_length=LOOK_BACK,
+    forecast_horizon=FORECAST_HORIZON,
     model_dir=DEFAULT_MODEL_DIR,
     version=None,
     socketio=None,
     task_id=None,
     continue_training=False,
-    progress_callback=None
+    progress_callback=None,
+    patience=10,
+    learning_rate=0.001,
+    clip_norm=1.0
 ):
     """
     使用mLSTM训练产品销售预测模型
@@ -169,19 +97,10 @@ def train_product_model_with_mlstm(
 
     emit_progress("开始mLSTM模型训练...")
 
-    # 根据训练模式加载数据
-    from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-
-    # 确定版本号
-    if version is None:
-        if continue_training:
-            version = get_latest_model_version(product_id, 'mlstm')
-            if version is None:
-                version = get_next_model_version(product_id, 'mlstm')
-        else:
-            version = get_next_model_version(product_id, 'mlstm')
-
-    emit_progress(f"开始训练 mLSTM 模型版本 {version}")
+    emit_progress(f"开始训练 mLSTM 模型")
+    if version:
+        emit_progress(f"使用指定版本: {version}")
 
     # 初始化训练进度管理器(如果还未初始化)
     if socketio and task_id:
@@ -204,42 +123,21 @@ def train_product_model_with_mlstm(
             print(f"[mLSTM] 任务 {task_id}: 使用现有进度管理器", flush=True)
         except Exception as e:
             print(f"[mLSTM] 任务 {task_id}: 进度管理器初始化失败: {e}", flush=True)
 
-    # 根据训练模式加载数据
-    try:
-        if training_mode == 'store' and store_id:
-            # 加载特定店铺的数据
-            product_df = get_store_product_sales_data(
-                store_id,
-                product_id,
-                'pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"店铺 {store_id}"
-        elif training_mode == 'global':
-            # 聚合所有店铺的数据
-            product_df = aggregate_multi_store_data(
-                product_id,
-                aggregation_method=aggregation_method,
-                file_path='pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"全局聚合({aggregation_method})"
-        else:
-            # 默认:加载所有店铺的产品数据
-            product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-            training_scope = "所有店铺"
-    except Exception as e:
-        print(f"多店铺数据加载失败: {e}")
-        # 后备方案:尝试原始数据
-        df = pd.read_excel('pharmacy_sales.xlsx')
-        product_df = df[df['product_id'] == product_id].sort_values(by='date')
-        training_scope = "原始数据"
+    # 数据现在由调用方传入,不再在此处加载
+    if training_mode == 'store' and store_id:
+        training_scope = f"店铺 {store_id}"
+    elif training_mode == 'global':
+        training_scope = f"全局聚合({aggregation_method})"
+    else:
+        training_scope = "所有店铺"
 
     # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
         error_msg = (
             f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
             f"产品ID: {product_id}, 训练模式: {training_mode}\n"
             f"建议解决方案:\n"
@@ -255,7 +153,7 @@ def train_product_model_with_mlstm(
 
     print(f"[mLSTM] 使用mLSTM模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
     print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
-    print(f"[mLSTM] 版本: {version}", flush=True)
+    # print(f"[mLSTM] 版本: {version}", flush=True)  # Version is now handled by model_manager
     print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
     print(f"[mLSTM] 模型将保存到目录: {model_dir}", flush=True)
     print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
@@ -263,7 +161,7 @@ def train_product_model_with_mlstm(
     emit_progress(f"训练产品: {product_name} (ID: {product_id}) - {training_scope}")
 
     # 创建特征和目标变量
-    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
 
     print(f"[mLSTM] 开始数据预处理,特征: {features}", flush=True)
 
@@ -289,8 +187,8 @@ def train_product_model_with_mlstm(
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
 
     # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
 
     # 转换为PyTorch的Tensor
     trainX_tensor = torch.Tensor(trainX)
@@ -315,7 +213,7 @@ def train_product_model_with_mlstm(
 
     # 初始化mLSTM结合Transformer模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
     hidden_size = 128
     num_heads = 4
     dropout_rate = 0.1
@@ -344,23 +242,16 @@ def train_product_model_with_mlstm(
 
     # 如果是继续训练,加载现有模型
     if continue_training and version != 'v1':
-        try:
-            existing_model_path = get_model_file_path(product_id, 'mlstm', version)
-            if os.path.exists(existing_model_path):
-                checkpoint = torch.load(existing_model_path, map_location=DEVICE)
-                model.load_state_dict(checkpoint['model_state_dict'])
-                print(f"加载现有模型: {existing_model_path}")
-                emit_progress(f"加载现有模型版本 {version} 进行继续训练")
-        except Exception as e:
-            print(f"无法加载现有模型,将重新开始训练: {e}")
-            emit_progress("无法加载现有模型,重新开始训练")
+        # TODO: Implement continue_training logic with the new model_manager
+        pass
 
     # 将模型移动到设备上
     model = model.to(DEVICE)
 
     criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)
 
     emit_progress("数据预处理完成,开始模型训练...", progress=10)
 
     # 训练模型
@@ -371,8 +262,9 @@ def train_product_model_with_mlstm(
     # 配置检查点保存
     checkpoint_interval = max(1, epochs // 10)  # 每10%进度保存一次,最少每1个epoch
     best_loss = float('inf')
+    epochs_no_improve = 0
 
-    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
+    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
 
     for epoch in range(epochs):
         emit_progress(f"开始训练 Epoch {epoch+1}/{epochs}")
@@ -384,9 +276,6 @@ def train_product_model_with_mlstm(
             X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
 
-            # 确保目标张量有正确的形状
-            if y_batch.dim() == 2:
-                y_batch = y_batch.unsqueeze(-1)
-
             # 前向传播
             outputs = model(X_batch)
             loss = criterion(outputs, y_batch)
@@ -394,6 +283,8 @@ def train_product_model_with_mlstm(
             # 反向传播和优化
             optimizer.zero_grad()
             loss.backward()
+            if clip_norm:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
             optimizer.step()
 
             epoch_loss += loss.item()
@@ -409,10 +300,6 @@ def train_product_model_with_mlstm(
             for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                 X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
 
-                # 确保目标张量有正确的形状
-                if y_batch.dim() == 2:
-                    y_batch = y_batch.unsqueeze(-1)
-
                 outputs = model(X_batch)
                 loss = criterion(outputs, y_batch)
                 test_loss += loss.item()
@@ -420,6 +307,9 @@ def train_product_model_with_mlstm(
         test_loss = test_loss / len(test_loader)
         test_losses.append(test_loss)
 
+        # 更新学习率
+        scheduler.step(test_loss)
+
         # 计算总体训练进度
         epoch_progress = ((epoch + 1) / epochs) * 90 + 10  # 10-100% 范围
 
@@ -452,12 +342,13 @@ def train_product_model_with_mlstm(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_heads': num_heads,
-            'dropout': dropout_rate,
+            'dropout_rate': dropout_rate,
             'num_blocks': num_blocks,
             'embed_dim': embed_dim,
             'dense_dim': dense_dim,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'mlstm_layers': 2,  # 确保这个参数被保存
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
             'model_type': 'mlstm'
         },
         'training_info': {
@@ -471,21 +362,31 @@ def train_product_model_with_mlstm(
             }
         }
 
-        # 保存检查点
-        save_checkpoint(checkpoint_data, epoch + 1, product_id, 'mlstm',
-                        model_dir, store_id, training_mode, aggregation_method)
-
-        # 如果是最佳模型,额外保存一份
         if test_loss < best_loss:
             best_loss = test_loss
-            save_checkpoint(checkpoint_data, 'best', product_id, 'mlstm',
-                            model_dir, store_id, training_mode, aggregation_method)
+            model_manager.save_model(
+                model_data=checkpoint_data,
+                product_id=model_identifier,  # 修正:使用唯一的标识符
+                model_type='mlstm',
+                store_id=store_id,
+                training_mode=training_mode,
+                aggregation_method=aggregation_method,
+                product_name=product_name,
+                version='best'
+            )
+            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
+            epochs_no_improve = 0
+        else:
+            epochs_no_improve += 1
 
-        emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
 
         if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)
 
+        # 提前停止逻辑
+        if epochs_no_improve >= patience:
+            emit_progress(f"连续 {patience} 个epoch测试损失未改善,提前停止训练。")
+            break
 
     # 计算训练时间
     training_time = time.time() - start_time
@@ -527,19 +428,15 @@ def train_product_model_with_mlstm(
     model.eval()
     with torch.no_grad():
         test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
 
     # 处理输出形状
     if len(test_pred.shape) == 3:
         test_pred = test_pred.squeeze(-1)
+    test_true = testY
 
     # 反归一化预测结果和真实值
-    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
-    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
+    test_pred_inv = scaler_y.inverse_transform(test_pred)
+    test_true_inv = scaler_y.inverse_transform(test_true)
 
     # 计算评估指标
     metrics = evaluate_model(test_true_inv, test_pred_inv)
     metrics['training_time'] = training_time
-    metrics['version'] = version
 
     # 打印评估指标
     print("\n模型评估指标:")
@@ -568,12 +465,13 @@ def train_product_model_with_mlstm(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_heads': num_heads,
-            'dropout': dropout_rate,
+            'dropout_rate': dropout_rate,
             'num_blocks': num_blocks,
             'embed_dim': embed_dim,
             'dense_dim': dense_dim,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'mlstm_layers': 2,  # 确保这个参数被保存
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
             'model_type': 'mlstm'
         },
         'metrics': metrics,
@@ -590,10 +488,15 @@ def train_product_model_with_mlstm(
         }
     }
 
-    # 保存最终模型(使用epoch标识)
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'mlstm',
-        model_dir, store_id, training_mode, aggregation_method
+    # 保存最终模型,让 model_manager 自动处理版本号
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,  # 修正:使用唯一的标识符
+        model_type='mlstm',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
     )
 
     # 发送训练完成消息
@@ -605,9 +508,14 @@ def train_product_model_with_mlstm(
         'mape': metrics['mape'],
         'training_time': training_time,
         'final_epoch': epochs,
-        'model_path': final_model_path
+        'model_path': final_model_path,
+        'version': final_version
     }
 
-    emit_progress(f"✅ mLSTM模型训练完成!最终epoch: {epochs} 已保存", progress=100, metrics=final_metrics)
+    emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
 
     return model, metrics, epochs, final_model_path
+
+# --- 将此训练器注册到系统中 ---
+from models.model_registry import register_trainer
+register_trainer('mlstm', train_product_model_with_mlstm)
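Every trainer now delegates persistence to utils.model_manager.model_manager, which this comparison does not show. The call sites imply a save_model that builds a canonical filename from model type, training mode, and identifier, auto-increments the version unless one is passed (version='best' overwrites in place), and returns (path, version). A rough sketch under those assumptions — this is not the repository's actual implementation:

import os
import re
import torch

class ModelManager:
    """Hypothetical sketch of the save side of utils.model_manager."""
    def __init__(self, model_dir="saved_models"):
        self.model_dir = model_dir

    def save_model(self, model_data, product_id, model_type, store_id=None,
                   training_mode='product', aggregation_method=None,
                   product_name=None, version=None):
        os.makedirs(self.model_dir, exist_ok=True)
        scope = store_id or aggregation_method or 'all'
        base = f"{model_type}_{training_mode}_{product_id}_{scope}"
        if version is None:
            # auto-increment: scan existing v<N> files sharing this base name
            pattern = re.compile(re.escape(base) + r"_v(\d+)\.pth$")
            existing = [int(m.group(1)) for f in os.listdir(self.model_dir)
                        if (m := pattern.match(f))]
            version = f"v{max(existing, default=0) + 1}"
        path = os.path.join(self.model_dir, f"{base}_{version}.pth")
        torch.save(model_data, path)  # named versions like 'best' simply overwrite
        return path, version

model_manager = ModelManager()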
server/trainers/tcn_trainer.py
@@ -20,49 +20,18 @@ from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 from utils.training_progress import progress_manager
-
-def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
-                    model_type: str, model_dir: str, store_id=None,
-                    training_mode: str = 'product', aggregation_method=None):
-    """
-    保存训练检查点
-
-    Args:
-        checkpoint_data: 检查点数据
-        epoch_or_label: epoch编号或标签(如'best')
-        product_id: 产品ID
-        model_type: 模型类型
-        model_dir: 模型保存目录
-        store_id: 店铺ID
-        training_mode: 训练模式
-        aggregation_method: 聚合方法
-    """
-    # 创建检查点目录
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
-    os.makedirs(checkpoint_dir, exist_ok=True)
-
-    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
-
-    checkpoint_path = os.path.join(checkpoint_dir, filename)
-
-    # 保存检查点
-    torch.save(checkpoint_data, checkpoint_path)
-    print(f"[TCN] 检查点已保存: {checkpoint_path}", flush=True)
-
-    return checkpoint_path
+from utils.model_manager import model_manager
 
 def train_product_model_with_tcn(
     product_id,
+    model_identifier,
+    product_df=None,
     store_id=None,
     training_mode='product',
     aggregation_method='sum',
     epochs=50,
+    sequence_length=LOOK_BACK,
+    forecast_horizon=FORECAST_HORIZON,
     model_dir=DEFAULT_MODEL_DIR,
     version=None,
     socketio=None,
@@ -71,21 +40,6 @@ def train_product_model_with_tcn(
 ):
     """
     使用TCN模型训练产品销售预测模型
-
-    参数:
-    product_id: 产品ID
-    epochs: 训练轮次
-    model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
-    version: 指定版本号,如果为None则自动生成
-    socketio: WebSocket对象,用于实时反馈
-    task_id: 训练任务ID
-    continue_training: 是否继续训练现有模型
-
-    返回:
-    model: 训练好的模型
-    metrics: 模型评估指标
-    version: 实际使用的版本号
-    model_path: 模型文件路径
     """
 
     def emit_progress(message, progress=None, metrics=None):
@@ -102,64 +56,28 @@ def train_product_model_with_tcn(
             data['metrics'] = metrics
         socketio.emit('training_progress', data, namespace='/training')
 
-    # 确定版本号
-    if version is None:
-        from core.config import get_latest_model_version, get_next_model_version
-        if continue_training:
-            version = get_latest_model_version(product_id, 'tcn')
-            if version is None:
-                version = get_next_model_version(product_id, 'tcn')
-        else:
-            version = get_next_model_version(product_id, 'tcn')
-
-    emit_progress(f"开始训练 TCN 模型版本 {version}")
-
-    # 根据训练模式加载数据
-    from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-
-    try:
-        if training_mode == 'store' and store_id:
-            # 加载特定店铺的数据
-            product_df = get_store_product_sales_data(
-                store_id,
-                product_id,
-                'pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"店铺 {store_id}"
-        elif training_mode == 'global':
-            # 聚合所有店铺的数据
-            product_df = aggregate_multi_store_data(
-                product_id,
-                aggregation_method=aggregation_method,
-                file_path='pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"全局聚合({aggregation_method})"
-        else:
-            # 默认:加载所有店铺的产品数据
-            product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-            training_scope = "所有店铺"
-    except Exception as e:
-        print(f"多店铺数据加载失败: {e}")
-        # 后备方案:尝试原始数据
-        df = pd.read_excel('pharmacy_sales.xlsx')
-        product_df = df[df['product_id'] == product_id].sort_values('date')
-        training_scope = "原始数据"
+    emit_progress(f"开始训练 TCN 模型")
+
+    if product_df is None:
+        from utils.multi_store_data_utils import aggregate_multi_store_data
+        product_df = aggregate_multi_store_data(
+            product_id=product_id,
+            aggregation_method=aggregation_method
+        )
+        training_scope = f"全局聚合({aggregation_method})"
+    else:
+        training_scope = "所有店铺"
 
     if product_df.empty:
         raise ValueError(f"产品 {product_id} 没有可用的销售数据")
 
     # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
         error_msg = (
             f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
             f"实际数据量: {len(product_df)} 天\n"
             f"产品ID: {product_id}, 训练模式: {training_mode}\n"
             f"建议解决方案:\n"
             f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
             f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
             f"3. 使用全局训练模式聚合更多数据"
         )
         print(error_msg)
         emit_progress(f"训练失败:数据不足 ({len(product_df)}/{min_required_samples} 天)")
@@ -170,48 +88,39 @@ def train_product_model_with_tcn(
 
     print(f"使用TCN模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
     print(f"训练范围: {training_scope}")
-    print(f"版本: {version}")
     print(f"使用设备: {DEVICE}")
     print(f"模型将保存到目录: {model_dir}")
 
     emit_progress(f"训练产品: {product_name} (ID: {product_id})")
 
     # 创建特征和目标变量
-    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
 
     # 预处理数据
     X = product_df[features].values
-    y = product_df[['sales']].values  # 保持为二维数组
+    y = product_df[['sales']].values
 
     # 设置数据预处理阶段
     progress_manager.set_stage("data_preprocessing", 0)
     emit_progress("数据预处理中...")
 
     # 归一化数据
     scaler_X = MinMaxScaler(feature_range=(0, 1))
     scaler_y = MinMaxScaler(feature_range=(0, 1))
 
     X_scaled = scaler_X.fit_transform(X)
     y_scaled = scaler_y.fit_transform(y)
 
     # 划分训练集和测试集(80% 训练,20% 测试)
     train_size = int(len(X_scaled) * 0.8)
     X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
 
     progress_manager.set_stage("data_preprocessing", 50)
 
     # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
 
     # 转换为PyTorch的Tensor
     trainX_tensor = torch.Tensor(trainX)
     trainY_tensor = torch.Tensor(trainY)
     testX_tensor = torch.Tensor(testX)
     testY_tensor = torch.Tensor(testY)
 
     # 创建数据加载器
     train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
     test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
@@ -219,7 +128,6 @@ def train_product_model_with_tcn(
     train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
 
-    # 更新进度管理器的批次信息
     total_batches = len(train_loader)
     total_samples = len(train_dataset)
     progress_manager.total_batches_per_epoch = total_batches
@@ -228,9 +136,8 @@ def train_product_model_with_tcn(
 
     progress_manager.set_stage("data_preprocessing", 100)
 
-    # 初始化TCN模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
     hidden_size = 64
     num_layers = 3
     kernel_size = 3
@@ -244,21 +151,8 @@ def train_product_model_with_tcn(
         dropout=dropout_rate
     )
 
-    # 如果是继续训练,加载现有模型
-    if continue_training and version != 'v1':
-        try:
-            from core.config import get_model_file_path
-            existing_model_path = get_model_file_path(product_id, 'tcn', version)
-            if os.path.exists(existing_model_path):
-                checkpoint = torch.load(existing_model_path, map_location=DEVICE)
-                model.load_state_dict(checkpoint['model_state_dict'])
-                print(f"加载现有模型: {existing_model_path}")
-                emit_progress(f"加载现有模型版本 {version} 进行继续训练")
-        except Exception as e:
-            print(f"无法加载现有模型,将重新开始训练: {e}")
-            emit_progress("无法加载现有模型,重新开始训练")
+    # TODO: Implement continue_training logic with the new model_manager
 
     # 将模型移动到设备上
     model = model.to(DEVICE)
 
     criterion = nn.MSELoss()
@@ -266,20 +160,17 @@ def train_product_model_with_tcn(
 
     emit_progress("开始模型训练...")
 
-    # 训练模型
     train_losses = []
     test_losses = []
     start_time = time.time()
 
-    # 配置检查点保存
-    checkpoint_interval = max(1, epochs // 10)  # 每10%进度保存一次,最少每1个epoch
+    checkpoint_interval = max(1, epochs // 10)
     best_loss = float('inf')
 
     progress_manager.set_stage("model_training", 0)
     emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
 
     for epoch in range(epochs):
-        # 开始新的轮次
         progress_manager.start_epoch(epoch)
 
         model.train()
@@ -288,43 +179,34 @@ def train_product_model_with_tcn(
         for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
             X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
 
-            # 确保目标张量有正确的形状 (batch_size, forecast_horizon, 1)
-            if y_batch.dim() == 2:
-                y_batch = y_batch.unsqueeze(-1)
-
             # 前向传播
             outputs = model(X_batch)
-
-            # 确保输出和目标形状匹配
             loss = criterion(outputs, y_batch)
 
             # 反向传播和优化
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
 
             epoch_loss += loss.item()
 
-            # 更新批次进度(每10个批次更新一次)
             if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                 current_lr = optimizer.param_groups[0]['lr']
                 progress_manager.update_batch(batch_idx, loss.item(), current_lr)
 
-        # 计算训练损失
         train_loss = epoch_loss / len(train_loader)
         train_losses.append(train_loss)
 
-        # 设置验证阶段
        progress_manager.set_stage("validation", 0)
 
         # 在测试集上评估
         model.eval()
         test_loss = 0
         with torch.no_grad():
             for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                 X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
 
                 # 确保目标张量有正确的形状
                 if y_batch.dim() == 2:
                     y_batch = y_batch.unsqueeze(-1)
 
@@ -332,7 +214,6 @@ def train_product_model_with_tcn(
                 outputs = model(X_batch)
                 loss = criterion(outputs, y_batch)
                 test_loss += loss.item()
 
-                # 更新验证进度
                 if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
                     val_progress = (batch_idx / len(test_loader)) * 100
                     progress_manager.set_stage("validation", val_progress)
@@ -340,10 +221,8 @@ def train_product_model_with_tcn(
         test_loss = test_loss / len(test_loader)
         test_losses.append(test_loss)
 
-        # 完成当前轮次
         progress_manager.finish_epoch(train_loss, test_loss)
 
-        # 发送训练进度(保持与旧系统的兼容性)
         if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
             progress = ((epoch + 1) / epochs) * 100
             current_metrics = {
@@ -355,7 +234,6 @@ def train_product_model_with_tcn(
             emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                           progress=progress, metrics=current_metrics)
 
-        # 定期保存检查点
         if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
             checkpoint_data = {
                 'epoch': epoch + 1,
@@ -372,10 +250,11 @@ def train_product_model_with_tcn(
                     'output_dim': output_dim,
                     'hidden_size': hidden_size,
                     'num_layers': num_layers,
+                    'num_channels': [hidden_size] * num_layers,
                     'dropout': dropout_rate,
                     'kernel_size': kernel_size,
-                    'sequence_length': LOOK_BACK,
-                    'forecast_horizon': FORECAST_HORIZON,
+                    'sequence_length': sequence_length,
+                    'forecast_horizon': forecast_horizon,
                     'model_type': 'tcn'
                 },
                 'training_info': {
@@ -388,30 +267,28 @@ def train_product_model_with_tcn(
                 }
             }
 
-            # 保存检查点
-            save_checkpoint(checkpoint_data, epoch + 1, product_id, 'tcn',
-                            model_dir, store_id, training_mode, aggregation_method)
-
-            # 如果是最佳模型,额外保存一份
             if test_loss < best_loss:
                 best_loss = test_loss
-                save_checkpoint(checkpoint_data, 'best', product_id, 'tcn',
-                                model_dir, store_id, training_mode, aggregation_method)
+                model_manager.save_model(
+                    model_data=checkpoint_data,
+                    product_id=model_identifier,  # 修正:使用唯一的标识符
+                    model_type='tcn',
+                    store_id=store_id,
+                    training_mode=training_mode,
+                    aggregation_method=aggregation_method,
+                    product_name=product_name,
+                    version='best'
+                )
+                emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
-
-            emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
 
         if (epoch + 1) % 10 == 0:
             print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
 
     # 计算训练时间
     training_time = time.time() - start_time
 
     # 设置模型保存阶段
     progress_manager.set_stage("model_saving", 0)
     emit_progress("训练完成,正在保存模型...")
 
-    # 绘制损失曲线并保存到模型目录
     loss_curve_path = plot_loss_curve(
         train_losses,
         test_losses,
@@ -421,23 +298,17 @@ def train_product_model_with_tcn(
     )
     print(f"损失曲线已保存到: {loss_curve_path}")
 
     # 评估模型
     model.eval()
     with torch.no_grad():
-        # 确保测试数据的形状正确
         test_pred = model(testX_tensor.to(DEVICE))
-        # 将输出转换为二维数组 [samples, forecast_horizon]
         test_pred = test_pred.squeeze(-1).cpu().numpy()
 
     # 反归一化预测结果和真实值
     test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
     test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
 
     # 计算评估指标
     metrics = evaluate_model(test_true_inv, test_pred_inv)
     metrics['training_time'] = training_time
 
     # 打印评估指标
     print("\n模型评估指标:")
     print(f"MSE: {metrics['mse']:.4f}")
     print(f"RMSE: {metrics['rmse']:.4f}")
@@ -446,9 +317,8 @@ def train_product_model_with_tcn(
     print(f"MAPE: {metrics['mape']:.2f}%")
     print(f"训练时间: {training_time:.2f}秒")
 
-    # 保存最终训练完成的模型(基于最终epoch)
     final_model_data = {
-        'epoch': epochs,  # 最终epoch
+        'epoch': epochs,
         'model_state_dict': model.state_dict(),
         'optimizer_state_dict': optimizer.state_dict(),
         'train_loss': train_losses[-1],
@@ -462,10 +332,11 @@ def train_product_model_with_tcn(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_layers': num_layers,
+            'num_channels': [hidden_size] * num_layers,
             'dropout': dropout_rate,
             'kernel_size': kernel_size,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
             'model_type': 'tcn'
         },
         'metrics': metrics,
@@ -483,10 +354,14 @@ def train_product_model_with_tcn(
 
     progress_manager.set_stage("model_saving", 50)
 
-    # 保存最终模型(使用epoch标识)
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'tcn',
-        model_dir, store_id, training_mode, aggregation_method
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,  # 修正:使用唯一的标识符
+        model_type='tcn',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
     )
 
     progress_manager.set_stage("model_saving", 100)
@@ -498,9 +373,14 @@ def train_product_model_with_tcn(
         'r2': metrics['r2'],
         'mape': metrics['mape'],
         'training_time': training_time,
-        'final_epoch': epochs
+        'final_epoch': epochs,
+        'version': final_version
    }
 
-    emit_progress(f"模型训练完成!最终epoch: {epochs}", progress=100, metrics=final_metrics)
+    emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
 
     return model, metrics, epochs, final_model_path
+
+# --- 将此训练器注册到系统中 ---
+from models.model_registry import register_trainer
+register_trainer('tcn', train_product_model_with_tcn)
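All four refactored trainers window their scaled arrays through utils.data_utils.create_dataset before tensor conversion. The helper itself is not shown in this comparison; judging by how its outputs are consumed — inputs of shape (N, sequence_length, n_features), targets of shape (N, forecast_horizon) — a sliding-window sketch would be:

import numpy as np

def create_dataset(X, y, sequence_length, forecast_horizon):
    """Sketch (assumed semantics): slide a window over aligned feature/target arrays.

    X: (T, n_features) scaled features; y: (T, 1) scaled target.
    Returns inputs of shape (N, sequence_length, n_features) and targets of shape
    (N, forecast_horizon), where N = T - sequence_length - forecast_horizon + 1.
    """
    Xs, ys = [], []
    for i in range(len(X) - sequence_length - forecast_horizon + 1):
        Xs.append(X[i:i + sequence_length])
        ys.append(y[i + sequence_length:i + sequence_length + forecast_horizon, 0])
    return np.array(Xs), np.array(ys)

This is also why the data-sufficiency check requires at least sequence_length + forecast_horizon rows: with fewer, N would be zero and training could not start.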
@ -21,79 +21,34 @@ from utils.multi_store_data_utils import get_store_product_sales_data, aggregate
|
||||
from utils.visualization import plot_loss_curve
|
||||
from analysis.metrics import evaluate_model
|
||||
from core.config import (
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
|
||||
get_next_model_version, get_model_file_path, get_latest_model_version
|
||||
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
|
||||
)
|
||||
from utils.training_progress import progress_manager
|
||||
from utils.model_manager import model_manager
|
||||
|
||||
def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
|
||||
model_type: str, model_dir: str, store_id=None,
|
||||
training_mode: str = 'product', aggregation_method=None):
|
||||
"""
|
||||
保存训练检查点
|
||||
|
||||
Args:
|
||||
checkpoint_data: 检查点数据
|
||||
epoch_or_label: epoch编号或标签(如'best')
|
||||
product_id: 产品ID
|
||||
model_type: 模型类型
|
||||
model_dir: 模型保存目录
|
||||
store_id: 店铺ID
|
||||
training_mode: 训练模式
|
||||
aggregation_method: 聚合方法
|
||||
"""
|
||||
# 创建检查点目录
|
||||
checkpoint_dir = os.path.join(model_dir, 'checkpoints')
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# 生成检查点文件名
|
||||
if training_mode == 'store' and store_id:
|
||||
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
elif training_mode == 'global' and aggregation_method:
|
||||
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
|
||||
else:
|
||||
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
|
||||
|
||||
checkpoint_path = os.path.join(checkpoint_dir, filename)
|
||||
|
||||
# 保存检查点
|
||||
torch.save(checkpoint_data, checkpoint_path)
|
||||
print(f"[Transformer] 检查点已保存: {checkpoint_path}", flush=True)
|
||||
|
||||
return checkpoint_path
|
||||
|
||||
def train_product_model_with_transformer(
|
||||
product_id,
|
||||
store_id=None,
|
||||
training_mode='product',
|
||||
aggregation_method='sum',
|
||||
epochs=50,
|
||||
product_id,
|
||||
model_identifier,
|
||||
product_df=None,
|
||||
store_id=None,
|
||||
training_mode='product',
|
||||
aggregation_method='sum',
|
||||
epochs=50,
|
||||
sequence_length=LOOK_BACK,
|
||||
forecast_horizon=FORECAST_HORIZON,
|
||||
model_dir=DEFAULT_MODEL_DIR,
|
||||
version=None,
|
||||
socketio=None,
|
||||
task_id=None,
|
||||
continue_training=False
|
||||
continue_training=False,
|
||||
patience=10,
|
||||
learning_rate=0.001,
|
||||
clip_norm=1.0
|
||||
):
|
||||
"""
|
||||
使用Transformer模型训练产品销售预测模型
|
||||
|
||||
参数:
|
||||
product_id: 产品ID
|
||||
epochs: 训练轮次
|
||||
model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR
|
||||
version: 指定版本号,如果为None则自动生成
|
||||
socketio: WebSocket对象,用于实时反馈
|
||||
task_id: 训练任务ID
|
||||
continue_training: 是否继续训练现有模型
|
||||
|
||||
返回:
|
||||
model: 训练好的模型
|
||||
metrics: 模型评估指标
|
||||
version: 实际使用的版本号
|
||||
"""

    # WebSocket progress feedback helper
    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the frontend"""
        if socketio and task_id:
@@ -108,18 +63,15 @@ def train_product_model_with_transformer(
            data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')
        print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
-        # Force-flush the output buffers
-        import sys
-        sys.stdout.flush()
-        sys.stderr.flush()

    emit_progress("Starting Transformer model training...")

    # Get the training progress manager instance
    try:
        from utils.training_progress import progress_manager
    except ImportError:
        # If the import fails, fall back to a no-op manager to avoid errors
        class DummyProgressManager:
            def set_stage(self, *args, **kwargs): pass
            def start_training(self, *args, **kwargs): pass
@@ -129,52 +81,26 @@ def train_product_model_with_transformer(
            def finish_training(self, *args, **kwargs): pass
        progress_manager = DummyProgressManager()

-    # Load data according to the training mode
-    from utils.multi_store_data_utils import load_multi_store_data
-
-    try:
-        if training_mode == 'store' and store_id:
-            # Load data for a specific store
-            product_df = get_store_product_sales_data(
-                store_id,
-                product_id,
-                'pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"store {store_id}"
-        elif training_mode == 'global':
-            # Aggregate data across all stores
-            product_df = aggregate_multi_store_data(
-                product_id,
-                aggregation_method=aggregation_method,
-                file_path='pharmacy_sales_multi_store.csv'
-            )
-            training_scope = f"global aggregation ({aggregation_method})"
-        else:
-            # Default: load this product's data from all stores
-            product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-            training_scope = "all stores"
-    except Exception as e:
-        print(f"Multi-store data loading failed: {e}")
-        # Fallback: try the original data file
-        df = pd.read_excel('pharmacy_sales.xlsx')
-        product_df = df[df['product_id'] == product_id].sort_values('date')
-        training_scope = "original data"
+    if product_df is None:
+        from utils.multi_store_data_utils import aggregate_multi_store_data
+        product_df = aggregate_multi_store_data(
+            product_id=product_id,
+            aggregation_method=aggregation_method
+        )
+        training_scope = f"global aggregation ({aggregation_method})"
+    else:
+        training_scope = "all stores"

    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")

    # Check the amount of data
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ Insufficient training data\n"
-            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"The current configuration needs: {min_required_samples} days of data (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"Actual amount of data: {len(product_df)} days\n"
            f"Product ID: {product_id}, training mode: {training_mode}\n"
            f"Suggested fixes:\n"
            f"1. Generate more data: uv run generate_multi_store_data.py\n"
            f"2. Adjust the configuration: reduce LOOK_BACK or FORECAST_HORIZON\n"
            f"3. Use global training mode to aggregate more data"
        )
        print(error_msg)
        raise ValueError(error_msg)
@@ -186,18 +112,14 @@ def train_product_model_with_transformer(
    print(f"[Device] Using device: {DEVICE}", flush=True)
    print(f"[Model] Models will be saved to directory: {model_dir}", flush=True)

    # Build feature and target variables
-    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Enter the data preprocessing stage
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("Preprocessing data...")

    # Preprocess the data
    X = product_df[features].values
-    y = product_df[['sales']].values  # keep as a 2-D array
+    y = product_df[['sales']].values

    # Normalize the data
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

@@ -206,24 +128,20 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("data_preprocessing", 40)

    # Split into training and test sets (80% train, 20% test)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # Build the time-series dataset
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    progress_manager.set_stage("data_preprocessing", 70)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    # Create the datasets
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

@@ -231,7 +149,6 @@ def train_product_model_with_transformer(
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Update the progress manager's batch information
    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    progress_manager.total_batches_per_epoch = total_batches
@@ -241,9 +158,8 @@ def train_product_model_with_transformer(
    progress_manager.set_stage("data_preprocessing", 100)
    emit_progress("Data preprocessing finished; starting model training...")

    # Initialize the Transformer model
    input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1
@@ -257,30 +173,28 @@ def train_product_model_with_transformer(
        dim_feedforward=hidden_size * 2,
        dropout=dropout_rate,
        output_sequence_length=output_dim,
-        seq_length=LOOK_BACK,
+        seq_length=sequence_length,
        batch_size=batch_size
    )

    # Move the model to the device
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
-    optimizer = optim.Adam(model.parameters(), lr=0.001)
+    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    # Train the model
    train_losses = []
    test_losses = []
    start_time = time.time()

    # Configure checkpointing
-    checkpoint_interval = max(1, epochs // 10)  # save every 10% of progress, at least once per epoch
+    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
+    epochs_no_improve = 0

    progress_manager.set_stage("model_training", 0)
-    emit_progress(f"Starting training - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}")
+    emit_progress(f"Starting training - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}, patience: {patience}")

    for epoch in range(epochs):
        # Start a new epoch
        progress_manager.start_epoch(epoch)

        model.train()
@@ -289,49 +203,36 @@ def train_product_model_with_transformer(
        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # Make sure the target tensor has the right shape
            if y_batch.dim() == 2:
                y_batch = y_batch.unsqueeze(-1)

            # Forward pass
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
+            if clip_norm:
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()

            epoch_loss += loss.item()

            # Update batch progress
            if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        # Compute the training loss
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Enter the validation stage
        progress_manager.set_stage("validation", 0)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # Make sure the target tensor has the right shape
                if y_batch.dim() == 2:
                    y_batch = y_batch.unsqueeze(-1)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                # Update validation progress
                if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
                    val_progress = (batch_idx / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)
@@ -339,10 +240,10 @@ def train_product_model_with_transformer(
        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # Finish the current epoch
+        scheduler.step(test_loss)

        progress_manager.finish_epoch(train_loss, test_loss)

        # Report training progress
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
@@ -354,7 +255,6 @@ def train_product_model_with_transformer(
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # Save checkpoints periodically
        if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
            checkpoint_data = {
                'epoch': epoch + 1,
@@ -373,8 +273,8 @@ def train_product_model_with_transformer(
                    'num_heads': num_heads,
                    'dropout': dropout_rate,
                    'num_layers': num_layers,
-                    'sequence_length': LOOK_BACK,
-                    'forecast_horizon': FORECAST_HORIZON,
+                    'sequence_length': sequence_length,
+                    'forecast_horizon': forecast_horizon,
                    'model_type': 'transformer'
                },
                'training_info': {
@@ -387,30 +287,35 @@ def train_product_model_with_transformer(
                }
            }

-            # Save the checkpoint
-            save_checkpoint(checkpoint_data, epoch + 1, product_id, 'transformer',
-                            model_dir, store_id, training_mode, aggregation_method)
-
            # If this is the best model so far, save an extra copy
            if test_loss < best_loss:
                best_loss = test_loss
-                save_checkpoint(checkpoint_data, 'best', product_id, 'transformer',
-                                model_dir, store_id, training_mode, aggregation_method)
+                model_manager.save_model(
+                    model_data=checkpoint_data,
+                    product_id=model_identifier,  # fix: use the unique identifier
+                    model_type='transformer',
+                    store_id=store_id,
+                    training_mode=training_mode,
+                    aggregation_method=aggregation_method,
+                    product_name=product_name,
+                    version='best'
+                )
+                emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
-
-            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")
+                epochs_no_improve = 0
+            else:
+                epochs_no_improve += 1

        if (epoch + 1) % 10 == 0:
            print(f"📊 Epoch {epoch+1}/{epochs}, train loss: {train_loss:.4f}, test loss: {test_loss:.4f}", flush=True)

+        if epochs_no_improve >= patience:
+            emit_progress(f"Test loss has not improved for {patience} consecutive epochs; stopping early.")
+            break

    # Compute the training time
    training_time = time.time() - start_time

    # Enter the model saving stage
    progress_manager.set_stage("model_saving", 0)
    emit_progress("Training finished; saving the model...")

    # Plot the loss curve and save it to the model directory
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
@@ -420,24 +325,17 @@ def train_product_model_with_transformer(
    )
    print(f"📈 Loss curve saved to: {loss_curve_path}", flush=True)

    # Evaluate the model
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()

+        # Handle the output shape
+        if len(test_pred.shape) == 3:
+            test_pred = test_pred.squeeze(-1)
+        test_true = testY

        # De-normalize the predictions and ground truth
-        test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
-        test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
+        test_pred_inv = scaler_y.inverse_transform(test_pred)
+        test_true_inv = scaler_y.inverse_transform(test_true)

    # Compute evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # Print evaluation metrics
    print(f"\n📊 Model evaluation metrics:", flush=True)
    print(f"   MSE: {metrics['mse']:.4f}", flush=True)
    print(f"   RMSE: {metrics['rmse']:.4f}", flush=True)
@@ -446,9 +344,8 @@ def train_product_model_with_transformer(
    print(f"   MAPE: {metrics['mape']:.2f}%", flush=True)
    print(f"   ⏱️ Training time: {training_time:.2f}s", flush=True)

    # Save the final trained model (based on the final epoch)
    final_model_data = {
-        'epoch': epochs,  # final epoch
+        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
@@ -464,8 +361,8 @@ def train_product_model_with_transformer(
            'num_heads': num_heads,
            'dropout': dropout_rate,
            'num_layers': num_layers,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'transformer'
        },
        'metrics': metrics,
@@ -483,10 +380,14 @@ def train_product_model_with_transformer(

    progress_manager.set_stage("model_saving", 50)

-    # Save the final model (identified by epoch)
-    final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'transformer',
-        model_dir, store_id, training_mode, aggregation_method
+    final_model_path, final_version = model_manager.save_model(
+        model_data=final_model_data,
+        product_id=model_identifier,  # fix: use the unique identifier
+        model_type='transformer',
+        store_id=store_id,
+        training_mode=training_mode,
+        aggregation_method=aggregation_method,
+        product_name=product_name
    )

    progress_manager.set_stage("model_saving", 100)
@@ -494,7 +395,6 @@ def train_product_model_with_transformer(

    print(f"💾 Model saved to {final_model_path}", flush=True)

    # Prepare the final metrics to return
    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
@@ -502,7 +402,12 @@ def train_product_model_with_transformer(
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
-        'final_epoch': epochs
+        'final_epoch': epochs,
+        'version': final_version
    }

    return model, final_metrics, epochs

# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('transformer', train_product_model_with_transformer)
server/trainers/xgboost_trainer.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""
Pharmacy sales forecasting system - XGBoost model trainer (plug-in style)
"""

import time
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping

# Core utilities
from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
from utils.visualization import plot_loss_curve  # plotting helper

def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    Train a product sales forecasting model with XGBoost.
    The signature matches the other trainers so the registry can call it.
    """
    print(f"🚀 XGBoost trainer started: model_identifier='{model_identifier}'")

    # --- 1. Data preparation and validation ---
    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"Insufficient data: need {min_required_samples} rows, got {len(product_df)}.")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    # --- 2. Preprocessing and adaptation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # **Key adaptation step**: XGBoost needs 2-D input
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

    # **Key adaptation**: convert to XGBoost's core DMatrix format, to use the stable xgb.train API
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

    # --- 3. Model training (core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',  # eval_metric is natively supported here
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)

    print("Starting XGBoost training (core xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,  # early_stopping_rounds is natively supported here
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost training finished in {training_time:.2f}s")

    # --- 4. Evaluation and visualization ---
    # Predict with the best boosting round via model.best_iteration
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration))

    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time
    metrics['best_iteration'] = model.best_iteration

    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # Extract the losses and plot the curve
    train_losses = evals_result['train']['rmse']
    test_losses = evals_result['test']['rmse']
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'xgboost',
        model_dir=model_dir
    )
    print(f"📈 Loss curve saved to: {loss_curve_path}")

    # --- 5. Model saving (via utils.model_manager) ---
    model_data = {
        'model_state_dict': model,  # save the model object directly
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result,
        'loss_curve_path': loss_curve_path  # record the loss plot path
    }

    # Save the final versioned model
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ Final XGBoost model saved through the unified manager, version: {final_version}, path: {final_model_path}")

    # Save the 'best' model
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'  # explicitly tag this copy as 'best'
    )
    print(f"✅ Best XGBoost model saved through the unified manager, version: {best_version}, path: {best_model_path}")

    # Return values follow the unified format
    return model, metrics, final_version, final_model_path

# --- Register this trainer with the system ---
register_trainer('xgboost', train_product_model_with_xgboost)
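
A hedged usage sketch for driving this trainer directly with a tiny synthetic frame. Everything below the imports is illustrative: the column values are made up, and whether the call runs end to end depends on the project environment (the `model_manager` and `plot_loss_curve` side effects in particular), so treat it as a shape check rather than a recipe:

```python
import numpy as np
import pandas as pd

days = pd.date_range('2025-01-01', periods=120, freq='D')
demo_df = pd.DataFrame({
    'date': days,
    'sales': np.random.default_rng(0).poisson(20, len(days)).astype(float),
    'weekday': days.dayofweek,
    'month': days.month,
    'is_holiday': False,
    'is_weekend': days.dayofweek >= 5,
    'is_promotion': False,
    'temperature': 20.0,
    'product_name': 'Demo Product',   # illustrative values only
})

# epochs is accepted for signature compatibility but unused by this trainer;
# tree count comes from the n_estimators kwarg instead.
model, metrics, version, path = train_product_model_with_xgboost(
    product_id='P001', model_identifier='P001', product_df=demo_df,
    store_id=None, training_mode='product', aggregation_method='sum',
    epochs=None, sequence_length=30, forecast_horizon=7,
    model_dir='saved_models', n_estimators=200,
)
```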
server/utils/data_utils.py
@@ -41,7 +41,7 @@ def create_dataset(datasetX, datasetY, look_back=1, predict_steps=1):
        x = datasetX[i:(i + look_back)]
        dataX.append(x)
        y = datasetY[(i + look_back):(i + look_back + predict_steps)]
-        dataY.append(y)
+        dataY.append(y.flatten())
    return np.array(dataX), np.array(dataY)
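
The `y.flatten()` change matters because `datasetY` is a 2-D `(n, 1)` array here, so each window slice has shape `(predict_steps, 1)`; flattening makes the stacked targets `(samples, predict_steps)`, which is what the XGBoost trainer's 2-D reshape above expects. A quick shape check (illustrative only):

```python
import numpy as np

y = np.arange(10, dtype=float).reshape(-1, 1)  # (10, 1), like y_scaled
window = y[3:3 + 7]                            # one target window: (7, 1)
print(window.shape, window.flatten().shape)    # (7, 1) (7,)
# Stacked over all windows: (samples, 7, 1) before the fix, (samples, 7) after.
```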

def prepare_data(product_data, sequence_length=30, forecast_horizon=7):
server/utils/model_manager.py
@@ -8,10 +8,12 @@ import json
import torch
import glob
from datetime import datetime
+import re
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR


class ModelManager:
    """Unified model manager"""

@@ -24,56 +26,91 @@ class ModelManager:
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

-    def generate_model_filename(self,
-                                product_id: str,
-                                model_type: str,
+    def _get_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> int:
+        """Get the next model version number (plain integer)"""
+        search_pattern = self.generate_model_filename(
+            model_type=model_type,
+            version='v*',
+            product_id=product_id,
+            store_id=store_id,
+            training_mode=training_mode,
+            aggregation_method=aggregation_method
+        )
+
+        full_search_path = os.path.join(self.model_dir, search_pattern)
+        existing_files = glob.glob(full_search_path)
+
+        max_version = 0
+        for f in existing_files:
+            match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
+            if match:
+                max_version = max(max_version, int(match.group(1)))
+
+        return max_version + 1
+
+    def generate_model_filename(self,
+                                model_type: str,
                                version: str,
-                                store_id: Optional[str] = None,
                                training_mode: str = 'product',
+                                product_id: Optional[str] = None,
+                                store_id: Optional[str] = None,
                                aggregation_method: Optional[str] = None) -> str:
        """
        Generate a unified model filename

-        Format:
+        Format (v2):
        - product mode: {model_type}_product_{product_id}_{version}.pth
-        - store mode:   {model_type}_store_{store_id}_{product_id}_{version}.pth
-        - global mode:  {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
+        - store mode:   {model_type}_store_{store_id}_{version}.pth
+        - global mode:  {model_type}_global_{aggregation_method}_{version}.pth
        """
        if training_mode == 'store' and store_id:
-            return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
+            return f"{model_type}_store_{store_id}_{version}.pth"
        elif training_mode == 'global' and aggregation_method:
-            return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
-        else:
-            # Default: product mode
-            return f"{model_type}_product_{product_id}_{version}.pth"
+            return f"{model_type}_global_{aggregation_method}_{version}.pth"
+        elif training_mode == 'product' and product_id:
+            return f"{model_type}_product_{product_id}_{version}.pth"
+        else:
+            # Fall back by raising, to avoid generating an invalid filename
+            raise ValueError(f"Cannot generate a valid filename for training mode '{training_mode}': a required ID is missing.")

    def save_model(self,
                   model_data: dict,
                   product_id: str,
-                   model_type: str,
-                   version: str,
+                   model_type: str,
                   store_id: Optional[str] = None,
                   training_mode: str = 'product',
                   aggregation_method: Optional[str] = None,
-                   product_name: Optional[str] = None) -> str:
+                   product_name: Optional[str] = None,
+                   version: Optional[str] = None) -> Tuple[str, str]:
        """
-        Save a model to the unified location
+        Save a model to the unified location, with automatic version management.

        Parameters:
            model_data: dict containing the model state and configuration
            product_id: product ID
            model_type: model type
-            version: version number
-            store_id: store ID (optional)
-            training_mode: training mode
-            aggregation_method: aggregation method (optional)
-            product_name: product name (optional)
+            ...
+            version: (optional) overrides automatic versioning when provided (e.g. 'best').

        Returns:
-            the model file path
+            (model file path, version string used)
        """
+        if version is None:
+            next_version_num = self._get_next_version(
+                model_type=model_type,
+                product_id=product_id,
+                store_id=store_id,
+                training_mode=training_mode,
+                aggregation_method=aggregation_method
+            )
+            version_str = f"v{next_version_num}"
+        else:
+            version_str = version

        filename = self.generate_model_filename(
-            product_id, model_type, version, store_id, training_mode, aggregation_method
+            model_type=model_type,
+            version=version_str,
+            training_mode=training_mode,
+            product_id=product_id,
+            store_id=store_id,
+            aggregation_method=aggregation_method
        )

        # Save everything in the root directory, avoiding a complicated subdirectory structure
@@ -86,7 +123,7 @@ class ModelManager:
            'product_id': product_id,
            'product_name': product_name or product_id,
            'model_type': model_type,
-            'version': version,
+            'version': version_str,
            'store_id': store_id,
            'training_mode': training_mode,
            'aggregation_method': aggregation_method,
@@ -99,7 +136,7 @@ class ModelManager:
        torch.save(enhanced_model_data, model_path)

        print(f"Model saved: {model_path}")
-        return model_path
+        return model_path, version_str

    def list_models(self,
                    product_id: Optional[str] = None,
@@ -180,9 +217,9 @@ class ModelManager:
                model_info['modified_at'] = datetime.fromtimestamp(
                    os.path.getmtime(model_file)
                ).isoformat()

                models.append(model_info)

            except Exception as e:
                print(f"Failed to process model file {model_file}: {e}")
                continue
@@ -228,127 +265,58 @@ class ModelManager:

    def parse_model_filename(self, filename: str) -> Optional[Dict]:
        """
-        Parse a model filename and extract the model metadata
-
-        Supported formats:
-        - {model_type}_product_{product_id}_{version}.pth
-        - {model_type}_store_{store_id}_{product_id}_{version}.pth
-        - {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
-        - legacy-format compatibility
+        Parse a model filename and extract the model metadata (v2)
+
+        Supported formats:
+        - product: {model_type}_product_{product_id}_{version}.pth
+        - store:   {model_type}_store_{store_id}_{version}.pth
+        - global:  {model_type}_global_{aggregation_method}_{version}.pth
        """
        if not filename.endswith('.pth'):
            return None

        base_name = filename.replace('.pth', '')

+        parts = base_name.split('_')
+
+        if len(parts) < 3:
+            return None  # does not meet the minimum format requirements
+
+        # **Core fix**: parse from the end backwards, which is more robust and
+        # supports model type names that themselves contain underscores
        try:
-            # New-format parsing
-            if '_product_' in base_name:
-                # product mode: model_type_product_product_id_version
-                parts = base_name.split('_product_')
-                model_type = parts[0]
-                rest = parts[1]
-
-                # Separate the product ID and the version
-                if '_v' in rest:
-                    last_v_index = rest.rfind('_v')
-                    product_id = rest[:last_v_index]
-                    version = rest[last_v_index+1:]
-                else:
-                    product_id = rest
-                    version = 'v1'
-
+            version = parts[-1]
+            identifier = parts[-2]
+            mode_candidate = parts[-3]
+
+            if mode_candidate == 'product':
+                model_type = '_'.join(parts[:-3])
                return {
                    'model_type': model_type,
-                    'product_id': product_id,
-                    'version': version,
                    'training_mode': 'product',
-                    'store_id': None,
-                    'aggregation_method': None
+                    'product_id': identifier,
+                    'version': version,
                }
-
-            elif '_store_' in base_name:
-                # store mode: model_type_store_store_id_product_id_version
-                parts = base_name.split('_store_')
-                model_type = parts[0]
-                rest = parts[1]
-
-                # Separate the store ID, product ID and version
-                rest_parts = rest.split('_')
-                if len(rest_parts) >= 3:
-                    store_id = rest_parts[0]
-                    if rest_parts[-1].startswith('v'):
-                        # the last part is the version number
-                        version = rest_parts[-1]
-                        product_id = '_'.join(rest_parts[1:-1])
-                    else:
-                        version = 'v1'
-                        product_id = '_'.join(rest_parts[1:])
-
-                    return {
-                        'model_type': model_type,
-                        'product_id': product_id,
-                        'version': version,
-                        'training_mode': 'store',
-                        'store_id': store_id,
-                        'aggregation_method': None
-                    }
-
-            elif '_global_' in base_name:
-                # global mode: model_type_global_product_id_aggregation_method_version
-                parts = base_name.split('_global_')
-                model_type = parts[0]
-                rest = parts[1]
-
-                rest_parts = rest.split('_')
-                if len(rest_parts) >= 3:
-                    if rest_parts[-1].startswith('v'):
-                        # the last part is the version number
-                        version = rest_parts[-1]
-                        aggregation_method = rest_parts[-2]
-                        product_id = '_'.join(rest_parts[:-2])
-                    else:
-                        version = 'v1'
-                        aggregation_method = rest_parts[-1]
-                        product_id = '_'.join(rest_parts[:-1])
-
-                    return {
-                        'model_type': model_type,
-                        'product_id': product_id,
-                        'version': version,
-                        'training_mode': 'global',
-                        'store_id': None,
-                        'aggregation_method': aggregation_method
-                    }
-
-            # Legacy-format compatibility
-            else:
-                # Try to parse other formats
-                if 'model_product_' in base_name:
-                    parts = base_name.split('_model_product_')
-                    model_type = parts[0]
-                    product_part = parts[1]
-
-                    if '_v' in product_part:
-                        last_v_index = product_part.rfind('_v')
-                        product_id = product_part[:last_v_index]
-                        version = product_part[last_v_index+1:]
-                    else:
-                        product_id = product_part
-                        version = 'v1'
-
-                    return {
-                        'model_type': model_type,
-                        'product_id': product_id,
-                        'version': version,
-                        'training_mode': 'product',
-                        'store_id': None,
-                        'aggregation_method': None
-                    }
-
+            elif mode_candidate == 'store':
+                model_type = '_'.join(parts[:-3])
+                return {
+                    'model_type': model_type,
+                    'training_mode': 'store',
+                    'store_id': identifier,
+                    'version': version,
+                }
+            elif mode_candidate == 'global':
+                model_type = '_'.join(parts[:-3])
+                return {
+                    'model_type': model_type,
+                    'training_mode': 'global',
+                    'aggregation_method': identifier,
+                    'version': version,
+                }
+        except IndexError:
+            # Not enough filename parts: parsing failed
+            pass
        except Exception as e:
            print(f"Failed to parse filename {filename}: {e}")

        return None
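
Given the back-to-front split above (`version = parts[-1]`, `identifier = parts[-2]`, `mode_candidate = parts[-3]`), parsing behaves like this on v2 names; note how a model type containing underscores survives the `'_'.join(parts[:-3])` step (filenames are illustrative):

```python
from utils.model_manager import model_manager

print(model_manager.parse_model_filename('tcn_product_P001_v3.pth'))
# {'model_type': 'tcn', 'training_mode': 'product', 'product_id': 'P001', 'version': 'v3'}
print(model_manager.parse_model_filename('my_custom_net_store_S001_v1.pth'))
# {'model_type': 'my_custom_net', 'training_mode': 'store', 'store_id': 'S001', 'version': 'v1'}
print(model_manager.parse_model_filename('transformer_global_sum_best.pth'))
# {'model_type': 'transformer', 'training_mode': 'global', 'aggregation_method': 'sum', 'version': 'best'}
```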

    def delete_model(self, model_file: str) -> bool:

server/utils/multi_store_data_utils.py
@@ -8,8 +8,9 @@ import numpy as np
import os
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
+from core.config import DEFAULT_DATA_PATH

-def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
+def load_multi_store_data(file_path: str = None,
                          store_id: Optional[str] = None,
                          product_id: Optional[str] = None,
                          start_date: Optional[str] = None,
@@ -18,7 +19,7 @@ def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
    Load multi-store sales data, with optional filtering by store, product and date range

    Parameters:
-        file_path: data file path
+        file_path: data file path (.csv, .xlsx or .parquet supported). When None, the default path from the config is used.
        store_id: store ID; None returns data for all stores
        product_id: product ID; None returns data for all products
        start_date: start date (YYYY-MM-DD)
@@ -28,43 +29,27 @@ def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
        DataFrame: the filtered sales data
    """

-    # Try several possible file paths
-    possible_paths = [
-        file_path,
-        f'../{file_path}',
-        f'server/{file_path}',
-        'pharmacy_sales_multi_store.csv',
-        '../pharmacy_sales_multi_store.csv',
-        'pharmacy_sales.xlsx',  # backwards compatibility with the original file
-        '../pharmacy_sales.xlsx'
-    ]
-
-    df = None
-    for path in possible_paths:
-        try:
-            if path.endswith('.csv'):
-                df = pd.read_csv(path)
-            elif path.endswith('.xlsx'):
-                df = pd.read_excel(path)
-                # Add default store metadata to the original Excel file
-                if 'store_id' not in df.columns:
-                    df['store_id'] = 'S001'
-                    df['store_name'] = 'Default store'
-                    df['store_location'] = 'Unknown location'
-                    df['store_type'] = 'standard'
-
-            if df is not None:
-                print(f"Successfully loaded data file: {path}")
-                break
-        except Exception as e:
-            continue
-
-    if df is None:
-        raise FileNotFoundError(f"Could not find a data file; tried paths: {possible_paths}")
-
-    # Make sure the date column has datetime dtype
-    if 'date' in df.columns:
-        df['date'] = pd.to_datetime(df['date'])
+    # When no file path is given, use the default path from the config
+    if file_path is None:
+        file_path = DEFAULT_DATA_PATH
+
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"Data file does not exist: {file_path}")
+
+    try:
+        if file_path.endswith('.csv'):
+            df = pd.read_csv(file_path)
+        elif file_path.endswith('.xlsx'):
+            df = pd.read_excel(file_path)
+        elif file_path.endswith('.parquet'):
+            df = pd.read_parquet(file_path)
+        else:
+            raise ValueError(f"Unsupported file format: {file_path}")
+
+        print(f"Successfully loaded data file: {file_path}")
+    except Exception as e:
+        print(f"Failed to load file {file_path}: {e}")
+        raise

    # Filter by store
    if store_id:
@@ -76,28 +61,38 @@ def load_multi_store_data(file_path: str = 'pharmacy_sales_multi_store.csv',
        df = df[df['product_id'] == product_id].copy()
        print(f"Filtered by product: {product_id}, remaining records: {len(df)}")

-    # Filter by date range
+    # Standardize column names and dtypes
+    df = standardize_column_names(df)
+
+    # Filter by date range after standardization
    if start_date:
-        start_date = pd.to_datetime(start_date)
-        df = df[df['date'] >= start_date].copy()
-        print(f"Start-date filter: {start_date}, remaining records: {len(df)}")
+        try:
+            start_date_dt = pd.to_datetime(start_date)
+            # Make sure the comparison happens between datetime objects
+            if 'date' in df.columns:
+                df = df[df['date'] >= start_date_dt].copy()
+                print(f"Start-date filter: {start_date_dt}, remaining records: {len(df)}")
+        except (ValueError, TypeError):
+            print(f"Warning: invalid start date format '{start_date}'; ignored.")

    if end_date:
-        end_date = pd.to_datetime(end_date)
-        df = df[df['date'] <= end_date].copy()
-        print(f"End-date filter: {end_date}, remaining records: {len(df)}")
+        try:
+            end_date_dt = pd.to_datetime(end_date)
+            # Make sure the comparison happens between datetime objects
+            if 'date' in df.columns:
+                df = df[df['date'] <= end_date_dt].copy()
+                print(f"End-date filter: {end_date_dt}, remaining records: {len(df)}")
+        except (ValueError, TypeError):
+            print(f"Warning: invalid end date format '{end_date}'; ignored.")

    if len(df) == 0:
        print("Warning: no data left after filtering")

-    # Standardize column names to match what the training code expects
-    df = standardize_column_names(df)
-
    return df

def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """
-    Standardize column names to match what the training code expects
+    Standardize column names to match what the training code and the API expect

    Parameters:
        df: the raw DataFrame
@@ -107,55 +102,67 @@ def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """
    df = df.copy()

-    # Column mapping: new name -> original name
-    column_mapping = {
-        'sales': 'quantity_sold',   # sales quantity
-        'price': 'unit_price',      # unit price
-        'weekday': 'day_of_week'    # day of week
+    # Define the column mapping and rename unconditionally
+    rename_map = {
+        'sales_quantity': 'sales',             # fix: match the raw column name
+        'temperature_2m_mean': 'temperature',  # new: handle the temperature column
+        'dayofweek': 'weekday'                 # fix: match the raw column name
    }
+    df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True)

-    # Apply the column mapping
-    for new_name, old_name in column_mapping.items():
-        if old_name in df.columns and new_name not in df.columns:
-            df[new_name] = df[old_name]
-
-    # Create missing feature columns
    # Make sure the date column has datetime dtype
    if 'date' in df.columns:
-        df['date'] = pd.to_datetime(df['date'])
-
-        # Create a numeric weekday (0=Monday, 6=Sunday)
-        if 'weekday' not in df.columns:
-            df['weekday'] = df['date'].dt.dayofweek
-        elif df['weekday'].dtype == 'object':
-            # If weekday is a string, convert it to a number
-            weekday_map = {
-                'Monday': 0, 'Tuesday': 1, 'Wednesday': 2, 'Thursday': 3,
-                'Friday': 4, 'Saturday': 5, 'Sunday': 6
-            }
-            df['weekday'] = df['weekday'].map(weekday_map).fillna(df['date'].dt.dayofweek)
-
-        # Add month information
-        if 'month' not in df.columns:
-            df['month'] = df['date'].dt.month
+        df['date'] = pd.to_datetime(df['date'], errors='coerce')
+        df.dropna(subset=['date'], inplace=True)  # drop rows whose date cannot be parsed
+    else:
+        # Without a date column we cannot continue; return an empty DataFrame
+        return pd.DataFrame()

+    # Compute sales_amount
+    # Since there is no price column, the sales_amount computation has to be adjusted or removed.
+    # It is commented out here because the raw data already contains sales_amount.
+    # if 'sales_amount' not in df.columns and 'sales' in df.columns and 'price' in df.columns:
+    #     # make sure sales and price are numeric first
+    #     df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
+    #     df['price'] = pd.to_numeric(df['price'], errors='coerce')
+    #     df['sales_amount'] = df['sales'] * df['price']

+    # Create missing feature columns
+    if 'weekday' not in df.columns:
+        df['weekday'] = df['date'].dt.dayofweek

-    # Add missing boolean feature columns (set defaults when absent)
+    if 'month' not in df.columns:
+        df['month'] = df['date'].dt.month
+
+    # Add missing metadata columns
+    meta_columns = {
+        'store_name': 'Unknown Store',
+        'store_location': 'Unknown Location',
+        'store_type': 'Unknown',
+        'product_name': 'Unknown Product',
+        'product_category': 'Unknown Category'
+    }
+    for col, default in meta_columns.items():
+        if col not in df.columns:
+            df[col] = default
+
+    # Add missing boolean feature columns
    default_features = {
-        'is_holiday': False,    # holiday flag
-        'is_weekend': None,     # weekend flag (computed from weekday)
-        'is_promotion': False,  # promotion flag
-        'temperature': 20.0     # default temperature
+        'is_holiday': False,
+        'is_weekend': None,
+        'is_promotion': False,
+        'temperature': 20.0
    }

    for feature, default_value in default_features.items():
        if feature not in df.columns:
-            if feature == 'is_weekend' and 'weekday' in df.columns:
-                # weekend: Saturday (5) and Sunday (6)
+            if feature == 'is_weekend':
                df['is_weekend'] = df['weekday'].isin([5, 6])
            else:
                df[feature] = default_value

    # Make sure the numeric dtypes are correct
-    numeric_columns = ['sales', 'price', 'weekday', 'month', 'temperature']
+    numeric_columns = ['sales', 'sales_amount', 'weekday', 'month', 'temperature']
    for col in numeric_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
@@ -166,11 +173,11 @@ def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
        if col in df.columns:
            df[col] = df[col].astype(bool)

-    print(f"Column standardization finished; available feature columns: {[col for col in ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")
+    print(f"Column standardization finished; available feature columns: {[col for col in ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")

    return df
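
A minimal sanity check of the rename step above, feeding in a toy frame with the raw parquet column names (values are made up):

```python
import pandas as pd

raw = pd.DataFrame({
    'date': ['2025-01-01', '2025-01-02'],
    'sales_quantity': [12, 15],
    'temperature_2m_mean': [18.5, 19.2],
    'dayofweek': [2, 3],
})
out = standardize_column_names(raw)
# The renamed feature columns should all be present afterwards.
print(sorted(set(out.columns) & {'sales', 'temperature', 'weekday'}))
# ['sales', 'temperature', 'weekday']
```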

-def get_available_stores(file_path: str = 'pharmacy_sales_multi_store.csv') -> List[Dict[str, Any]]:
+def get_available_stores(file_path: str = None) -> List[Dict[str, Any]]:
    """
    Get the list of available stores

@@ -183,15 +190,31 @@ def get_available_stores(file_path: str = 'pharmacy_sales_multi_store.csv') -> L
    try:
        df = load_multi_store_data(file_path)

-        # Get the unique store information
-        stores = df[['store_id', 'store_name', 'store_location', 'store_type']].drop_duplicates()
+        if 'store_id' not in df.columns:
+            print("The data file is missing a 'store_id' column")
+            return []

-        return stores.to_dict('records')
+        # Collect store information robustly, even when some columns are missing
+        store_info = []
+
+        # Use drop_duplicates to get the unique store combinations
+        stores_df = df.drop_duplicates(subset=['store_id'])
+
+        for _, row in stores_df.iterrows():
+            store_info.append({
+                'store_id': row['store_id'],
+                'store_name': row.get('store_name', f"Store {row['store_id']}"),
+                'location': row.get('store_location', 'Unknown location'),
+                'type': row.get('store_type', 'standard'),
+                'opening_date': row.get('opening_date', 'unknown'),
+            })
+
+        return store_info
    except Exception as e:
        print(f"Failed to get the store list: {e}")
        return []

-def get_available_products(file_path: str = 'pharmacy_sales_multi_store.csv',
+def get_available_products(file_path: str = None,
                           store_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Get the list of available products
@@ -222,7 +245,7 @@ def get_available_products(file_path: str = 'pharmacy_sales_multi_store.csv',

def get_store_product_sales_data(store_id: str,
                                 product_id: str,
-                                 file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
+                                 file_path: str = None) -> pd.DataFrame:
    """
    Get the sales data of a specific store and product, for model training

@@ -252,27 +275,53 @@ def get_store_product_sales_data(store_id: str,
        print(f"Warning: columns {missing_columns} are still missing after standardization")
        raise ValueError(f"Cannot obtain complete feature data; missing columns: {missing_columns}")

-    return df
+    # Define all columns needed for model training (features + target)
+    final_columns = [
+        'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
+        'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
+    ]
+
+    # Keep only the columns that actually exist in the DataFrame
+    existing_columns = [col for col in final_columns if col in df.columns]
+
+    # Return a DataFrame restricted to these required columns
+    return df[existing_columns]

-def aggregate_multi_store_data(product_id: str,
-                               aggregation_method: str = 'sum',
-                               file_path: str = 'pharmacy_sales_multi_store.csv') -> pd.DataFrame:
+def aggregate_multi_store_data(product_id: Optional[str] = None,
+                               store_id: Optional[str] = None,
+                               aggregation_method: str = 'sum',
+                               file_path: str = None) -> pd.DataFrame:
    """
-    Aggregate sales data from multiple stores, for global model training
+    Aggregate sales data, either per product (global) or per store (all products)

    Parameters:
        file_path: data file path
-        product_id: product ID
+        product_id: product ID (for the global model)
+        store_id: store ID (for store-level aggregation)
        aggregation_method: aggregation method ('sum', 'mean', 'median')

    Returns:
        DataFrame: the aggregated sales data
    """
-    # Load this product's data from all stores
-    df = load_multi_store_data(file_path, product_id=product_id)
-
-    if len(df) == 0:
-        raise ValueError(f"No sales data found for product {product_id}")
+    # Decide what to load: store-level, product-level, or truly global aggregation
+    if store_id:
+        # Store aggregation: load all data for that store
+        df = load_multi_store_data(file_path, store_id=store_id)
+        if len(df) == 0:
+            raise ValueError(f"No sales data found for store {store_id}")
+        grouping_entity = f"store {store_id}"
+    elif product_id:
+        # Product aggregation: load that product's data across all stores
+        df = load_multi_store_data(file_path, product_id=product_id)
+        if len(df) == 0:
+            raise ValueError(f"No sales data found for product {product_id}")
+        grouping_entity = f"product {product_id}"
+    else:
+        # Truly global aggregation: load everything
+        df = load_multi_store_data(file_path)
+        if len(df) == 0:
+            raise ValueError("The data file is empty; global aggregation is impossible")
+        grouping_entity = "all products"

    # Aggregate by date (using the standardized column names)
    agg_dict = {}
@@ -317,9 +366,19 @@ def aggregate_multi_store_data(product_id: Optional[str] = None,
    aggregated_df = aggregated_df.sort_values('date').copy()
    aggregated_df = standardize_column_names(aggregated_df)

-    return aggregated_df
+    # Define all columns needed for model training (features + target)
+    final_columns = [
+        'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
+        'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
+    ]
+
+    # Keep only the columns that actually exist in the DataFrame
+    existing_columns = [col for col in final_columns if col in aggregated_df.columns]
+
+    # Return a DataFrame restricted to these required columns
+    return aggregated_df[existing_columns]
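
The body that builds `agg_dict` and performs the roll-up is elided by the hunk above. Based on the `agg_dict` initialization and the 'sum'/'mean'/'median' options, the date-level aggregation presumably reduces to something like the following sketch (assumed shape, not the verbatim implementation):

```python
import pandas as pd

def roll_up_by_date(df: pd.DataFrame, aggregation_method: str = 'sum') -> pd.DataFrame:
    # Assumed shape of the elided step: sales aggregated with the chosen
    # method, continuous covariates such as temperature averaged.
    agg_dict = {'sales': aggregation_method, 'temperature': 'mean'}
    agg_dict = {col: how for col, how in agg_dict.items() if col in df.columns}
    return df.groupby('date', as_index=False).agg(agg_dict)
```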

-def get_sales_statistics(file_path: str = 'pharmacy_sales_multi_store.csv',
+def get_sales_statistics(file_path: str = None,
                         store_id: Optional[str] = None,
                         product_id: Optional[str] = None) -> Dict[str, Any]:
    """
@@ -358,8 +417,8 @@ def get_sales_statistics(file_path: str = None,
        return {'error': str(e)}

# Backwards-compatible helper
-def load_data(file_path='pharmacy_sales.xlsx', store_id=None):
+def load_data(file_path=None, store_id=None):
    """
    Backwards-compatible data loading function
    """
    return load_multi_store_data(file_path, store_id=store_id)
@@ -24,6 +24,17 @@ server_dir = os.path.dirname(current_dir)
sys.path.append(server_dir)

from utils.logging_config import setup_api_logging, get_training_logger, log_training_progress
+import numpy as np
+
+def convert_numpy_types(obj):
+    """Recursively convert NumPy types inside dicts/lists to native Python types"""
+    if isinstance(obj, dict):
+        return {k: convert_numpy_types(v) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        return [convert_numpy_types(i) for i in obj]
+    elif isinstance(obj, np.generic):
+        return obj.item()
+    return obj
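
Why this helper is needed: `json.dumps` rejects NumPy scalars, which is exactly what training metrics tend to contain. A quick demonstration:

```python
import json
import numpy as np

metrics = {'mse': np.float32(0.042), 'final_epoch': np.int64(50), 'loss': [np.float64(0.1)]}
# json.dumps(metrics)  # would raise TypeError: Object of type float32 is not JSON serializable
print(json.dumps(convert_numpy_types(metrics)))  # {"mse": 0.042..., "final_epoch": 50, "loss": [0.1]}
```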

@dataclass
class TrainingTask:
@@ -325,13 +336,16 @@ class TrainingProcessManager:

        task_id = task_data['task_id']

+        # Convert the data taken off the queue immediately
+        serializable_task_data = convert_numpy_types(task_data)
+
        with self.lock:
            if task_id in self.tasks:
-                # Update the task state
-                for key, value in task_data.items():
+                # Update the task state using the converted data
+                for key, value in serializable_task_data.items():
                    setattr(self.tasks[task_id], key, value)

-        # WebSocket notification - emit different events depending on the action type
+        # WebSocket notification - use the already-converted data
        if self.websocket_callback:
            try:
                if action == 'complete':
@@ -341,21 +355,21 @@ class TrainingProcessManager:
                        'action': 'completed',
                        'status': 'completed',
                        'progress': 100,
-                        'message': task_data.get('message', 'Training finished'),
-                        'metrics': task_data.get('metrics'),
-                        'end_time': task_data.get('end_time'),
-                        'product_id': task_data.get('product_id'),
-                        'model_type': task_data.get('model_type')
+                        'message': serializable_task_data.get('message', 'Training finished'),
+                        'metrics': serializable_task_data.get('metrics'),
+                        'end_time': serializable_task_data.get('end_time'),
+                        'product_id': serializable_task_data.get('product_id'),
+                        'model_type': serializable_task_data.get('model_type')
                    })
                    # Additionally emit a completion event, to make sure the frontend receives it
                    self.websocket_callback('training_completed', {
                        'task_id': task_id,
                        'status': 'completed',
                        'progress': 100,
-                        'message': task_data.get('message', 'Training finished'),
-                        'metrics': task_data.get('metrics'),
-                        'product_id': task_data.get('product_id'),
-                        'model_type': task_data.get('model_type')
+                        'message': serializable_task_data.get('message', 'Training finished'),
+                        'metrics': serializable_task_data.get('metrics'),
+                        'product_id': serializable_task_data.get('product_id'),
+                        'model_type': serializable_task_data.get('model_type')
                    })
                elif action == 'error':
                    # Training failed
@@ -364,22 +378,22 @@ class TrainingProcessManager:
                        'action': 'failed',
                        'status': 'failed',
                        'progress': 0,
-                        'message': task_data.get('message', 'Training failed'),
-                        'error': task_data.get('error'),
-                        'product_id': task_data.get('product_id'),
-                        'model_type': task_data.get('model_type')
+                        'message': serializable_task_data.get('message', 'Training failed'),
+                        'error': serializable_task_data.get('error'),
+                        'product_id': serializable_task_data.get('product_id'),
+                        'model_type': serializable_task_data.get('model_type')
                    })
                else:
                    # Status update
                    self.websocket_callback('training_update', {
                        'task_id': task_id,
                        'action': action,
-                        'status': task_data.get('status'),
-                        'progress': task_data.get('progress', 0),
-                        'message': task_data.get('message', ''),
-                        'metrics': task_data.get('metrics'),
-                        'product_id': task_data.get('product_id'),
-                        'model_type': task_data.get('model_type')
+                        'status': serializable_task_data.get('status'),
+                        'progress': serializable_task_data.get('progress', 0),
+                        'message': serializable_task_data.get('message', ''),
+                        'metrics': serializable_task_data.get('metrics'),
+                        'product_id': serializable_task_data.get('product_id'),
+                        'model_type': serializable_task_data.get('model_type')
                    })
            except Exception as e:
                self.logger.error(f"WebSocket notification failed: {e}")
@@ -441,7 +455,9 @@ class TrainingProcessManager:
        # WebSocket progress-update notification
        if self.websocket_callback and 'progress' in progress_data:
            try:
-                self.websocket_callback('training_progress', progress_data)
+                # Make sure every value is JSON-serializable before sending
+                serializable_data = convert_numpy_types(progress_data)
+                self.websocket_callback('training_progress', serializable_data)
            except Exception as e:
                self.logger.error(f"Progress WebSocket notification failed: {e}")
xz修改记录日志和启动依赖.md (new file, 788 lines)
@@ -0,0 +1,788 @@
### Starting from the repository root
`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow xgboost`

### UI
`npm install` `npm run dev`



# "Prediction Analysis" module UI refactor log

**Goal**: Replace the original single prediction page, whose mode was switched via a dropdown, with a multi-page layout whose modes are switched through the left-hand sub-navigation, so the UI structure matches the "Model Training" module.


### Backend fix (2025-07-13)

**Goal**: Fix the data loading failure during model training caused by a wrong data file path.

- **Root cause**: the `PharmacyPredictor` class in `server/core/predictor.py` hard-coded the wrong default data path (`'pharmacy_sales_multi_store.csv'`) at initialization.
- **Fix**:
  1. Changed `server/core/predictor.py` so the default data path is the correct `'data/timeseries_training_data_sample_10s50p.parquet'`.
  2. Updated every data-loading call in `server/trainers/mlstm_trainer.py` accordingly, so the correct file path is used.
- **Result**: data loading in the standalone training process no longer fails.

---
### Backend fix (2025-07-13) - data-flow refactor

**Goal**: Fix the root cause of model training failures: a broken data-processing pipeline dropped the key `sales` and `price` features.

- **Root causes**:
  1. The `train_model` method in `server/core/predictor.py` did not pass the already-preprocessed data when calling a trainer (e.g. `train_product_model_with_mlstm`).
  2. `server/trainers/mlstm_trainer.py` was therefore forced to reload and reprocess the data, but its column-standardization function `standardize_column_names` had a logic flaw that dropped key columns.

- **Fix (data-flow refactor)**:
  1. **Changed `server/trainers/mlstm_trainer.py`**:
     - Refactored `train_product_model_with_mlstm` so it accepts a preprocessed DataFrame (`product_df`) as a parameter.
     - Removed all data loading and duplicate processing logic from inside the function.
  2. **Changed `server/core/predictor.py`**:
     - In `train_model`, the already loaded and processed `product_data` is now passed explicitly to `train_product_model_with_mlstm`.
  3. **Changed `server/utils/multi_store_data_utils.py`**:
     - In `standardize_column_names`, a forced rename via pandas' `rename` now reliably maps `quantity_sold` and `unit_price` to `sales` and `price` (a sketch of the call follows below).

- **Result**: the data-processing pipeline is fixed; data is loaded and standardized exactly once and passed along correctly, resolving the training failures at the root.
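
For reference, the forced-rename pattern described in item 3 boils down to a single idempotent pandas call (sketch; the keys shown are the ones this entry mentions):

```python
import pandas as pd

df = pd.DataFrame({'quantity_sold': [3], 'unit_price': [9.9], 'day_of_week': [1]})
rename_map = {'quantity_sold': 'sales', 'unit_price': 'price', 'day_of_week': 'weekday'}
# Only rename keys that are actually present, so the call is safe to repeat.
df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True)
print(list(df.columns))  # ['sales', 'price', 'weekday']
```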
---

### First refactor (multi-page, two-column layout)

- **New files**:
  - `UI/src/views/prediction/ProductPredictionView.vue`
  - `UI/src/views/prediction/StorePredictionView.vue`
  - `UI/src/views/prediction/GlobalPredictionView.vue`
- **Changed files**:
  - `UI/src/router/index.js`: added routes to the new pages.
  - `UI/src/App.vue`: turned "Prediction Analysis" into a parent menu with three sub-menus.

---

### Second refactor (single-page layout, based on user feedback)

**Goal**: Unify the layout of the three prediction sub-pages on the old single-page prediction style and decouple navigation from page content.

- **Changed files**:
  - **`UI/src/views/prediction/ProductPredictionView.vue`**:
    - **Content**: replaced with the layout of `UI/src/views/NewPredictionView.vue`.
    - **Logic**: removed the "training mode" selector and hard-coded the page's prediction mode to `product`.
  - **`UI/src/views/prediction/StorePredictionView.vue`**:
    - **Content**: replaced with the layout of `UI/src/views/NewPredictionView.vue`.
    - **Logic**: removed the "training mode" selector and hard-coded the page's prediction mode to `store`.
  - **`UI/src/views/prediction/GlobalPredictionView.vue`**:
    - **Content**: replaced with the layout of `UI/src/views/NewPredictionView.vue`.
    - **Logic**: removed the "training mode" and specific-target selectors and hard-coded the page's prediction mode to `global`.

---

**Summary**: After two refactors, prediction modes are switched through the left navigation bar while the right-hand content area keeps a unified, simple single-page layout, exactly as requested.


---
**Per-product training changes**
**Date**: 2025-07-14
**File**: `server/trainers/mlstm_trainer.py`
**Problem**: model training failed with `KeyError: "['sales', 'price'] not in index"`.
**Analysis**:
1. The `'price'` column does not exist in the provided data, causing the `KeyError`.
2. The `'sales'` column is required as historical input (the autoregressive feature).
**Fix**: removed the non-existent `'price'` column from the `mlstm_trainer` feature list and kept `'sales'` for autoregression.

---
**Date**: 2025-07-14 (follow-up)
**Files**:
*   `server/trainers/transformer_trainer.py`
*   `server/trainers/tcn_trainer.py`
*   `server/trainers/kan_trainer.py`
**Problem**: preventive fix. These files carried the same latent `KeyError` as `mlstm_trainer.py`.
**Analysis**: on inspection, these trainers share `mlstm_trainer`'s data-handling logic, and their hard-coded feature lists all contained the non-existent `'price'` column.
**Fix**: removed `'price'` from the feature lists of all affected trainers, making every model's training robust.

---
**Date**: 2025-07-14 (deep fix)
**File**: `server/utils/multi_store_data_utils.py`
**Problem**: while tracking down `KeyError: "['sales'] not in index"`, several flaws were found in the column-standardization flow.
**Analysis**:
1. Read the `.parquet` data file via `uv run` and confirmed the raw column names.
2. Found that the rename mapping in `standardize_column_names` did not match the raw column names (e.g. `quantity_sold` vs `sales_quantity`).
3. Confirmed that the raw data has no `price` column, yet the code depended on it.
4. The function lacked an explicit final column selection, so the `sales` column was accidentally dropped during data preparation.
**Fix**:
1. Corrected `rename_map` to match the raw column names (`sales_quantity` -> `sales`, `temperature_2m_mean` -> `temperature`, `dayofweek` -> `weekday`).
2. Removed the dependency on the non-existent `price` column.
3. Added logic at the end of the function so the returned `DataFrame` contains all standard columns needed for training (features + target), stabilizing the data flow.
4. Raw data columns (reproducible with the snippet below): ['date', 'store_id', 'product_id', 'sales_quantity', 'sales_amount', 'gross_profit', 'customer_traffic', 'store_name', 'city', 'product_name', 'manufacturer', 'category_l1', 'category_l2', 'category_l3', 'abc_category', 'temperature_2m_mean', 'temperature_2m_max', 'temperature_2m_min', 'year', 'month', 'day', 'dayofweek', 'dayofyear', 'weekofyear', 'is_weekend', 'sl_lag_7', 'sl_lag_14', 'sl_rolling_mean_7', 'sl_rolling_std_7', 'sl_rolling_mean_14', 'sl_rolling_std_14']
|
||||
|
||||
---

**日期**: 2025-07-14 10:16
**主题**: 修复模型训练中的 `KeyError` 及数据流问题 (详细版)

### 阶段一:修复训练器层 `KeyError`

* **问题**: 模型训练因 `KeyError: "['sales', 'price'] not in index"` 失败。
* **分析**: 训练器硬编码的特征列表中包含了数据源中不存在的 `'price'` 列。
* **涉及文件**:
    * `server/trainers/mlstm_trainer.py`
    * `server/trainers/transformer_trainer.py`
    * `server/trainers/tcn_trainer.py`
    * `server/trainers/kan_trainer.py`
* **修改详情**:
    * **位置**: 每个训练器文件中的 `features` 列表定义处。
    * **操作**: 修改。
    * **内容**:
      ```diff
      - features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
      + features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
      ```
    * **原因**: 移除对不存在的 `'price'` 列的依赖,解决 `KeyError`。

### 阶段二:修复数据标准化层

* **问题**: 修复后出现新错误 `KeyError: "['sales'] not in index"`,表明数据标准化流程存在缺陷。
* **分析**: 通过 `uv run` 读取 `.parquet` 文件确认,`standardize_column_names` 函数中的列名映射错误,且缺少最终列选择机制。
* **涉及文件**: `server/utils/multi_store_data_utils.py`
* **修改详情**:
    1. **位置**: `standardize_column_names` 函数, `rename_map` 字典。
        * **操作**: 修改。
        * **内容**:
          ```diff
          - rename_map = { 'quantity_sold': 'sales', 'unit_price': 'price', 'day_of_week': 'weekday' }
          + rename_map = { 'sales_quantity': 'sales', 'temperature_2m_mean': 'temperature', 'dayofweek': 'weekday' }
          ```
        * **原因**: 修正键名以匹配数据源的真实列名 (`sales_quantity`, `temperature_2m_mean`, `dayofweek`)。
    2. **位置**: `standardize_column_names` 函数, `sales_amount` 计算部分。
        * **操作**: 修改 (注释)。
        * **内容**:
          ```diff
          - if 'sales_amount' not in df.columns and 'sales' in df.columns and 'price' in df.columns:
          -     df['sales_amount'] = df['sales'] * df['price']
          + # 由于没有price列,sales_amount的计算逻辑需要调整或移除
          + # if 'sales_amount' not in df.columns and 'sales' in df.columns and 'price' in df.columns:
          + #     df['sales_amount'] = df['sales'] * df['price']
          ```
        * **原因**: 避免因缺少 `'price'` 列而导致潜在错误。
    3. **位置**: `standardize_column_names` 函数, `numeric_columns` 列表。
        * **操作**: 删除。
        * **内容**:
          ```diff
          - numeric_columns = ['sales', 'price', 'sales_amount', 'weekday', 'month', 'temperature']
          + numeric_columns = ['sales', 'sales_amount', 'weekday', 'month', 'temperature']
          ```
        * **原因**: 从数值类型转换列表中移除不存在的 `'price'` 列。
    4. **位置**: `standardize_column_names` 函数, `return` 语句前。
        * **操作**: 增加。
        * **内容**:
          ```diff
          + # 定义模型训练所需的所有列(特征 + 目标)
          + final_columns = [
          +     'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
          +     'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
          + ]
          + # 筛选出DataFrame中实际存在的列
          + existing_columns = [col for col in final_columns if col in df.columns]
          + # 返回只包含这些必需列的DataFrame
          + return df[existing_columns]
          ```
        * **原因**: 增加列选择机制,确保函数返回的 `DataFrame` 结构统一且包含 `sales` 列,从根源上解决 `KeyError: "['sales'] not in index"`。

### 阶段三:修复数据流分发层

* **问题**: `predictor.py` 未将处理好的数据统一传递给所有训练器。
* **分析**: `train_model` 方法中,只有 `mlstm` 的调用传递了 `product_df`,其他模型则没有,导致它们重新加载未处理的数据。
* **涉及文件**: `server/core/predictor.py`
* **修改详情**:
    * **位置**: `train_model` 方法中对 `train_product_model_with_transformer`, `_tcn`, `_kan` 的调用处。
    * **操作**: 增加。
    * **内容**: 在函数调用中增加了 `product_df=product_data` 参数。
      ```diff
      - model_result, metrics, actual_version = train_product_model_with_transformer(product_id, ...)
      + model_result, metrics, actual_version = train_product_model_with_transformer(product_id=product_id, product_df=product_data, ...)
      ```
      *(对 `tcn` 和 `kan` 的调用也做了类似修改)*
    * **原因**: 统一数据流,确保所有训练器都使用经过正确预处理的、包含完整信息的 `DataFrame`。

### 阶段四:适配训练器以接收数据

* **问题**: `transformer`, `tcn`, `kan` 训练器需要能接收上游传来的数据。
* **分析**: 需要修改这三个训练器的函数签名和内部逻辑,使其在接收到 `product_df` 时跳过数据加载。
* **涉及文件**: `server/trainers/transformer_trainer.py`, `tcn_trainer.py`, `kan_trainer.py`
* **修改详情**:
    1. **位置**: 每个训练器主函数的定义处。
        * **操作**: 增加。
        * **内容**: 在函数参数中增加了 `product_df=None`。
          ```diff
          - def train_product_model_with_transformer(product_id, ...)
          + def train_product_model_with_transformer(product_id, product_df=None, ...)
          ```
    2. **位置**: 每个训练器内部的数据加载逻辑处。
        * **操作**: 增加。
        * **内容**: 增加了 `if product_df is None:` 的判断逻辑,只有在未接收到数据时才执行内部加载。
          ```diff
          + if product_df is None:
          -     # 根据训练模式加载数据
          -     from utils.multi_store_data_utils import load_multi_store_data
          -     ...
          +     # [原有的数据加载逻辑]
          + else:
          +     # 如果传入了product_df,直接使用
          +     ...
          ```
        * **原因**: 完成数据流修复的最后一环,使训练器能够灵活地接收外部数据或自行加载,彻底解决问题。

---

**日期**: 2025-07-14 10:38
**主题**: 修复因NumPy类型导致的JSON序列化失败问题

### 阶段五:修复前后端通信层

* **问题**: 模型训练成功后,后端向前端发送包含训练指标(metrics)的WebSocket消息或API响应时失败,导致前端状态无法更新为“已完成”。
* **日志错误**: `Object of type float32 is not JSON serializable`
* **分析**: 训练过程产生的评估指标(如 `mse`, `rmse`)是NumPy的 `float32` 类型。Python标准的 `json` 库无法直接序列化这种类型,导致在通过WebSocket或HTTP API发送数据时出错。
* **涉及文件**: `server/utils/training_process_manager.py`
* **修改详情**:
    1. **位置**: 文件顶部。
        * **操作**: 增加。
        * **内容**:
          ```python
          import numpy as np

          def convert_numpy_types(obj):
              """递归地将字典/列表中的NumPy类型转换为Python原生类型"""
              if isinstance(obj, dict):
                  return {k: convert_numpy_types(v) for k, v in obj.items()}
              elif isinstance(obj, list):
                  return [convert_numpy_types(i) for i in obj]
              elif isinstance(obj, np.generic):
                  return obj.item()
              return obj
          ```
        * **原因**: 添加一个通用的辅助函数,用于将包含NumPy类型的数据结构转换为JSON兼容的格式。
    2. **位置**: `_monitor_results` 方法内部,调用 `self.websocket_callback` 之前。
        * **操作**: 增加。
        * **内容**:
          ```diff
          + serializable_task_data = convert_numpy_types(task_data)
          - self.websocket_callback('training_update', { ... 'metrics': task_data.get('metrics'), ... })
          + self.websocket_callback('training_update', { ... 'metrics': serializable_task_data.get('metrics'), ... })
          ```
        * **原因**: 在通过WebSocket发送数据之前,调用 `convert_numpy_types` 函数对包含训练结果的 `task_data` 进行处理,确保所有 `float32` 等类型都被转换为Python原生的 `float`,从而解决序列化错误。

**总结**: 通过在数据发送前进行类型转换,彻底解决了前后端通信中的序列化问题,确保了训练状态能够被正确地更新到前端。

---

**日期**: 2025-07-14 11:04
**主题**: 根治JSON序列化问题

### 阶段六:修复API层序列化错误

* **问题**: 在修复WebSocket的序列化问题后,发现直接轮询 `GET /api/training` 接口时,仍然出现 `Object of type float32 is not JSON serializable` 错误。
* **分析**: 上一阶段的修复只转换了准备通过WebSocket发送的数据,但没有转换**存放在 `TrainingProcessManager` 内部 `self.tasks` 字典中的数据**。因此,当API通过 `get_all_tasks()` 方法读取这个字典时,获取到的仍然是包含NumPy类型的原始数据,导致 `jsonify` 失败。
* **涉及文件**: `server/utils/training_process_manager.py`
* **修改详情**:
    * **位置**: `_monitor_results` 方法,从 `result_queue` 获取数据之后。
    * **操作**: 调整逻辑。
    * **内容**:
      ```diff
      - with self.lock:
      -     # ... 更新 self.tasks ...
      -     if self.websocket_callback:
      -         serializable_task_data = convert_numpy_types(task_data)
      -         # ... 使用 serializable_task_data 发送消息 ...
      + # 立即对从队列中取出的数据进行类型转换
      + serializable_task_data = convert_numpy_types(task_data)
      + with self.lock:
      +     # 使用转换后的数据更新任务状态
      +     for key, value in serializable_task_data.items():
      +         setattr(self.tasks[task_id], key, value)
      +     # WebSocket通知 - 使用已转换的数据
      +     if self.websocket_callback:
      +         # ... 使用 serializable_task_data 发送消息 ...
      ```
    * **原因**: 将类型转换的步骤提前,确保存入 `self.tasks` 的数据已经是JSON兼容的。这样,无论是通过WebSocket推送还是通过API查询,获取到的都是安全的数据,从根源上解决了所有序列化问题。

**最终总结**: 至此,所有已知的数据流和数据类型问题均已解决。

---

**日期**: 2025-07-14 11:15
**主题**: 修复模型评估中的MAPE计算错误

### 阶段七:修复评估指标计算

* **问题**: 训练 `transformer` 模型时,日志显示 `MAPE: nan%` 并伴有 `RuntimeWarning: Mean of empty slice.`。
* **分析**: `MAPE` (平均绝对百分比误差) 的计算涉及除以真实值。当测试集中的所有真实销量(`y_true`)都为0时,用于避免除零错误的 `mask` 会导致一个空数组被传递给 `np.mean()`,从而产生 `nan` 和运行时警告。
* **涉及文件**: `server/analysis/metrics.py`
* **修改详情**:
    * **位置**: `evaluate_model` 函数中计算 `mape` 的部分。
    * **操作**: 增加条件判断。
    * **内容**:
      ```diff
      - mask = y_true != 0
      - mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
      + mask = y_true != 0
      + if np.any(mask):
      +     mape = np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])) * 100
      + else:
      +     # 如果所有真实值都为0,无法计算MAPE,返回0
      +     mape = 0.0
      ```
    * **原因**: 在计算MAPE之前,先检查是否存在任何非零的真实值。如果不存在,则直接将MAPE设为0,避免了对空数组求平均值,从而解决了 `nan` 和 `RuntimeWarning` 的问题。

## 2025-07-14 11:41:修复“按店铺训练”页面店铺列表加载失败问题

**问题描述:**
在“模型训练” -> “按店铺训练”页面中,“选择店铺”的下拉列表为空,无法加载任何店铺信息。

**根本原因:**
位于 `server/utils/multi_store_data_utils.py` 的 `standardize_column_names` 函数在标准化数据后,错误地移除了包括店铺元数据在内的非训练必需列。这导致调用该函数的 `get_available_stores` 函数无法获取到完整的店铺信息,最终返回一个空列表。

**解决方案:**
本着最小改动和保持代码清晰的原则,我进行了以下重构:

1. **净化 `standardize_column_names` 函数**:移除了其中所有与列筛选相关的代码,使其只专注于数据标准化这一核心职责。
2. **精确应用筛选逻辑**:将列筛选的逻辑精确地移动到了 `get_store_product_sales_data` 和 `aggregate_multi_store_data` 这两个为模型训练准备数据的函数中。这确保了只有在需要为模型准备数据时,才会执行列筛选。
3. **增强 `get_available_stores` 函数**:由于 `load_multi_store_data` 现在可以返回所有列,`get_available_stores` 将能够正常工作。同时,我增强了其代码的健壮性,以优雅地处理数据文件中可能存在的列缺失问题。

**代码变更:**
- **文件:** `server/utils/multi_store_data_utils.py`
- **主要改动:**
  - 从 `standardize_column_names` 中移除列筛选逻辑。
  - 在 `get_store_product_sales_data` 和 `aggregate_multi_store_data` 中添加列筛选逻辑。
  - 重写 `get_available_stores` 以更健壮地处理数据(思路见下方示意代码)。
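
按上述思路增强后的 `get_available_stores` 大致如下(示意代码,假设店铺元数据列为 `store_id`/`store_name`/`city`,实际以数据文件为准,并非源码原文):

```python
def get_available_stores(file_path: str = None) -> list:
    """返回数据中出现过的店铺列表(示意),对缺失列做容错处理"""
    df = load_multi_store_data(file_path)  # 现在会返回所有列
    if 'store_id' not in df.columns:
        return []  # 数据中没有店铺信息时优雅降级,而不是抛出异常
    # 只保留实际存在的元数据列,避免 KeyError
    meta_cols = [c for c in ['store_id', 'store_name', 'city'] if c in df.columns]
    stores = df[meta_cols].drop_duplicates(subset='store_id')
    return stores.to_dict('records')
```
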
---

**日期**: 2025-07-14 13:00
**主题**: 修复“按店铺训练-所有药品”模式下的训练失败问题

### 问题描述
在“模型训练” -> “按店铺训练”页面,当选择“所有药品”进行训练时,后端日志显示 `获取店铺产品数据失败: 没有找到店铺 [store_id] 产品 unknown 的销售数据`,导致训练任务失败。

### 根本原因
1. **API层**: `server/api.py` 在处理来自前端的训练请求时,如果 `product_id` 为 `null`(对应“所有药品”选项),会执行 `product_id or "unknown"`,错误地将产品ID设置为字符串 `"unknown"`。
2. **预测器层**: `server/core/predictor.py` 中的 `train_model` 方法接收到无效的 `product_id="unknown"` 后,尝试使用它来获取数据,但数据源中不存在ID为“unknown”的产品,导致数据加载失败。
3. **数据工具层**: `server/utils/multi_store_data_utils.py` 中的 `aggregate_multi_store_data` 函数只支持按产品ID进行全局聚合,不支持按店铺ID聚合其下所有产品的数据。

### 解决方案 (保留"unknown"字符串)
为了在不改变API层行为的前提下解决问题,采用了在下游处理这个特殊值的策略:

1. **修改 `server/core/predictor.py`**:
    * **位置**: `train_model` 方法。
    * **操作**: 增加了对 `product_id == 'unknown'` 的特殊处理逻辑。
    * **内容**:
      ```python
      # 如果product_id是'unknown',则表示为店铺所有商品训练一个聚合模型
      if product_id == 'unknown':
          try:
              # 使用聚合函数,按店铺聚合
              product_data = aggregate_multi_store_data(
                  store_id=store_id,
                  aggregation_method=aggregation_method,
                  file_path=self.data_path
              )
              # 将product_id设置为店铺ID,以便模型保存时使用有意义的标识
              product_id = store_id
          except Exception as e:
              # ... 错误处理 ...
      else:
          # ... 原有的按单个产品获取数据的逻辑 ...
      ```
    * **原因**: 在预测器层面拦截无效的 `"unknown"` ID,并将其意图正确地转换为“聚合此店铺的所有产品数据”。同时,将 `product_id` 重新赋值为 `store_id`,确保了后续模型保存时能使用一个唯一且有意义的名称(如 `store_01010023_mlstm_v1.pth`)。

2. **修改 `server/utils/multi_store_data_utils.py`**:
    * **位置**: `aggregate_multi_store_data` 函数。
    * **操作**: 重构函数签名和内部逻辑。
    * **内容**:
      ```python
      def aggregate_multi_store_data(product_id: Optional[str] = None,
                                     store_id: Optional[str] = None,
                                     aggregation_method: str = 'sum',
                                     ...):
          # ...
          if store_id:
              # 店铺聚合:加载该店铺的所有数据
              df = load_multi_store_data(file_path, store_id=store_id)
              # ...
          elif product_id:
              # 全局聚合:加载该产品的所有数据
              df = load_multi_store_data(file_path, product_id=product_id)
              # ...
          else:
              raise ValueError("必须提供 product_id 或 store_id")
      ```
    * **原因**: 扩展了数据聚合函数的功能,使其能够根据传入的 `store_id` 参数,加载并聚合特定店铺的所有销售数据,为店铺级别的综合模型训练提供了数据基础。

**最终结果**: 通过这两处修改,系统现在可以正确处理“按店铺-所有药品”的训练请求。它会聚合该店铺所有产品的销售数据,训练一个综合模型,并以店铺ID为标识来保存该模型,彻底解决了该功能点的训练失败问题。

---

**日期**: 2025-07-14 14:19
**主题**: 修复并发训练中的稳定性和日志错误

### 阶段八:修复并发训练中的多个错误

* **问题**: 在并发执行多个训练任务时,系统出现 `JSON序列化错误`、`API列表排序错误` 和 `WebSocket连接错误`。
* **分析**:
    1. **`Object of type float32 is not JSON serializable`**: `training_process_manager.py` 在通过WebSocket发送**中途**的训练进度时,没有对包含NumPy `float32` 类型的 `metrics` 数据进行序列化。
    2. **`'<' not supported between instances of 'str' and 'NoneType'`**: `api.py` 在获取训练任务列表时,对 `start_time` 进行排序,但未处理某些任务的 `start_time` 可能为 `None` 的情况,导致 `TypeError`。
    3. **`AssertionError: write() before start_response`**: `api.py` 中,当以 `debug=True` 模式运行时,Flask内置的Werkzeug服务器的调试器与Socket.IO的连接管理机制发生冲突。
* **解决方案**:
    1. **文件**: `server/utils/training_process_manager.py`
        * **位置**: `_monitor_progress` 方法。
        * **操作**: 在发送 `training_progress` 事件前,调用 `convert_numpy_types` 函数对 `progress_data` 进行完全序列化。
        * **原因**: 确保所有通过WebSocket发送的数据(包括中途进度)都是JSON兼容的,彻底解决序列化问题。
    2. **文件**: `server/api.py`
        * **位置**: `get_all_training_tasks` 函数。
        * **操作**: 修改 `sorted` 函数的 `key`,使用 `lambda x: x.get('start_time') or '1970-01-01 00:00:00'`。
        * **原因**: 为 `None` 类型的 `start_time` 提供一个有效的默认值,使其可以和字符串类型的日期进行安全比较,解决了排序错误。
    3. **文件**: `server/api.py`
        * **位置**: `socketio.run()` 调用处。
        * **操作**: 增加 `allow_unsafe_werkzeug=True if args.debug else False` 参数。
        * **原因**: 这是 `Flask-SocketIO` 官方推荐的解决方案,用于在调试模式下协调Werkzeug与Socket.IO的事件循环,避免底层WSGI错误。(第2、3项在 `server/api.py` 中的落点示意见下方代码。)
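
两处 `server/api.py` 修复落点的简化示意(`all_tasks`、`args`、`app` 等变量名为假设,仅示意参数用法,并非源码原文):

```python
# server/api.py 中两处修复的示意写法

# 修复2: 排序时为 None 的 start_time 提供可比较的默认字符串
tasks = sorted(
    all_tasks,
    key=lambda x: x.get('start_time') or '1970-01-01 00:00:00',
)

# 修复3: 仅在调试模式下放开 Werkzeug 限制,协调其与 Socket.IO
socketio.run(
    app,
    host='0.0.0.0',
    port=5000,
    debug=args.debug,
    allow_unsafe_werkzeug=True if args.debug else False,
)
```
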
**最终结果**: 通过这三项修复,系统的并发稳定性和健壮性得到显著提升,解决了在高并发训练场景下出现的各类错误。

---

**日期**: 2025-07-14 14:48
**主题**: 修复模型评估指标计算错误并优化训练过程

### 阶段九:修复模型评估与训练优化

* **问题**: 所有模型训练完成后,评估指标 `R²` 始终为0.0,`MAPE` 始终为0.00%,这表明模型评估或训练过程存在严重问题。
* **分析**:
    1. **核心错误**: 在 `mlstm_trainer.py` 和 `transformer_trainer.py` 中,计算损失函数时,模型输出 `outputs` 的维度是 `(batch_size, forecast_horizon)`,而目标 `y_batch` 的维度被错误地通过 `unsqueeze(-1)` 修改为 `(batch_size, forecast_horizon, 1)`。这种维度不匹配导致损失计算错误,模型无法正确学习。
    2. **优化缺失**: 训练过程中缺少学习率调度、梯度裁剪和提前停止等关键的优化策略,影响了训练效率和稳定性。
* **解决方案**:
    1. **修复维度不匹配 (关键修复)**:
        * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
        * **位置**: 训练和验证循环中的损失计算部分。
        * **操作**: 移除了对 `y_batch` 的 `unsqueeze(-1)` 操作,确保 `outputs` 和 `y_batch` 维度一致。
          ```diff
          - loss = criterion(outputs, y_batch.unsqueeze(-1))
          + loss = criterion(outputs, y_batch.squeeze(-1) if y_batch.dim() == 3 else y_batch)
          ```
        * **原因**: 修正损失函数的输入,使模型能够根据正确的误差进行学习,从而解决评估指标恒为0的问题。
    2. **增加训练优化策略**:
        * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
        * **操作**: 在两个训练器中增加了以下功能(三者组合使用的通用写法示意见下方代码):
            * **学习率调度器**: 引入 `torch.optim.lr_scheduler.ReduceLROnPlateau`,当测试损失停滞时自动降低学习率。
            * **梯度裁剪**: 在优化器更新前,使用 `torch.nn.utils.clip_grad_norm_` 对梯度进行裁剪,防止梯度爆炸。
            * **提前停止**: 增加了 `patience` 参数,当测试损失连续多个epoch未改善时,提前终止训练,防止过拟合。
        * **原因**: 引入这些业界标准的优化技术,可以显著提高训练过程的稳定性、收敛速度和最终的模型性能。
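
这三项优化策略组合使用时的通用 PyTorch 写法示意(封装成函数以便自包含,超参数均为示例值,并非训练器源码原文):

```python
import torch

def train_with_optimizations(model, optimizer, criterion,
                             train_loader, test_loader,
                             epochs: int, patience: int = 10):
    """学习率调度 + 梯度裁剪 + 提前停止的通用组合(示意)"""
    # 测试损失停滞时自动将学习率减半
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5)
    best_loss, bad_epochs = float('inf'), 0
    for epoch in range(epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            # 梯度裁剪: 防止梯度爆炸
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        model.eval()
        with torch.no_grad():
            test_loss = sum(criterion(model(X), y).item()
                            for X, y in test_loader) / len(test_loader)
        scheduler.step(test_loss)  # 学习率调度以测试损失为依据

        # 提前停止: 测试损失连续 patience 个 epoch 未改善则终止训练
        if test_loss < best_loss:
            best_loss, bad_epochs = test_loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                break
    return model
```
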
**最终结果**: 通过修复核心的逻辑错误并引入多项优化措施,模型现在不仅能够正确学习,而且训练过程更加健壮和高效。

---

**日期**: 2025-07-14 15:30
**主题**: 根治模型维度错误并统一数据流 (完整调试过程)

### 阶段九:错误的修复尝试 (记录备查)

* **问题**: 所有模型训练完成后,评估指标 `R²` 始终为0.0,`MAPE` 始终为0.00%。
* **初步分析**: 怀疑损失函数计算时,`outputs` 和 `y_batch` 维度不匹配。
* **错误的假设**: 当时错误地认为是 `y_batch` 的维度有问题,而 `outputs` 的维度是正确的。
* **错误的修复**:
    * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
    * **操作**: 尝试在训练器层面使用 `squeeze` 调整 `y_batch` 的维度来匹配 `outputs`。
      ```diff
      - loss = criterion(outputs, y_batch)
      + loss = criterion(outputs, y_batch.squeeze(-1) if y_batch.dim() == 3 else y_batch)
      ```
    * **结果**: 此修改导致了新的运行时错误 `UserWarning: Using a target size (torch.Size([32, 3])) that is different to the input size (torch.Size([32, 3, 1]))`,证明了修复方向错误,但帮助定位了问题的真正根源。

### 阶段十:根治维度不匹配问题

* **问题**: 深入分析阶段九的错误后,确认了问题的根源在于模型输出维度。
* **根本原因**: `server/models/mlstm_model.py` 中的 `MLSTMTransformer` 模型,其 `forward` 方法的最后一层输出了一个多余的维度,导致其输出形状为 `(B, H, 1)`,而并非期望的 `(B, H)`。
* **正确的解决方案 (端到端维度一致性)**:
    1. **修复模型层 (治本)**:
        * **文件**: `server/models/mlstm_model.py`
        * **位置**: `MLSTMTransformer` 的 `forward` 方法。
        * **操作**: 在 `output_layer` 之后增加 `.squeeze(-1)`,将模型输出的维度从 `(B, H, 1)` 修正为 `(B, H)`。
          ```diff
          - return self.output_layer(decoder_outputs)
          + return self.output_layer(decoder_outputs).squeeze(-1)
          ```
    2. **净化训练器层 (治标)**:
        * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
        * **操作**: 撤销了阶段九的错误修改,恢复为最直接的损失计算 `loss = criterion(outputs, y_batch)`。
    3. **优化评估逻辑**:
        * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
        * **操作**: 简化了模型评估部分的反归一化逻辑,使其更清晰、更直接地处理 `(样本数, 预测步长)` 形状的数据。
          ```diff
          - test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
          - test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
          + test_pred_inv = scaler_y.inverse_transform(test_pred)
          + test_true_inv = scaler_y.inverse_transform(test_true)
          ```

### 阶段十一:最终修复与逻辑统一

* **问题**: 在应用阶段十的修复后,训练仍然失败。mLSTM出现维度反转错误 (`target size (B, H, 1)` vs `input size (B, H)`),而Transformer则出现评估错误 (`'numpy.ndarray' object has no attribute 'numpy'`)。
* **分析**:
    1. **维度反转根源**: 问题的最终根源在 `server/utils/data_utils.py` 的 `create_dataset` 函数。它在创建目标数据集 `dataY` 时,错误地保留了一个多余的维度,导致 `y_batch` 的形状变为 `(B, H, 1)`。
    2. **评估Bug**: 在 `mlstm_trainer.py` 和 `transformer_trainer.py` 的评估部分,代码 `test_true = testY.numpy()` 是错误的,因为 `testY` 已经是NumPy数组。
* **最终解决方案 (端到端修复)**:
    1. **修复数据加载层 (治本)**:
        * **文件**: `server/utils/data_utils.py`
        * **位置**: `create_dataset` 函数。
        * **操作**: 修改 `dataY.append(y)` 为 `dataY.append(y.flatten())`,从源头上确保 `y` 标签的维度是正确的 `(B, H)`。(按此思路实现的 `create_dataset` 示意见下方代码。)
    2. **修复训练器评估层**:
        * **文件**: `server/trainers/mlstm_trainer.py`, `server/trainers/transformer_trainer.py`
        * **位置**: 模型评估部分。
        * **操作**: 修正 `test_true = testY.numpy()` 为 `test_true = testY`,解决了属性错误。
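
按上述修复思路实现的 `create_dataset` 最小示意(函数签名与参数名为假设,并非 `data_utils.py` 源码原文):

```python
import numpy as np

def create_dataset(data: np.ndarray, target: np.ndarray,
                   sequence_length: int, forecast_horizon: int):
    """用滑动窗口构造监督学习样本(示意)。

    返回:
        dataX: (样本数, sequence_length, 特征数)
        dataY: (样本数, forecast_horizon)  # 注意: 不保留多余的最后一维
    """
    dataX, dataY = [], []
    for i in range(len(data) - sequence_length - forecast_horizon + 1):
        x = data[i : i + sequence_length]
        y = target[i + sequence_length : i + sequence_length + forecast_horizon]
        dataX.append(x)
        dataY.append(y.flatten())  # 关键修复: 展平,避免产生 (H, 1) 形状
    return np.array(dataX), np.array(dataY)
```
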
**最终结果**: 通过记录并分析整个调试过程(阶段九到十一),我们最终定位并修复了从数据加载、模型设计到训练器评估的整个流程中的维度不一致问题。代码现在更加简洁、健壮,并遵循了端到端维度一致的良好设计实践。

---

**日期**: 2025-07-14 15:34
**主题**: 扩展维度修复至Transformer模型

### 阶段十二:统一所有模型的输出维度

* **问题**: 在修复 `mLSTM` 模型后,`Transformer` 模型的训练仍然因为完全相同的维度不匹配问题而失败。
* **分析**: `server/models/transformer_model.py` 中的 `TimeSeriesTransformer` 类也存在与 `mLSTM` 相同的设计缺陷,其 `forward` 方法的输出维度为 `(B, H, 1)` 而非 `(B, H)`。
* **解决方案**:
    1. **修复Transformer模型层**:
        * **文件**: `server/models/transformer_model.py`
        * **位置**: `TimeSeriesTransformer` 的 `forward` 方法。
        * **操作**: 在 `output_layer` 之后增加 `.squeeze(-1)`,将模型输出的维度从 `(B, H, 1)` 修正为 `(B, H)`。
          ```diff
          - return self.output_layer(decoder_outputs)
          + return self.output_layer(decoder_outputs).squeeze(-1)
          ```

**最终结果**: 通过将维度修复方案应用到所有相关的模型文件,我们确保了整个系统的模型层都遵循了统一的、正确的输出维度标准。至此,所有已知的维度相关问题均已从根源上解决。

---

**日期**: 2025-07-14 16:10
**主题**: 修复“全局模型训练-所有药品”模式下的训练失败问题

### 问题描述
在“全局模型训练”页面,当选择“所有药品”进行训练时,后端日志显示 `聚合全局数据失败: 没有找到产品 unknown 的销售数据`,导致训练任务失败。

### 根本原因
1. **API层 (`server/api.py`)**: 在处理全局训练请求时,如果前端未提供 `product_id`(对应“所有药品”选项),API层会执行 `product_id or "unknown"`,错误地将产品ID设置为字符串 `"unknown"`。
2. **预测器层 (`server/core/predictor.py`)**: `train_model` 方法接收到无效的 `product_id="unknown"` 后,在 `training_mode='global'` 分支下,直接将其传递给数据聚合函数。
3. **数据工具层 (`server/utils/multi_store_data_utils.py`)**: `aggregate_multi_store_data` 函数缺少处理“真正”全局聚合(即不按任何特定产品或店铺过滤)的逻辑,当收到 `product_id="unknown"` 时,它会尝试按一个不存在的产品进行过滤,最终导致失败。

### 解决方案 (遵循现有设计模式)
为了在不影响现有功能的前提下修复此问题,采用了与历史修复类似的、在中间层进行逻辑适配的策略。

1. **修改 `server/utils/multi_store_data_utils.py`**:
    * **位置**: `aggregate_multi_store_data` 函数。
    * **操作**: 扩展了函数功能。
    * **内容**: 增加了新的逻辑分支。当 `product_id` 和 `store_id` 参数都为 `None` 时,函数现在会加载**所有**数据进行聚合,以支持真正的全局模型训练。
      ```python
      # ...
      elif product_id:
          # 按产品聚合...
      else:
          # 真正全局聚合:加载所有数据
          df = load_multi_store_data(file_path)
          if len(df) == 0:
              raise ValueError("数据文件为空,无法进行全局聚合")
          grouping_entity = "所有产品"
      ```
    * **原因**: 使数据聚合函数的功能更加完整和健壮,能够服务于真正的全局训练场景,同时不影响其原有的按店铺或按产品的聚合功能。

2. **修改 `server/core/predictor.py`**:
    * **位置**: `train_model` 方法,`training_mode == 'global'` 的逻辑分支内。
    * **操作**: 增加了对 `product_id == 'unknown'` 的特殊处理。
    * **内容**:
      ```python
      if product_id == 'unknown':
          product_data = aggregate_multi_store_data(
              product_id=None,  # 传递None以触发真正的全局聚合
              # ...
          )
          # 将product_id设置为一个有意义的标识符
          product_id = 'all_products'
      else:
          # ...原有的按单个产品聚合的逻辑...
      ```
    * **原因**: 在核心预测器层面拦截无效的 `"unknown"` ID,并将其正确地解释为“聚合所有产品数据”的意图。通过向聚合函数传递 `product_id=None` 来调用新增强的全局聚合功能,并用一个有意义的标识符 `all_products` 来命名模型,确保了后续流程的正确执行。

**最终结果**: 通过这两处修改,系统现在可以正确处理“全局模型-所有药品”的训练请求,聚合所有产品的销售数据来训练一个通用的全局模型,彻底解决了该功能点的训练失败问题。

---

**日期**: 2025-07-14
**主题**: UI导航栏重构

### 描述
根据用户请求,对左侧功能导航栏进行了调整。

### 主要改动
1. **删除“数据管理”**:
    * 从 `UI/src/App.vue` 的导航菜单中移除了“数据管理”项。
    * 从 `UI/src/router/index.js` 中删除了对应的 `/data` 路由。
    * 删除了视图文件 `UI/src/views/DataView.vue`。

2. **提升“店铺管理”**:
    * 将“店铺管理”菜单项在 `UI/src/App.vue` 中的位置提升,以填补原“数据管理”的位置,使其在导航中更加突出。

### 涉及文件
* `UI/src/App.vue`
* `UI/src/router/index.js`
* `UI/src/views/DataView.vue` (已删除)

**按药品模型预测**

---

**日期**: 2025-07-14
**主题**: 修复导航菜单高亮问题

### 描述
修复了首次进入或刷新页面时,左侧导航菜单项与当前路由不匹配导致不高亮的问题。

### 主要改动
* **文件**: `UI/src/App.vue`
* **修改**:
    1. 引入 `useRoute` 和 `computed`。
    2. 创建了一个计算属性 `activeMenu`,其值动态地等于当前路由的路径 (`route.path`)。
    3. 将 `el-menu` 组件的 `:default-active` 属性绑定到 `activeMenu`。

### 结果
确保了导航菜单的高亮状态始终与当前页面的URL保持同步。

---

**日期**: 2025-07-15
**主题**: 修复硬编码文件路径问题,提高项目可移植性

### 问题描述
项目在从一台计算机迁移到另一台时,由于数据文件路径被硬编码在代码中,导致程序无法找到数据文件而运行失败。

### 根本原因
多个Python文件(`predictor.py`, `multi_store_data_utils.py`)中直接写入了相对路径 `'data/timeseries_training_data_sample_10s50p.parquet'` 作为默认值。这种方式在不同运行环境下(如从根目录运行 vs 从子目录运行)会产生路径解析错误。

### 解决方案:集中配置,统一管理
1. **修改 `server/core/config.py` (核心)**:
    * 动态计算并定义了一个全局变量 `PROJECT_ROOT`,它始终指向项目的根目录。
    * 基于 `PROJECT_ROOT`,使用 `os.path.join` 创建了一个跨平台的、绝对的默认数据路径 `DEFAULT_DATA_PATH` 和模型保存路径 `DEFAULT_MODEL_DIR`。
    * 这确保了无论从哪个位置执行代码,路径总能被正确解析。(`PROJECT_ROOT` 的计算方式见下方示意代码。)

2. **修改 `server/utils/multi_store_data_utils.py`**:
    * 从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    * 将所有数据加载函数的 `file_path` 参数的默认值从硬编码的字符串改为 `None`。
    * 在函数内部,如果 `file_path` 为 `None`,则自动使用导入的 `DEFAULT_DATA_PATH`。
    * 移除了原有的、复杂的、为了猜测正确路径而编写的冗余代码。

3. **修改 `server/core/predictor.py`**:
    * 同样从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    * 在初始化 `PharmacyPredictor` 时,如果未提供数据路径,则使用导入的 `DEFAULT_DATA_PATH` 作为默认值。
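
`PROJECT_ROOT` 动态计算方式的最小示意(假设 `config.py` 位于 `server/core/` 目录下,向上两级即项目根目录;`DEFAULT_MODEL_DIR` 的目录名为示例,非源码原文):

```python
# file: server/core/config.py (示意)
import os

# 基于当前文件位置向上回溯,得到始终正确的项目根目录
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..')
)

# 基于根目录构建跨平台的绝对默认路径
DEFAULT_DATA_PATH = os.path.join(
    PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet'
)
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')  # 目录名为示例
```

数据加载函数中对应的回退逻辑大致是:`if file_path is None: file_path = DEFAULT_DATA_PATH`,这样调用方无需关心运行目录。
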
### 最终结果
通过将数据源路径集中到唯一的配置文件中进行管理,彻底解决了因硬编码路径导致的可移植性问题。项目现在可以在任何环境下可靠地运行。

---

### 未来如何修改数据源(例如,连接到服务器数据库)

本次重构为将来更换数据源打下了坚实的基础。操作非常简单:

1. **定位配置文件**: 打开 `server/core/config.py` 文件。

2. **修改数据源定义**:
    * **当前 (文件)**:
      ```python
      DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
      ```
    * **未来 (数据库示例)**:
      您可以将这行替换为数据库连接字符串,或者添加新的数据库配置变量。例如:
      ```python
      # 注释掉或删除旧的文件路径配置
      # DEFAULT_DATA_PATH = ...

      # 新增数据库连接配置
      DATABASE_URL = "postgresql://user:password@your_server_ip:5432/your_database_name"
      ```

3. **修改数据加载逻辑**:
    * **定位数据加载函数**: 打开 `server/utils/multi_store_data_utils.py`。
    * **修改 `load_multi_store_data` 函数**:
        * 引入数据库连接库(如 `sqlalchemy` 或 `psycopg2`)。
        * 修改函数逻辑,使其使用 `config.py` 中的 `DATABASE_URL` 来连接数据库,并执行SQL查询来获取数据,而不是读取文件。
    * **示例**:
      ```python
      from sqlalchemy import create_engine
      from core.config import DATABASE_URL  # 导入新的数据库配置

      def load_multi_store_data(...):
          # ...
          engine = create_engine(DATABASE_URL)
          query = "SELECT * FROM sales_data"  # 根据需要构建查询
          df = pd.read_sql(query, engine)
          # ... 后续处理逻辑保持不变 ...
      ```

通过以上步骤,您就可以在不改动项目其他任何部分的情况下,轻松地将数据源从本地文件切换到服务器数据库。

---

**日期**: 2025-07-15
**主题**: 修复“按药品预测”功能并增强图表展示
**开发者**: lyf

### 问题描述
“预测分析” -> “按药品预测”页面无法正常使用。前端API调用地址错误,且图表渲染逻辑与后端返回的数据结构不匹配。

### 解决方案
对 `UI/src/views/prediction/ProductPredictionView.vue` 文件进行了以下修复和增强:

1. **API端点修复**:
    * **位置**: `startPrediction` 函数。
    * **操作**: 将API请求地址从错误的 `/api/predict` 修正为正确的 `/api/prediction`。

2. **数据处理对齐**:
    * **位置**: `startPrediction` 和 `renderChart` 函数。
    * **操作**: 修改了数据接收逻辑,使其能够正确处理后端返回的 `history_data` 和 `prediction_data` 字段。

3. **图表功能增强**:
    * **位置**: `renderChart` 函数。
    * **操作**: 重构了图表渲染逻辑,现在可以同时展示历史销量(绿色实线)和预测销量(蓝色虚线),为用户提供更直观的对比分析。

4. **错误提示优化**:
    * **位置**: `startPrediction` 函数的 `catch` 块。
    * **操作**: 改进了错误处理,现在可以从响应中提取并显示来自后端的更具体的错误信息。

### 最终结果
“按药品预测”功能已与后端成功对接,可以正常使用,并且提供了更丰富、更健壮的可视化体验。
222
xz新模型添加流程.md
Normal file
@ -0,0 +1,222 @@

# 如何向系统添加新模型

本指南详细说明了如何向本预测系统添加一个全新的预测模型。系统采用灵活的插件式架构,集成新模型的过程非常模块化,主要围绕 **模型(Model)**、**训练器(Trainer)** 和 **预测器(Predictor)** 这三个核心组件进行。

## 核心理念

系统的核心是 `models/model_registry.py`,它维护了两个独立的注册表:一个用于训练函数,另一个用于预测函数。添加新模型的本质就是:

1. **定义模型**:创建模型的架构。
2. **创建训练器**:编写一个函数来训练这个模型,并将其注册到训练器注册表。
3. **集成预测器**:确保系统知道如何加载模型并用它来预测,然后将预测逻辑注册到预测器注册表。

---

## 第 1 步:定义模型架构

首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。

**示例:创建 `ShopTRAINING/server/models/my_new_model.py`**

如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。

```python
# file: ShopTRAINING/server/models/my_new_model.py

import torch
import torch.nn as nn

class MyNewModel(nn.Module):
    def __init__(self, input_features, hidden_size, output_sequence_length):
        """
        定义模型的层和结构。
        """
        super(MyNewModel, self).__init__()
        self.layer1 = nn.Linear(input_features, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_sequence_length)
        # ... 可添加更复杂的结构

    def forward(self, x):
        """
        定义数据通过模型的前向传播路径。
        x 的形状通常是 (batch_size, sequence_length, num_features)
        """
        # 确保输入是正确的形状
        # 例如,对于简单的线性层,可能需要展平
        batch_size, seq_len, features = x.shape
        x = x.view(batch_size * seq_len, features)  # 展平

        out = self.layer1(x)
        out = self.relu(out)
        out = self.layer2(out)

        # 恢复形状以匹配输出
        out = out.view(batch_size, seq_len, -1)
        # 通常我们只关心序列的最后一个预测
        return out[:, -1, :]
```

---

## 第 2 步:创建模型训练器

接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。

**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`**

这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。

```python
# file: ShopTRAINING/server/trainers/my_new_model_trainer.py

import torch
import torch.optim as optim
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from models.my_new_model import MyNewModel  # 导入您的新模型

# 遵循系统的标准函数签名
def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. 数据准备 ---
    # (此处省略了数据加载、标准化和创建数据集的详细代码,
    #  您可以参考 xgboost_trainer.py 或其他训练器中的实现)
    # ...
    # 假设您已准备好 trainX, trainY, testX, testY, scaler_y 等变量
    # trainX = ...
    # trainY = ...
    # testX = ...
    # testY = ...
    # scaler_y = ...
    # features = [...]

    # --- 2. 实例化模型和优化器 ---
    input_dim = trainX.shape[2]  # 获取特征数量
    hidden_size = 64  # 示例超参数

    model = MyNewModel(
        input_features=input_dim,
        hidden_size=hidden_size,
        output_sequence_length=forecast_horizon
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.MSELoss()

    # --- 3. 训练循环 ---
    print("开始训练 MyNewModel...")
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(trainX)
        loss = criterion(outputs, trainY)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. 模型评估 ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # 反标准化并计算指标
    # ... (参考其他训练器)
    metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0}  # 示例

    # --- 5. 模型保存 ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': None,  # 替换为您的 scaler_X
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'mynewmodel',  # **关键**: 使用唯一的模型名称
            'input_dim': input_dim,
            'hidden_size': hidden_size,
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'features': features
        },
        'metrics': metrics
    }

    model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='mynewmodel',  # **关键**: 再次确认模型名称
        # ... 其他参数
    )

    print("✅ MyNewModel 模型训练并保存完成!")
    return model, metrics, "v1", "path/to/model"  # 返回值遵循统一格式

# --- 关键步骤: 将训练器注册到系统中 ---
register_trainer('mynewmodel', train_with_mynewmodel)
```

---

## 第 3 步:集成模型预测器

最后,您需要让系统知道如何加载和使用您的新模型进行预测。这需要在 `ShopTRAINING/server/predictors/model_predictor.py` 中进行两处修改。

**文件: `ShopTRAINING/server/predictors/model_predictor.py`**

1. **让系统知道如何构建您的模型实例**

   在 `load_model_and_predict` 函数中,有一个 `if/elif` 结构用于根据模型类型实例化不同的模型。您需要为 `MyNewModel` 添加一个新的分支。

   ```python
   # 在 model_predictor.py 中

   # 首先,导入您的新模型类
   from models.my_new_model import MyNewModel

   # ... 在 load_model_and_predict 函数内部 ...

   # ... 其他模型的 elif 分支 ...
   elif loaded_model_type == 'tcn':
       model = TCNForecaster(...)

   # vvv 添加这个新的分支 vvv
   elif loaded_model_type == 'mynewmodel':
       model = MyNewModel(
           input_features=config['input_dim'],
           hidden_size=config['hidden_size'],
           output_sequence_length=config['forecast_horizon']
       ).to(DEVICE)
   # ^^^ 添加结束 ^^^

   else:
       raise ValueError(f"不支持的模型类型: {loaded_model_type}")

   model.load_state_dict(checkpoint['model_state_dict'])
   model.eval()
   ```

2. **注册预测逻辑**

   如果您的模型是一个标准的 PyTorch 模型,并且其预测逻辑与现有的模型(如 Transformer, KAN)相同,您可以直接复用 `default_pytorch_predictor`。只需在文件末尾添加一行注册代码即可。

   ```python
   # 在 model_predictor.py 文件末尾

   # ...
   # 将增强后的默认预测器也注册给xgboost
   register_predictor('xgboost', default_pytorch_predictor)

   # vvv 添加这行代码 vvv
   # 让 'mynewmodel' 也使用通用的 PyTorch 预测器
   register_predictor('mynewmodel', default_pytorch_predictor)
   # ^^^ 添加结束 ^^^
   ```

   如果您的模型需要特殊的预测逻辑(例如,像 XGBoost 那样有不同的输入格式或调用方式),您可以复制 `default_pytorch_predictor` 创建一个新函数,修改其内部逻辑,然后将新函数注册给 `'mynewmodel'`。

---

## 总结

完成以上三个步骤后,您的新模型 `MyNewModel` 就已完全集成到系统中了。系统会自动在 `trainers` 目录中发现您的新训练器。当您通过 API 或界面选择 `mynewmodel` 进行训练和预测时,系统将自动调用您刚刚编写和注册的所有相应逻辑。
679
xz系统API注解和swagger访问.md
Normal file
@ -0,0 +1,679 @@

## 访问Swagger UI:
## 服务启动后,在您的浏览器中打开以下地址即可查看交互式API文档:
## http://localhost:5173/swagger/
## 文件路径:C:\WorkSpace\ShopTRAINING\server\swagger.json

# 药店销售预测系统 API 文档

**版本:** 1.0.0
**联系方式:** [API开发团队](mailto:support@example.com)

本文档详细描述了用于药店销售预测的RESTful API。

---

## 数据管理

数据上传和查询相关接口。

### `GET /api/products`

**摘要:** 获取所有产品列表

**描述:** 返回系统中所有产品的ID和名称。

**响应:**
* `200`: 成功获取产品列表。
  ```json
  {
    "status": "success",
    "data": [
      {"product_id": "P001", "product_name": "产品A"},
      {"product_id": "P002", "product_name": "产品B"}
    ]
  }
  ```
* `500`: 服务器内部错误。

---

### `GET /api/products/{product_id}`

**摘要:** 获取单个产品详情

**描述:** 返回指定产品ID的详细信息。

**路径参数:**
* `product_id` (string, required): 产品ID,例如P001。

**响应:**
* `200`: 成功获取产品详情。
  ```json
  {
    "status": "success",
    "data": {
      "product_id": "P001",
      "product_name": "产品A",
      "data_points": 365,
      "date_range": {
        "start": "2023-01-01",
        "end": "2023-12-31"
      }
    }
  }
  ```
* `404`: 产品不存在。
* `500`: 服务器内部错误。

---

### `GET /api/products/{product_id}/sales`

**摘要:** 获取产品销售数据

**描述:** 返回指定产品在特定日期范围内的销售数据。

**路径参数:**
* `product_id` (string, required): 产品ID,例如P001。

**查询参数:**
* `start_date` (string): 开始日期,格式为YYYY-MM-DD。
* `end_date` (string): 结束日期,格式为YYYY-MM-DD。

**响应:**
* `200`: 成功获取销售数据。
  ```json
  {
    "status": "success",
    "data": [
      {"date": "2023-12-01", "sales": 150},
      {"date": "2023-12-02", "sales": 155}
    ]
  }
  ```
* `404`: 产品不存在。
* `500`: 服务器内部错误。

---

### `POST /api/data/upload`

**摘要:** 上传销售数据

**描述:** 上传新的销售数据文件(Excel格式)。

**请求体:** `multipart/form-data`
* `file` (binary, required): Excel文件(.xlsx),包含销售数据。

**响应:**
* `200`: 数据上传成功。
  ```json
  {
    "status": "success",
    "message": "数据上传成功",
    "data": {
      "products": 10,
      "rows": 3650
    }
  }
  ```
* `400`: 请求错误。
* `500`: 服务器内部错误。

---

### `GET /api/stores`

**摘要:** 获取所有店铺列表

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {"store_id": "S001", "store_name": "第一分店"},
      {"store_id": "S002", "store_name": "第二分店"}
    ],
    "count": 2
  }
  ```

---

### `POST /api/stores`

**摘要:** 创建新店铺

**响应:**
* `200`: 创建成功。
  ```json
  {
    "status": "success",
    "message": "店铺创建成功",
    "data": {
      "store_id": "S003"
    }
  }
  ```

---

### `GET /api/stores/{store_id}`

**摘要:** 获取单个店铺信息

**路径参数:**
* `store_id` (string, required): 店铺ID。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": {
      "store_id": "S001",
      "store_name": "第一分店",
      "location": "市中心",
      "size": 120.5,
      "type": "旗舰店",
      "opening_date": "2022-01-01",
      "status": "active"
    }
  }
  ```
* `404`: 店铺不存在。

---

### `PUT /api/stores/{store_id}`

**摘要:** 更新店铺信息

**路径参数:**
* `store_id` (string, required): 店铺ID。

**响应:**
* `200`: 更新成功。
  ```json
  {
    "status": "success",
    "message": "店铺更新成功"
  }
  ```
* `404`: 店铺不存在。

---

### `DELETE /api/stores/{store_id}`

**摘要:** 删除店铺

**路径参数:**
* `store_id` (string, required): 店铺ID。

**响应:**
* `200`: 删除成功。
  ```json
  {
    "status": "success",
    "message": "店铺删除成功"
  }
  ```
* `404`: 店铺不存在。

---

### `GET /api/stores/{store_id}/products`

**摘要:** 获取店铺的产品列表

**路径参数:**
* `store_id` (string, required): 店铺ID。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {"product_id": "P001", "product_name": "产品A"}
    ],
    "count": 1
  }
  ```

---

### `GET /api/stores/{store_id}/statistics`

**摘要:** 获取店铺销售统计信息

**路径参数:**
* `store_id` (string, required): 店铺ID。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": {
      "total_sales": 150000.0,
      "total_quantity": 7500,
      "products_count": 50
    }
  }
  ```

---

### `GET /api/sales/data`

**摘要:** 获取销售数据列表

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {
        "date": "2023-12-01",
        "store_id": "S001",
        "product_id": "P001",
        "sales": 150,
        "price": 25.5
      }
    ],
    "total": 100,
    "page": 1,
    "page_size": 1
  }
  ```

---

## 模型训练

模型训练相关接口。

### `GET /api/training`

**摘要:** 获取所有训练任务列表

**描述:** 返回所有正在进行、已完成或失败的训练任务。

**响应:**
* `200`: 成功获取任务列表。
  ```json
  {
    "status": "success",
    "data": [
      {
        "task_id": "uuid-1234",
        "product_id": "P001",
        "model_type": "mlstm",
        "status": "completed",
        "start_time": "2023-12-25T10:00:00Z",
        "metrics": {"R2": 0.95, "RMSE": 5.5},
        "error": null,
        "model_path": "/path/to/model.pth"
      }
    ]
  }
  ```

---

### `POST /api/training`

**摘要:** 启动模型训练任务

**描述:** 为指定产品启动一个新的模型训练任务。

**请求体:** `application/json`
```json
{
  "product_id": "P001",
  "model_type": "mlstm",
  "store_id": "S001",
  "epochs": 50
}
```

**响应:**
* `200`: 训练任务已启动。
  ```json
  {
    "message": "模型训练已开始",
    "task_id": "new-uuid-5678"
  }
  ```
* `400`: 请求错误。

---

### `GET /api/training/{task_id}`

**摘要:** 查询训练任务状态

**描述:** 获取特定训练任务的当前状态和详情。

**路径参数:**
* `task_id` (string, required): 训练任务ID。

**响应:**
* `200`: 成功获取任务状态。
  ```json
  {
    "status": "success",
    "data": {
      "product_id": "P001",
      "model_type": "mlstm",
      "status": "running",
      "progress": 50.5,
      "created_at": "2023-12-25T10:00:00Z"
    }
  }
  ```
* `404`: 任务不存在。
* `500`: 服务器内部错误。

---

## 模型预测

预测销售数据相关接口。

### `POST /api/prediction`

**摘要:** 使用模型进行预测

**描述:** 使用指定模型预测未来销售数据。

**请求体:** `application/json`
```json
{
  "product_id": "string",
  "model_type": "mlstm",
  "store_id": "string",
  "version": "string",
  "future_days": 7,
  "include_visualization": true,
  "start_date": "2024-01-01"
}
```

**响应:**
* `200`: 预测成功。
  ```json
  {
    "status": "success",
    "data": {
      "product_id": "P001",
      "product_name": "产品A",
      "model_type": "mlstm",
      "predictions": [
        {"date": "2024-01-01", "predicted_sales": 100},
        {"date": "2024-01-02", "predicted_sales": 105}
      ],
      "visualization": "base64-encoded-image-string"
    }
  }
  ```
* `400`: 请求错误。
* `404`: 产品或模型不存在。
* `500`: 服务器内部错误。

---

### `POST /api/prediction/compare`

**摘要:** 比较不同模型预测结果

**描述:** 比较不同模型对同一产品的预测结果。

**请求体:** `application/json`
```json
{
  "product_id": "string",
  "model_types": ["mlstm", "transformer"],
  "versions": ["v1", "v2"],
  "include_visualization": true
}
```

**响应:**
* `200`: 比较成功。
  ```json
  {
    "status": "success",
    "data": {
      "product_id": "P001",
      "comparison": [
        {"date": "2024-01-01", "mlstm": 100, "transformer": 102},
        {"date": "2024-01-02", "mlstm": 105, "transformer": 106}
      ],
      "visualization": "base64-encoded-image-string"
    }
  }
  ```
* `400`: 请求错误。
* `404`: 产品或模型不存在。
* `500`: 服务器内部错误。

---

### `GET /api/prediction/history`

**摘要:** 获取历史预测记录

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {
        "prediction_id": "pred-uuid-1",
        "product_id": "P001",
        "model_type": "mlstm",
        "created_at": "2023-12-20T11:00:00Z"
      }
    ]
  }
  ```

---

### `GET /api/prediction/history/{prediction_id}`

**摘要:** 获取特定预测记录的详情

**路径参数:**
* `prediction_id` (string, required): 预测记录ID。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": {
      "prediction_id": "pred-uuid-1",
      "product_id": "P001",
      "model_type": "mlstm",
      "predictions": [{"date": "2023-12-21", "predicted_sales": 110}],
      "analysis": {"trend": "upward"}
    }
  }
  ```
* `404`: 记录不存在。

---

### `DELETE /api/prediction/history/{prediction_id}`

**摘要:** 删除预测记录

**路径参数:**
* `prediction_id` (string, required): 预测记录ID。

**响应:**
* `200`: 删除成功。
  ```json
  {
    "status": "success",
    "message": "预测记录已删除"
  }
  ```
* `404`: 记录不存在。

---

## 模型管理

模型查询、导出和删除接口。

### `GET /api/models`

**摘要:** 获取模型列表

**查询参数:**
* `product_id` (string): 按产品ID筛选。
* `model_type` (string): 按模型类型筛选。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {
        "model_id": "P001_mlstm_v1",
        "product_id": "P001",
        "model_type": "mlstm",
        "version": "v1",
        "created_at": "2023-12-15T09:00:00Z"
      }
    ]
  }
  ```

---

### `GET /api/models/{model_id}`

**摘要:** 获取模型详情

**路径参数:**
* `model_id` (string, required): 模型ID。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": {
      "model_id": "P001_mlstm_v1",
      "product_id": "P001",
      "model_type": "mlstm",
      "version": "v1",
      "metrics": {"R2": 0.95, "RMSE": 5.5}
    }
  }
  ```
* `404`: 模型不存在。

---

### `DELETE /api/models/{model_id}`

**摘要:** 删除模型

**路径参数:**
* `model_id` (string, required): 模型ID。

**响应:**
* `200`: 删除成功。
  ```json
  {
    "status": "success",
    "message": "模型已删除"
  }
  ```
* `404`: 模型不存在。

---

### `GET /api/models/{model_id}/export`

**摘要:** 导出模型

**路径参数:**
* `model_id` (string, required): 模型ID。

**响应:**
* `200`: 模型文件下载 (二进制流)。
* `404`: 模型不存在。

---

### `GET /api/model_types`

**摘要:** 获取系统支持的所有模型类型

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": [
      {"id": "mlstm", "name": "mLSTM"},
      {"id": "transformer", "name": "Transformer"}
    ]
  }
  ```

---

### `GET /api/models/{product_id}/{model_type}/versions`

**摘要:** 获取模型版本列表

**路径参数:**
* `product_id` (string, required): 产品ID。
* `model_type` (string, required): 模型类型。

**响应:**
* `200`: 获取成功。
  ```json
  {
    "status": "success",
    "data": {
      "product_id": "P001",
      "model_type": "mlstm",
      "versions": ["v1", "v2"],
      "latest_version": "v2"
    }
  }
  ```
@ -1,41 +0,0 @@
@echo off
chcp 65001 >nul 2>&1
echo 🚀 启动药店销售预测系统API服务器 (WebSocket修复版)
echo.

:: 设置编码环境变量
set PYTHONIOENCODING=utf-8
set PYTHONLEGACYWINDOWSSTDIO=0

:: 显示当前配置
echo 📋 当前环境配置:
echo 编码: UTF-8
echo 路径: %CD%
echo Python: uv管理
echo.

:: 检查依赖
echo 🔍 检查Python依赖...
uv list --quiet >nul 2>&1
if errorlevel 1 (
    echo ⚠️ UV环境未配置,正在初始化...
    uv sync
)

echo ✅ 依赖检查完成
echo.

:: 启动API服务器
echo 🌐 启动API服务器 (WebSocket支持)...
echo 💡 访问地址: http://localhost:5000
echo 🔗 WebSocket端点: ws://localhost:5000/socket.io
echo.
echo 📝 启动日志:
echo ----------------------------------------

uv run server/api.py --host 0.0.0.0 --port 5000

echo.
echo ----------------------------------------
echo 🛑 API服务器已停止
pause
11
启动API服务器.bat
@ -1,11 +0,0 @@
@echo off
chcp 65001 >nul 2>&1
set PYTHONIOENCODING=utf-8
set PYTHONLEGACYWINDOWSSTDIO=0
cd /d %~dp0
echo 🚀 启动药店销售预测系统API服务器...
echo 📝 编码设置: UTF-8
echo 🌐 服务地址: http://127.0.0.1:5000
echo.
uv run server/api.py
pause
30
导出依赖配置.bat
@ -1,30 +0,0 @@
@echo off
chcp 65001 >nul 2>&1
echo 📦 导出UV依赖配置
echo.

:: 设置编码
set PYTHONIOENCODING=utf-8

echo 📋 导出requirements.txt格式...
uv export --format requirements-txt > requirements-exported.txt

echo 📋 导出依赖树状图...
uv tree > dependency-tree.txt

echo 📋 显示当前已安装的包...
uv list > installed-packages.txt

echo 📋 显示uv配置...
uv config list > uv-config.txt

echo.
echo ✅ 依赖配置导出完成!
echo.
echo 📁 生成的文件:
echo - requirements-exported.txt (标准requirements格式)
echo - dependency-tree.txt (依赖关系树)
echo - installed-packages.txt (已安装包列表)
echo - uv-config.txt (UV配置信息)
echo.
pause
43
快速安装依赖.bat
@ -1,43 +0,0 @@
@echo off
chcp 65001 >nul 2>&1
echo 🚀 药店销售预测系统 - 快速安装依赖
echo.

:: 设置编码环境变量
set PYTHONIOENCODING=utf-8
set PYTHONLEGACYWINDOWSSTDIO=0

echo 📁 配置UV缓存目录...
uv config set cache-dir ".uv_cache"

echo 🌐 配置镜像源...
uv config set global.index-url "https://pypi.tuna.tsinghua.edu.cn/simple"

echo.
echo 📦 安装核心依赖包...
echo.

:: 分批安装,避免超时
echo 1/4 安装基础数据处理包...
uv add numpy pandas openpyxl

echo 2/4 安装机器学习包...
uv add scikit-learn matplotlib tqdm

echo 3/4 安装Web框架包...
uv add flask flask-cors flask-socketio flasgger werkzeug

echo 4/4 安装深度学习框架...
uv add torch torchvision --index-url https://download.pytorch.org/whl/cpu

echo.
echo ✅ 核心依赖安装完成!
echo.
echo 🔍 检查安装状态...
uv list

echo.
echo 🎉 依赖安装完成!可以启动系统了
echo 💡 启动命令: uv run server/api.py
echo.
pause
150
新需求开发流程.md
Normal file
@ -0,0 +1,150 @@

# 新需求开发标准流程

本文档旨在提供一个标准、安全、高效的新功能开发工作流,涵盖从创建功能分支到最终合并回主开发分支的完整步骤,并融入日常开发的最佳实践。

## 核心开发理念

- **主分支保护**: `lyf-dev` 是团队的主开发分支,应始终保持稳定和可部署状态。所有新功能开发都必须在独立的功能分支中进行。
- **功能分支**: 每个新需求(如 `req0001`)都对应一个功能分支(如 `lyf-dev-req0001`)。分支命名应清晰、有意义。
- **小步快跑**: 频繁提交(Commit)、频繁推送(Push)、频繁与主线同步(`rebase` 或 `merge`)。这能有效减少后期合并的难度和风险。
- **清晰的历史**: 保持 Git 提交历史的可读性,方便代码审查(Code Review)和问题追溯。

---

## 每日工作第一步:同步最新代码

**无论你昨天工作到哪里,每天开始新一天的工作时,请务必执行以下步骤。这是保证团队高效协作、避免合并冲突的基石。**

1. **更新主开发分支 `lyf-dev`**
   ```bash
   # 切换到主开发分支
   git checkout lyf-dev

   # 从远程拉取最新代码,--prune 会清理远程已删除的分支引用
   git pull origin lyf-dev --prune
   ```

2. **同步你的功能分支 (团队选择一种方案)**
   将主分支的最新代码同步到你的功能分支,有两种主流方案,请团队根据偏好选择其一。

---

### 方案一 (推荐): 使用 `rebase` 保持历史清爽

此方案会让你的分支提交历史保持为一条直线,非常清晰。

```bash
# 切换回你正在开发的功能分支(例如 lyf-dev-req0001)
git checkout lyf-dev-req0001

# 使用 rebase 将 lyf-dev 的最新更新同步到你的分支
git rebase lyf-dev
```

- **优点**: 最终的提交历史非常干净、线性,便于代码审查和问题追溯。
- **缺点**: 重写了提交历史,需要使用 `git push --force-with-lease` 强制推送。
- **冲突解决**:
  1. 手动修改冲突文件。
  2. 执行 `git add <冲突文件>`。
  3. 执行 `git rebase --continue`。
  4. 若想中止,执行 `git rebase --abort`。

---

### 方案二: 使用 `merge` 保留完整历史

此方案会忠实记录每一次合并操作,不修改历史提交。

```bash
# 切换回你正在开发的功能分支(例如 lyf-dev-req0001)
git checkout lyf-dev-req0001

# 将最新的 lyf-dev 合并到你当前的分支
git merge lyf-dev
```

- **优点**: 操作安全,不修改历史,推送时无需强制。
- **缺点**: 会在功能分支中产生额外的合并提交记录 (e.g., "Merge branch 'lyf-dev' into ..."),使历史记录变得复杂。
- **冲突解决**:
  1. 手动修改冲突文件。
  2. 执行 `git add <冲突文件>`。
  3. 执行 `git commit` 完成合并。

---

## 完整开发流程

### 1. 开始新需求:创建功能分支

**当你需要开启一个全新的功能开发时:**

1. **确保 `lyf-dev` 已是最新**
   (此步骤已在“每日工作第一步”中完成,此处作为提醒)

2. **从 `lyf-dev` 创建并切换到新分支**
   假设新需求编号是 `req0002`:
   ```bash
   # 这会从最新的 lyf-dev 创建 lyf-dev-req0002 分支并切换过去
   git checkout -b lyf-dev-req0002
   ```

### 2. 日常开发:提交与推送

**在你的功能分支上(如 `lyf-dev-req0002`)进行开发:**

1. **编码与本地提交**
   完成一个小的、完整的功能点后,就进行一次提交。
   ```bash
   # 查看修改状态
   git status
   # 添加所有相关文件到暂存区
   git add .
   # 提交并撰写清晰的说明(feat: 功能, fix: 修复, docs: 文档等)
   git commit -m "feat: 实现用户认证模块"
   ```

2. **推送改动到远程备份**
   为了代码安全和方便团队协作,应频繁将本地提交推送到远程。
   ```bash
   # -u 参数会设置本地分支跟踪远程分支,后续只需 git push 即可
   git push -u origin lyf-dev-req0002
   ```

### 3. 功能完成:合并回主线

**当功能开发完成并通过测试后,将其合并回 `lyf-dev`:**

1. **最后一次同步**
   在正式合并前,做最后一次同步,确保分支包含了 `lyf-dev` 的所有最新内容。
   (重复“每日工作第一步”中的同步流程)

2. **切换到主分支并拉取最新代码**
   ```bash
   git checkout lyf-dev
   git pull origin lyf-dev
   ```

3. **合并功能分支**
   我们使用 `--no-ff` (No Fast-forward) 参数来创建合并提交,这样可以清晰地记录“合并了一个功能”这个行为。
   ```bash
   # --no-ff 会创建一个新的合并提交,保留分支历史
   git merge --no-ff lyf-dev-req0002
   ```
   如果同步工作做得好,这一步通常不会有冲突。

4. **推送合并后的主分支**
   ```bash
   git push origin lyf-dev
   ```

### 4. 清理工作

**合并完成后,功能分支的历史使命就完成了:**

1. **删除远程分支**
   ```bash
   git push origin --delete lyf-dev-req0002
   ```

2. **删除本地分支**
   ```bash
   git branch -d lyf-dev-req0002
   ```

遵循以上流程,可以确保团队的开发工作流清晰、安全且高效。
466
系统调用逻辑与核心代码分析.md
Normal file
@ -0,0 +1,466 @@
|
||||
# 系统调用逻辑与核心代码分析
|
||||
|
||||
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
|
||||
|
||||
## 1. 系统启动
|
||||
|
||||
系统由两部分组成:Vue.js前端和Flask后端。
|
||||
|
||||
### 1.1. 启动后端API服务
|
||||
|
||||
在项目根目录下,通过以下命令启动后端服务:
|
||||
|
||||
```bash
|
||||
python server/api.py
|
||||
```
|
||||
|
||||
该命令会启动一个Flask应用,监听在 `http://localhost:5000`,并提供所有API和WebSocket服务。
|
||||
|
||||
### 1.2. 启动前端开发服务器
|
||||
|
||||
进入 `UI` 目录,执行以下命令:
|
||||
|
||||
```bash
|
||||
cd UI
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
这将启动Vite开发服务器,通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
|
||||
|
||||
## 2. Core Call Chain at a Glance

Taking the central **"train by product -> predict by product"** flow as the example, the high-level call chain is:

**Training flow:**
`Frontend UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `background process` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `save model .pth`

**Prediction flow:**
`Frontend UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `load model .pth` -> `return prediction JSON` -> `frontend chart rendering`
## 3. Detailed Flow: Training by Product

The goal of this flow is to train a dedicated prediction model for one specific product.

### 3.1. Frontend interaction and API request

1. **User action**: On the **"Train by Product"** page ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)), the user picks a product, a model type (e.g. Transformer), sets the number of epochs, and clicks **"Start product training"**.

2. **Triggered handler**: The click invokes the [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) method.

3. **Payload construction**: `startTraining` builds a `payload` object with the training parameters. The key field is `training_mode: 'product'`, which tells the backend this is product-specific training.

   *Core code ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
   ```javascript
   const startTraining = async () => {
     // ... form validation ...
     trainingLoading.value = true;
     try {
       const endpoint = "/api/training";

       const payload = {
         product_id: form.product_id,
         store_id: form.data_scope === 'global' ? null : form.store_id,
         model_type: form.model_type,
         epochs: form.epochs,
         training_mode: 'product' // marks this as product-training mode
       };

       const response = await axios.post(endpoint, payload);
       // ... handle the response and start listening on the WebSocket ...
     } catch (error) {
       // ... error handling ...
     }
   };
   ```

4. **API request**: `axios` sends the request to the backend as `POST /api/training`.
### 3.2. Backend API: receiving the request and dispatching the task

1. **Routing**: In [`server/api.py`](server/api.py:1), the [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) decorator captures the request, which is handled by [`start_training()`](server/api.py:971).

2. **Task submission**: `start_training()` parses the JSON body, then calls `training_manager.submit_task()` to push the training job into a background process pool so the API main thread is never blocked. The API can therefore return a task ID immediately while training proceeds asynchronously.

   *Core code ([`server/api.py`](server/api.py:971))*
   ```python
   @app.route('/api/training', methods=['POST'])
   def start_training():
       data = request.get_json()

       training_mode = data.get('training_mode', 'product')
       model_type = data.get('model_type')
       epochs = data.get('epochs', 50)
       product_id = data.get('product_id')
       store_id = data.get('store_id')

       if not model_type or (training_mode == 'product' and not product_id):
           return jsonify({'error': '缺少必要参数'}), 400

       try:
           # Submit the task via the training process manager
           task_id = training_manager.submit_task(
               product_id=product_id or "unknown",
               model_type=model_type,
               training_mode=training_mode,
               store_id=store_id,
               epochs=epochs
           )

           logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")

           return jsonify({
               'message': '模型训练已开始(使用独立进程)',
               'task_id': task_id,
           })

       except Exception as e:
           logger.error(f"❌ 提交训练任务失败: {str(e)}")
           return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
   ```
### 3.3. Core training logic

1. **Calling the core predictor**: The background process eventually calls [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) in [`server/core/predictor.py`](server/core/predictor.py:1).

2. **Data preparation**: Based on `training_mode` (`'product'`) and `product_id`, `train_model` first loads and aggregates the sales data for that product across all stores.

3. **Dispatch to a concrete trainer**: It then calls the training function matching `model_type`. For example, when `model_type` is `transformer`, it calls `train_product_model_with_transformer`.

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:63))*
   ```python
   class PharmacyPredictor:
       def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
           # ...
           if training_mode == 'product':
               product_data = self.data[self.data['product_id'] == product_id].copy()
               # ...

           # Build the model identifier according to the training mode
           model_identifier = product_id

           try:
               if model_type == 'transformer':
                   model_result, metrics, actual_version = train_product_model_with_transformer(
                       product_id=product_id,
                       model_identifier=model_identifier,
                       product_df=product_data,
                       # ... other parameters ...
                   )
               # ... elif branches for the other models ...

               return metrics
           except Exception as e:
               # ... error handling ...
               return None
   ```
### 3.4. Model training and saving

1. **The concrete trainer**: Taking [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) as the example, `train_product_model_with_transformer` performs these steps:
   * **Data preprocessing**: Calls `prepare_data` and `prepare_sequences` to turn the raw sales data into a supervised-learning format with time-series features (input sequences and target sequences).
   * **Model instantiation**: Creates a `TimeSeriesTransformer` instance.
   * **Training loop**: Runs for the requested number of `epochs`, computing the loss and updating the model weights via the optimizer.
   * **Progress updates**: During training, emits `training_progress` events to the frontend via `socketio.emit`, driving the live progress bar and log.
   * **Model saving**: Once training finishes, packs the model weights (`model.state_dict()`), the complete model configuration (`config`), and the data scalers (`scaler_X`, `scaler_y`) into a single checkpoint dict and saves it to a `.pth` file with `torch.save()`. The filename is generated centrally by `get_model_file_path` from `model_identifier`, `model_type`, and `version`.

   *Core code ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
   ```python
   def train_product_model_with_transformer(...):
       # ... data preparation ...

       # Define the model configuration
       config = {
           'input_dim': input_dim,
           'output_dim': forecast_horizon,
           'hidden_size': hidden_size,
           # ... every hyperparameter needed to rebuild the model ...
           'model_type': 'transformer'
       }

       model = TimeSeriesTransformer(...)

       # ... training loop ...

       # Save the model
       checkpoint = {
           'model_state_dict': model.state_dict(),
           'config': config,
           'scaler_X': scaler_X,
           'scaler_y': scaler_y,
           'metrics': test_metrics
       }

       model_path = get_model_file_path(model_identifier, 'transformer', version)
       torch.save(checkpoint, model_path)

       return model, test_metrics, version
   ```
## 4. Detailed Flow: Prediction by Product

Once training has finished, the user can run predictions with the saved model.

### 4.1. Frontend interaction and API request

1. **User action**: On the **"Predict by Product"** page ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)), the user selects the same product, the matching model and version, and clicks **"Start prediction"**.

2. **Triggered handler**: The click invokes the [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) method.

3. **Payload construction**: The method builds a `payload` with the prediction parameters.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
   ```javascript
   const startPrediction = async () => {
     try {
       predicting.value = true
       const payload = {
         product_id: form.product_id,
         model_type: form.model_type,
         version: form.version,
         future_days: form.future_days,
         // training_mode is implicitly 'product' here
       }
       const response = await axios.post('/api/prediction', payload)
       if (response.data.status === 'success') {
         predictionResult.value = response.data
         await nextTick()
         renderChart()
       }
     } catch (error) {
       // ... error handling ...
     }
     // ...
   }
   ```

4. **API request**: `axios` sends the request to the backend as `POST /api/prediction`.
### 4.2. Backend API: receiving the request and running the prediction

1. **Routing**: [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) in [`server/api.py`](server/api.py:1) captures the request, handled by [`predict()`](server/api.py:1469).

2. **Calling the core predictor**: `predict()` parses the parameters, then calls the `run_prediction` helper, which in turn calls [`PharmacyPredictor.predict()`](server/core/predictor.py:295) in [`server/core/predictor.py`](server/core/predictor.py:1).

   *Core code ([`server/api.py`](server/api.py:1469))*
   ```python
   @app.route('/api/prediction', methods=['POST'])
   def predict():
       try:
           data = request.json
           # ... parse parameters ...
           training_mode = data.get('training_mode', 'product')
           product_id = data.get('product_id')
           # ...

           # Determine the model identifier from the mode
           if training_mode == 'product':
               model_identifier = product_id
           # ...

           # Run the prediction
           prediction_result = run_prediction(model_type, product_id, model_id, ...)

           # ... format the response ...
           return jsonify(response_data)
       except Exception as e:
           # ... error handling ...
   ```

3. **Dispatch to the model loader**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) mainly re-derives the `model_identifier` from `training_mode` and `product_id`, then forwards all parameters to [`load_model_and_predict()`](server/predictors/model_predictor.py:26) in [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1).

   *Core code ([`server/core/predictor.py`](server/core/predictor.py:295))*
   ```python
   class PharmacyPredictor:
       def predict(self, product_id, model_type, ..., training_mode='product', ...):
           if training_mode == 'product':
               model_identifier = product_id
           # ...

           return load_model_and_predict(
               model_identifier,
               model_type,
               # ... other parameters ...
           )
   ```
### 4.3. Loading the model and running the prediction

[`load_model_and_predict()`](server/predictors/model_predictor.py:26) is the heart of the prediction flow. It performs these steps:

1. **Locate the model file**: Uses `get_model_file_path` to find the previously saved `.pth` file from `product_id` (i.e. the `model_identifier`), `model_type`, and `version`.

2. **Load the checkpoint**: Loads the file with `torch.load()`, recovering the dict that holds `model_state_dict`, `config`, and the scalers.

3. **Rebuild the model**: Re-creates a model instance with exactly the same architecture as at training time, using the hyperparameters in the loaded `config` (`hidden_size`, `num_layers`, and so on). **This was the key point of an earlier fix: every required parameter must be saved and loaded.**

4. **Load the weights**: Applies the loaded `model_state_dict` to the freshly built model instance.

5. **Prepare the input**: Fetches the most recent `sequence_length` days of history from the data source as the prediction input.

6. **Normalize the input**: Scales the input with the loaded `scaler_X`.

7. **Run the prediction**: Feeds the normalized data through the model (`model(X_input)`) to get the raw predictions.

8. **Denormalize**: Inverse-transforms the model output with the loaded `scaler_y`, converting the predicted values back to the original sales scale.

9. **Assemble the result**: Combines the predicted values with their future dates into a DataFrame and returns it together with the historical data.

   *Core code ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
   ```python
   def load_model_and_predict(...):
       # ... locate the model file path model_path ...

       # Load the model and its configuration
       checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
       config = checkpoint['config']
       scaler_X = checkpoint['scaler_X']
       scaler_y = checkpoint['scaler_y']

       # Build the model instance (Transformer as the example)
       model = TimeSeriesTransformer(
           num_features=config['input_dim'],
           d_model=config['hidden_size'],
           # ... all remaining parameters from config ...
       ).to(DEVICE)

       # Load the weights
       model.load_state_dict(checkpoint['model_state_dict'])
       model.eval()

       # ... prepare the input data ...

       # Normalize the input
       X_scaled = scaler_X.transform(X)
       X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)

       # Predict
       with torch.no_grad():
           y_pred_scaled = model(X_input).cpu().numpy()

       # Denormalize
       y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

       # ... assemble the return value ...
       return {
           'predictions': predictions_df,
           'history_data': recent_history,
           # ...
       }
   ```
### 4.4. Response formatting and frontend chart rendering

1. **Formatting in the API layer**: In [`predict()`](server/api.py:1469) of [`server/api.py`](server/api.py:1), the result returned by `load_model_and_predict` is shaped into the JSON structure the frontend expects, with both a `history_data` array and a `prediction_data` array at the top level.

2. **Frontend receives the data**: In `startPrediction`, [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) receives this JSON response and stores it in the `predictionResult` ref.

3. **Chart rendering**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) is then called. It pulls `history_data` and `prediction_data` out of `predictionResult.value` and plots both on the same `<canvas>` with Chart.js: history as a solid line, predictions as a dashed line, producing one continuous trend chart.

   *Core code ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
   ```javascript
   const renderChart = () => {
     if (!chartCanvas.value || !predictionResult.value) return
     // ...

     // The backend provides history_data and prediction_data directly
     const historyData = predictionResult.value.history_data || []
     const predictionData = predictionResult.value.prediction_data || []

     const historyLabels = historyData.map(p => p.date)
     const historySales = historyData.map(p => p.sales)

     const predictionLabels = predictionData.map(p => p.date)
     const predictionSales = predictionData.map(p => p.predicted_sales)

     // ... combine the labels and align the data points ...

     chart = new Chart(chartCanvas.value, {
       type: 'line',
       data: {
         labels: allLabels,
         datasets: [
           {
             label: '历史销量',
             data: alignedHistorySales,
             // ... styling ...
           },
           {
             label: '预测销量',
             data: alignedPredictionSales,
             // ... styling ...
           }
         ]
       },
       // ... Chart.js options ...
     })
   }
   ```

This completes the full "train -> predict -> display" call chain.
## 5. Model Saving and Version Management (Post-Refactor)

To eliminate version chaos and model-loading failures for good, the system went through a major refactor. All logic for saving, naming, and versioning models is now **centralized** in the `ModelManager` class in [`server/utils/model_manager.py`](server/utils/model_manager.py:1).

### 5.1. The single manager: `ModelManager`

- **Single responsibility**: `ModelManager` is the only component in the system that performs model file IO. Every trainer must go through it when saving a model.
- **Core capabilities**:
  1. **Automatic versioning**: Generates and increments well-formed version numbers.
  2. **Uniform naming**: Produces standardized filenames from the model's metadata (algorithm type, training mode, ID, and so on).
  3. **Safe saving**: Packs model data and metadata together into the `.pth` file.
  4. **Reliable retrieval**: Offers a single interface for listing and finding models.
### 5.2. Unified version scheme

Every model version now follows one strict, predictable format:

- **Numbered versions**: `v{n}`, e.g. `v1`, `v2`, `v3`...
  - **Created**: When a training run **completes normally**, `ModelManager` computes the next free version number for that model (e.g. if `v1` and `v2` exist, the new version is `v3`) and names the final model file with it.
  - **Meaning**: One complete, stable training artifact.
- **The special version**: `best`
  - **Created**: During training, whenever an epoch beats every previous epoch on the validation set, the trainer asks `ModelManager` to save that model as the `best` version, overwriting the previous `best`.
  - **Meaning**: Always points at the best-performing version of the model so far, handy for quick, high-quality predictions.
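The "highest existing `v{n}` plus one" rule is small enough to sketch in full; the function below only mirrors the behaviour described above and is not the project's actual implementation:

```python
import re

def next_version(existing_versions: list[str]) -> str:
    """Sketch of the version bump rule: highest numeric version plus one.

    'best' (and anything else that is not 'v<digits>') is ignored, because
    it is a floating alias rather than a numbered release.
    """
    numbers = [int(m.group(1)) for v in existing_versions
               if (m := re.fullmatch(r"v(\d+)", v))]
    return f"v{max(numbers, default=0) + 1}"

assert next_version([]) == "v1"
assert next_version(["v1", "v2", "best"]) == "v3"
```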
### 5.3. Unified naming convention (v2)

Now that the system also supports "per-store" and "global" training modes, `ModelManager`'s `generate_model_filename` method has been upgraded to produce richer, unambiguous names:

- **Product models**: `{model_type}_product_{product_id}_{version}.pth`
  - *Example*: `transformer_product_17002608_best.pth`
- **Store models**: `{model_type}_store_{store_id}_{version}.pth`
  - *Example*: `mlstm_store_01010023_v2.pth`
- **Global models**: `{model_type}_global_{aggregation_method}_{version}.pth`
  - *Example*: `tcn_global_sum_v1.pth`

This naming system guarantees that models produced by different training modes can be identified and managed without ambiguity.
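The convention is compact enough to sketch end to end; this function is reconstructed purely from the three formats listed above, so the real `generate_model_filename` may differ in its details:

```python
from typing import Optional

def generate_model_filename(model_type: str, training_mode: str, version: str,
                            product_id: Optional[str] = None,
                            store_id: Optional[str] = None,
                            aggregation_method: Optional[str] = None) -> str:
    """Sketch: build the standardized .pth filename for each training mode."""
    if training_mode == 'product':
        return f"{model_type}_product_{product_id}_{version}.pth"
    if training_mode == 'store':
        return f"{model_type}_store_{store_id}_{version}.pth"
    if training_mode == 'global':
        return f"{model_type}_global_{aggregation_method}_{version}.pth"
    raise ValueError(f"unknown training_mode: {training_mode}")

assert generate_model_filename('mlstm', 'store', 'v2',
                               store_id='01010023') == 'mlstm_store_01010023_v2.pth'
```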
### 5.4. Checkpoint file contents (unchanged structure)

Each `.pth` file is still a PyTorch checkpoint containing the model weights, the complete configuration, and the data scalers. The refactor tightened one rule: **every trainer must store the complete configuration in the `config` dict**, which guarantees models are fully reproducible.

### 5.5. Key benefits of the refactor

- **Centralized logic**: All the complexity of version management is encapsulated inside `ModelManager`; a trainer just calls `save_model` and never worries about how version numbers are produced.
- **Data consistency**: Because versions are generated, saved, and retrieved by one component with one logic, "model not found" failures caused by mismatched names or version formats are eliminated at the root.
- **Maintainability**: If the version policy or naming rules ever change, only `ModelManager` needs editing; no trainer has to be touched.
## 6. Evolution of the Core Flows: Store and Global Modes

On top of the original "per-product" flow, the system has been refactored into a complete AI loop for "per-store" and "global" modes as well. This introduced a few key changes:

### 6.1. Changes to the training flow

- **One entry point**: Every training request (product, store, global) goes through `POST /api/training`, distinguished by the `training_mode` parameter.
- **Data aggregation**: In `train_model` of [`predictor.py`](server/core/predictor.py:1), `aggregate_multi_store_data` is called according to `training_mode` to prepare the correctly aggregated time series for store or global mode.
- **Model identifier**: `train_model` now generates a unique `model_identifier` (e.g. `product_17002608`, `store_01010023`, `global_sum`) and passes it to every downstream trainer. This is what guarantees models end up correctly named.
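Based on the example identifiers above, the construction presumably looks like this (a sketch, not the project's verbatim code):

```python
def build_model_identifier(training_mode: str, product_id=None, store_id=None,
                           aggregation_method=None) -> str:
    """Sketch: derive the model_identifier from the training mode.

    Unknown modes raise KeyError, which a real implementation would
    presumably turn into a proper validation error.
    """
    key = {'product': product_id,
           'store': store_id,
           'global': aggregation_method}[training_mode]
    return f"{training_mode}_{key}"

assert build_model_identifier('global', aggregation_method='sum') == 'global_sum'
assert build_model_identifier('product', product_id='17002608') == 'product_17002608'
```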
### 6.2. The major fix to the prediction flow

The prediction flow was substantially repaired to cure the `404` errors previously caused by inconsistent logic.

- **Old helpers retired**: The flawed legacy helpers in `core/config.py`, such as `get_model_file_path` and `get_model_versions`, are now **fully deprecated**.
- **One lookup path**: The `predict` function in [`api.py`](server/api.py:1) **must** now locate models through `model_manager.list_models()`.
- **Reliable path passing**: Once `predict` has found the correct model file path, it passes that path as an argument all the way through `run_prediction` down to `load_model_and_predict`.
- **Root cause removed**: All the manual, outdated file-lookup logic inside `load_model_and_predict` has been **completely removed**. The function now only receives an explicit path and loads the model.

This fix makes the entire prediction chain depend on `ModelManager` as the single source of truth, eliminating prediction failures caused by mismatched paths.
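Pictured as code, the unified lookup is roughly the following; `list_models()` itself is real per the text above, but the metadata keys used in this sketch (`model_type`, `identifier`, `version`, `path`) are assumptions made only to illustrate the single-source-of-truth idea:

```python
def find_model_path(model_manager, model_type, model_identifier, version):
    """Hypothetical sketch of the lookup predict() performs via ModelManager."""
    for info in model_manager.list_models():
        if (info.get('model_type') == model_type
                and info.get('identifier') == model_identifier
                and info.get('version') == version):
            return info.get('path')
    return None  # the caller turns this into a 404-style error for the frontend
```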
43 配置UV环境.bat
@ -1,43 +0,0 @@
@echo off
chcp 65001 >nul 2>&1
echo 🔧 配置药店销售预测系统UV环境...
echo.

:: 设置编码环境变量
set PYTHONIOENCODING=utf-8
set PYTHONLEGACYWINDOWSSTDIO=0

:: 设置缓存目录
echo 📁 设置UV缓存目录...
uv config set cache-dir "H:\_Workings\_OneTree\_ShopTRAINING\.uv_cache"

:: 设置镜像源
echo 🌐 配置国内镜像源...
uv config set global.index-url "https://pypi.tuna.tsinghua.edu.cn/simple"

:: 设置信任主机
echo 🔒 配置信任主机...
uv config set global.trusted-host "pypi.tuna.tsinghua.edu.cn"

echo.
echo ✅ UV环境配置完成
echo 📋 当前配置:
uv config list

echo.
echo 🚀 初始化项目并同步依赖...
uv sync

echo.
echo 📦 安装完成,检查依赖状态...
uv tree

echo.
echo 🎉 环境配置和依赖同步完成!
echo.
echo 💡 使用方法:
echo 启动API服务器: uv run server/api.py
echo 运行测试: uv run pytest
echo 格式化代码: uv run black server/
echo.
pause
127 项目快速上手指南.md Normal file
@ -0,0 +1,127 @@
# Project Quick-Start Guide (for New Developers)

Welcome to the project! This guide is meant to get you up to speed on the core features, technical architecture, and development workflow, with entry points chosen specifically for a developer coming from a Java background.

## 1. What does the project do? (Features)

This is an **intelligent sales prediction system** built on historical sales data.

It has three core features, all operated through the web UI:
1. **Model training**: The user picks a **product**, a **store**, or the **global** dataset, chooses a machine-learning algorithm (Transformer, mLSTM, etc.), and trains a prediction model.
2. **Sales prediction**: Uses a trained model to forecast future sales.
3. **Result visualization**: Plots historical and predicted sales on the same chart, so trends are immediately visible.

In short, it is a complete **"data -> training -> model -> prediction -> visualization"** loop.
## 2. What is it built with? (Tech Stack)

You can map this project's stack onto the Java world as follows:

| Layer | This project | Java-world analogue | Notes |
| :--- | :--- | :--- | :--- |
| **Backend framework** | **Flask** | Spring Boot | A lightweight web framework that serves the API. |
| **Frontend framework** | **Vue.js** | React / Angular | A modern JavaScript framework for building the UI. |
| **Core ML library** | **PyTorch** | (no direct analogue) | Comparable to Java's Deeplearning4j; implements the deep-learning algorithms. |
| **Data processing** | **Pandas** | (no direct analogue) | Python's Swiss-army knife for data analysis; think of it as a powerful in-memory data table. |
| **Build / packaging** | **Vite** (frontend) | Maven / Gradle | Build and dependency management for the frontend project. |
| **Database** | **SQLite** | H2 / MySQL | A lightweight local file database, used for prediction history and the like. |
| **Realtime messaging** | **Socket.IO** | WebSocket / STOMP | Lets the backend push training progress to the frontend in real time. |
## 3. What does the architecture look like? (Layers and Design)

The project is a classic frontend/backend separation with four main layers:

```
+------------------------------------------------------+
|                    User (Browser)                    |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
| 1. Frontend layer (Vue.js)                           |
|  - Views (page components, e.g. ProductPredictionView.vue) |
|  - API calls (axios talks to the backend)            |
|  - Charting (Chart.js renders the charts)            |
+------------------------------------------------------+
                          | (HTTP/S, WebSocket)
+------------------------------------------------------+
| 2. Backend API layer (Flask)                         |
|  - api.py (Controller-like, defines the REST API)    |
|  - Receives requests, validates parameters, calls    |
|    the business-logic layer                          |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
| 3. Business-logic layer (Python)                     |
|  - core/predictor.py (Service-like)                  |
|  - Encapsulates core business rules, e.g. "pick the  |
|    right trainer for these parameters"               |
+------------------------------------------------------+
                          |
+------------------------------------------------------+
| 4. Data & model layer (PyTorch/Pandas)               |
|  - trainers/*.py (algorithm implementations and      |
|    training logic)                                   |
|  - predictors/model_predictor.py (model loading and  |
|    prediction logic)                                 |
|  - saved_models/ (trained .pth model files)          |
|  - data/ (raw .parquet data files)                   |
+------------------------------------------------------+
```
## 4. The Key Execution Flow

Taking the most common case, "predict by product":

1. **Frontend**: The user picks a product and a model on the page and clicks "Predict". The Vue component sends a POST request via `axios` to `/api/prediction`.
2. **API layer**: `api.py` receives the request and, Controller-style, parses out the product ID, model type, and other parameters.
3. **Business-logic layer**: `api.py` calls the `predict` method in `core/predictor.py`, passing the parameters down. This layer is the "dispatcher" of the business.
4. **Model layer**: `core/predictor.py` ultimately calls `load_model_and_predict` in `predictors/model_predictor.py`.
5. **Model loading and execution**:
   * Locate the matching model file under `saved_models/` (e.g. `transformer_store_01010023_best.pth` or `mlstm_product_17002608_v3.pth`).
   * Load the file and restore the **model architecture**, **model weights**, and **data scalers** from it.
   * Prepare the latest history as input and run the prediction.
   * Return the prediction result.
6. **Return and render**: The result bubbles back up to `api.py`, is formatted as JSON there, and is sent to the frontend, which draws history and predictions on the chart with `Chart.js`.
## 5. How do I add a new algorithm? (Developer Guide)

This is the new-feature work you are most likely to pick up. Suppose you want to add an algorithm called `NewNet`; proceed as follows:

**Goal**: Make `NewNet` appear in the frontend's "model type" dropdown, and train and predict with it successfully.

1. **Create the trainer file**:
   * In `server/trainers/`, copy an existing trainer (e.g. `tcn_trainer.py`) and rename the copy to `newnet_trainer.py`.
   * In `newnet_trainer.py`:
     * Define your `NewNet` model class (subclassing `torch.nn.Module`).
     * Rename the `train_..._with_tcn` function to `train_..._with_newnet`.
     * Inside the new function, make sure it instantiates your `NewNet` model.
     * **The most critical step**: When saving the checkpoint, make sure the `config` dict contains every hyperparameter needed to rebuild `NewNet` (layer count, unit count, and so on).

   * **Important convention: parameter naming**
     To prevent parameter-mismatch errors at model load time (e.g. `KeyError: 'num_layers'`), we adopted the following rule:
     > **Rule:** For a hyperparameter specific to one algorithm, its key in the `config` dict must carry that algorithm's name as a prefix or unique marker.

     **Examples:**
     * The layer count of an `mLSTM` model should use the key `mlstm_layers`.
     * The channel list of a `TCN` model can use the key `tcn_channels`.
     * The encoder layer count of a `Transformer` can use `num_encoder_layers` (unambiguous in a Transformer context).

     When **loading the model** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)), you must read these parameters back with exactly the keys used when saving. Following this rule eliminates load failures caused by inconsistent parameter names at the root.
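   For illustration, a `config` that follows the rule might look like this sketch (the hyperparameter values are invented):

   ```python
   # Illustrative checkpoint configs obeying the prefix rule; values are made up.
   mlstm_config = {
       'input_dim': 8,
       'output_dim': 7,
       'mlstm_layers': 2,            # algorithm-specific key, prefixed with 'mlstm'
       'model_type': 'mlstm',
   }

   tcn_config = {
       'input_dim': 8,
       'output_dim': 7,
       'tcn_channels': [64, 64, 64], # algorithm-specific key, prefixed with 'tcn'
       'model_type': 'tcn',
   }

   # At load time, read the keys back exactly as they were saved:
   num_layers = mlstm_config['mlstm_layers']   # not mlstm_config['num_layers']
   ```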
2. **Register the new model**:
   * Open `server/core/config.py`.
   * Find the `SUPPORTED_MODELS` list.
   * Add your new model identifier `'newnet'` to it.

3. **Hook into the business-logic layer (training)**:
   * Open `server/core/predictor.py`.
   * In the `train_model` method, find the `if/elif` model-selection chain.
   * Add a new `elif model_type == 'newnet':` branch that calls the `train_..._with_newnet` function you created in step 1 (see the sketch after step 4).

4. **Hook into the model layer (prediction)**:
   * Open `server/predictors/model_predictor.py`.
   * In `load_model_and_predict`, find the `if/elif` model-instantiation chain.
   * Add a new `elif model_type == 'newnet':` branch that builds the `NewNet` instance correctly from `config`.
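   Steps 3 and 4 each boil down to one new branch. A sketch of both, in the fragment style used above (`NewNet` and `newnet_layers` are hypothetical names following the conventions from step 1):

   ```python
   # server/core/predictor.py, inside train_model() -- step 3:
   elif model_type == 'newnet':
       model_result, metrics, actual_version = train_product_model_with_newnet(
           product_id=product_id,
           model_identifier=model_identifier,
           product_df=product_data,
           epochs=epochs,
       )

   # server/predictors/model_predictor.py, inside load_model_and_predict() -- step 4:
   elif model_type == 'newnet':
       model = NewNet(
           num_features=config['input_dim'],
           newnet_layers=config['newnet_layers'],  # prefixed key, per the naming rule
       ).to(DEVICE)
   ```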
5. **Update the frontend**:
   * Open the relevant Vue files under `UI/src/views/training/` and `UI/src/views/prediction/` (e.g. `ProductTrainingView.vue`).
   * Find where the model options are defined (usually an array or object).
   * Add a new option such as `{ label: '新网络模型 (NewNet)', value: 'newnet' }`.

After these steps, restart the services and your new algorithm is available for selection and use in the UI.