Fix the training and prediction pipeline

This commit is contained in:
parent aaf2672b7f
commit fab0c254be

@@ -410,3 +410,77 @@
2. **Debugging and fixes**: The first version of the script had a timezone-comparison bug and failed to match records stored in the local timezone. After a round of debugging that switched the comparison baseline from UTC to local time, the script worked as intended.

3. **Running the cleanup**: The script was then executed successfully, removing from the database every historical record created before the cutoff that carried a wrong file path.

- **Conclusion**: Cleaning up the legacy rows in the database resolved the `404` errors caused by the mismatch between old and new data paths, restoring data consistency and functional stability.

---

## 2025-07-26: Core Data Source Replacement and System-Level Refactoring

**Developers**: Roo (AI Assistant) & lyf

### Phase 1: Core Data Source Replacement and Adapter Refactoring

- **Goal**: Replace the project's data source entirely, swapping the old, simple `timeseries_training_data_sample_10s50p.parquet` for the new core dataset, which is richer in features and more complex in structure.

- **Core challenge**: The old and new datasets differ substantially in schema (e.g. `subbh` vs `store_id`), data types, feature set (the new data lacks the `is_promotion` column), and granularity; a direct swap would have broken the whole system.

- **Solution (adapter pattern)**:

  1. **Central data adapter**: Added `server/utils/new_data_loader.py` as the single data entry point for the entire system.

  2. **Dynamic conversion**: The adapter loads the new dataset and converts it on the fly into the format the legacy code understands: renaming columns, filling the missing `is_promotion` feature with `0`, and normalizing data types.

  3. **Comprehensive refactor**: Systematically refactored every data consumer (`server/api.py`, `server/core/predictor.py`, and **all** trainer scripts under `server/trainers/`) so they no longer load data directly but obtain it through the new adapter, fully decoupling the new data source from the existing business logic.
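The conversion step described above can be sketched as follows. `load_new_data`, the `subbh` column, and the `is_promotion` fill value come from this log; the helper name `adapt_new_dataset` and the sample frame are illustrative only.

```python
import pandas as pd

def adapt_new_dataset(df: pd.DataFrame) -> pd.DataFrame:
    """Sketch of the adapter: rename new-schema columns to the legacy names,
    fill features the new data lacks, and normalize dtypes."""
    df = df.rename(columns={'subbh': 'store_id'})  # new name -> legacy name
    if 'is_promotion' not in df.columns:
        df['is_promotion'] = 0                     # feature missing in the new data
    df['store_id'] = df['store_id'].astype(str)    # normalize dtypes
    return df

raw = pd.DataFrame({'subbh': [101, 102], 'sales': [3.0, 5.0]})
adapted = adapt_new_dataset(raw)
```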

### Phase 2: End-to-End Iterative Debugging and Cascading Fixes

After the refactor we ran a series of end-to-end tests and fixed the chain of issues triggered by this major change.

- **Fixed `KeyError`**: Resolved key-lookup failures caused by the old/new column-name mismatch (`subbh` vs `store_id`).

- **Fixed `NaN` values**: While implementing the "train by store, all products" aggregation, dates on which some products had no sales records caused the aggregation to introduce `NaN` values. Adding a `.fillna(0)` cleanup step to the data-preparation stage of every trainer eliminated the problem.
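A minimal illustration of that failure mode with toy data: when a product has no record on some date, reshaping across the union of dates introduces `NaN`, and `.fillna(0)` restores a clean table.

```python
import pandas as pd

# Two products whose sale dates do not fully overlap
sales = pd.DataFrame({
    'date': ['2025-07-01', '2025-07-01', '2025-07-02'],
    'product_id': ['A', 'B', 'A'],
    'sales': [5.0, 3.0, 2.0],
})

# A per-product daily table leaves NaN where a product had no record
daily = sales.pivot(index='date', columns='product_id', values='sales')
assert daily.isna().any().any()  # B has no record on 2025-07-02

daily = daily.fillna(0)  # the cleanup step added to every trainer
```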

- **Fixed `TypeError` (JSON serialization)**: API responses failed to serialize because they contained raw `numpy.ndarray` and `datetime.date` objects; these are now converted before encoding.
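One common way to fix this class of error is a converter passed through `json.dumps(default=...)`; the helper name below is ours, not from the project code.

```python
import json
import datetime
import numpy as np

def to_jsonable(obj):
    """Convert numpy and date objects that json.dumps cannot handle natively."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating)):
        return obj.item()
    if isinstance(obj, (datetime.date, datetime.datetime)):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

payload = {'date': datetime.date(2025, 7, 26), 'preds': np.array([1.5, 2.0])}
encoded = json.dumps(payload, default=to_jsonable)
```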

- **Fixed a fundamental flaw in the XGBoost prediction logic**:

  - **Symptom**: With models trained on the new data, the backend returned `200 OK` for `XGBoost` predictions, but the frontend chart failed to render.

  - **Root cause**: `server/predictors/model_predictor.py` had a serious logic defect: it applied the "autoregressive" loop designed for `PyTorch` models (predict one day, feed that prediction back in to predict the next) to `XGBoost`. But `XGBoost` here is a "direct multi-step output" model that returns predictions for all future dates in a single call. The misplaced loop kept only the first value of the full `XGBoost` output and copied it across every day, producing useless predictions.

  - **Fix**: Added a separate, non-looping branch for `XGBoost` in `model_predictor.py` that consumes the full output array, yielding correct predictions the frontend can render.
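The defect is easy to reproduce with plain arrays: a direct multi-step model returns the whole horizon at once, and the autoregressive loop kept only the first element. The numbers below are synthetic.

```python
import numpy as np

# A direct multi-step model returns the whole 7-day horizon in one call
direct_output = np.array([12.0, 11.5, 11.8, 12.2, 11.9, 12.1, 11.7])

# Buggy path: the autoregressive loop takes only output[0] each step,
# effectively copying that single value across all days
buggy = np.full(7, direct_output[0])

# Fixed path: consume the model's full output array as-is
fixed = direct_output
```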

### Phase 3: Environment Cleanup

- **Goal**: Guarantee a completely clean system environment, free of interference from old data or old models, before the final tests.

- **Steps**: Under the developer's precise direction, we completed the following cleanup:

  1. **File system**: Manually deleted all historical artifacts in the `saved_models` and `saved_predictions` folders.

  2. **Database**: Ran the `server/tools/delete_old_predictions.py` script, clearing all stale prediction records from the `prediction_history.db` database.

- **Conclusion**: With the data source replaced, the system refactored, the cascading issues fixed, and the environment cleaned, the project is in a fresh state with more robust code and a more reliable data source, ready for final end-to-end validation.

---

## 2025-07-26 (continued): End-to-End Testing and System-Level Fixes

**Developers**: Roo (AI Assistant) & lyf

### Phase 3 (continued): Iterative Fixes and Architecture Clarifications

Full end-to-end testing of the refactored system surfaced, and we fixed, a series of deeper bugs tied to specific model architectures and training modes.

- **Fixed the `XGBoost` custom-horizon prediction crash**:

  - **Symptom**: Predicting with an `XGBoost` model crashed with an array-length mismatch whenever the user-specified horizon differed from the fixed horizon the model was trained with.

  - **Architecture clarification**: We confirmed that `XGBoost` is a "direct multi-step output" model whose prediction length is fixed at training time.

  - **Fix**: Changed `server/predictors/model_predictor.py` to stop relying on the user-supplied horizon and instead generate the date sequence from the model's **actual** output length, making the code robust.
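The sizing fix can be reproduced in isolation; the toy array below stands in for a trained model's fixed-horizon output.

```python
import numpy as np
import pandas as pd

model_output = np.array([5.0, 6.0, 7.0])  # horizon fixed at training time: 3 days
requested_days = 10                        # the user asked for more than the model produces

# Before the fix: 10 dates vs 3 values leads to a length-mismatch crash.
# After the fix: derive the date axis from the model's actual output length.
future_dates = pd.date_range(start='2025-07-27', periods=len(model_output))
predictions = pd.DataFrame({'date': future_dates, 'predicted_sales': model_output})
```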

- **Fixed the `CnnBiLstmAttention` prediction logic**:

  - **Symptom**: The model predicted successfully, but the frontend chart failed to render.

  - **Root cause**: Like `XGBoost`, this model is a "direct multi-step output" architecture, but it had been registered to the prediction function designed for "autoregressive" models, which produced meaningless results.

  - **Fix**: Added a dedicated, correct, non-autoregressive prediction branch for `CnnBiLstmAttention` in `server/predictors/model_predictor.py` so it is handled properly.

- **Fixed the "global training" data-filtering logic**:

  - **Symptom**: In "global training" mode, selecting "all stores, all products" failed because no data could be found.

  - **Root cause**: When handling this mode, `server/core/predictor.py` mistakenly used a sentinel value (e.g. `unknown`) to filter on `product_id`, producing an empty dataset.

  - **Fix**: Reworked the logic so that selecting "all products" correctly skips the `product_id` filter and aggregates the whole dataset directly.
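The corrected selection logic boils down to never filtering on a sentinel value; a sketch with toy data, where the sentinel names (`unknown`, `all_products`) are taken from this log.

```python
import pandas as pd

full_df = pd.DataFrame({
    'product_id': ['P1', 'P2'],
    'sales': [10.0, 20.0],
})

product_id = 'unknown'  # sentinel sent when "all products" is selected

# Buggy: filtering on the sentinel matches nothing
empty = full_df[full_df['product_id'] == product_id]

# Fixed: skip the product filter for sentinel values and keep everything
if product_id in ('unknown', 'all_products'):
    selected = full_df.copy()
else:
    selected = full_df[full_df['product_id'] == product_id]
```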

- **Fixed the "global prediction" data-loading logic**:

  - **Symptom**: "Global prediction" mode failed because no data could be found.

  - **Root cause**: `server/predictors/model_predictor.py` contained a mirror of the training-side bug, inconsistent with the training logic.

  - **Fix**: Updated `model_predictor.py` so its "global prediction" data loading is fully synchronized with the training logic in `predictor.py`.

- **Clarified that "negative sales" are a data property, not a bug**:

  - **Symptom**: Historical sales in the "global prediction" chart dipped below zero.

  - **Analysis**: After discussion we confirmed that the raw data includes business events such as returns, so summing by day can yield negative totals.

  - **Decision**: To stay faithful to the source data, we decided **not** to clamp negative values to 0 in code, accepting the possible impact on model training and prediction. This is a significant data-handling policy decision.
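A toy aggregation showing how returns can push a daily total below zero; matching the decision above, the negative value is kept rather than clipped.

```python
import pandas as pd

rows = pd.DataFrame({
    'date': ['2025-07-01', '2025-07-01', '2025-07-02'],
    'sales': [2.0, -5.0, 3.0],  # -5.0 is a return (退货)
})

daily = rows.groupby('date')['sales'].sum()
# Deliberately NOT clipped to zero: negatives are real business events
```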

- **End-to-end enhancement to support "custom-scope global training"**:

  - **Symptom**: In "global training", choosing "custom scope" (specific store and product lists) made training fail.

  - **Root cause**: A **broken parameter hand-off across the whole chain**: from the API endpoint through the process manager to the core training logic, no layer was prepared to handle the new, more complex custom-scope parameters.

  - **Fix (a three-layer, end-to-end change)**:

    1. **API layer (`api.py`)**: Updated the `/api/training` endpoint to accept the `selected_stores` and `selected_products` lists sent by the frontend.

    2. **Process-manager layer (`training_process_manager.py`)**: Extended the `TrainingTask` dataclass and the core functions so they receive, store, and forward the new custom-scope parameters intact.

    3. **Core-logic layer (`predictor.py`)**: Substantially extended `train_model` with a new branch that handles the `selected_stores` and `selected_products` lists.

- **Conclusion**: This end-to-end change not only fixed the bug but also added a powerful new capability to the system. With it, every known issue found during end-to-end testing has been resolved.
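The shape of that parameter hand-off can be sketched with a dataclass. `TrainingTask`, `selected_stores`, and `selected_products` are named in this log; the field types, defaults, and the example values are illustrative, not the project's actual definition.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TrainingTask:
    """Carries the custom-scope parameters through the process-manager layer."""
    product_id: str
    model_type: str
    training_mode: str
    store_id: Optional[str] = None
    epochs: int = 50
    selected_stores: Optional[List[str]] = None    # new: custom global scope
    selected_products: Optional[List[str]] = None  # new: custom global scope
    aggregation_method: str = 'sum'

# API layer -> process manager: the new fields ride along unchanged
task = TrainingTask(
    product_id='unknown', model_type='xgboost', training_mode='global',
    selected_stores=['S1', 'S2'], selected_products=['P1'],
)
```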

Binary file not shown.
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
{"product_id": "11020059", "product_name": "Product 11020059", "model_type": "kan", "predictions": [{"date": "2022-01-01", "predicted_sales": 0.12464457750320435}, {"date": "2022-01-02", "predicted_sales": 0.11735431849956512}, {"date": "2022-01-03", "predicted_sales": 0.11733274161815643}, {"date": "2022-01-04", "predicted_sales": 0.11753901839256287}, {"date": "2022-01-05", "predicted_sales": 0.11754211783409119}, {"date": "2022-01-06", "predicted_sales": 0.11753572523593903}, {"date": "2022-01-07", "predicted_sales": 0.11738376319408417}], "prediction_data": [{"date": "2022-01-01", "predicted_sales": 0.12464457750320435}, {"date": "2022-01-02", "predicted_sales": 0.11735431849956512}, {"date": "2022-01-03", "predicted_sales": 0.11733274161815643}, {"date": "2022-01-04", "predicted_sales": 0.11753901839256287}, {"date": "2022-01-05", "predicted_sales": 0.11754211783409119}, {"date": "2022-01-06", "predicted_sales": 0.11753572523593903}, {"date": "2022-01-07", "predicted_sales": 0.11738376319408417}], "history_data": [{"store_id": "01010108", "product_id": "11020059", "date": "2021-12-02", "sales_quantity": 0.0, "return_quantity": 0.0, "sales": 0.0, "gross_profit_total": 0.0, "transaction_count": 0, "sales_quantity_rolling_mean_7d": 0.5, "return_quantity_rolling_mean_7d": 0.0, "net_sales_quantity_rolling_mean_7d": 0.5, "sales_quantity_rolling_sum_7d": 1.0, "return_quantity_rolling_sum_7d": 0.0, "net_sales_quantity_rolling_sum_7d": 1.0, "sales_quantity_rolling_mean_15d": 0.5, "return_quantity_rolling_mean_15d": 0.0, "net_sales_quantity_rolling_mean_15d": 0.5, "sales_quantity_rolling_sum_15d": 1.0, "return_quantity_rolling_sum_15d": 0.0, "net_sales_quantity_rolling_sum_15d": 1.0, "sales_quantity_rolling_mean_30d": 0.71, "return_quantity_rolling_mean_30d": -0.14, "net_sales_quantity_rolling_mean_30d": 0.57, "sales_quantity_rolling_sum_30d": 5.0, "return_quantity_rolling_sum_30d": -1.0, "net_sales_quantity_rolling_sum_30d": 4.0, "sales_quantity_rolling_mean_90d": 0.77, "return_quantity_rolling_mean_90d": -0.09, "net_sales_quantity_rolling_mean_90d": 0.68, "sales_quantity_rolling_sum_90d": 17.0, "return_quantity_rolling_sum_90d": -2.0, "net_sales_quantity_rolling_sum_90d": 15.0, "is_weekend": false, "weekday": 3, "day_of_month": 2, "day_of_year": 336, "week_of_month": 1, "month": 12, "quarter": 4, "is_holiday": false, "first_sale_date":
@@ -919,12 +919,16 @@ def start_training():

    # Submit the task via the new training process manager
    try:
        # Upgrade: also pass the custom-scope parameters to the training task
        task_id = training_manager.submit_task(
            product_id=product_id or "unknown",
            model_type=model_type,
            training_mode=training_mode,
            store_id=store_id,
            epochs=epochs,
            selected_stores=data.get('selected_stores'),
            selected_products=data.get('selected_products'),
            aggregation_method=aggregation_method
        )

        logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
@@ -64,6 +64,7 @@ class PharmacyPredictor:
                    learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                    hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
                    store_id=None, training_mode='product', aggregation_method='sum',
                    selected_stores: list = None, selected_products: list = None,  # new parameters for custom scope
                    socketio=None, task_id=None, version=None, continue_training=False,
                    progress_callback=None):
        """
@@ -142,14 +143,38 @@ class PharmacyPredictor:
            }).reset_index()
            log_message(f"聚合后数据量: {len(product_data)}", 'info')

            # Data cleaning: fill NaN values introduced by aggregation with 0
            product_data.fillna(0, inplace=True)
            log_message("已对聚合数据进行NaN值填充处理(使用0填充)", 'info')

        elif training_mode == 'global':
            # Upgrade: support "custom scope" global training
            if selected_stores and selected_products:
                log_message(f"全局训练模式: 自定义范围 -> 店铺 {selected_stores}, 产品 {selected_products}", 'info')
                # Filter precisely on the selected store and product lists
                product_data = full_df[
                    full_df['store_id'].isin(selected_stores) &
                    full_df['product_id'].isin(selected_products)
                ].copy()
            # Fall back to the legacy global-training logic
            elif product_id and product_id not in ['unknown', 'all_products']:
                log_message(f"全局训练模式: 聚合单个产品 '{product_id}' 的跨店数据...", 'info')
                product_data = full_df[full_df['product_id'] == product_id].copy()
            else:
                log_message("全局训练模式: 聚合所有店铺、所有药品的数据...", 'info')
                product_data = full_df.copy()

            if product_data.empty:
                log_message("在指定的全局范围或自定义范围内找不到任何数据。", 'error')
                return None

            # Aggregate the filtered data by day
            product_data = product_data.groupby('date').agg({
                'sales': aggregation_method,
                'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
                'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
            }).reset_index()
            log_message(f"全局训练聚合完成,聚合方法: {aggregation_method}, 数据量: {len(product_data)}", 'info')

        else:
            log_message(f"不支持的训练模式: {training_mode}", 'error')
@@ -160,7 +185,15 @@ class PharmacyPredictor:
            # A store model's identifier should depend only on the store ID
            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            # Upgrade: build a unique identifier for custom-scope global models
            if selected_stores and selected_products:
                # Derive a stable identifier from a content hash
                import hashlib
                s_stores = "_".join(sorted(selected_stores))
                s_products = "_".join(sorted(selected_products))
                hash_input = f"{s_stores}_{s_products}_{aggregation_method}"
                model_identifier = f"global_custom_{hashlib.md5(hash_input.encode()).hexdigest()[:8]}"
            else:
                model_identifier = f"global_{aggregation_method}"
        else:  # product mode
            model_identifier = product_id
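The identifier scheme in the hunk above can be exercised standalone; sorting the lists first makes the hash independent of selection order.

```python
import hashlib

def custom_global_identifier(stores, products, aggregation_method):
    """Stable content-based identifier for a custom-scope global model."""
    s_stores = "_".join(sorted(stores))
    s_products = "_".join(sorted(products))
    hash_input = f"{s_stores}_{s_products}_{aggregation_method}"
    return f"global_custom_{hashlib.md5(hash_input.encode()).hexdigest()[:8]}"

a = custom_global_identifier(['S2', 'S1'], ['P1'], 'sum')
b = custom_global_identifier(['S1', 'S2'], ['P1'], 'sum')  # different order, same id
```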
@@ -49,24 +49,64 @@ def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    # --- Prediction logic restructured to handle XGBoost correctly ---
    if isinstance(model, xgb.Booster):
        # --- XGBoost path (direct multi-step, one-shot prediction) ---
        X_current_scaled = scaler_X.transform(prediction_input_df[features].values)
        X_input_reshaped = X_current_scaled.reshape(1, -1)
        d_input = xgb.DMatrix(X_input_reshaped)

        # Predict all future days in a single call
        y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))

        # Inverse-transform the whole sequence
        y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled.reshape(1, -1)).flatten()
        y_pred_unscaled = np.maximum(0, y_pred_unscaled)  # keep sales non-negative

        # Build the future date sequence.
        # Fix: its length must match the number of points the model actually
        # outputs, not the user-supplied future_days, because an XGBoost model's
        # output length is fixed at training time.
        future_dates = pd.date_range(start=start_date_dt, periods=len(y_pred_unscaled))

        # Build the result DataFrame directly
        predictions_df = pd.DataFrame({
            'date': future_dates,
            'predicted_sales': y_pred_unscaled
        })

    elif isinstance(model, CnnBiLstmAttention):
        # --- CnnBiLstmAttention path (direct multi-step, one-shot prediction) ---
        X_current_scaled = scaler_X.transform(prediction_input_df[features].values)
        X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)

        with torch.no_grad():
            # Predict all future days in a single call
            y_pred_scaled = model(X_input).cpu().numpy()

        # Inverse-transform the whole sequence
        y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled).flatten()
        y_pred_unscaled = np.maximum(0, y_pred_unscaled)  # keep sales non-negative

        # The future date sequence matches the model's actual output length
        future_dates = pd.date_range(start=start_date_dt, periods=len(y_pred_unscaled))

        # Build the result DataFrame directly
        predictions_df = pd.DataFrame({
            'date': future_dates,
            'predicted_sales': y_pred_unscaled
        })

    else:
        # --- Default PyTorch path (autoregressive) ---
        all_predictions = []
        current_sequence_df = prediction_input_df.copy()

        for _ in range(future_days):
            X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)

            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()

            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
            next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
@@ -116,13 +156,21 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st
        }).reset_index()
        product_name = f"店铺 {store_id} (所有药品聚合)"
    elif training_mode == 'global':
        # Corrected data-loading logic for global prediction
        if product_id and product_id not in ['unknown', 'all_products']:
            # A concrete product ID is uncommon in global mode but supported:
            # aggregate that product's data across all stores
            product_df = full_df[full_df['product_id'] == product_id].copy()
            product_name = f"全局聚合 - 产品 {product_id}"
        else:
            # "All products" global prediction: aggregate everything
            product_df = full_df.copy()
            product_name = "全局聚合 (所有药品)"

        product_df = product_df.groupby('date').agg({
            'sales': 'sum',  # defaults to sum; could be made configurable later
            'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
            'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
        }).reset_index()
    else:  # default 'product' mode
        product_df = full_df[full_df['product_id'] == product_id].copy()
        # Compatibility: the new data may not have a product_name column
@@ -13,6 +13,7 @@ from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from utils.new_data_loader import load_new_data
from sklearn.preprocessing import MinMaxScaler
from utils.visualization import plot_loss_curve  # plotting helper
@@ -26,9 +27,38 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. Data loading and filtering (refactored) ---
    if product_df is None:
        print("正在使用新的统一数据加载器...")
        full_df = load_new_data()

        if training_mode == 'store' and store_id:
            store_df = full_df[full_df['store_id'] == store_id].copy()
            if product_id and product_id != 'unknown' and product_id != 'all_products':
                product_df = store_df[store_df['product_id'] == product_id].copy()
            else:
                product_df = store_df.groupby('date').agg({
                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
                    'is_holiday': 'first', 'is_weekend': 'first',
                    'is_promotion': 'first', 'temperature': 'mean'
                }).reset_index()
                product_df.fillna(0, inplace=True)
        elif training_mode == 'global':
            product_df = full_df[full_df['product_id'] == product_id].copy()
            product_df = product_df.groupby('date').agg({
                'sales': 'sum', 'weekday': 'first', 'month': 'first',
                'is_holiday': 'first', 'is_weekend': 'first',
                'is_promotion': 'first', 'temperature': 'mean'
            }).reset_index()
            product_df.fillna(0, inplace=True)
        else:  # default 'product' mode
            product_df = full_df[full_df['product_id'] == product_id].copy()

    # --- 2. Data preparation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    product_name = f"产品 {product_id}"
    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
        product_name = product_df['product_name'].iloc[0]

    X = product_df[features].values
    y = product_df[['sales']].values
@@ -53,6 +53,9 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None,
            'is_promotion': 'first', 'temperature': 'mean'
        }).reset_index()
        training_scope = f"店铺 {store_id} (所有药品聚合)"

        # Data cleaning: fill NaN values introduced by aggregation with 0
        product_df.fillna(0, inplace=True)
    elif training_mode == 'global':
        # Aggregate data for a specific product across all stores
        # Note: the new data is already expanded by (store_id, product_id, date);
        # this aggregation logic may need to be revisited
@@ -143,6 +143,9 @@ def train_product_model_with_mlstm(
            'is_promotion': 'first', 'temperature': 'mean'
        }).reset_index()
        training_scope = f"店铺 {store_id} (所有药品聚合)"

        # Data cleaning: fill NaN values introduced by aggregation with 0
        product_df.fillna(0, inplace=True)
    elif training_mode == 'global':
        product_df = full_df[full_df['product_id'] == product_id].copy()
        product_df = product_df.groupby('date').agg({
@@ -16,6 +16,7 @@ from tqdm import tqdm

from models.tcn_model import TCNForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from utils.new_data_loader import load_new_data
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
@ -58,12 +59,41 @@ def train_product_model_with_tcn(
|
|||||||
|
|
||||||
emit_progress(f"开始训练 TCN 模型")
|
emit_progress(f"开始训练 TCN 模型")
|
||||||
|
|
||||||
|
# --- 数据加载与筛选重构 ---
|
||||||
if product_df is None:
|
if product_df is None:
|
||||||
from utils.multi_store_data_utils import aggregate_multi_store_data
|
print("正在使用新的统一数据加载器...")
|
||||||
product_df = aggregate_multi_store_data(
|
full_df = load_new_data()
|
||||||
product_id=product_id,
|
|
||||||
aggregation_method=aggregation_method
|
if training_mode == 'store' and store_id:
|
||||||
)
|
store_df = full_df[full_df['store_id'] == store_id].copy()
|
||||||
|
if product_id and product_id != 'unknown' and product_id != 'all_products':
|
||||||
|
product_df = store_df[store_df['product_id'] == product_id].copy()
|
||||||
|
training_scope = f"店铺 {store_id} - 产品 {product_id}"
|
||||||
|
else:
|
||||||
|
product_df = store_df.groupby('date').agg({
|
||||||
|
'sales': 'sum', 'weekday': 'first', 'month': 'first',
|
||||||
|
'is_holiday': 'first', 'is_weekend': 'first',
|
||||||
|
'is_promotion': 'first', 'temperature': 'mean'
|
||||||
|
}).reset_index()
|
||||||
|
training_scope = f"店铺 {store_id} (所有药品聚合)"
|
||||||
|
product_df.fillna(0, inplace=True)
|
||||||
|
elif training_mode == 'global':
|
||||||
|
product_df = full_df[full_df['product_id'] == product_id].copy()
|
||||||
|
product_df = product_df.groupby('date').agg({
|
||||||
|
'sales': 'sum', 'weekday': 'first', 'month': 'first',
|
||||||
|
'is_holiday': 'first', 'is_weekend': 'first',
|
||||||
|
'is_promotion': 'first', 'temperature': 'mean'
|
||||||
|
}).reset_index()
|
||||||
|
training_scope = f"全局聚合({aggregation_method})"
|
||||||
|
product_df.fillna(0, inplace=True)
|
||||||
|
else: # 默认 'product' 模式
|
||||||
|
product_df = full_df[full_df['product_id'] == product_id].copy()
|
||||||
|
training_scope = f"所有店铺中的产品 {product_id}"
|
||||||
|
else:
|
||||||
|
# 如果传入了product_df,直接使用
|
||||||
|
if training_mode == 'store' and store_id:
|
||||||
|
training_scope = f"店铺 {store_id}"
|
||||||
|
elif training_mode == 'global':
|
||||||
training_scope = f"全局聚合({aggregation_method})"
|
training_scope = f"全局聚合({aggregation_method})"
|
||||||
else:
|
else:
|
||||||
training_scope = "所有店铺"
|
training_scope = "所有店铺"
|
||||||
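The aggregation branches above are the source of the `NaN` issue described in the changelog: products with no sales record on some dates leave holes after `groupby(...).agg(...)`. A minimal runnable sketch on synthetic data (column subset and values are invented for illustration) shows why the trailing `fillna(0)` matters:

```python
import pandas as pd

# Synthetic stand-in for full_df: store S1 has no temperature
# readings on 2025-07-01, so the aggregated mean would be NaN.
full_df = pd.DataFrame({
    'store_id':    ['S1', 'S1', 'S1', 'S2'],
    'product_id':  ['P1', 'P2', 'P1', 'P1'],
    'date':        ['2025-07-01', '2025-07-01', '2025-07-02', '2025-07-01'],
    'sales':       [3.0, 2.0, 5.0, 7.0],
    'weekday':     [1, 1, 2, 1],
    'temperature': [None, None, 28.0, 31.0],
})

# "Store mode, all products" branch: filter one store, aggregate by date
store_df = full_df[full_df['store_id'] == 'S1'].copy()
product_df = store_df.groupby('date').agg({
    'sales': 'sum', 'weekday': 'first', 'temperature': 'mean'
}).reset_index()
product_df.fillna(0, inplace=True)  # NaN introduced by aggregation -> 0

print(product_df)
```

Without the `fillna(0)` step, the 2025-07-01 row would carry a `NaN` temperature straight into the training tensors.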
**TCN trainer** — the `product_name` lookup is made tolerant of the new dataset, which may lack a `product_name` column:

```diff
@@ -84,7 +114,11 @@ def train_product_model_with_tcn(
         raise ValueError(error_msg)
 
     product_df = product_df.sort_values('date')
-    product_name = product_df['product_name'].iloc[0]
+    # 兼容性处理:新数据可能没有 product_name 列
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]
+    else:
+        product_name = f"产品 {product_id}"
 
     print(f"使用TCN模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
     print(f"训练范围: {training_scope}")
```
**Transformer trainer** (`train_product_model_with_transformer`) — fill the `NaN` values produced by store-level aggregation:

```diff
@@ -99,6 +99,9 @@ def train_product_model_with_transformer(
                 'is_promotion': 'first', 'temperature': 'mean'
             }).reset_index()
             training_scope = f"店铺 {store_id} (所有药品聚合)"
+
+            # 数据清洗:使用0填充聚合后可能产生的NaN值
+            product_df.fillna(0, inplace=True)
     elif training_mode == 'global':
         product_df = full_df[full_df['product_id'] == product_id].copy()
         product_df = product_df.groupby('date').agg({
```
**XGBoost trainer** — import the new unified data loader:

```diff
@@ -11,6 +11,7 @@ from xgboost.callback import EarlyStopping
 
 # 导入核心工具
 from utils.data_utils import create_dataset
+from utils.new_data_loader import load_new_data
 from analysis.metrics import evaluate_model
 from utils.model_manager import model_manager
 from models.model_registry import register_trainer
```
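The `load_new_data` adapter imported here is the central entry point described in the changelog; its implementation is not part of this diff, so the following is only a hypothetical sketch of what such an adapter does (the `adapt_to_legacy_schema` helper and the parquet path are invented names):

```python
import pandas as pd

def adapt_to_legacy_schema(df: pd.DataFrame) -> pd.DataFrame:
    # Rename new-schema columns to the names the legacy code expects
    df = df.rename(columns={'subbh': 'store_id'})
    # The new dataset has no is_promotion column; fill with 0 so the
    # old feature pipeline keeps working
    if 'is_promotion' not in df.columns:
        df['is_promotion'] = 0
    df['date'] = pd.to_datetime(df['date'])
    return df

def load_new_data(path: str = 'data/new_dataset.parquet') -> pd.DataFrame:
    # Single data entry point for the whole system (path is an assumption)
    return adapt_to_legacy_schema(pd.read_parquet(path))

# Demo on an in-memory frame instead of the real parquet file
demo = adapt_to_legacy_schema(pd.DataFrame({
    'subbh': ['S1'], 'product_id': ['P1'],
    'date': ['2025-07-01'], 'sales': [3.0],
}))
print(demo.columns.tolist())
```

The point of the adapter pattern is that every trainer calls `load_new_data()` and never touches the raw file, so a future schema change only has to be absorbed in one place.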
**XGBoost trainer** (`train_product_model_with_xgboost`) — same data-loading refactor as the TCN trainer:

```diff
@@ -23,9 +24,32 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
     """
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")
 
-    # --- 1. 数据准备和验证 ---
-    if product_df.empty:
-        raise ValueError(f"产品 {product_id} 没有可用的销售数据")
+    # --- 1. 数据加载与筛选重构 ---
+    if product_df is None:
+        print("正在使用新的统一数据加载器...")
+        full_df = load_new_data()
+
+        if training_mode == 'store' and store_id:
+            store_df = full_df[full_df['store_id'] == store_id].copy()
+            if product_id and product_id != 'unknown' and product_id != 'all_products':
+                product_df = store_df[store_df['product_id'] == product_id].copy()
+            else:
+                product_df = store_df.groupby('date').agg({
+                    'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                    'is_holiday': 'first', 'is_weekend': 'first',
+                    'is_promotion': 'first', 'temperature': 'mean'
+                }).reset_index()
+                product_df.fillna(0, inplace=True)
+        elif training_mode == 'global':
+            product_df = full_df[full_df['product_id'] == product_id].copy()
+            product_df = product_df.groupby('date').agg({
+                'sales': 'sum', 'weekday': 'first', 'month': 'first',
+                'is_holiday': 'first', 'is_weekend': 'first',
+                'is_promotion': 'first', 'temperature': 'mean'
+            }).reset_index()
+            product_df.fillna(0, inplace=True)
+        else:  # 默认 'product' 模式
+            product_df = full_df[full_df['product_id'] == product_id].copy()
 
     min_required_samples = sequence_length + forecast_horizon
     if len(product_df) < min_required_samples:
```
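The `min_required_samples = sequence_length + forecast_horizon` guard that follows this block exists because a single training sample consumes `sequence_length` input rows plus `forecast_horizon` target rows. A minimal standalone sketch of that check (the function name is invented for illustration):

```python
def check_min_samples(n_rows: int, sequence_length: int, forecast_horizon: int) -> None:
    # One (X, y) training pair needs sequence_length input rows plus
    # forecast_horizon target rows, so fewer rows than their sum cannot
    # produce even a single sample.
    min_required_samples = sequence_length + forecast_horizon
    if n_rows < min_required_samples:
        raise ValueError(
            f"数据不足: 需要至少 {min_required_samples} 行, 实际 {n_rows} 行"
        )

check_min_samples(40, sequence_length=30, forecast_horizon=7)  # passes silently
```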
**XGBoost trainer** — safe `product_name` fallback, mirroring the TCN change:

```diff
@@ -33,7 +57,9 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
         raise ValueError(error_msg)
 
     product_df = product_df.sort_values('date')
-    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
+    product_name = f"产品 {product_id}"
+    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
+        product_name = product_df['product_name'].iloc[0]
 
     # --- 2. 数据预处理和适配 ---
     features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
```
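The changelog entry above traces the broken XGBoost predictions to applying an autoregressive loop to a direct multi-step model. A toy stub (not the real model API) makes the failure mode concrete:

```python
import numpy as np

class DirectMultiStepStub:
    """Stand-in for the XGBoost model: a single predict() call
    returns the whole forecast horizon at once."""
    def predict(self, X):
        return np.array([[10.0, 11.0, 12.0]])  # 3-day forecast in one row

model = DirectMultiStepStub()

# Correct handling: take the full row of horizon values
correct = model.predict(np.zeros((1, 5)))[0]

# The buggy autoregressive loop effectively kept only the first value
# of each call and repeated it for every future day
buggy = [float(model.predict(np.zeros((1, 5)))[0][0]) for _ in range(3)]

print(correct.tolist())  # [10.0, 11.0, 12.0]
print(buggy)             # [10.0, 10.0, 10.0]
```

The flat `buggy` series is exactly the "one value copied across all days" symptom the frontend chart failed to render meaningfully.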
**Training process manager** — `TrainingTask` gains fields for custom-scope global training:

```diff
@@ -45,6 +45,10 @@ class TrainingTask:
     training_mode: str
     store_id: Optional[str] = None
     epochs: int = 100
+    # 新增字段以支持自定义范围的全局训练
+    selected_stores: Optional[list] = None
+    selected_products: Optional[list] = None
+    aggregation_method: str = 'sum'
     status: str = "pending"  # pending, running, completed, failed
     start_time: Optional[str] = None
     end_time: Optional[str] = None
```
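The reason the new `TrainingTask` fields all carry defaults is backward compatibility: existing call sites that construct tasks without them must keep working. A trimmed-down, runnable sketch (only the fields visible in the hunk are reproduced):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainingTask:
    # Required fields come first; every new field has a default,
    # so old positional constructions remain valid.
    product_id: str
    model_type: str
    training_mode: str
    store_id: Optional[str] = None
    epochs: int = 100
    # New fields for custom-scope global training
    selected_stores: Optional[list] = None
    selected_products: Optional[list] = None
    aggregation_method: str = 'sum'

old_style = TrainingTask('P1', 'xgboost', 'product')  # pre-change call site
print(old_style.aggregation_method)
```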
**Training process manager** — `TrainingWorker` forwards the new parameters to `predictor.train_model`:

```diff
@@ -138,12 +142,16 @@ class TrainingWorker:
                 training_logger.error(f"进度回调失败: {e}")
 
             # 执行真正的训练,传递进度回调
+            # 升级:传递新增的自定义范围参数
             metrics = predictor.train_model(
                 product_id=task.product_id,
                 model_type=task.model_type,
                 epochs=task.epochs,
                 store_id=task.store_id,
                 training_mode=task.training_mode,
+                selected_stores=task.selected_stores,
+                selected_products=task.selected_products,
+                aggregation_method=task.aggregation_method,
                 socketio=None,  # 子进程中不能直接使用socketio
                 task_id=task.task_id,
                 progress_callback=progress_callback  # 传递进度回调函数
```
**Training process manager** — `submit_task` accepts the new parameters explicitly instead of burying them in `**kwargs`:

```diff
@@ -282,7 +290,11 @@ class TrainingProcessManager:
         self.logger.info("✅ 训练进程管理器已停止")
 
     def submit_task(self, product_id: str, model_type: str, training_mode: str = "product",
-                    store_id: str = None, epochs: int = 100, **kwargs) -> str:
+                    store_id: str = None, epochs: int = 100,
+                    selected_stores: Optional[list] = None,
+                    selected_products: Optional[list] = None,
+                    aggregation_method: str = 'sum',
+                    **kwargs) -> str:
         """提交训练任务"""
         task_id = str(uuid.uuid4())
 
```
**Training process manager** — the new parameters are passed into the `TrainingTask` constructor:

```diff
@@ -292,7 +304,10 @@ class TrainingProcessManager:
             model_type=model_type,
             training_mode=training_mode,
             store_id=store_id,
-            epochs=epochs
+            epochs=epochs,
+            selected_stores=selected_stores,
+            selected_products=selected_products,
+            aggregation_method=aggregation_method
         )
 
         with self.lock:
```
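Promoting the three values out of `**kwargs` into named parameters gives them defaults and makes the forwarding into `TrainingTask` explicit. A simplified stand-in for `submit_task` (it returns the task record instead of enqueueing a real subprocess) illustrates the signature:

```python
import uuid
from typing import Optional

def submit_task(product_id: str, model_type: str, training_mode: str = "product",
                store_id: Optional[str] = None, epochs: int = 100,
                selected_stores: Optional[list] = None,
                selected_products: Optional[list] = None,
                aggregation_method: str = 'sum',
                **kwargs) -> dict:
    # Simplified sketch: build the task record the real manager would
    # hand to a TrainingTask and enqueue.
    return {
        'task_id': str(uuid.uuid4()),
        'product_id': product_id,
        'model_type': model_type,
        'training_mode': training_mode,
        'store_id': store_id,
        'epochs': epochs,
        'selected_stores': selected_stores,
        'selected_products': selected_products,
        'aggregation_method': aggregation_method,
    }

task = submit_task('P1', 'tcn', training_mode='global',
                   selected_stores=['S1', 'S2'], aggregation_method='mean')
print(task['aggregation_method'])
```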