99 lines
3.0 KiB
Markdown
99 lines
3.0 KiB
Markdown
# 药店销售预测系统
|
||
|
||
这是一个基于多种深度学习模型的药店销售预测系统,支持多种时序预测模型,包括 Transformer、mLSTM、KAN 和 TCN。
|
||
|
||
## 功能特点
|
||
|
||
- 支持多种深度学习模型进行销量预测
|
||
- 提供命令行界面和API服务两种使用方式
|
||
- 支持模型训练、预测和评估
|
||
- 提供预测结果可视化和分析
|
||
- 支持模型比较和管理
|
||
|
||
## 项目结构
|
||
|
||
```
|
||
├── core/ # 核心模块
|
||
│ ├── __init__.py
|
||
│ ├── config.py # 全局配置参数
|
||
│ └── predictor.py # 核心预测器类
|
||
├── trainers/ # 模型训练器
|
||
│ ├── __init__.py
|
||
│ ├── mlstm_trainer.py # mLSTM模型训练函数
|
||
│ ├── kan_trainer.py # KAN模型训练函数
|
||
│ ├── tcn_trainer.py # TCN模型训练函数
|
||
│ └── transformer_trainer.py # Transformer模型训练函数
|
||
├── predictors/ # 预测模块
|
||
│ ├── __init__.py
|
||
│ └── model_predictor.py # 模型预测函数
|
||
├── analysis/ # 分析模块
|
||
│ ├── __init__.py
|
||
│ ├── metrics.py # 评估指标计算函数
|
||
│ ├── trend_analysis.py # 趋势分析函数
|
||
│ └── explanation.py # 预测解释函数
|
||
├── utils/ # 工具模块
|
||
│ ├── __init__.py
|
||
│ ├── data_utils.py # 数据处理工具函数
|
||
│ └── visualization.py # 可视化工具函数
|
||
├── models/ # 模型定义
|
||
│ ├── transformer_model.py
|
||
│ ├── mlstm_model.py
|
||
│ ├── kan_model.py
|
||
│ ├── tcn_model.py
|
||
│ └── optimized_kan_forecaster.py
|
||
├── pharmacy_predictor.py # 主接口文件
|
||
├── run_pharmacy_prediction.py # 命令行运行入口
|
||
├── api.py # API服务入口
|
||
└── pharmacy_sales.xlsx # 示例数据文件
|
||
```
|
||
|
||
## 支持的模型
|
||
|
||
1. **Transformer**: 基于自注意力机制的时序预测模型
|
||
2. **mLSTM**: 矩阵LSTM模型,结合了LSTM和Transformer的优点
|
||
3. **KAN**: Kolmogorov-Arnold Network,一种基于柯尔莫哥洛夫-阿诺德定理的神经网络
|
||
4. **TCN**: 时间卷积网络,使用因果卷积进行时序建模
|
||
5. **优化版KAN**: 经过优化的KAN模型,提高了预测精度和训练效率
|
||
|
||
## 使用方法
|
||
|
||
### 命令行界面
|
||
|
||
运行命令行界面:
|
||
|
||
```bash
|
||
python run_pharmacy_prediction.py
|
||
```
|
||
|
||
### API服务
|
||
|
||
启动API服务:
|
||
|
||
```bash
|
||
python api.py
|
||
```
|
||
|
||
### 代码中使用
|
||
|
||
```python
|
||
from pharmacy_predictor import PharmacyPredictor
|
||
|
||
# 创建预测器实例
|
||
predictor = PharmacyPredictor(data_path='pharmacy_sales.xlsx')
|
||
|
||
# 训练模型
|
||
metrics = predictor.train_model(product_id='P001', model_type='tcn', epochs=50)
|
||
|
||
# 使用模型预测
|
||
result = predictor.predict(product_id='P001', model_type='tcn', future_days=7, analyze_result=True)
|
||
```
|
||
|
||
## 依赖库
|
||
|
||
- PyTorch
|
||
- pandas
|
||
- numpy
|
||
- matplotlib
|
||
- scikit-learn
|
||
- Flask (用于API服务)
|
||
- pytorch-tcn (用于TCN模型) |