ShopTRAINING/docs/输出文档/开发者文档.md

912 lines
22 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 药店单品销售预测系统 - 开发者文档
## 1. 项目结构
### 1.1 目录结构
```
xLSTM+transformer/
├── models/ # 模型定义和实现
│ ├── __init__.py # 模型初始化文件
│ ├── data_utils.py # 数据处理工具
│ ├── mlstm/ # mLSTM模型实现
│ ├── transformer_model.py # Transformer模型实现
│ ├── mlstm_model.py # mLSTM模型实现
│ ├── kan_model.py # KAN模型实现
│ └── utils.py # 通用工具函数
├── predictions/ # 预测结果和保存的模型
│ ├── mlstm/ # mLSTM模型预测结果
│ ├── transformer/ # Transformer模型预测结果
│ └── kan/ # KAN模型预测结果
├── Server/ # API服务器实现
├── UI/ # 用户界面组件
│ └── src/ # UI源代码
├── docs/ # 项目文档
├── api.py # API服务入口
├── api_test.py # API测试脚本
├── pharmacy_predictor.py # 核心预测功能实现
├── run_pharmacy_prediction.py # 主程序入口
├── model_management.py # 模型管理工具
├── generate_pharmacy_data.py # 生成模拟数据
├── check_gpu.py # 检查GPU支持状态
├── pharmacy_sales.xlsx # 示例销售数据
├── requirements.txt # 依赖项列表
└── README.md # 项目说明
```
### 1.2 主要模块
- **pharmacy_predictor.py**: 核心预测功能实现,包含各种模型的训练和预测函数
- **model_management.py**: 模型管理工具,用于管理和操作已训练的模型
- **run_pharmacy_prediction.py**: 主程序入口,提供交互式界面
- **api.py**: RESTful API服务
- **models/**: 模型定义和实现
- **transformer_model.py**: Transformer模型实现
- **mlstm_model.py**: mLSTM模型实现
- **kan_model.py**: KAN模型实现
- **data_utils.py**: 数据处理工具
## 2. 核心类与接口
### 2.1 数据处理接口
```python
# models/data_utils.py
# 创建时间序列数据集
def create_dataset(X, y, look_back, future_days):
"""
创建时间序列数据集
参数:
X: 输入特征
y: 目标变量
look_back: 使用过去多少天的数据进行预测
future_days: 预测未来多少天
返回:
X_out: 形状为 (samples, look_back, features)
y_out: 形状为 (samples, future_days)
"""
pass
# 数据集类
class PharmacyDataset(Dataset):
"""
药店销售数据集
参数:
X: 输入特征
y: 目标变量
"""
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return len(self.X)
def __getitem__(self, idx):
return self.X[idx], self.y[idx]
# 评估模型
def evaluate_model(y_true, y_pred):
"""
评估模型性能
参数:
y_true: 真实值
y_pred: 预测值
返回:
包含各评估指标的字典
"""
pass
```
### 2.2 模型定义接口
```python
# models/transformer_model.py
class TimeSeriesTransformer(nn.Module):
"""
基于Transformer的时间序列预测模型
参数:
num_features: 输入特征数
embed_dim: 嵌入维度
num_heads: 注意力头数
dense_dim: 前馈网络隐藏层维度
dropout_rate: Dropout比率
num_blocks: Transformer块数量
output_sequence_length: 输出序列长度
"""
def __init__(self, num_features, embed_dim, num_heads, dense_dim, dropout_rate, num_blocks, output_sequence_length):
pass
def forward(self, x):
"""
前向传播
参数:
x: 输入张量,形状为 [batch_size, seq_len, num_features]
返回:
输出张量,形状为 [batch_size, output_sequence_length]
"""
pass
```
```python
# models/mlstm_model.py
class MLSTMTransformer(nn.Module):
"""
结合mLSTM和Transformer的混合模型
参数:
num_features: 输入特征数
hidden_size: LSTM隐藏层大小
mlstm_layers: mLSTM层数
embed_dim: Transformer嵌入维度
dense_dim: 前馈网络隐藏层维度
num_heads: 注意力头数
dropout_rate: Dropout比率
num_blocks: Transformer块数量
output_sequence_length: 输出序列长度
"""
def __init__(self, num_features, hidden_size, mlstm_layers, embed_dim, dense_dim, num_heads, dropout_rate, num_blocks, output_sequence_length):
pass
def forward(self, x):
"""
前向传播
参数:
x: 输入张量,形状为 [batch_size, seq_len, num_features]
返回:
输出张量,形状为 [batch_size, output_sequence_length]
"""
pass
```
```python
# models/kan_model.py
class KANForecaster(nn.Module):
"""
基于Kolmogorov-Arnold网络的预测模型
参数:
num_features: 输入特征数
hidden_sizes: 隐藏层大小列表
grid_size: 网格大小
spline_order: 样条阶数
output_sequence_length: 输出序列长度
"""
def __init__(self, num_features, hidden_sizes, grid_size, spline_order, output_sequence_length):
pass
def forward(self, x):
"""
前向传播
参数:
x: 输入张量,形状为 [batch_size, seq_len, num_features]
返回:
输出张量,形状为 [batch_size, output_sequence_length]
"""
pass
```
### 2.3 工具类接口
```python
# models/utils.py
def get_device():
"""
获取可用设备GPU或CPU
返回:
torch.device
"""
pass
def to_device(data, device):
"""
将数据移动到指定设备上
参数:
data: 要移动的数据
device: 目标设备
返回:
移动后的数据
"""
pass
class DeviceDataLoader():
"""
设备数据加载器
参数:
dl: 原始数据加载器
device: 目标设备
"""
def __init__(self, dl, device):
pass
def __iter__(self):
"""迭代器"""
pass
def __len__(self):
"""长度"""
pass
```
## 3. 模型训练接口
### 3.1 模型训练函数
```python
# pharmacy_predictor.py
def train_product_model_with_mlstm(product_id, epochs=50):
"""
使用mLSTM模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
返回:
训练好的模型和评估指标
"""
pass
def train_product_model_with_transformer(product_id, epochs=50):
"""
使用Transformer模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
返回:
训练好的模型和评估指标
"""
pass
def train_product_model_with_kan(product_id, epochs=50):
"""
使用KAN模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
返回:
训练好的模型和评估指标
"""
pass
```
### 3.2 预测函数
```python
# pharmacy_predictor.py
def load_model_and_predict(product_id, model_type, future_days=7):
"""
加载模型并进行预测
参数:
product_id: 产品ID
model_type: 模型类型
future_days: 预测未来天数
返回:
预测结果
"""
pass
```
## 4. 模型管理接口
```python
# model_management.py
class ModelManager:
"""模型管理器"""
def list_models(self, product_id=None, model_type=None):
"""
列出模型
参数:
product_id: 按产品ID筛选
model_type: 按模型类型筛选
返回:
模型列表
"""
pass
def get_model_details(self, product_id, model_type, version=None):
"""
获取模型详情
参数:
product_id: 产品ID
model_type: 模型类型
version: 版本
返回:
模型详情
"""
pass
def delete_model(self, product_id, model_type, version=None):
"""
删除模型
参数:
product_id: 产品ID
model_type: 模型类型
version: 版本
返回:
是否成功
"""
pass
def export_model(self, product_id, model_type, version=None, export_dir="exported_models"):
"""
导出模型
参数:
product_id: 产品ID
model_type: 模型类型
version: 版本
export_dir: 导出目录
返回:
导出路径
"""
pass
def import_model(self, import_file, overwrite=False):
"""
导入模型
参数:
import_file: 导入文件路径
overwrite: 是否覆盖同名模型
返回:
导入路径
"""
pass
def predict_with_model(self, product_id, model_type, version=None):
"""
使用模型预测
参数:
product_id: 产品ID
model_type: 模型类型
version: 版本
返回:
预测结果
"""
pass
def compare_models(self, product_id, model_types=None):
"""
比较模型
参数:
product_id: 产品ID
model_types: 模型类型列表
返回:
比较结果
"""
pass
```
## 5. API接口
```python
# api.py
# 数据管理API
@app.route('/api/products', methods=['GET'])
def get_products():
"""获取所有产品列表"""
pass
@app.route('/api/products/<product_id>', methods=['GET'])
def get_product(product_id):
"""获取单个产品详情"""
pass
@app.route('/api/products/<product_id>/sales', methods=['GET'])
def get_product_sales(product_id):
"""获取产品销售数据"""
pass
@app.route('/api/data/upload', methods=['POST'])
def upload_data():
"""上传销售数据"""
pass
# 模型训练API
@app.route('/api/training', methods=['POST'])
def start_training():
"""启动模型训练任务"""
pass
@app.route('/api/training/<task_id>', methods=['GET'])
def get_training_status(task_id):
"""查询训练任务状态"""
pass
# 模型预测API
@app.route('/api/prediction', methods=['POST'])
def predict():
"""使用模型预测"""
pass
@app.route('/api/prediction/compare', methods=['POST'])
def compare_predictions():
"""比较不同模型预测结果"""
pass
# 模型管理API
@app.route('/api/models', methods=['GET'])
def get_models():
"""获取模型列表"""
pass
@app.route('/api/models/<product_id>/<model_type>', methods=['GET'])
def get_model_details(product_id, model_type):
"""获取模型详情"""
pass
@app.route('/api/models/<product_id>/<model_type>', methods=['DELETE'])
def delete_model(product_id, model_type):
"""删除模型"""
pass
@app.route('/api/models/<product_id>/<model_type>/export', methods=['GET'])
def export_model(product_id, model_type):
"""导出模型"""
pass
@app.route('/api/models/import', methods=['POST'])
def import_model():
"""导入模型"""
pass
```
## 6. 扩展指南
### 6.1 添加新模型
要添加新的预测模型,请按以下步骤操作:
1.`models/`目录下创建新的模型文件,例如`new_model.py`
2. 实现模型类,遵循现有模型的接口约定
3.`pharmacy_predictor.py`中添加训练函数`train_product_model_with_new_model()`
4. 更新`load_model_and_predict()`函数以支持新模型
5.`run_pharmacy_prediction.py`中添加新模型的训练选项
6. 更新模型管理工具以支持新模型
示例:
```python
# models/new_model.py
import torch
import torch.nn as nn
class NewModel(nn.Module):
def __init__(self, num_features, hidden_size, output_sequence_length):
super().__init__()
self.hidden_size = hidden_size
self.output_sequence_length = output_sequence_length
# 定义模型架构
self.input_layer = nn.Linear(num_features, hidden_size)
self.hidden_layer = nn.Linear(hidden_size, hidden_size)
self.output_layer = nn.Linear(hidden_size, output_sequence_length)
def forward(self, x):
# 实现前向传播
batch_size, seq_len, _ = x.shape
# 处理输入序列
x = x.view(batch_size, -1)
# 通过网络层
x = torch.relu(self.input_layer(x))
x = torch.relu(self.hidden_layer(x))
output = self.output_layer(x)
return output
```
```python
# pharmacy_predictor.py 中添加训练函数
def train_product_model_with_new_model(product_id, epochs=50):
# 读取数据
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
product_name = product_df['product_name'].iloc[0]
print(f"使用NewModel训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
# 数据预处理
# ...
# 创建模型
model = NewModel(
num_features=num_features,
hidden_size=64,
output_sequence_length=T
)
# 训练模型
# ...
return model, metrics
```
### 6.2 添加新特征
要添加新的预测特征,请按以下步骤操作:
1. 修改`generate_pharmacy_data.py`,在生成的数据中添加新特征
2. 修改`pharmacy_predictor.py`中的特征列表
3. 更新数据预处理代码以处理新特征
4. 如果需要,调整模型架构以适应新特征
示例,添加"促销折扣率"特征:
```python
# pharmacy_predictor.py 更新特征列表
features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature', 'discount_rate']
```
### 6.3 添加新API端点
要添加新的API端点请按以下步骤操作
1.`api.py`中定义新的路由和处理函数
2. 更新API文档
3.`api_test.py`中添加相应的测试
示例:
```python
# api.py 添加新端点
@app.route('/api/products/<product_id>/forecast', methods=['GET'])
def get_product_forecast(product_id):
"""获取产品销售预测"""
try:
# 获取查询参数
days = request.args.get('days', default=7, type=int)
model_type = request.args.get('model_type', default='mlstm', type=str)
# 调用预测函数
predictions = load_model_and_predict(product_id, model_type, future_days=days)
if predictions is None:
return jsonify({
'status': 'error',
'error': {
'code': 'MODEL_NOT_FOUND',
'message': '找不到指定的模型'
}
}), 404
# 处理预测结果
response = {
'status': 'success',
'data': predictions
}
return jsonify(response)
except Exception as e:
return jsonify({
'status': 'error',
'error': {
'code': 'PREDICTION_ERROR',
'message': str(e)
}
}), 500
```
## 7. 测试指南
### 7.1 单元测试
项目使用`unittest`框架进行单元测试。测试文件存放在`tests/`目录下(需要创建)。
添加新测试:
1.`tests/`目录下创建测试文件,例如`test_new_model.py`
2. 编写测试用例
3. 运行测试:`python -m unittest discover tests`
示例:
```python
# tests/test_new_model.py
import unittest
import torch
import numpy as np
from models.new_model import NewModel
class TestNewModel(unittest.TestCase):
def setUp(self):
self.num_features = 8
self.hidden_size = 64
self.output_sequence_length = 7
self.batch_size = 16
self.seq_len = 14
self.model = NewModel(
num_features=self.num_features,
hidden_size=self.hidden_size,
output_sequence_length=self.output_sequence_length
)
def test_model_output_shape(self):
# 创建随机输入
x = torch.rand(self.batch_size, self.seq_len, self.num_features)
# 前向传播
output = self.model(x)
# 检查输出形状
self.assertEqual(output.shape, (self.batch_size, self.output_sequence_length))
def test_model_forward_pass(self):
# 创建随机输入
x = torch.rand(self.batch_size, self.seq_len, self.num_features)
# 检查前向传播不会抛出异常
try:
output = self.model(x)
success = True
except:
success = False
self.assertTrue(success)
if __name__ == '__main__':
unittest.main()
```
### 7.2 API测试
API测试使用`api_test.py`脚本。要添加新的API测试请修改该文件并添加测试函数。
示例:
```python
# api_test.py 添加新测试
def test_get_product_forecast():
"""测试获取产品销售预测"""
print("测试获取产品销售预测...")
# 请求API
response = requests.get(f"{BASE_URL}/products/P001/forecast?days=10&model_type=mlstm")
# 检查响应
if response.status_code == 200:
data = response.json()
if data['status'] == 'success':
print("✓ 成功获取产品销售预测")
print(f"预测销量: {data['data']['predicted_sales']}")
else:
print(f"✗ 获取产品销售预测失败: {data['error']['message']}")
else:
print(f"✗ 请求失败,状态码: {response.status_code}")
print()
```
## 8. 部署指南
### 8.1 生产环境部署
#### 8.1.1 Docker部署
您可以使用Docker容器化应用程序进行部署。创建一个`Dockerfile`
```dockerfile
FROM python:3.8-slim
WORKDIR /app
# 复制项目文件
COPY . .
# 安装依赖
RUN pip install --no-cache-dir -r requirements.txt
# 暴露端口
EXPOSE 5000
# 启动API服务
CMD ["python", "api.py", "--host", "0.0.0.0"]
```
构建和运行Docker容器
```bash
docker build -t pharmacy-prediction-system .
docker run -p 5000:5000 pharmacy-prediction-system
```
#### 8.1.2 传统服务器部署
1. 安装Python 3.8+
2. 克隆项目代码
3. 安装依赖:`pip install -r requirements.txt`
4. 设置环境变量(可选)
5. 使用Gunicorn运行API服务`gunicorn -w 4 -b 0.0.0.0:5000 api:app`
### 8.2 模型部署
模型训练和预测可以分离部署:
1. 在训练服务器上训练模型
2. 导出模型:`python model_management.py --action export --product_id P001 --model_type mlstm`
3. 将模型文件传输到预测服务器
4. 在预测服务器上导入模型:`python model_management.py --action import --file_path exported_models/P001_mlstm_20230615123456.pt`
## 9. 性能优化
### 9.1 模型优化
- **量化**: 使用PyTorch的量化功能减小模型大小并加速推理
- **剪枝**: 移除对预测贡献小的神经元,减小模型复杂度
- **知识蒸馏**: 训练小模型学习大模型的行为
### 9.2 推理优化
- **批处理**: 尽可能使用批处理进行预测
- **TorchScript**: 将模型转换为TorchScript以优化推理
- **ONNX导出**: 导出为ONNX格式使用专用推理引擎
### 9.3 数据加载优化
- **预处理缓存**: 缓存预处理后的数据
- **异步数据加载**: 使用多线程加载和预处理数据
- **内存映射**: 对大型数据集使用内存映射
## 10. 贡献指南
### 10.1 代码规范
- 遵循PEP 8 Python代码风格指南
- 使用类型注解提高代码可读性
- 为所有公共函数和类编写文档字符串
- 使用英文编写注释和文档字符串
### 10.2 提交流程
1. Fork项目仓库
2. 创建功能分支
3. 编写代码并遵循代码规范
4. 编写测试
5. 提交拉取请求
6. 等待代码审查
### 10.3 版本控制
项目使用语义化版本控制:
- 主版本号不兼容的API更改
- 次版本号:向后兼容的功能性新增
- 修订号:向后兼容的问题修正
## 11. 常见问题解答
### 11.1 模型训练相关
**Q: 如何加速模型训练?**
A: 您可以尝试以下方法:
- 使用GPU加速系统会自动检测并使用GPU
- 减小batch_size
- 减少训练轮次epochs
- 减小模型复杂度例如减少Transformer层数或隐藏层大小
**Q: 模型保存失败怎么办?**
A: 请检查:
- 是否有足够的磁盘空间
- 是否有写入权限
- predictions目录是否存在如果不存在手动创建
### 11.2 API相关
**Q: API服务启动失败怎么办**
A: 常见原因包括:
- 端口被占用(更改端口)
- 缺少依赖安装API依赖
- 权限问题(以管理员/root身份运行
**Q: 如何限制API访问速率**
A: 您可以使用Flask-Limiter扩展
```python
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
limiter = Limiter(
app,
key_func=get_remote_address,
default_limits=["200 per day", "50 per hour"]
)
# 然后在路由上应用限制
@app.route('/api/products')
@limiter.limit("10 per minute")
def get_products():
# ...
```
### 11.3 模型管理相关
**Q: 如何清理旧的模型文件以节省空间?**
A: 您可以使用模型管理工具删除旧模型:
```bash
python model_management.py --action delete --product_id P001 --model_type mlstm
```
或者编写一个脚本自动删除某个日期之前的模型:
```python
import os
import glob
import datetime
def clean_old_models(days=30):
"""删除指定天数之前的模型"""
cutoff_date = datetime.datetime.now() - datetime.timedelta(days=days)
for model_dir in glob.glob('predictions/*/*'):
for model_file in glob.glob(f'{model_dir}/*.pt'):
# 从文件名中提取时间戳
try:
timestamp = os.path.basename(model_file).split('.')[0]
model_date = datetime.datetime.strptime(timestamp, '%Y%m%d%H%M%S')
if model_date < cutoff_date:
os.remove(model_file)
print(f"已删除旧模型: {model_file}")
except:
continue
if __name__ == '__main__':
clean_old_models(30) # 删除30天前的模型
```