Merge branch 'lyf-dev-req0004' into lyf-dev

This commit is contained in:
LYFxiaoan 2025-07-26 14:43:51 +08:00
commit 3b261feb30
80 changed files with 827 additions and 1198 deletions

View File

@@ -410,3 +410,77 @@
2. **Debugging and fixes**: The first version of the script had a timezone-comparison bug and failed to identify records stamped in the local timezone. After one round of debugging that changed the comparison baseline from UTC to local time, the script worked correctly.
3. **Running the cleanup**: The script was then executed successfully, removing from the database every historical record created before the specified cutoff that carried an incorrect file path.
- **Conclusion**: Cleaning up the legacy rows in the database fully resolved the `404` errors caused by the old/new data-path mismatch and restored the system's data consistency and functional stability.
---
## 2025-07-26: Core Data Source Replacement and System-Level Refactoring
**Developers**: Roo (AI Assistant) & lyf
### Phase 1: Core Data Source Replacement and Adapter Refactoring
- **Goal**: Replace the project's data source, the old and simple `timeseries_training_data_sample_10s50p.parquet`, with the new core dataset, which is richer in features and more complex in structure.
- **Core challenge**: The old and new datasets differ enormously in schema (e.g. `subbh` vs `store_id`), data types, feature set (the new data lacks the `is_promotion` column), and data granularity; a direct swap would have crashed the whole system.
- **Solution (adapter pattern)**:
1. **Create a central data adapter**: Added `server/utils/new_data_loader.py` as the single data entry point for the entire system.
2. **Implement dynamic conversion**: The adapter loads the new dataset and converts it on the fly into the format the legacy codebase understands: renaming columns, filling the missing `is_promotion` feature with `0`, and normalizing data types.
3. **Refactor everywhere**: Systematically refactored every data consumer (`server/api.py`, `server/core/predictor.py`, and **all** trainer scripts under `server/trainers/`) so that none of them loads data directly anymore; they all go through the new adapter, fully decoupling the new data source from the existing business logic. A sketch of such an adapter follows.
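The adapter's own source is not visible in this commit's diffs, so the following is only a minimal sketch of the shape such a loader might take; the dataset path and the extra dtype normalization are assumptions, while the `subbh` to `store_id` rename and the `is_promotion` backfill come from the notes above.

```python
# server/utils/new_data_loader.py -- illustrative sketch, not the shipped code
import pandas as pd

NEW_DATA_PATH = "data/new_core_dataset.parquet"  # assumed location

def load_new_data(path: str = NEW_DATA_PATH) -> pd.DataFrame:
    """Load the new dataset and adapt it to the legacy schema."""
    df = pd.read_parquet(path)
    # Rename new-schema columns to the names the legacy code expects
    df = df.rename(columns={"subbh": "store_id"})
    # The new dataset has no promotion flag; backfill it for compatibility
    if "is_promotion" not in df.columns:
        df["is_promotion"] = 0
    # Normalize the dtypes the legacy code relies on
    df["date"] = pd.to_datetime(df["date"])
    df["store_id"] = df["store_id"].astype(str)
    df["product_id"] = df["product_id"].astype(str)
    return df.sort_values("date").reset_index(drop=True)
```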
### Phase 2: End-to-End Iterative Debugging and Chain-Reaction Fixes
After the refactor we ran a series of end-to-end tests and fixed the chain of issues this major change set off.
- **Fixed `KeyError`**: Resolved key-lookup failures caused by the mismatch between old and new column names (`subbh` vs `store_id`).
- **Fixed `NaN` values**: While implementing the "per-store training over all products" aggregation, products with no sales records on some dates caused the aggregation to introduce `NaN` values. Adding a `.fillna(0)` cleanup step to the data-preparation stage of every trainer resolved this for good (first sketch below).
- **Fixed `TypeError` (JSON serialization)**: Resolved JSON serialization failures in API responses caused by returning raw `numpy.ndarray` and `datetime.date` objects (second sketch below).
- **Fixed a fundamental error in the XGBoost prediction logic**:
  - **Symptom**: With an `XGBoost` model trained on the new data, the backend returned `200 OK` but the frontend chart failed to render.
  - **Root cause**: `server/predictors/model_predictor.py` had a serious logic flaw: it applied the "autoregressive" loop designed for `PyTorch` models (predict one day, feed that value back in to predict the next) to the `XGBoost` model. The `XGBoost` model here is a "direct multi-step output" model that returns predictions for all future dates in a single call. The wrong loop took only the first value of the full `XGBoost` output and copied it across every day, producing useless predictions.
  - **Fix**: Added a dedicated, non-looping branch for the `XGBoost` model in `model_predictor.py` that correctly receives and handles the full `XGBoost` output array, yielding predictions the frontend can render (third sketch below).
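First, the per-store aggregation plus NaN cleanup, abbreviated from the trainer diffs later in this commit (`store_df` is the store-filtered frame):

```python
# Aggregate one store's sales across all products by day, then clean up NaNs
product_data = store_df.groupby('date').agg({
    'sales': 'sum',
    'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
    'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
}).reset_index()
# Dates on which some products had no records would otherwise carry NaN
product_data.fillna(0, inplace=True)
```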
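Second, the serialization fix: `api.py` calls a helper named `convert_numpy_types_for_json`, whose body is not visible in this commit, so this is only a plausible sketch of what such a deep-cleaning helper does, under the assumption that it walks the response recursively:

```python
import datetime
import numpy as np

def convert_numpy_types_for_json(obj):
    """Recursively convert numpy and date objects into JSON-safe Python types."""
    if isinstance(obj, dict):
        return {k: convert_numpy_types_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types_for_json(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):  # np.int64, np.float32, ...
        return obj.item()
    if isinstance(obj, (datetime.date, datetime.datetime)):
        return obj.isoformat()
    return obj
```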
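Third, the core of the XGBoost fix is a type dispatch, condensed here from the `model_predictor.py` diff later in this commit (`X_scaled`, `scaler_y`, and `start_date_dt` come from that surrounding code):

```python
import numpy as np
import pandas as pd
import xgboost as xgb

if isinstance(model, xgb.Booster):
    # Direct multi-step output: one call yields the whole forecast horizon
    d_input = xgb.DMatrix(X_scaled.reshape(1, -1))
    y_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
    y_pred = scaler_y.inverse_transform(y_scaled.reshape(1, -1)).flatten()
    y_pred = np.maximum(0, y_pred)  # sales cannot go negative
    # The date range is driven by the model's actual output length,
    # not by the user-supplied number of days
    dates = pd.date_range(start=start_date_dt, periods=len(y_pred))
    predictions_df = pd.DataFrame({'date': dates, 'predicted_sales': y_pred})
else:
    # Autoregressive path for the default PyTorch models:
    # predict one step, append it to the input window, repeat
    ...
```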
### Phase 3: System Environment Cleanup
- **Goal**: Before the final tests, guarantee a completely clean system environment, free of interference from any old data or old models.
- **Process**: Under the developer's precise direction we completed the following cleanup:
1. **File system**: Manually deleted all historical artifacts in the `saved_models` and `saved_predictions` folders.
2. **Database**: Successfully ran the `server/tools/delete_old_predictions.py` script, clearing every stale prediction record out of the `prediction_history.db` database (a rough sketch of such a cleanup follows below).
- **Conclusion**: With this, the data source replacement, system refactoring, chain-reaction fixes, and environment cleanup are all complete. The project is now in a fresh state, with more robust code and a more reliable data source, ready for its final integrity verification.
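Neither the script nor the database schema appears in this commit, so the following is only a rough sketch of a cutoff-based cleanup against a SQLite file like `prediction_history.db`; the table and column names (`prediction_history`, `created_at`) are assumptions, and the local-time cutoff mirrors the timezone fix described at the top of this log:

```python
# Hypothetical sketch of a delete_old_predictions.py-style cleanup
import sqlite3
from datetime import datetime

DB_PATH = "prediction_history.db"
# Compare in local time, not UTC, per the timezone fix noted above
CUTOFF = datetime(2025, 7, 26, 0, 0, 0)  # assumed cutoff

conn = sqlite3.connect(DB_PATH)
try:
    cur = conn.execute(
        "DELETE FROM prediction_history WHERE created_at < ?",
        (CUTOFF.strftime("%Y-%m-%d %H:%M:%S"),),
    )
    conn.commit()
    print(f"Deleted {cur.rowcount} stale prediction records")
finally:
    conn.close()
```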
---
## 2025-07-26 (cont.): End-to-End Testing and System-Level Fixes
**Developers**: Roo (AI Assistant) & lyf
### Phase 3 (cont.): Iterative Fixes and Architecture Clarifications
During comprehensive end-to-end testing of the refactored system, we found and fixed a series of deeper bugs tied to specific model architectures and training modes.
- **Fixed the `XGBoost` crash on custom forecast horizons**:
  - **Symptom**: When predicting with an `XGBoost` model, the program crashed with an array-length mismatch whenever the user's custom forecast horizon differed from the fixed horizon the model was trained with.
  - **Architecture clarification**: We confirmed that `XGBoost` is a "direct multi-step output" model whose prediction length is fixed at training time.
  - **Fix**: Changed `server/predictors/model_predictor.py` to stop relying on the user-supplied number of days and instead generate the date sequence from the model's **actual** output length, keeping the program robust.
- **Fixed the `CnnBiLstmAttention` prediction-logic error**:
  - **Symptom**: The model predicted successfully, but the frontend chart failed to render.
  - **Root cause**: Like `XGBoost`, this model has a "direct multi-step output" architecture, but it had been registered to the prediction function designed for "autoregressive" models, which produced meaningless predictions.
  - **Fix**: Added a dedicated, correct, non-autoregressive prediction branch for the `CnnBiLstmAttention` model in `server/predictors/model_predictor.py` so it is handled properly.
- **Fixed the "global training" data-filtering logic**:
  - **Symptom**: In "global training" mode, selecting "all stores, all products" made training fail because no data could be found.
  - **Root cause**: When handling this mode, `server/core/predictor.py` incorrectly used a sentinel value (such as `unknown`) to filter on `product_id`, which left the dataset empty.
  - **Fix**: Reworked this logic so that when "all products" is selected, the program correctly skips the `product_id` filter and aggregates over the entire dataset, as in the sketch below.
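The corrected branching, condensed from the `predictor.py` diff later in this commit (`full_df` and `aggregation_method` come from that surrounding code):

```python
# Global mode: only filter by product_id when a real product is requested;
# sentinel values such as 'unknown' or 'all_products' mean "use everything"
if product_id and product_id not in ('unknown', 'all_products'):
    product_data = full_df[full_df['product_id'] == product_id].copy()
else:
    product_data = full_df.copy()

# Then aggregate the filtered rows by day
product_data = product_data.groupby('date').agg({
    'sales': aggregation_method,
    'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
    'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
}).reset_index()
```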
- **Fixed the "global prediction" data-loading logic**:
  - **Symptom**: "Global prediction" mode failed because no data could be found.
  - **Root cause**: `server/predictors/model_predictor.py` contained a mirror image of the training-side bug.
  - **Fix**: Changed `model_predictor.py` again so that its "global prediction" data loading is fully in sync with the training logic in `predictor.py`.
- **Clarified that "negative sales" are a data property, not a bug**:
  - **Symptom**: In the "global prediction" chart, historical sales dipped below zero.
  - **Analysis**: After discussion we confirmed that the raw data includes business events such as returns, so summing by day can legitimately produce negative values.
  - **Decision**: To respect the fidelity of the source data, we decided **not** to clamp negative values to 0 in code, accepting whatever effect this may have on training and prediction results. This is an important decision about our data-handling philosophy.
- **Full-chain enhancement to support "custom-scope global training"**:
  - **Symptom**: In "global training", choosing "custom scope" (explicit lists of stores and products) made training fail.
  - **Root cause**: A **full-chain break in parameter passing**: from the API endpoint through the process manager down to the core training logic, no layer of the system was prepared to handle the new, more complex custom-scope parameters.
  - **Fix (a coordinated overhaul of all three layers, sketched after this list)**:
    1. **API layer (`api.py`)**: Modified the `/api/training` endpoint to correctly receive the `selected_stores` and `selected_products` lists sent by the frontend.
    2. **Process-manager layer (`training_process_manager.py`)**: Overhauled the `TrainingTask` dataclass and the core functions so they can fully receive, store, and pass down the new custom-scope parameters.
    3. **Core-logic layer (`predictor.py`)**: Substantially extended the `train_model` function with a brand-new logic branch that handles the `selected_stores` and `selected_products` lists.
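A condensed sketch of how the new parameters flow through the chain; the `submit_task` call and the filtering branch are abbreviated from the `api.py` and `predictor.py` diffs later in this commit, while the `TrainingTask` changes are not shown in the visible diffs:

```python
# 1. API layer (api.py): forward the custom-scope lists into the task
task_id = training_manager.submit_task(
    product_id=product_id or "unknown",
    model_type=model_type,
    training_mode=training_mode,
    store_id=store_id,
    epochs=epochs,
    selected_stores=data.get('selected_stores'),
    selected_products=data.get('selected_products'),
    aggregation_method=aggregation_method,
)

# 2. training_process_manager.py stores the same fields on TrainingTask and
#    passes them through to train_model (not visible in this commit's diffs).

# 3. Core logic (predictor.py): the new branch filters on both lists at once
if selected_stores and selected_products:
    product_data = full_df[
        full_df['store_id'].isin(selected_stores)
        & full_df['product_id'].isin(selected_products)
    ].copy()
```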
- **Conclusion**: This thorough full-chain overhaul not only fixed the bug but also added a genuinely powerful new capability to the system. With that, every known issue uncovered during end-to-end testing has been resolved.

Binary file not shown.

(51 file diffs suppressed because one or more lines are too long.)
View File

@@ -1,525 +0,0 @@
{
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"model_type": "cnn_bilstm_attention",
"predictions": [
{
"date": "2025-07-25",
"predicted_sales": 0.8147072196006775
},
{
"date": "2025-07-26",
"predicted_sales": 0.8167740106582642
},
{
"date": "2025-07-27",
"predicted_sales": 0.8197348117828369
},
{
"date": "2025-07-28",
"predicted_sales": 0.8219858407974243
},
{
"date": "2025-07-29",
"predicted_sales": 0.8112776875495911
},
{
"date": "2025-07-30",
"predicted_sales": 0.8004958629608154
},
{
"date": "2025-07-31",
"predicted_sales": 0.8058184385299683
}
],
"prediction_data": [
{
"date": "2025-07-25",
"predicted_sales": 0.8147072196006775
},
{
"date": "2025-07-26",
"predicted_sales": 0.8167740106582642
},
{
"date": "2025-07-27",
"predicted_sales": 0.8197348117828369
},
{
"date": "2025-07-28",
"predicted_sales": 0.8219858407974243
},
{
"date": "2025-07-29",
"predicted_sales": 0.8112776875495911
},
{
"date": "2025-07-30",
"predicted_sales": 0.8004958629608154
},
{
"date": "2025-07-31",
"predicted_sales": 0.8058184385299683
}
],
"history_data": [
{
"date": "2025-06-25",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 2,
"month": 6,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-06-26",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 3,
"month": 6,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-06-27",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 4,
"month": 6,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-06-28",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 5,
"month": 6,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-06-29",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 6,
"month": 6,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-06-30",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 0,
"month": 6,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-01",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 1,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-02",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 2,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-03",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 3,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-04",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 4,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-05",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 5,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-06",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 6,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-07",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 0,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-08",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 1,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-09",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 2,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-10",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 3,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-11",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 4,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-12",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 5,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-13",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 6,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-14",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 0,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-15",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 1,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-16",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 2,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-17",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 3,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-18",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 4,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-19",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 5,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-20",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 6,
"month": 7,
"is_holiday": false,
"is_weekend": true,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-21",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 0,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-22",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 1,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-23",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 2,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
},
{
"date": "2025-07-24",
"sales": 0.0,
"product_id": "17021449",
"product_name": "布洛芬混悬液(美林)",
"store_id": "GLOBAL",
"store_name": "全部店铺-SUM",
"weekday": 3,
"month": 7,
"is_holiday": false,
"is_weekend": false,
"is_promotion": false,
"temperature": 20.0
}
],
"analysis": {
"trend": {
"slope": -0.0024171343871525353,
"trend_type": "平稳",
"r_squared": 0.4619268323887481,
"p_value": 0.0930247330579927,
"volatility": 0.008749220910445412,
"volatility_level": "低"
},
"statistics": {
"mean": 0.8129705531256539,
"median": 0.8147072196006775,
"min": 0.8004958629608154,
"max": 0.8219858407974243,
"std": 0.007112858962983344,
"q1": 0.8085480630397797,
"q3": 0.8182544112205505
},
"day_over_day": [
0.25368512857903625,
0.36249942896524706,
0.2746045406674448,
-1.3027174820243943,
-1.328993112252526,
0.6649098159565847
],
"influencing_factors": {
"product_id": "17021449",
"model_type": "cnn_bilstm_attention",
"feature_count": 7,
"important_features": [
"价格",
"周末",
"节假日"
]
},
"explanation": "cnn_bilstm_attention模型对产品17021449的预测分析\n预测显示销量整体呈平稳趋势销量基本保持稳定。\n预测期内销量波动性低表明销量相对稳定预测可信度较高。\n预测期内平均日销量为0.81个单位最高日销量为0.82个单位最低日销量为0.80个单位。\n\n主要影响因素包括价格, 周末, 节假日。"
}
}

(12 file diffs suppressed because one or more lines are too long.)

View File

@@ -63,9 +63,11 @@ from core.config import (
 )
 # Import the multi-store data utilities
-from utils.multi_store_data_utils import (
-    get_available_stores, get_available_products, get_sales_statistics
-)
+# from utils.multi_store_data_utils import (
+#     get_available_stores, get_available_products, get_sales_statistics
+# )
+# The old module above is replaced by the new unified data loader
+from utils.new_data_loader import load_new_data
 # Import the database init utilities
 from init_multi_store_db import get_db_connection
@@ -202,6 +204,9 @@ class CustomJSONEncoder(json.JSONEncoder):
         # Handle datetime types
         elif isinstance(obj, datetime):
             return obj.isoformat()
+        # New: handle date-like objects
+        elif isinstance(obj, pd.Timestamp) or hasattr(obj, 'isoformat'):
+            return obj.isoformat()
         return super(CustomJSONEncoder, self).default(obj)

 # Helper function to convert numpy types to native python types for JSON serialization
@@ -515,11 +520,21 @@ def swagger_ui():
     }
 })
 def get_products():
+    """Get the full product list (refactored to use the new data source)."""
     try:
-        from utils.multi_store_data_utils import get_available_products
-        products = get_available_products()
+        df = load_new_data()
+        # Extract the unique product IDs from the new data
+        products_df = df[['product_id']].drop_duplicates().sort_values('product_id')
+        # The new data has no product_name, so build a compatible format
+        products = [
+            {'product_id': pid, 'product_name': f'Product {pid}'}
+            for pid in products_df['product_id']
+        ]
         return jsonify({"status": "success", "data": products})
     except Exception as e:
+        logger.error(f"Failed to fetch the product list: {traceback.format_exc()}")
         return jsonify({"status": "error", "message": str(e)}), 500
 @app.route('/api/products/<product_id>', methods=['GET'])
@@ -570,25 +585,32 @@ def get_products():
     }
 })
 def get_product(product_id):
+    """Get a single product's details (refactored to use the new data source)."""
     try:
-        from utils.multi_store_data_utils import load_multi_store_data
-        df = load_multi_store_data(product_id=product_id)
-        if df.empty:
+        df = load_new_data()
+        product_df = df[df['product_id'] == product_id]
+        if product_df.empty:
             return jsonify({"status": "error", "message": "Product not found"}), 404

+        # Pull the product info out of the new data
+        product_name = f"Product {product_id}"  # fallback name
+        if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+            product_name = product_df['product_name'].iloc[0]
+
         product_info = {
             "product_id": product_id,
-            "product_name": df['product_name'].iloc[0],
-            "data_points": len(df),
+            "product_name": product_name,
+            "data_points": len(product_df),
             "date_range": {
-                "start": df['date'].min().strftime('%Y-%m-%d'),
-                "end": df['date'].max().strftime('%Y-%m-%d')
+                "start": product_df['date'].min().strftime('%Y-%m-%d'),
+                "end": product_df['date'].max().strftime('%Y-%m-%d')
             }
         }
         return jsonify({"status": "success", "data": product_info})
     except Exception as e:
+        logger.error(f"Failed to fetch product details: {traceback.format_exc()}")
         return jsonify({"status": "error", "message": str(e)}), 500
 @app.route('/api/products/<product_id>/sales', methods=['GET'])
@@ -644,29 +666,29 @@ def get_product(product_id):
     }
 })
 def get_product_sales(product_id):
+    """Get a product's sales data (refactored to use the new data source)."""
     try:
         start_date = request.args.get('start_date')
         end_date = request.args.get('end_date')

-        from utils.multi_store_data_utils import load_multi_store_data
-        df = load_multi_store_data(
-            product_id=product_id,
-            start_date=start_date,
-            end_date=end_date
-        )
+        df = load_new_data()
+        df_product = df[df['product_id'] == product_id]

-        if df.empty:
-            return jsonify({"status": "error", "message": "Product not found or has no data"}), 404
+        if start_date:
+            df_product = df_product[df_product['date'] >= pd.to_datetime(start_date)]
+        if end_date:
+            df_product = df_product[df_product['date'] <= pd.to_datetime(end_date)]

-        # Make sure the data is sorted by date
-        df = df.sort_values('date')
+        if df_product.empty:
+            return jsonify({"status": "error", "message": "Product not found or has no data in the given date range"}), 404

-        # Convert dates to strings for JSON serialization
-        df['date'] = df['date'].dt.strftime('%Y-%m-%d')
-        sales_data = df.to_dict('records')
+        df_product = df_product.sort_values('date')
+        df_product['date'] = df_product['date'].dt.strftime('%Y-%m-%d')
+        sales_data = df_product.to_dict('records')

         return jsonify({"status": "success", "data": sales_data})
     except Exception as e:
+        logger.error(f"Failed to fetch product sales data: {traceback.format_exc()}")
         return jsonify({"status": "error", "message": str(e)}), 500
 @app.route('/api/data/upload', methods=['POST'])
@@ -897,12 +919,16 @@ def start_training():
     # Submit the task through the new training process manager
     try:
+        # Upgrade: pass the custom-scope parameters through to the training task as well
         task_id = training_manager.submit_task(
             product_id=product_id or "unknown",
             model_type=model_type,
             training_mode=training_mode,
             store_id=store_id,
-            epochs=epochs
+            epochs=epochs,
+            selected_stores=data.get('selected_stores'),
+            selected_products=data.get('selected_products'),
+            aggregation_method=aggregation_method
         )

         logger.info(f"🚀 Training task submitted to the process manager: {task_id[:8]}")
@@ -1500,7 +1526,9 @@ def predict():
             traceback.print_exc()
             # This must not block returning results to the user, so only log a warning

-        return jsonify(response_data)
+        # Deep-clean the entire response with our helper before returning it
+        cleaned_response_data = convert_numpy_types_for_json(response_data)
+        return jsonify(cleaned_response_data)
     except Exception as e:
         print(f"Prediction failed: {str(e)}")
         import traceback
@@ -2550,32 +2578,24 @@ def get_latest_model_id(model_type, product_id):
 # Helper for looking up a product's name
 def get_product_name(product_id):
-    """Look up a product name by product ID"""
+    """Look up a product name by product ID (refactored)"""
     try:
-        # Look the product name up in the Excel file
-        from utils.multi_store_data_utils import load_multi_store_data
-        df = load_multi_store_data()
-        product_df = df[df['product_id'] == product_id]
-        if not product_df.empty:
-            return product_df['product_name'].iloc[0]
-        return None
+        # Note: the new data source has no 'product_name'; return a placeholder.
+        # A future iteration may join against a product-info table.
+        return f"Product {product_id}"
     except Exception as e:
-        print(f"Failed to fetch the product name: {str(e)}")
-        return None
+        logger.warning(f"Problem while fetching the product name: {e}")
+        return product_id

-# Helper for looking up a store's name
 def get_store_name(store_id):
-    """Look up a store name by store ID"""
+    """Look up a store name by store ID (refactored)"""
     try:
-        from utils.multi_store_data_utils import get_available_stores
-        stores = get_available_stores()
-        for store in stores:
-            if store['store_id'] == store_id:
-                return store['store_name']
-        return None
+        # Note: the new data source has no 'store_name'; return a placeholder.
+        # A future iteration may join against a store-info table.
+        return f"Store {store_id}"
     except Exception as e:
-        print(f"Failed to fetch the store name: {str(e)}")
-        return None
+        logger.warning(f"Problem while fetching the store name: {e}")
+        return store_id
 # The run_prediction function has been removed; its logic is fully integrated into the /api/prediction route handler
@@ -3837,23 +3857,38 @@ def update_train_task_with_websocket():
 @app.route('/api/stores', methods=['GET'])
 def get_stores():
-    """
-    Get the list of all stores
-    """
+    """Get the list of all stores (refactored to use the new data source, with defaults filled in)."""
     try:
-        from utils.multi_store_data_utils import get_available_stores
-        stores = get_available_stores()
+        df = load_new_data()
+
+        # Extract the unique store records from the new data.
+        # Fix: select only columns that actually exist in the data file;
+        # per the earlier analysis the new data has a 'district' column but no 'city' or 'province'.
+        stores_df = df[['store_id', 'district']].drop_duplicates('store_id')
+
+        stores_data = []
+        for _, row in stores_df.iterrows():
+            # Build the location field
+            location = row['district'] if pd.notna(row['district']) else "Unknown Location"
+            stores_data.append({
+                "store_id": row['store_id'],
+                "store_name": f"Store {row['store_id']}",  # use the ID as a temporary name
+                "location": location,
+                "type": "standard pharmacy",   # default value
+                "size": "120㎡",               # default value
+                "opening_date": "2023-01-01",  # default value
+                "status": "open"               # default value
+            })

         return jsonify({
             "status": "success",
-            "data": stores,
-            "count": len(stores)
+            "data": stores_data,
+            "count": len(stores_data)
         })
     except Exception as e:
-        return jsonify({
-            "status": "error",
-            "message": f"Failed to fetch the store list: {str(e)}"
-        }), 500
+        logger.error(f"Failed to fetch the store list: {traceback.format_exc()}")
+        return jsonify({"status": "error", "message": f"Failed to fetch the store list: {str(e)}"}), 500
 @app.route('/api/stores/<store_id>', methods=['GET'])
 def get_store(store_id):
@@ -4043,11 +4078,20 @@ def delete_store(store_id):
 @app.route('/api/stores/<store_id>/products', methods=['GET'])
 def get_store_products(store_id):
-    """
-    Get a store's product list
-    """
+    """Get a store's product list (refactored to use the new data source)."""
     try:
-        products = get_available_products(store_id=store_id)
+        df = load_new_data()
+        store_df = df[df['store_id'] == store_id]
+        if store_df.empty:
+            return jsonify({"status": "success", "data": [], "count": 0})
+
+        products_df = store_df[['product_id']].drop_duplicates().sort_values('product_id')
+        products = [
+            {'product_id': pid, 'product_name': f'Product {pid}'}
+            for pid in products_df['product_id']
+        ]

         return jsonify({
             "status": "success",
@@ -4055,10 +4099,8 @@ def get_store_products(store_id):
             "count": len(products)
         })
     except Exception as e:
-        return jsonify({
-            "status": "error",
-            "message": f"Failed to fetch the store's product list: {str(e)}"
-        }), 500
+        logger.error(f"Failed to fetch the store's product list: {traceback.format_exc()}")
+        return jsonify({"status": "error", "message": f"Failed to fetch the store's product list: {str(e)}"}), 500

 @app.route('/api/stores/<store_id>/statistics', methods=['GET'])
 def get_store_statistics(store_id):

View File

@@ -20,11 +20,13 @@ from datetime import datetime
 # The imports above are no longer needed; trainers are now fetched dynamically from the model registry
 from predictors.model_predictor import load_model_and_predict
 from utils.data_utils import prepare_data, prepare_sequences
-from utils.multi_store_data_utils import (
-    load_multi_store_data,
-    get_store_product_sales_data,
-    aggregate_multi_store_data
-)
+# from utils.multi_store_data_utils import (
+#     load_multi_store_data,
+#     get_store_product_sales_data,
+#     aggregate_multi_store_data
+# )
+# The old module above is replaced by the new unified data loader
+from utils.new_data_loader import load_new_data
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
@@ -53,18 +55,16 @@ class PharmacyPredictor:
         print(f"Using device: {self.device}")

-        # Try to load the multi-store data
-        try:
-            self.data = load_multi_store_data(data_path)
-            print(f"Loaded multi-store data from: {data_path}")
-        except Exception as e:
-            print(f"Failed to load data: {e}")
-            self.data = None
+        # Refactor: no longer preload the whole dataset into memory;
+        # self.data is loaded on demand instead
+        self.data = None
+        print("PharmacyPredictor initialized; data will be loaded on demand.")

     def train_model(self, product_id, model_type='transformer', epochs=100, batch_size=32,
                     learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                     hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
                     store_id=None, training_mode='product', aggregation_method='sum',
+                    selected_stores: list = None, selected_products: list = None,  # new params for custom scope
                     socketio=None, task_id=None, version=None, continue_training=False,
                     progress_callback=None):
         """
@@ -104,76 +104,78 @@ class PharmacyPredictor:
             except Exception as e:
                 print(f"Progress callback failed: {e}", flush=True)

-        if self.data is None:
-            log_message("No data available; please load or generate data first", 'error')
-            return None
-
-        # Prepare the data according to the training mode
-        if training_mode == 'product':
-            # Per-product training: use this product's data from all stores
-            product_data = self.data[self.data['product_id'] == product_id].copy()
-            if product_data.empty:
-                log_message(f"No data found for product {product_id}", 'error')
-                return None
-            log_message(f"Per-product training mode: product {product_id}, rows: {len(product_data)}")
-
-        elif training_mode == 'store':
-            # Per-store training
-            if not store_id:
-                log_message("Store training mode requires a store_id", 'error')
-                return None
-            # If product_id is 'unknown', train one aggregate model over all of the store's products
-            if product_id == 'unknown':
-                try:
-                    # Use the new aggregation function to aggregate by store
-                    product_data = aggregate_multi_store_data(
-                        store_id=store_id,
-                        aggregation_method=aggregation_method,
-                        file_path=self.data_path
-                    )
-                    log_message(f"Store-aggregate training: store {store_id}, method {aggregation_method}, rows: {len(product_data)}")
-                    # Set product_id to 'store_{store_id}' to stay consistent with the API lookup logic
-                    product_id = f"store_{store_id}"
-                except Exception as e:
-                    log_message(f"Failed to aggregate data for store {store_id}: {e}", 'error')
-                    return None
-            else:
-                # Train for one specific product within the store
-                try:
-                    product_data = get_store_product_sales_data(
-                        store_id=store_id,
-                        product_id=product_id,
-                        file_path=self.data_path
-                    )
-                    log_message(f"Store-product training: store {store_id}, product {product_id}, rows: {len(product_data)}")
-                except Exception as e:
-                    log_message(f"Failed to fetch the store's product data: {e}", 'error')
-                    return None
-
-        elif training_mode == 'global':
-            # Global training: aggregate product data across all stores
-            try:
-                # If product_id is 'unknown', train one aggregate model over all products globally
-                if product_id == 'unknown':
-                    product_data = aggregate_multi_store_data(
-                        product_id=None,  # pass None to trigger a true global aggregation
-                        aggregation_method=aggregation_method,
-                        file_path=self.data_path
-                    )
-                    log_message(f"Global training mode: all products, method {aggregation_method}, rows: {len(product_data)}")
-                    # Set product_id to a meaningful identifier
-                    product_id = 'all_products'
-                else:
-                    product_data = aggregate_multi_store_data(
-                        product_id=product_id,
-                        aggregation_method=aggregation_method,
-                        file_path=self.data_path
-                    )
-                    log_message(f"Global training mode: product {product_id}, method {aggregation_method}, rows: {len(product_data)}")
-            except Exception as e:
-                log_message(f"Failed to aggregate global data: {e}", 'error')
-                return None
+        # --- Data loading and filtering refactor ---
+        # Use the new unified data loader everywhere, replacing the old scattered loading logic
+        log_message("Using the new unified data loader...", 'info')
+        try:
+            full_df = load_new_data()
+        except Exception as e:
+            log_message(f"The new data loader failed: {e}", 'error')
+            return None
+
+        if training_mode == 'product':
+            product_data = full_df[full_df['product_id'] == product_id].copy()
+            if product_data.empty:
+                log_message(f"No data found for product {product_id}", 'error')
+                return None
+            log_message(f"Per-product training mode: product {product_id}, rows: {len(product_data)}", 'info')
+
+        elif training_mode == 'store':
+            if not store_id:
+                log_message("Store training mode requires a store_id", 'error')
+                return None
+            # Filter down to all of this store's data
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            # Decide between training for one product or aggregating the whole store
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                # Train for one specific product within the store
+                product_data = store_df[store_df['product_id'] == product_id].copy()
+                log_message(f"Store-product training: store {store_id}, product {product_id}, rows: {len(product_data)}", 'info')
+            else:
+                # Aggregate training over the whole store
+                log_message(f"Store-aggregate training: store {store_id} (all products)", 'info')
+                product_data = store_df.groupby('date').agg({
+                    'sales': 'sum',
+                    'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
+                    'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                log_message(f"Rows after aggregation: {len(product_data)}", 'info')
+                # Data cleanup: fill NaN values the aggregation may have introduced with 0
+                product_data.fillna(0, inplace=True)
+                log_message("Filled NaN values in the aggregated data with 0", 'info')
+
+        elif training_mode == 'global':
+            # Upgrade: support "custom scope" global training
+            if selected_stores and selected_products:
+                log_message(f"Global training mode: custom scope -> stores {selected_stores}, products {selected_products}", 'info')
+                # Filter precisely on the selected store and product lists
+                product_data = full_df[
+                    full_df['store_id'].isin(selected_stores) &
+                    full_df['product_id'].isin(selected_products)
+                ].copy()
+            # Otherwise fall back to the old global-training logic
+            elif product_id and product_id not in ['unknown', 'all_products']:
+                log_message(f"Global training mode: aggregating cross-store data for product '{product_id}'...", 'info')
+                product_data = full_df[full_df['product_id'] == product_id].copy()
+            else:
+                log_message("Global training mode: aggregating data across all stores and all products...", 'info')
+                product_data = full_df.copy()
+
+            if product_data.empty:
+                log_message("No data found within the given global or custom scope.", 'error')
+                return None
+
+            # Aggregate the filtered rows by day
+            product_data = product_data.groupby('date').agg({
+                'sales': aggregation_method,
+                'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
+                'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            log_message(f"Global aggregation done, method: {aggregation_method}, rows: {len(product_data)}", 'info')
         else:
             log_message(f"Unsupported training mode: {training_mode}", 'error')
             return None
@@ -183,7 +185,15 @@ class PharmacyPredictor:
             # A store model's identifier is based only on the store ID
             model_identifier = f"store_{store_id}"
         elif training_mode == 'global':
-            # A global model's identifier must not depend on a single product_id
-            model_identifier = f"global_{aggregation_method}"
+            # Upgrade: build a unique identifier for custom-scope global models
+            if selected_stores and selected_products:
+                # Derive a stable identifier from a hash of the scope's content
+                import hashlib
+                s_stores = "_".join(sorted(selected_stores))
+                s_products = "_".join(sorted(selected_products))
+                hash_input = f"{s_stores}_{s_products}_{aggregation_method}"
+                model_identifier = f"global_custom_{hashlib.md5(hash_input.encode()).hexdigest()[:8]}"
+            else:
+                model_identifier = f"global_{aggregation_method}"
         else:  # product mode
             model_identifier = product_id

View File

@@ -23,7 +23,7 @@ import xgboost as xgb
 from analysis.trend_analysis import analyze_prediction_result
 from utils.visualization import plot_prediction_results
-from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
+from utils.new_data_loader import load_new_data
 from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
 from models.model_registry import get_predictor, register_predictor
@@ -49,24 +49,64 @@ def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_
     history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

-    all_predictions = []
-    current_sequence_df = prediction_input_df.copy()
-
-    for _ in range(future_days):
-        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
-
-        # **Core improvement**: detect the model type and dispatch to the matching prediction method
-        if isinstance(model, xgb.Booster):
-            # XGBoost prediction path
-            X_input_reshaped = X_current_scaled.reshape(1, -1)
-            d_input = xgb.DMatrix(X_input_reshaped)
-            # **Key fix**: predict with best_iteration to match the early-stopping strategy
-            y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
-            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
-        else:
-            # Default PyTorch prediction path
-            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
-            with torch.no_grad():
-                y_pred_scaled = model(X_input).cpu().numpy()
-            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
-
-        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
+    # --- Prediction logic reworked to handle XGBoost correctly ---
+    if isinstance(model, xgb.Booster):
+        # --- XGBoost path (non-autoregressive, single-shot prediction) ---
+        X_current_scaled = scaler_X.transform(prediction_input_df[features].values)
+        X_input_reshaped = X_current_scaled.reshape(1, -1)
+        d_input = xgb.DMatrix(X_input_reshaped)
+
+        # Predict all future days in a single call
+        y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
+
+        # Inverse-transform the whole sequence
+        y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled.reshape(1, -1)).flatten()
+        y_pred_unscaled = np.maximum(0, y_pred_unscaled)  # sales must not go negative
+
+        # Build the future date sequence.
+        # Fix: the number of future dates must match the number of points the model
+        # actually outputs rather than the user-supplied future_days, because an
+        # XGBoost model's output length is fixed at training time.
+        future_dates = pd.date_range(start=start_date_dt, periods=len(y_pred_unscaled))
+
+        # Build the result DataFrame directly
+        predictions_df = pd.DataFrame({
+            'date': future_dates,
+            'predicted_sales': y_pred_unscaled
+        })
+
+    elif isinstance(model, CnnBiLstmAttention):
+        # --- CnnBiLstmAttention path (non-autoregressive, single-shot prediction) ---
+        X_current_scaled = scaler_X.transform(prediction_input_df[features].values)
+        X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
+
+        with torch.no_grad():
+            # Predict all future days in a single call
+            y_pred_scaled = model(X_input).cpu().numpy()
+
+        # Inverse-transform the whole sequence
+        y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled).flatten()
+        y_pred_unscaled = np.maximum(0, y_pred_unscaled)  # sales must not go negative
+
+        # The future date sequence again matches the model's actual output length
+        future_dates = pd.date_range(start=start_date_dt, periods=len(y_pred_unscaled))
+
+        # Build the result DataFrame directly
+        predictions_df = pd.DataFrame({
+            'date': future_dates,
+            'predicted_sales': y_pred_unscaled
+        })
+
+    else:
+        # --- Default PyTorch path (autoregressive) ---
+        all_predictions = []
+        current_sequence_df = prediction_input_df.copy()
+
+        for _ in range(future_days):
+            X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
+            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
+            with torch.no_grad():
+                y_pred_scaled = model(X_input).cpu().numpy()
+            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
+
+            next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
@@ -96,19 +136,48 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st
     if not os.path.exists(model_path):
         raise FileNotFoundError(f"Model file {model_path} does not exist")

-    # --- The data-loading section stays unchanged ---
-    from utils.multi_store_data_utils import aggregate_multi_store_data
-    if training_mode == 'store' and store_id:
-        from utils.multi_store_data_utils import load_multi_store_data
-        store_df_for_name = load_multi_store_data(store_id=store_id)
-        product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"Store {store_id}"
-        product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
-    elif training_mode == 'global':
-        product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
-        product_name = "Global sales data"
-    else:
-        product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
-        product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
+    # --- Data loading refactor ---
+    # Use the new unified data loader so the data source and processing logic
+    # match the training side exactly
+    print("Using the new unified data loader for prediction...")
+    full_df = load_new_data()
+
+    if training_mode == 'store' and store_id:
+        store_df = full_df[full_df['store_id'] == store_id].copy()
+        # Decide between predicting one product or the whole store's aggregate
+        if product_id and product_id != 'unknown' and product_id != 'all_products':
+            product_df = store_df[store_df['product_id'] == product_id].copy()
+            product_name = f"Store {store_id} - Product {product_id}"
+        else:
+            # Predict the store's aggregated sales
+            product_df = store_df.groupby('date').agg({
+                'sales': 'sum',
+                'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
+                'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            product_name = f"Store {store_id} (all products aggregated)"
+    elif training_mode == 'global':
+        # Fix the data-loading logic for global prediction
+        if product_id and product_id not in ['unknown', 'all_products']:
+            # If a concrete product ID was supplied (uncommon in global mode,
+            # but supported for compatibility), aggregate it across stores
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_name = f"Global aggregate - Product {product_id}"
+        else:
+            # A global "all products" prediction aggregates everything
+            product_df = full_df.copy()
+            product_name = "Global aggregate (all products)"
+        product_df = product_df.groupby('date').agg({
+            'sales': 'sum',  # sum by default; may become configurable
+            'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
+            'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
+        }).reset_index()
+    else:  # default 'product' mode
+        product_df = full_df[full_df['product_id'] == product_id].copy()
+        # Compatibility: the new data may not carry a product_name column
+        if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+            product_name = product_df['product_name'].iloc[0]
+        else:
+            product_name = f"Product {product_id}"

     if product_df.empty:
         raise ValueError(f"Product {product_id} or store {store_id} has no sales data")

View File

@@ -13,6 +13,7 @@ from models.model_registry import register_trainer
 from utils.model_manager import model_manager
 from analysis.metrics import evaluate_model
 from utils.data_utils import create_dataset
+from utils.new_data_loader import load_new_data
 from sklearn.preprocessing import MinMaxScaler
 from utils.visualization import plot_loss_curve  # import the loss-curve plotting helper
@@ -26,9 +27,38 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
     print(f"🚀 CNN-BiLSTM-Attention trainer started: model_identifier='{model_identifier}'")
     start_time = time.time()

-    # --- 1. Data preparation ---
+    # --- 1. Data loading and filtering refactor ---
+    if product_df is None:
+        print("Using the new unified data loader...")
+        full_df = load_new_data()
+        if training_mode == 'store' and store_id:
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                product_df = store_df[store_df['product_id'] == product_id].copy()
+            else:
+                product_df = store_df.groupby('date').agg({
+                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                    'is_holiday': 'first', 'is_weekend': 'first',
+                    'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                product_df.fillna(0, inplace=True)
+        elif training_mode == 'global':
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_df = product_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            product_df.fillna(0, inplace=True)
+        else:  # default 'product' mode
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+
+    # --- 2. Data preparation ---
     features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
-    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
+    product_name = f"Product {product_id}"
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]

     X = product_df[features].values
     y = product_df[['sales']].values

View File

@@ -17,6 +17,7 @@ from tqdm import tqdm
 from models.kan_model import KANForecaster
 from models.optimized_kan_forecaster import OptimizedKANForecaster
 from utils.data_utils import create_dataset, PharmacyDataset
+from utils.new_data_loader import load_new_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
@@ -35,45 +36,47 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
         model: the trained model
         metrics: the model's evaluation metrics
     """
-    # If no product_df was passed in, load data according to the training mode
-    if product_df is None:
-        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
-        try:
-            if training_mode == 'store' and store_id:
-                # Load one specific store's data
-                product_df = get_store_product_sales_data(
-                    store_id,
-                    product_id,
-                    'pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"Store {store_id}"
-            elif training_mode == 'global':
-                # Aggregate the data across all stores
-                product_df = aggregate_multi_store_data(
-                    product_id,
-                    aggregation_method=aggregation_method,
-                    file_path='pharmacy_sales_multi_store.csv'
-                )
-                training_scope = f"Global aggregate ({aggregation_method})"
-            else:
-                # Default: load the product's data from all stores
-                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
-                training_scope = "All stores"
-        except Exception as e:
-            print(f"Multi-store data loading failed: {e}")
-            # Fallback: try the original data
-            df = pd.read_excel('pharmacy_sales.xlsx')
-            product_df = df[df['product_id'] == product_id].sort_values('date')
-            training_scope = "Original data"
-    else:
-        # If a product_df was passed in, use it directly
-        if training_mode == 'store' and store_id:
-            training_scope = f"Store {store_id}"
-        elif training_mode == 'global':
-            training_scope = f"Global aggregate ({aggregation_method})"
-        else:
-            training_scope = "All stores"
+    # --- Data loading and filtering refactor ---
+    # Use the new unified data loader, replacing the old scattered loading logic
+    print("Using the new unified data loader...")
+    full_df = load_new_data()  # load the complete, adapted new dataset
+
+    if training_mode == 'store' and store_id:
+        store_df = full_df[full_df['store_id'] == store_id].copy()
+        if product_id and product_id != 'unknown' and product_id != 'all_products':
+            product_df = store_df[store_df['product_id'] == product_id].copy()
+            training_scope = f"Store {store_id} - Product {product_id}"
+        else:
+            product_df = store_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            training_scope = f"Store {store_id} (all products aggregated)"
+            # Data cleanup: fill NaN values the aggregation may introduce with 0
+            product_df.fillna(0, inplace=True)
+    elif training_mode == 'global':
+        # Select the product's aggregated data across all stores.
+        # Note: the new data is already expanded by (store_id, product_id, date),
+        # so this aggregation logic may deserve another look; for now only the
+        # product ID is filtered.
+        product_df = full_df[full_df['product_id'] == product_id].copy()
+        # Sum the product's sales across stores by date
+        product_df = product_df.groupby('date').agg({
+            'sales': 'sum',
+            # Keep the other needed features, e.g. the first value or the mean
+            'weekday': 'first',
+            'month': 'first',
+            'is_holiday': 'first',
+            'is_weekend': 'first',
+            'is_promotion': 'first',
+            'temperature': 'mean'
+        }).reset_index()
+        training_scope = f"Global aggregate ({aggregation_method})"
+    else:  # default 'product' mode
+        # Select this product's data (possibly spanning stores; no aggregation here)
+        product_df = full_df[full_df['product_id'] == product_id].copy()
+        training_scope = f"Product {product_id} across all stores"

     if product_df.empty:
         raise ValueError(f"Product {product_id} has no sales data available")
@@ -95,7 +98,11 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
         raise ValueError(error_msg)

     product_df = product_df.sort_values('date')
-    product_name = product_df['product_name'].iloc[0]
+    # Compatibility: the new data may not carry a product_name column
+    if 'product_name' in product_df.columns:
+        product_name = product_df['product_name'].iloc[0]
+    else:
+        product_name = f"Product {product_id}"  # fall back to the product_id

     model_type = "optimized KAN" if use_optimized else "KAN"
     print(f"Training a {model_type} sales-forecast model for product '{product_name}' (ID: {product_id})")

View File

@@ -16,7 +16,8 @@ from tqdm import tqdm
 from models.mlstm_model import MLSTMTransformer as MatrixLSTM
 from utils.data_utils import create_dataset, PharmacyDataset
-from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
+# from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
+from utils.new_data_loader import load_new_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import (
@@ -124,7 +125,40 @@ def train_product_model_with_mlstm(
     except Exception as e:
         print(f"[mLSTM] Task {task_id}: progress manager initialization failed: {e}", flush=True)

-    # The data is now passed in by the caller and is no longer loaded here
+    # --- Data loading and filtering refactor ---
+    # Use the new unified data loader
+    if product_df is None:
+        print("Using the new unified data loader...")
+        full_df = load_new_data()
+
+        if training_mode == 'store' and store_id:
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                product_df = store_df[store_df['product_id'] == product_id].copy()
+                training_scope = f"Store {store_id} - Product {product_id}"
+            else:
+                product_df = store_df.groupby('date').agg({
+                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                    'is_holiday': 'first', 'is_weekend': 'first',
+                    'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                training_scope = f"Store {store_id} (all products aggregated)"
+                # Data cleanup: fill NaN values the aggregation may introduce with 0
+                product_df.fillna(0, inplace=True)
+        elif training_mode == 'global':
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_df = product_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            training_scope = f"Global aggregate ({aggregation_method})"
+        else:  # default 'product' mode
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            training_scope = f"Product {product_id} across all stores"
+    else:
+        # If a product_df was passed in, use it directly
         if training_mode == 'store' and store_id:
             training_scope = f"Store {store_id}"
         elif training_mode == 'global':

@@ -149,7 +183,11 @@ def train_product_model_with_mlstm(
         emit_progress(f"Training failed: insufficient data ({len(product_df)}/{min_required_samples} days)")
         raise ValueError(error_msg)

-    product_name = product_df['product_name'].iloc[0]
+    # Compatibility: the new data may not carry a product_name column
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]
+    else:
+        product_name = f"Product {product_id}"

     print(f"[mLSTM] Training an mLSTM sales-forecast model for product '{product_name}' (ID: {product_id})", flush=True)
     print(f"[mLSTM] Training scope: {training_scope}", flush=True)

View File

@@ -16,6 +16,7 @@ from tqdm import tqdm
 from models.tcn_model import TCNForecaster
 from utils.data_utils import create_dataset, PharmacyDataset
+from utils.new_data_loader import load_new_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
@@ -58,12 +59,41 @@ def train_product_model_with_tcn(
     emit_progress(f"Starting TCN model training")

+    # --- Data loading and filtering refactor ---
     if product_df is None:
-        from utils.multi_store_data_utils import aggregate_multi_store_data
-        product_df = aggregate_multi_store_data(
-            product_id=product_id,
-            aggregation_method=aggregation_method
-        )
-        training_scope = f"Global aggregate ({aggregation_method})"
+        print("Using the new unified data loader...")
+        full_df = load_new_data()
+
+        if training_mode == 'store' and store_id:
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                product_df = store_df[store_df['product_id'] == product_id].copy()
+                training_scope = f"Store {store_id} - Product {product_id}"
+            else:
+                product_df = store_df.groupby('date').agg({
+                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                    'is_holiday': 'first', 'is_weekend': 'first',
+                    'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                training_scope = f"Store {store_id} (all products aggregated)"
+            product_df.fillna(0, inplace=True)
+        elif training_mode == 'global':
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_df = product_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            training_scope = f"Global aggregate ({aggregation_method})"
+            product_df.fillna(0, inplace=True)
+        else:  # default 'product' mode
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            training_scope = f"Product {product_id} across all stores"
     else:
-        training_scope = "All stores"
+        # If a product_df was passed in, use it directly
+        if training_mode == 'store' and store_id:
+            training_scope = f"Store {store_id}"
+        elif training_mode == 'global':
+            training_scope = f"Global aggregate ({aggregation_method})"
+        else:
+            training_scope = "All stores"

@@ -84,7 +114,11 @@ def train_product_model_with_tcn(
         raise ValueError(error_msg)

     product_df = product_df.sort_values('date')
-    product_name = product_df['product_name'].iloc[0]
+    # Compatibility: the new data may not carry a product_name column
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]
+    else:
+        product_name = f"Product {product_id}"

     print(f"Training a TCN sales-forecast model for product '{product_name}' (ID: {product_id})")
     print(f"Training scope: {training_scope}")

View File

@@ -17,7 +17,8 @@ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
 from models.transformer_model import TimeSeriesTransformer
 from utils.data_utils import create_dataset, PharmacyDataset
-from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
+# from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
+from utils.new_data_loader import load_new_data
 from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import (
@ -81,12 +82,42 @@ def train_product_model_with_transformer(
def finish_training(self, *args, **kwargs): pass def finish_training(self, *args, **kwargs): pass
progress_manager = DummyProgressManager() progress_manager = DummyProgressManager()
# --- 数据加载与筛选重构 ---
if product_df is None: if product_df is None:
from utils.multi_store_data_utils import aggregate_multi_store_data print("正在使用新的统一数据加载器...")
product_df = aggregate_multi_store_data( full_df = load_new_data()
product_id=product_id,
aggregation_method=aggregation_method if training_mode == 'store' and store_id:
) store_df = full_df[full_df['store_id'] == store_id].copy()
if product_id and product_id != 'unknown' and product_id != 'all_products':
product_df = store_df[store_df['product_id'] == product_id].copy()
training_scope = f"店铺 {store_id} - 产品 {product_id}"
else:
product_df = store_df.groupby('date').agg({
'sales': 'sum', 'weekday': 'first', 'month': 'first',
'is_holiday': 'first', 'is_weekend': 'first',
'is_promotion': 'first', 'temperature': 'mean'
}).reset_index()
training_scope = f"店铺 {store_id} (所有药品聚合)"
# 数据清洗使用0填充聚合后可能产生的NaN值
product_df.fillna(0, inplace=True)
elif training_mode == 'global':
product_df = full_df[full_df['product_id'] == product_id].copy()
product_df = product_df.groupby('date').agg({
'sales': 'sum', 'weekday': 'first', 'month': 'first',
'is_holiday': 'first', 'is_weekend': 'first',
'is_promotion': 'first', 'temperature': 'mean'
}).reset_index()
training_scope = f"全局聚合({aggregation_method})"
else: # 默认 'product' 模式
product_df = full_df[full_df['product_id'] == product_id].copy()
training_scope = f"所有店铺中的产品 {product_id}"
else:
# 如果传入了product_df直接使用
if training_mode == 'store' and store_id:
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
training_scope = f"全局聚合({aggregation_method})" training_scope = f"全局聚合({aggregation_method})"
else: else:
training_scope = "所有店铺" training_scope = "所有店铺"
@ -106,7 +137,11 @@ def train_product_model_with_transformer(
raise ValueError(error_msg) raise ValueError(error_msg)
product_df = product_df.sort_values('date') product_df = product_df.sort_values('date')
# 兼容性处理:新数据可能没有 product_name 列
if 'product_name' in product_df.columns and not product_df['product_name'].empty:
product_name = product_df['product_name'].iloc[0] product_name = product_df['product_name'].iloc[0]
else:
product_name = f"产品 {product_id}"
print(f"[Transformer] 训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True) print(f"[Transformer] 训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
print(f"[Device] 使用设备: {DEVICE}", flush=True) print(f"[Device] 使用设备: {DEVICE}", flush=True)

View File

@@ -11,6 +11,7 @@ from xgboost.callback import EarlyStopping
 # Core utilities
 from utils.data_utils import create_dataset
+from utils.new_data_loader import load_new_data
 from analysis.metrics import evaluate_model
 from utils.model_manager import model_manager
 from models.model_registry import register_trainer
@@ -23,9 +24,32 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
     """
     print(f"🚀 XGBoost trainer started: model_identifier='{model_identifier}'")

-    # --- 1. Data preparation and validation ---
-    if product_df.empty:
-        raise ValueError(f"Product {product_id} has no sales data available")
+    # --- 1. Refactored data loading and filtering ---
+    if product_df is None:
+        print("Using the new unified data loader...")
+        full_df = load_new_data()
+
+        if training_mode == 'store' and store_id:
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                product_df = store_df[store_df['product_id'] == product_id].copy()
+            else:
+                product_df = store_df.groupby('date').agg({
+                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                    'is_holiday': 'first', 'is_weekend': 'first',
+                    'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                product_df.fillna(0, inplace=True)
+        elif training_mode == 'global':
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_df = product_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            product_df.fillna(0, inplace=True)
+        else:  # default 'product' mode
+            product_df = full_df[full_df['product_id'] == product_id].copy()

     min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
@@ -33,7 +57,9 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
         raise ValueError(error_msg)

     product_df = product_df.sort_values('date')
-    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
+    product_name = f"Product {product_id}"
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]

     # --- 2. Data preprocessing and adaptation ---
     features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
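The load-and-scope block above now appears nearly verbatim in the TCN, Transformer, and XGBoost trainers. As a sketch only (the helper name `prepare_training_df` and its signature are hypothetical, not part of this commit), the duplication could be collapsed into one shared function:

```python
from utils.new_data_loader import load_new_data

# Aggregation spec shared by all three trainers
AGG_SPEC = {
    'sales': 'sum', 'weekday': 'first', 'month': 'first',
    'is_holiday': 'first', 'is_weekend': 'first',
    'is_promotion': 'first', 'temperature': 'mean',
}

def prepare_training_df(training_mode, product_id=None, store_id=None):
    """Hypothetical helper mirroring the block duplicated across the trainers."""
    full_df = load_new_data()
    if training_mode == 'store' and store_id:
        store_df = full_df[full_df['store_id'] == store_id].copy()
        if product_id and product_id not in ('unknown', 'all_products'):
            return store_df[store_df['product_id'] == product_id].copy()
        return store_df.groupby('date').agg(AGG_SPEC).reset_index().fillna(0)
    if training_mode == 'global':
        df = full_df[full_df['product_id'] == product_id].copy()
        return df.groupby('date').agg(AGG_SPEC).reset_index().fillna(0)
    return full_df[full_df['product_id'] == product_id].copy()  # 'product' mode
```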
@@ -1,424 +0,0 @@
"""
Multi-store sales forecasting system - data handling utilities
Supports loading, filtering, and processing of multi-store data
"""

import pandas as pd
import numpy as np
import os
from datetime import datetime, timedelta
from typing import Optional, List, Tuple, Dict, Any
from core.config import DEFAULT_DATA_PATH

def load_multi_store_data(file_path: str = None,
                          store_id: Optional[str] = None,
                          product_id: Optional[str] = None,
                          start_date: Optional[str] = None,
                          end_date: Optional[str] = None) -> pd.DataFrame:
    """
    Load multi-store sales data, with optional filtering by store, product, and date range.

    Args:
        file_path: data file path (.csv, .xlsx, .parquet supported); falls back to the default path from config when None
        store_id: store ID; None returns data for all stores
        product_id: product ID; None returns data for all products
        start_date: start date (YYYY-MM-DD)
        end_date: end date (YYYY-MM-DD)

    Returns:
        DataFrame: the filtered sales data
    """
    # Fall back to the default path from the config file when none is given
    if file_path is None:
        file_path = DEFAULT_DATA_PATH

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Data file does not exist: {file_path}")

    try:
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif file_path.endswith('.xlsx'):
            df = pd.read_excel(file_path)
        elif file_path.endswith('.parquet'):
            df = pd.read_parquet(file_path)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")
        print(f"Successfully loaded data file: {file_path}")
    except Exception as e:
        print(f"Failed to load file {file_path}: {e}")
        raise

    # Filter by store
    if store_id:
        df = df[df['store_id'] == store_id].copy()
        print(f"Filtered by store {store_id}, {len(df)} records remain")

    # Filter by product
    if product_id:
        df = df[df['product_id'] == product_id].copy()
        print(f"Filtered by product {product_id}, {len(df)} records remain")

    # Standardize column names and dtypes
    df = standardize_column_names(df)

    # Apply the date-range filter after standardization
    if start_date:
        try:
            start_date_dt = pd.to_datetime(start_date)
            # Make sure the comparison is between datetime objects
            if 'date' in df.columns:
                df = df[df['date'] >= start_date_dt].copy()
                print(f"Filtered from start date {start_date_dt}, {len(df)} records remain")
        except (ValueError, TypeError):
            print(f"Warning: invalid start date format '{start_date}', ignored.")

    if end_date:
        try:
            end_date_dt = pd.to_datetime(end_date)
            # Make sure the comparison is between datetime objects
            if 'date' in df.columns:
                df = df[df['date'] <= end_date_dt].copy()
                print(f"Filtered up to end date {end_date_dt}, {len(df)} records remain")
        except (ValueError, TypeError):
            print(f"Warning: invalid end date format '{end_date}', ignored.")

    if len(df) == 0:
        print("Warning: no data left after filtering")

    return df

def standardize_column_names(df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize column names to match what the training code and API expect.

    Args:
        df: the raw DataFrame

    Returns:
        DataFrame: the DataFrame with standardized column names
    """
    df = df.copy()

    # Define the rename map and apply it
    rename_map = {
        'sales_quantity': 'sales',             # fix: match the original column name
        'temperature_2m_mean': 'temperature',  # new: handle the temperature column
        'dayofweek': 'weekday'                 # fix: match the original column name
    }
    df.rename(columns={k: v for k, v in rename_map.items() if k in df.columns}, inplace=True)

    # Make sure the date column is of datetime type
    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'], errors='coerce')
        df.dropna(subset=['date'], inplace=True)  # drop rows whose dates cannot be parsed
    else:
        # Without a date column we cannot proceed; return an empty DataFrame
        return pd.DataFrame()

    # Computing sales_amount: since there is no price column, the calculation
    # would need to change or be dropped; it is commented out here because the
    # raw data already carries sales_amount.
    # if 'sales_amount' not in df.columns and 'sales' in df.columns and 'price' in df.columns:
    #     # make sure sales and price are numeric first
    #     df['sales'] = pd.to_numeric(df['sales'], errors='coerce')
    #     df['price'] = pd.to_numeric(df['price'], errors='coerce')
    #     df['sales_amount'] = df['sales'] * df['price']

    # Create missing feature columns
    if 'weekday' not in df.columns:
        df['weekday'] = df['date'].dt.dayofweek
    if 'month' not in df.columns:
        df['month'] = df['date'].dt.month

    # Add missing metadata columns
    meta_columns = {
        'store_name': 'Unknown Store',
        'store_location': 'Unknown Location',
        'store_type': 'Unknown',
        'product_name': 'Unknown Product',
        'product_category': 'Unknown Category'
    }
    for col, default in meta_columns.items():
        if col not in df.columns:
            df[col] = default

    # Add missing boolean feature columns
    default_features = {
        'is_holiday': False,
        'is_weekend': None,
        'is_promotion': False,
        'temperature': 20.0
    }
    for feature, default_value in default_features.items():
        if feature not in df.columns:
            if feature == 'is_weekend':
                df['is_weekend'] = df['weekday'].isin([5, 6])
            else:
                df[feature] = default_value

    # Enforce numeric dtypes
    numeric_columns = ['sales', 'sales_amount', 'weekday', 'month', 'temperature']
    for col in numeric_columns:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')

    # Enforce boolean dtypes
    boolean_columns = ['is_holiday', 'is_weekend', 'is_promotion']
    for col in boolean_columns:
        if col in df.columns:
            df[col] = df[col].astype(bool)

    print(f"Data standardization complete; available feature columns: {[col for col in ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] if col in df.columns]}")

    return df

def get_available_stores(file_path: str = None) -> List[Dict[str, Any]]:
    """
    Get the list of available stores.

    Args:
        file_path: data file path

    Returns:
        List[Dict]: store info records
    """
    try:
        df = load_multi_store_data(file_path)
        if 'store_id' not in df.columns:
            print("The data file is missing the 'store_id' column")
            return []

        # Collect store info defensively, even if some columns are missing
        store_info = []
        # Use drop_duplicates to get unique store rows
        stores_df = df.drop_duplicates(subset=['store_id'])
        for _, row in stores_df.iterrows():
            store_info.append({
                'store_id': row['store_id'],
                'store_name': row.get('store_name', f"Store {row['store_id']}"),
                'location': row.get('store_location', 'Unknown location'),
                'type': row.get('store_type', 'Standard'),
                'opening_date': row.get('opening_date', 'Unknown'),
            })
        return store_info
    except Exception as e:
        print(f"Failed to get store list: {e}")
        return []

def get_available_products(file_path: str = None,
                           store_id: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    Get the list of available products.

    Args:
        file_path: data file path
        store_id: store ID; None returns products across all stores

    Returns:
        List[Dict]: product info records
    """
    try:
        df = load_multi_store_data(file_path, store_id=store_id)

        # Collect unique product info
        product_columns = ['product_id', 'product_name']
        if 'product_category' in df.columns:
            product_columns.append('product_category')
        if 'unit_price' in df.columns:
            product_columns.append('unit_price')

        products = df[product_columns].drop_duplicates()
        return products.to_dict('records')
    except Exception as e:
        print(f"Failed to get product list: {e}")
        return []

def get_store_product_sales_data(store_id: str,
                                 product_id: str,
                                 file_path: str = None) -> pd.DataFrame:
    """
    Get sales data for a specific store and product, for model training.

    Args:
        file_path: data file path
        store_id: store ID
        product_id: product ID

    Returns:
        DataFrame: processed sales data containing the features the models need
    """
    # Load the data
    df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)

    if len(df) == 0:
        raise ValueError(f"No sales data found for store {store_id}, product {product_id}")

    # Make sure the data is sorted by date
    df = df.sort_values('date').copy()

    # Standardization already happened inside load_multi_store_data

    # Verify that the required columns exist
    required_columns = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        print(f"Warning: columns {missing_columns} are still missing after standardization")
        raise ValueError(f"Cannot build the full feature set; missing columns: {missing_columns}")

    # All columns the model training needs (features + target)
    final_columns = [
        'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
        'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
    ]
    # Keep only the columns that actually exist in the DataFrame
    existing_columns = [col for col in final_columns if col in df.columns]

    # Return a DataFrame restricted to those required columns
    return df[existing_columns]

def aggregate_multi_store_data(product_id: Optional[str] = None,
                               store_id: Optional[str] = None,
                               aggregation_method: str = 'sum',
                               file_path: str = None) -> pd.DataFrame:
    """
    Aggregate sales data, either per product (global) or per store (all products).

    Args:
        file_path: data file path
        product_id: product ID (for global models)
        store_id: store ID (for store-level aggregated models)
        aggregation_method: aggregation method ('sum', 'mean', 'median')

    Returns:
        DataFrame: the aggregated sales data
    """
    # Load data according to whether this is a store-level, product-level, or truly global aggregation
    if store_id:
        # Store aggregation: load all data for that store
        df = load_multi_store_data(file_path, store_id=store_id)
        if len(df) == 0:
            raise ValueError(f"No sales data found for store {store_id}")
        grouping_entity = f"store {store_id}"
    elif product_id:
        # Product aggregation: load that product's data across all stores
        df = load_multi_store_data(file_path, product_id=product_id)
        if len(df) == 0:
            raise ValueError(f"No sales data found for product {product_id}")
        grouping_entity = f"product {product_id}"
    else:
        # Truly global aggregation: load everything
        df = load_multi_store_data(file_path)
        if len(df) == 0:
            raise ValueError("The data file is empty; global aggregation is impossible")
        grouping_entity = "all products"

    # Aggregate by date (using the standardized column names)
    agg_dict = {}
    if aggregation_method == 'sum':
        agg_dict = {
            'sales': 'sum',          # standardized sales-quantity column
            'sales_amount': 'sum',
            'price': 'mean'          # standardized price column, averaged
        }
    elif aggregation_method == 'mean':
        agg_dict = {
            'sales': 'mean',
            'sales_amount': 'mean',
            'price': 'mean'
        }
    elif aggregation_method == 'median':
        agg_dict = {
            'sales': 'median',
            'sales_amount': 'median',
            'price': 'median'
        }

    # Keep only the columns that actually exist
    available_cols = df.columns.tolist()
    agg_dict = {k: v for k, v in agg_dict.items() if k in available_cols}

    # Aggregate the data
    aggregated_df = df.groupby('date').agg(agg_dict).reset_index()

    # Take the product info from the first row (i.e., the first store)
    product_info = df[['product_id', 'product_name', 'product_category']].iloc[0]
    for col, val in product_info.items():
        aggregated_df[col] = val

    # Add store columns that mark the data as global
    aggregated_df['store_id'] = 'GLOBAL'
    aggregated_df['store_name'] = f'All stores-{aggregation_method.upper()}'
    aggregated_df['store_location'] = 'Global aggregation'
    aggregated_df['store_type'] = 'global'

    # Standardize the aggregated data (adds any missing feature columns)
    aggregated_df = aggregated_df.sort_values('date').copy()
    aggregated_df = standardize_column_names(aggregated_df)

    # All columns the model training needs (features + target)
    final_columns = [
        'date', 'sales', 'product_id', 'product_name', 'store_id', 'store_name',
        'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'
    ]
    # Keep only the columns that actually exist in the DataFrame
    existing_columns = [col for col in final_columns if col in aggregated_df.columns]

    # Return a DataFrame restricted to those required columns
    return aggregated_df[existing_columns]

def get_sales_statistics(file_path: str = None,
                         store_id: Optional[str] = None,
                         product_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Get summary statistics for the sales data.

    Args:
        file_path: data file path
        store_id: store ID
        product_id: product ID

    Returns:
        Dict: the statistics
    """
    try:
        df = load_multi_store_data(file_path, store_id=store_id, product_id=product_id)

        if len(df) == 0:
            return {'error': 'no data'}

        stats = {
            'total_records': len(df),
            'date_range': {
                'start': df['date'].min().strftime('%Y-%m-%d'),
                'end': df['date'].max().strftime('%Y-%m-%d')
            },
            'stores': df['store_id'].nunique(),
            'products': df['product_id'].nunique(),
            'total_sales_amount': float(df['sales_amount'].sum()) if 'sales_amount' in df.columns else 0,
            'total_quantity': int(df['quantity_sold'].sum()) if 'quantity_sold' in df.columns else 0,
            'avg_daily_sales': float(df.groupby('date')['quantity_sold'].sum().mean()) if 'quantity_sold' in df.columns else 0
        }

        return stats
    except Exception as e:
        return {'error': str(e)}

# Backward-compatible helper
def load_data(file_path=None, store_id=None):
    """
    Backward-compatible data loading function
    """
    return load_multi_store_data(file_path, store_id=store_id)
@@ -0,0 +1,86 @@
import pandas as pd
import os

def load_new_data(file_path='data/old_5shops_50skus.parquet'):
    """
    Load and adapt the new Parquet data file, providing a data format
    compatible with the existing system.

    Core principles:
    1. Preserve the new data in full: no original feature is dropped.
    2. Adapt the new data first: rename columns and create proxy columns to
       stay compatible with the old code.

    Args:
        file_path (str): path to the new data file

    Returns:
        pandas.DataFrame: the adapted DataFrame, containing all original features
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Data file does not exist: {file_path}")

    print(f"Loading new data from {file_path}...")
    df = pd.read_parquet(file_path)
    print("Data loaded; starting adaptation...")

    # Work on a copy so the original df stays untouched
    df_adapted = df.copy()

    # --- 1. Column renaming (match the naming the old code expects) ---
    # Step 1.1: safely drop the redundant 'date' column from the new data; 'kdrq' is authoritative
    if 'date' in df_adapted.columns and 'kdrq' in df_adapted.columns:
        df_adapted.drop(columns=['date'], inplace=True)
        print("Dropped the redundant 'date' column from the new data; 'kdrq' is authoritative.")

    rename_map = {
        'subbh': 'store_id',
        'hh': 'product_id',
        'kdrq': 'date',                 # kdrq can now safely become 'date'
        'net_sales_quantity': 'sales',  # map the target variable to 'sales'
        'temperature_2m_mean': 'temperature',
        'day_of_week': 'weekday'
    }
    df_adapted.rename(columns=rename_map, inplace=True)
    print(f"Column renaming done: {list(rename_map.keys())} -> {list(rename_map.values())}")

    # --- 2. Dtype conversion ---
    # Convert the 'date' column to proper datetime objects
    df_adapted['date'] = pd.to_datetime(df_adapted['date'])
    print("Converted the 'date' column to datetime.")

    # --- 3. Key feature engineering (proxy columns) ---
    # The existing models rely on the 'is_promotion' and 'is_weekend' features.
    # 'is_weekend' already exists in the new data and needs no handling.
    # 'is_promotion' does not; create a proxy column defaulting to 0.
    if 'is_promotion' not in df_adapted.columns:
        df_adapted['is_promotion'] = 0
        print("Created the proxy column 'is_promotion' with default value 0.")

    # Make sure a 'month' column exists; derive it from the date if missing
    if 'month' not in df_adapted.columns and 'date' in df_adapted.columns:
        df_adapted['month'] = df_adapted['date'].dt.month
        print("Derived the 'month' column from the 'date' column.")

    print("Data adaptation complete.")

    # Return the adapted DataFrame with all columns intact
    return df_adapted

if __name__ == '__main__':
    # Quick self-test when running this script directly
    print("--- Testing the data loader ---")
    try:
        adapted_df = load_new_data()
        print("\n--- Adapted data info ---")
        adapted_df.info()

        print("\n--- Key columns check ---")
        key_cols = [
            'store_id', 'product_id', 'date', 'sales',
            'temperature', 'weekday', 'is_promotion', 'month'
        ]
        print(adapted_df[key_cols].head())

        print(f"\nTest passed! The adapted DataFrame has {len(adapted_df.columns)} columns.")
    except Exception as e:
        print(f"\nTest failed: {e}")
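Given the adapter above, a quick contract check is easy to script. The `REQUIRED` set below is inferred from the columns the trainers read; it is an assumption, not an official schema:

```python
from utils.new_data_loader import load_new_data

# Columns the trainers consume; an assumed contract, not a formal one
REQUIRED = {'store_id', 'product_id', 'date', 'sales', 'weekday',
            'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'}

df = load_new_data()
missing = REQUIRED - set(df.columns)
assert not missing, f"Adapter contract broken; missing columns: {missing}"
print(f"OK: {len(df)} rows, all {len(REQUIRED)} required columns present")
```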
@@ -45,6 +45,10 @@ class TrainingTask:
     training_mode: str
     store_id: Optional[str] = None
     epochs: int = 100
+    # New fields to support custom-scope global training
+    selected_stores: Optional[list] = None
+    selected_products: Optional[list] = None
+    aggregation_method: str = 'sum'
     status: str = "pending"  # pending, running, completed, failed
     start_time: Optional[str] = None
     end_time: Optional[str] = None
@@ -138,12 +142,16 @@ class TrainingWorker:
                 training_logger.error(f"Progress callback failed: {e}")

             # Run the actual training, passing the progress callback along
+            # Upgrade: also pass the new custom-scope parameters
             metrics = predictor.train_model(
                 product_id=task.product_id,
                 model_type=task.model_type,
                 epochs=task.epochs,
                 store_id=task.store_id,
                 training_mode=task.training_mode,
+                selected_stores=task.selected_stores,
+                selected_products=task.selected_products,
+                aggregation_method=task.aggregation_method,
                 socketio=None,  # socketio cannot be used directly in a subprocess
                 task_id=task.task_id,
                 progress_callback=progress_callback  # pass the progress callback
@@ -282,7 +290,11 @@ class TrainingProcessManager:
         self.logger.info("✅ Training process manager stopped")

     def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
-                    store_id: str = None, epochs: int = 100, **kwargs) -> str:
+                    store_id: str = None, epochs: int = 100,
+                    selected_stores: Optional[list] = None,
+                    selected_products: Optional[list] = None,
+                    aggregation_method: str = 'sum',
+                    **kwargs) -> str:
         """Submit a training task"""
         task_id = str(uuid.uuid4())
@@ -292,7 +304,10 @@ class TrainingProcessManager:
             model_type=model_type,
             training_mode=training_mode,
             store_id=store_id,
-            epochs=epochs
+            epochs=epochs,
+            selected_stores=selected_stores,
+            selected_products=selected_products,
+            aggregation_method=aggregation_method
         )

         with self.lock:
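For illustration, a hedged sketch of a call site using the extended `submit_task` signature; `manager` is assumed to be an already-running `TrainingProcessManager` instance, and the store IDs are made up:

```python
# Submit a custom-scope global training task (illustrative values only)
task_id = manager.submit_task(
    product_id='all_products',         # sentinel: no per-product filter
    model_type='transformer',
    training_mode='global',
    selected_stores=['S001', 'S002'],  # hypothetical store IDs
    selected_products=None,            # None = keep all products
    aggregation_method='sum',
    epochs=50,
)
print(f"Submitted training task {task_id}")
```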
temp_data_analysis.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import pandas as pd
import os

def analyze_parquet_files():
    """
    Analyze the structural differences between the two Parquet data files
    """
    data_path = 'data'
    current_data_file = os.path.join(data_path, 'timeseries_training_data_sample_10s50p.parquet')
    new_data_file = os.path.join(data_path, 'old_5shops_50skus.parquet')

    print("="*50)
    print("Data file difference report")
    print("="*50)

    try:
        # --- Analyze the current data file ---
        print(f"\n--- 1. Current data: {current_data_file} ---\n")
        if os.path.exists(current_data_file):
            df_current = pd.read_parquet(current_data_file)
            print("[Columns and dtypes]:")
            df_current.info(verbose=False)
            print("\n[First 5 sample rows]:")
            print(df_current.head())
            print(f"\n[Total rows]: {len(df_current)}")
            print(f"[Unique stores]: {df_current['store_id'].nunique()}")
            print(f"[Unique products]: {df_current['product_id'].nunique()}")
        else:
            print(f"Error: file does not exist: {current_data_file}")

        print("\n" + "-"*40 + "\n")

        # --- Analyze the new data file ---
        print(f"\n--- 2. New data: {new_data_file} ---\n")
        if os.path.exists(new_data_file):
            df_new = pd.read_parquet(new_data_file)
            print("[Columns and dtypes (partial)]:")
            df_new.info(verbose=True, max_cols=10, show_counts=True)  # show more detail
            print("\n[All column names]:")
            print(df_new.columns.tolist())
            print("\n[First 5 sample rows (selected columns)]:")
            # Show a few key columns
            display_cols = ['subbh', 'hh', 'kdrq', 'net_sales_quantity', 'is_weekend', 'sales_quantity_rolling_mean_7d', 'province', 'temperature_2m_mean', 'brand_encoded']
            print(df_new[display_cols].head())
            print(f"\n[Total rows]: {len(df_new)}")
            print(f"[Unique stores (subbh)]: {df_new['subbh'].nunique()}")
            print(f"[Unique products (hh)]: {df_new['hh'].nunique()}")
        else:
            print(f"Error: file does not exist: {new_data_file}")

    except Exception as e:
        print(f"\nError during analysis: {e}")

if __name__ == '__main__':
    analyze_parquet_files()
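The script uses relative paths, so it should be run from the project root (where the `data/` directory lives): `python temp_data_analysis.py`.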