**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。 **1. 修改 `server/trainers/kan_trainer.py`** * **内容**: 完全重写了 `kan_trainer.py`。 * **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。 * **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。 * **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。 * **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。 **2. 修改 `server/trainers/tcn_trainer.py`** * **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。 * 移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。 * 全面转向使用 `model_manager` 进行版本控制和文件保存。 * 统一了函数签名和进度反馈逻辑。 **3. 修改 `server/trainers/transformer_trainer.py`** * **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。 * 移除了所有遗留的、基于文件名的路径拼接和保存逻辑。 * 实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。 **4. 修改 `server/core/predictor.py`** * **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。 * **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。 * **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。 * **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。 * **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
药店销售预测系统
这是一个基于多种深度学习模型的药店销售预测系统,支持多种时序预测模型,包括 Transformer、mLSTM、KAN 和 TCN。
功能特点
- 支持多种深度学习模型进行销量预测
- 提供命令行界面和API服务两种使用方式
- 支持模型训练、预测和评估
- 提供预测结果可视化和分析
- 支持模型比较和管理
项目结构
├── core/ # 核心模块
│ ├── __init__.py
│ ├── config.py # 全局配置参数
│ └── predictor.py # 核心预测器类
├── trainers/ # 模型训练器
│ ├── __init__.py
│ ├── mlstm_trainer.py # mLSTM模型训练函数
│ ├── kan_trainer.py # KAN模型训练函数
│ ├── tcn_trainer.py # TCN模型训练函数
│ └── transformer_trainer.py # Transformer模型训练函数
├── predictors/ # 预测模块
│ ├── __init__.py
│ └── model_predictor.py # 模型预测函数
├── analysis/ # 分析模块
│ ├── __init__.py
│ ├── metrics.py # 评估指标计算函数
│ ├── trend_analysis.py # 趋势分析函数
│ └── explanation.py # 预测解释函数
├── utils/ # 工具模块
│ ├── __init__.py
│ ├── data_utils.py # 数据处理工具函数
│ └── visualization.py # 可视化工具函数
├── models/ # 模型定义
│ ├── transformer_model.py
│ ├── mlstm_model.py
│ ├── kan_model.py
│ ├── tcn_model.py
│ └── optimized_kan_forecaster.py
├── pharmacy_predictor.py # 主接口文件
├── run_pharmacy_prediction.py # 命令行运行入口
├── api.py # API服务入口
└── pharmacy_sales.xlsx # 示例数据文件
支持的模型
- Transformer: 基于自注意力机制的时序预测模型
- mLSTM: 矩阵LSTM模型,结合了LSTM和Transformer的优点
- KAN: Kolmogorov-Arnold Network,一种基于柯尔莫哥洛夫-阿诺德定理的神经网络
- TCN: 时间卷积网络,使用因果卷积进行时序建模
- 优化版KAN: 经过优化的KAN模型,提高了预测精度和训练效率
使用方法
命令行界面
运行命令行界面:
python run_pharmacy_prediction.py
API服务
启动API服务:
python api.py
代码中使用
from pharmacy_predictor import PharmacyPredictor
# 创建预测器实例
predictor = PharmacyPredictor(data_path='pharmacy_sales.xlsx')
# 训练模型
metrics = predictor.train_model(product_id='P001', model_type='tcn', epochs=50)
# 使用模型预测
result = predictor.predict(product_id='P001', model_type='tcn', future_days=7, analyze_result=True)
依赖库
- PyTorch
- pandas
- numpy
- matplotlib
- scikit-learn
- Flask (用于API服务)
- pytorch-tcn (用于TCN模型)
Description
Languages
Python
73.3%
Vue
22%
HTML
1.8%
CSS
1.1%
Batchfile
0.8%
Other
1%