# Pharmacy Single-Product Sales Forecasting System - Developer Documentation

## 1. Project Structure

### 1.1 Directory Layout

```
xLSTM+transformer/
├── models/                       # Model definitions and implementations
│   ├── __init__.py               # Package initializer
│   ├── data_utils.py             # Data processing utilities
│   ├── mlstm/                    # mLSTM model implementation
│   ├── transformer_model.py      # Transformer model implementation
│   ├── mlstm_model.py            # mLSTM model implementation
│   ├── kan_model.py              # KAN model implementation
│   └── utils.py                  # General-purpose helper functions
│
├── predictions/                  # Prediction results and saved models
│   ├── mlstm/                    # mLSTM prediction results
│   ├── transformer/              # Transformer prediction results
│   └── kan/                      # KAN prediction results
│
├── Server/                       # API server implementation
│
├── UI/                           # User interface components
│   └── src/                      # UI source code
│
├── docs/                         # Project documentation
│
├── api.py                        # API service entry point
├── api_test.py                   # API test script
├── pharmacy_predictor.py         # Core forecasting logic
├── run_pharmacy_prediction.py    # Main program entry point
├── model_management.py           # Model management tool
├── generate_pharmacy_data.py     # Generates simulated data
├── check_gpu.py                  # Checks GPU support
├── pharmacy_sales.xlsx           # Sample sales data
├── requirements.txt              # Dependency list
└── README.md                     # Project overview
```

### 1.2 Main Modules

- **pharmacy_predictor.py**: Core forecasting logic, including the training and prediction functions for each model
- **model_management.py**: Model management tool for managing and operating on trained models
- **run_pharmacy_prediction.py**: Main program entry point; provides an interactive interface
- **api.py**: RESTful API service
- **models/**: Model definitions and implementations
  - **transformer_model.py**: Transformer model implementation
  - **mlstm_model.py**: mLSTM model implementation
  - **kan_model.py**: KAN model implementation
  - **data_utils.py**: Data processing utilities

## 2. Core Classes and Interfaces

### 2.1 Data Processing Interfaces

```python
# models/data_utils.py
from torch.utils.data import Dataset


# Build a time-series dataset
def create_dataset(X, y, look_back, future_days):
    """
    Build a sliding-window time-series dataset.

    Args:
        X: input features
        y: target variable
        look_back: number of past days used as model input
        future_days: number of future days to predict

    Returns:
        X_out: array of shape (samples, look_back, features)
        y_out: array of shape (samples, future_days)
    """
    pass


# Dataset class
class PharmacyDataset(Dataset):
    """
    Pharmacy sales dataset.

    Args:
        X: input features
        y: target variable
    """
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# Model evaluation
def evaluate_model(y_true, y_pred):
    """
    Evaluate model performance.

    Args:
        y_true: ground-truth values
        y_pred: predicted values

    Returns:
        A dict containing each evaluation metric
    """
    pass
```
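
To make the shapes documented for `create_dataset` concrete, here is a minimal sliding-window sketch using NumPy. It is an assumption about how such a function might be written, not the project's actual implementation: it assumes `y` is a 1-D array aligned row by row with `X`, and the name `create_dataset_sketch` is hypothetical.

```python
import numpy as np


def create_dataset_sketch(X, y, look_back, future_days):
    """Slice (X, y) into overlapping windows (hypothetical helper).

    Each sample uses `look_back` consecutive rows of X as input and the
    next `future_days` values of y as the target.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    n_samples = len(X) - look_back - future_days + 1
    if n_samples <= 0:
        raise ValueError("not enough rows for the requested window sizes")

    X_out = np.stack([X[i:i + look_back] for i in range(n_samples)])
    y_out = np.stack([
        y[i + look_back:i + look_back + future_days] for i in range(n_samples)
    ])
    return X_out, y_out


# 30 days of data with 8 features per day
X = np.random.rand(30, 8)
y = np.arange(30, dtype=float)
X_win, y_win = create_dataset_sketch(X, y, look_back=14, future_days=7)
print(X_win.shape, y_win.shape)
```

With 30 rows, a 14-day look-back, and a 7-day horizon, the sketch yields 30 - 14 - 7 + 1 = 10 samples; the resulting arrays can be passed directly to `PharmacyDataset`.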

### 2.2 Model Definition Interfaces

```python
# models/transformer_model.py
import torch.nn as nn


class TimeSeriesTransformer(nn.Module):
    """
    Transformer-based time-series forecasting model.

    Args:
        num_features: number of input features
        embed_dim: embedding dimension
        num_heads: number of attention heads
        dense_dim: hidden dimension of the feed-forward network
        dropout_rate: dropout rate
        num_blocks: number of Transformer blocks
        output_sequence_length: length of the output sequence
    """
    def __init__(self, num_features, embed_dim, num_heads, dense_dim, dropout_rate, num_blocks, output_sequence_length):
        pass

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: input tensor of shape [batch_size, seq_len, num_features]

        Returns:
            Output tensor of shape [batch_size, output_sequence_length]
        """
        pass
```

```python
# models/mlstm_model.py
import torch.nn as nn


class MLSTMTransformer(nn.Module):
    """
    Hybrid model combining mLSTM and Transformer.

    Args:
        num_features: number of input features
        hidden_size: LSTM hidden size
        mlstm_layers: number of mLSTM layers
        embed_dim: Transformer embedding dimension
        dense_dim: hidden dimension of the feed-forward network
        num_heads: number of attention heads
        dropout_rate: dropout rate
        num_blocks: number of Transformer blocks
        output_sequence_length: length of the output sequence
    """
    def __init__(self, num_features, hidden_size, mlstm_layers, embed_dim, dense_dim, num_heads, dropout_rate, num_blocks, output_sequence_length):
        pass

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: input tensor of shape [batch_size, seq_len, num_features]

        Returns:
            Output tensor of shape [batch_size, output_sequence_length]
        """
        pass
```

```python
# models/kan_model.py
import torch.nn as nn


class KANForecaster(nn.Module):
    """
    Forecasting model based on Kolmogorov-Arnold Networks.

    Args:
        num_features: number of input features
        hidden_sizes: list of hidden layer sizes
        grid_size: grid size
        spline_order: spline order
        output_sequence_length: length of the output sequence
    """
    def __init__(self, num_features, hidden_sizes, grid_size, spline_order, output_sequence_length):
        pass

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: input tensor of shape [batch_size, seq_len, num_features]

        Returns:
            Output tensor of shape [batch_size, output_sequence_length]
        """
        pass
```

### 2.3 Utility Interfaces

```python
# models/utils.py

def get_device():
    """
    Get the available device (GPU or CPU).

    Returns:
        torch.device
    """
    pass


def to_device(data, device):
    """
    Move data to the given device.

    Args:
        data: the data to move
        device: the target device

    Returns:
        The data after it has been moved
    """
    pass


class DeviceDataLoader:
    """
    Data loader that places batches on a device.

    Args:
        dl: the underlying data loader
        device: the target device
    """
    def __init__(self, dl, device):
        pass

    def __iter__(self):
        """Iterate over batches."""
        pass

    def __len__(self):
        """Number of batches."""
        pass
```
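
These helpers follow a common PyTorch pattern. The sketch below shows one way `to_device` and `DeviceDataLoader` could behave; it is an assumption, not the project's code. To stay self-contained it is duck-typed: anything with a `.to(device)` method (such as a `torch.Tensor`) is moved, and `FakeTensor` is a hypothetical stand-in used only to demonstrate the behavior without importing PyTorch.

```python
def to_device_sketch(data, device):
    """Recursively move tensors inside lists/tuples to `device`."""
    if isinstance(data, (list, tuple)):
        return type(data)(to_device_sketch(item, device) for item in data)
    return data.to(device)


class DeviceDataLoaderSketch:
    """Wrap an iterable of batches and move each batch to `device` lazily."""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device_sketch(batch, self.device)

    def __len__(self):
        return len(self.dl)


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only records its device."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


batches = [(FakeTensor(), FakeTensor()), (FakeTensor(), FakeTensor())]
loader = DeviceDataLoaderSketch(batches, "cuda:0")
devices = [[t.device for t in batch] for batch in loader]
print(len(loader), devices)
```

Moving batches inside `__iter__` rather than up front keeps only the current batch on the device, which matters when the dataset is larger than GPU memory.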

## 3. Model Training Interfaces

### 3.1 Training Functions

```python
# pharmacy_predictor.py

def train_product_model_with_mlstm(product_id, epochs=50):
    """
    Train a product sales forecasting model using the mLSTM model.

    Args:
        product_id: product ID
        epochs: number of training epochs

    Returns:
        The trained model and its evaluation metrics
    """
    pass


def train_product_model_with_transformer(product_id, epochs=50):
    """
    Train a product sales forecasting model using the Transformer model.

    Args:
        product_id: product ID
        epochs: number of training epochs

    Returns:
        The trained model and its evaluation metrics
    """
    pass


def train_product_model_with_kan(product_id, epochs=50):
    """
    Train a product sales forecasting model using the KAN model.

    Args:
        product_id: product ID
        epochs: number of training epochs

    Returns:
        The trained model and its evaluation metrics
    """
    pass
```

### 3.2 Prediction Function

```python
# pharmacy_predictor.py

def load_model_and_predict(product_id, model_type, future_days=7):
    """
    Load a saved model and run a prediction.

    Args:
        product_id: product ID
        model_type: model type
        future_days: number of future days to predict

    Returns:
        The prediction results
    """
    pass
```

## 4. Model Management Interface

```python
# model_management.py

class ModelManager:
    """Model manager."""

    def list_models(self, product_id=None, model_type=None):
        """
        List models.

        Args:
            product_id: filter by product ID
            model_type: filter by model type

        Returns:
            A list of models
        """
        pass

    def get_model_details(self, product_id, model_type, version=None):
        """
        Get the details of a model.

        Args:
            product_id: product ID
            model_type: model type
            version: version

        Returns:
            The model details
        """
        pass

    def delete_model(self, product_id, model_type, version=None):
        """
        Delete a model.

        Args:
            product_id: product ID
            model_type: model type
            version: version

        Returns:
            Whether the deletion succeeded
        """
        pass

    def export_model(self, product_id, model_type, version=None, export_dir="exported_models"):
        """
        Export a model.

        Args:
            product_id: product ID
            model_type: model type
            version: version
            export_dir: export directory

        Returns:
            The export path
        """
        pass

    def import_model(self, import_file, overwrite=False):
        """
        Import a model.

        Args:
            import_file: path of the file to import
            overwrite: whether to overwrite a model with the same name

        Returns:
            The path of the imported model
        """
        pass

    def predict_with_model(self, product_id, model_type, version=None):
        """
        Run a prediction with a model.

        Args:
            product_id: product ID
            model_type: model type
            version: version

        Returns:
            The prediction results
        """
        pass

    def compare_models(self, product_id, model_types=None):
        """
        Compare models.

        Args:
            product_id: product ID
            model_types: list of model types

        Returns:
            The comparison results
        """
        pass
```
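
As an illustration of what `list_models` might do internally, the sketch below scans a directory tree with `pathlib`. The layout `<root>/<model_type>/<product_id>/<version>.pt` is an assumption loosely inferred from the `predictions/` tree and the cleanup script in the FAQ; the function name and the returned dictionary keys are hypothetical.

```python
import tempfile
from pathlib import Path


def list_models_sketch(root, product_id=None, model_type=None):
    """Return saved models under `root`, optionally filtered.

    Assumes the hypothetical layout
    <root>/<model_type>/<product_id>/<version>.pt
    """
    results = []
    for path in sorted(Path(root).glob("*/*/*.pt")):
        found_type, found_product = path.parts[-3], path.parts[-2]
        if model_type is not None and found_type != model_type:
            continue
        if product_id is not None and found_product != product_id:
            continue
        results.append({
            "product_id": found_product,
            "model_type": found_type,
            "version": path.stem,
            "path": str(path),
        })
    return results


# Demonstrate on a throwaway directory tree
with tempfile.TemporaryDirectory() as root:
    for rel in ("mlstm/P001/20230615123456.pt",
                "mlstm/P002/20230616080000.pt",
                "kan/P001/20230617090000.pt"):
        target = Path(root, rel)
        target.parent.mkdir(parents=True, exist_ok=True)
        target.touch()

    all_models = list_models_sketch(root)
    p001_models = list_models_sketch(root, product_id="P001")
    print(len(all_models), [m["model_type"] for m in p001_models])
```

Deriving the metadata from the path keeps the listing cheap, since no model file has to be loaded just to enumerate what is available.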

## 5. API Endpoints

```python
# api.py
from flask import Flask

app = Flask(__name__)


# Data management API
@app.route('/api/products', methods=['GET'])
def get_products():
    """List all products."""
    pass


@app.route('/api/products/<product_id>', methods=['GET'])
def get_product(product_id):
    """Get the details of a single product."""
    pass


@app.route('/api/products/<product_id>/sales', methods=['GET'])
def get_product_sales(product_id):
    """Get the sales data of a product."""
    pass


@app.route('/api/data/upload', methods=['POST'])
def upload_data():
    """Upload sales data."""
    pass


# Model training API
@app.route('/api/training', methods=['POST'])
def start_training():
    """Start a model training task."""
    pass


@app.route('/api/training/<task_id>', methods=['GET'])
def get_training_status(task_id):
    """Query the status of a training task."""
    pass


# Model prediction API
@app.route('/api/prediction', methods=['POST'])
def predict():
    """Run a prediction with a model."""
    pass


@app.route('/api/prediction/compare', methods=['POST'])
def compare_predictions():
    """Compare the predictions of different models."""
    pass


# Model management API
@app.route('/api/models', methods=['GET'])
def get_models():
    """List models."""
    pass


@app.route('/api/models/<product_id>/<model_type>', methods=['GET'])
def get_model_details(product_id, model_type):
    """Get the details of a model."""
    pass


@app.route('/api/models/<product_id>/<model_type>', methods=['DELETE'])
def delete_model(product_id, model_type):
    """Delete a model."""
    pass


@app.route('/api/models/<product_id>/<model_type>/export', methods=['GET'])
def export_model(product_id, model_type):
    """Export a model."""
    pass


@app.route('/api/models/import', methods=['POST'])
def import_model():
    """Import a model."""
    pass
```

## 6. Extension Guide

### 6.1 Adding a New Model

To add a new forecasting model, follow these steps:

1. Create a new model file under `models/`, for example `new_model.py`
2. Implement the model class, following the interface conventions of the existing models
3. Add a training function `train_product_model_with_new_model()` to `pharmacy_predictor.py`
4. Update `load_model_and_predict()` to support the new model
5. Add a training option for the new model in `run_pharmacy_prediction.py`
6. Update the model management tool to support the new model

Example:

```python
# models/new_model.py
import torch
import torch.nn as nn


class NewModel(nn.Module):
    def __init__(self, num_features, hidden_size, output_sequence_length):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_sequence_length = output_sequence_length

        # Define the model architecture
        self.input_layer = nn.Linear(num_features, hidden_size)
        self.hidden_layer = nn.Linear(hidden_size, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_sequence_length)

    def forward(self, x):
        # x has shape [batch_size, seq_len, num_features].
        # Average over the time dimension so the result has num_features
        # columns, which is what input_layer expects. (Flattening to
        # seq_len * num_features would not match the layer's input size.)
        x = x.mean(dim=1)

        # Pass through the network layers
        x = torch.relu(self.input_layer(x))
        x = torch.relu(self.hidden_layer(x))
        output = self.output_layer(x)

        return output
```

```python
# pharmacy_predictor.py: add the training function
import pandas as pd

from models.new_model import NewModel


def train_product_model_with_new_model(product_id, epochs=50):
    # Load the data
    df = pd.read_excel('pharmacy_sales.xlsx')
    product_df = df[df['product_id'] == product_id].sort_values('date')
    product_name = product_df['product_name'].iloc[0]

    print(f"Training a NewModel sales forecasting model for product '{product_name}' (ID: {product_id})")

    # Data preprocessing (elided). It must define `num_features` and the
    # forecast horizon `T` used below.
    # ...

    # Create the model
    model = NewModel(
        num_features=num_features,
        hidden_size=64,
        output_sequence_length=T
    )

    # Train the model (elided). It must produce `metrics`.
    # ...

    return model, metrics
```
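
Steps 3 to 6 each require a separate change whenever a model is added. One possible way to keep them in sync is a small registry that maps a `model_type` string to its training function. The sketch below is only an illustration of that idea: `MODEL_TRAINERS`, `register_trainer`, and `train_product_model` are hypothetical names that do not exist in the project, and the trainer body is a placeholder.

```python
MODEL_TRAINERS = {}  # hypothetical registry: model_type -> training function


def register_trainer(model_type):
    """Decorator that records a training function under `model_type`."""
    def decorator(func):
        MODEL_TRAINERS[model_type] = func
        return func
    return decorator


@register_trainer("new_model")
def train_product_model_with_new_model(product_id, epochs=50):
    # Placeholder body: the real function would train a network and
    # return (model, metrics)
    return f"new_model trained for {product_id}", {"epochs": epochs}


def train_product_model(product_id, model_type, epochs=50):
    """Dispatch to the trainer registered for `model_type`."""
    try:
        trainer = MODEL_TRAINERS[model_type]
    except KeyError:
        raise ValueError(
            f"unknown model_type {model_type!r}; "
            f"available: {sorted(MODEL_TRAINERS)}"
        ) from None
    return trainer(product_id, epochs=epochs)


model, metrics = train_product_model("P001", "new_model", epochs=5)
print(model, metrics)
```

With such a registry, the interactive menu and the API could enumerate `MODEL_TRAINERS` instead of hard-coding the list of supported model types.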

### 6.2 Adding a New Feature

To add a new input feature, follow these steps:

1. Modify `generate_pharmacy_data.py` so the generated data includes the new feature
2. Update the feature list in `pharmacy_predictor.py`
3. Update the preprocessing code to handle the new feature
4. If necessary, adjust the model architecture to accommodate the new feature

Example: adding a promotional discount rate feature (`discount_rate`):

```python
# pharmacy_predictor.py: update the feature list
features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature', 'discount_rate']
```
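
The preprocessing step for such a feature could be sketched as follows. This is an assumption rather than the project's code: it derives the rate from the selling price and a hypothetical `list_price` (the regular, non-promotional price), which the sample data may or may not contain.

```python
import numpy as np


def add_discount_rate(price, list_price):
    """Return discount_rate = 1 - price / list_price, clipped to [0, 1].

    `list_price` is a hypothetical input; rows with a non-positive list
    price get a rate of 0 instead of dividing by zero.
    """
    price = np.asarray(price, dtype=float)
    list_price = np.asarray(list_price, dtype=float)
    rate = np.zeros_like(price)
    valid = list_price > 0
    rate[valid] = 1.0 - price[valid] / list_price[valid]
    return np.clip(rate, 0.0, 1.0)


price = [8.0, 10.0, 12.0, 5.0]
list_price = [10.0, 10.0, 10.0, 0.0]
discount_rate = add_discount_rate(price, list_price)
print(discount_rate)
```

Clipping keeps the feature in a fixed range, so it can be fed to the model alongside the other scaled inputs; remember that adding a column changes `num_features` for every model constructor.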

### 6.3 Adding a New API Endpoint

To add a new API endpoint, follow these steps:

1. Define the new route and handler function in `api.py`
2. Update the API documentation
3. Add a matching test to `api_test.py`

Example:

```python
# api.py: add a new endpoint
from flask import jsonify, request

from pharmacy_predictor import load_model_and_predict


@app.route('/api/products/<product_id>/forecast', methods=['GET'])
def get_product_forecast(product_id):
    """Get the sales forecast for a product."""
    try:
        # Read the query parameters
        days = request.args.get('days', default=7, type=int)
        model_type = request.args.get('model_type', default='mlstm', type=str)

        # Call the prediction function
        predictions = load_model_and_predict(product_id, model_type, future_days=days)

        if predictions is None:
            return jsonify({
                'status': 'error',
                'error': {
                    'code': 'MODEL_NOT_FOUND',
                    'message': 'The specified model was not found'
                }
            }), 404

        # Build the response
        response = {
            'status': 'success',
            'data': predictions
        }

        return jsonify(response)

    except Exception as e:
        return jsonify({
            'status': 'error',
            'error': {
                'code': 'PREDICTION_ERROR',
                'message': str(e)
            }
        }), 500
```
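
The example endpoint builds the same two response shapes by hand: a success envelope with `data`, and an error envelope with a `code` and a `message`. A pair of small helpers could remove that repetition. The sketch below is a suggestion only; `success_response` and `error_response` are hypothetical names, not functions that exist in `api.py`.

```python
def success_response(data):
    """Build the success envelope: a status flag plus the payload."""
    return {"status": "success", "data": data}


def error_response(code, message):
    """Build the error envelope: a machine-readable code plus a message."""
    return {"status": "error", "error": {"code": code, "message": message}}


ok = success_response({"predicted_sales": [12, 15, 9]})
missing = error_response("MODEL_NOT_FOUND", "The specified model was not found")
print(ok["status"], missing["error"]["code"])
```

Inside a Flask handler the helpers would be wrapped in `jsonify`, for example `return jsonify(error_response(...)), 404`.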

## 7. Testing Guide

### 7.1 Unit Tests

The project uses the `unittest` framework for unit testing. Test files live in the `tests/` directory (which you need to create).

To add a new test:

1. Create a test file under `tests/`, for example `test_new_model.py`
2. Write the test cases
3. Run the tests: `python -m unittest discover tests`

Example:

```python
# tests/test_new_model.py
import unittest

import torch

from models.new_model import NewModel


class TestNewModel(unittest.TestCase):

    def setUp(self):
        self.num_features = 8
        self.hidden_size = 64
        self.output_sequence_length = 7
        self.batch_size = 16
        self.seq_len = 14

        self.model = NewModel(
            num_features=self.num_features,
            hidden_size=self.hidden_size,
            output_sequence_length=self.output_sequence_length
        )

    def test_model_output_shape(self):
        # Create a random input
        x = torch.rand(self.batch_size, self.seq_len, self.num_features)

        # Forward pass
        output = self.model(x)

        # Check the output shape
        self.assertEqual(output.shape, (self.batch_size, self.output_sequence_length))

    def test_model_forward_pass(self):
        # Create a random input
        x = torch.rand(self.batch_size, self.seq_len, self.num_features)

        # The forward pass should not raise an exception
        try:
            self.model(x)
            success = True
        except Exception:
            success = False

        self.assertTrue(success)


if __name__ == '__main__':
    unittest.main()
```

### 7.2 API Tests

API tests use the `api_test.py` script. To add a new API test, edit that file and add a test function.

Example:

```python
# api_test.py: add a new test
# (relies on `requests` and `BASE_URL` already being available in api_test.py)
def test_get_product_forecast():
    """Test fetching a product sales forecast."""
    print("Testing product sales forecast...")

    # Call the API
    response = requests.get(f"{BASE_URL}/products/P001/forecast?days=10&model_type=mlstm")

    # Check the response
    if response.status_code == 200:
        data = response.json()
        if data['status'] == 'success':
            print("✓ Fetched the product sales forecast")
            print(f"Predicted sales: {data['data']['predicted_sales']}")
        else:
            print(f"✗ Failed to fetch the product sales forecast: {data['error']['message']}")
    else:
        print(f"✗ Request failed with status code: {response.status_code}")

    print()
```

## 8. Deployment Guide

### 8.1 Production Deployment

#### 8.1.1 Docker Deployment

You can containerize the application with Docker. Create a `Dockerfile`:

```dockerfile
FROM python:3.8-slim

WORKDIR /app

# Copy the project files
COPY . .

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Expose the port
EXPOSE 5000

# Start the API service
CMD ["python", "api.py", "--host", "0.0.0.0"]
```

Build and run the Docker container:

```bash
docker build -t pharmacy-prediction-system .
docker run -p 5000:5000 pharmacy-prediction-system
```

#### 8.1.2 Traditional Server Deployment

1. Install Python 3.8+
2. Clone the project code
3. Install dependencies: `pip install -r requirements.txt`
4. Set environment variables (optional)
5. Run the API service with Gunicorn: `gunicorn -w 4 -b 0.0.0.0:5000 api:app`
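
Step 4 does not say which environment variables the service reads. As a sketch of how such configuration could be handled, the snippet below reads host, port, and worker count from the environment. The variable names (`PHARMACY_API_HOST`, `PHARMACY_API_PORT`, `PHARMACY_API_WORKERS`) are hypothetical and the defaults simply mirror the Gunicorn command in step 5.

```python
import os


def load_server_config(environ=None):
    """Read host/port/worker settings from environment variables.

    All variable names here are hypothetical; unset variables fall back
    to the values used in the Gunicorn command above.
    """
    environ = os.environ if environ is None else environ
    return {
        "host": environ.get("PHARMACY_API_HOST", "0.0.0.0"),
        "port": int(environ.get("PHARMACY_API_PORT", "5000")),
        "workers": int(environ.get("PHARMACY_API_WORKERS", "4")),
    }


# Pass a dict instead of os.environ to make the behavior easy to test
config = load_server_config({"PHARMACY_API_PORT": "8080"})
print(f"gunicorn -w {config['workers']} -b {config['host']}:{config['port']} api:app")
```

Accepting the environment as a parameter keeps the function free of global state, so it can be exercised in unit tests without modifying the real process environment.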

### 8.2 Model Deployment

Model training and prediction can be deployed separately:

1. Train the model on the training server
2. Export the model: `python model_management.py --action export --product_id P001 --model_type mlstm`
3. Transfer the model file to the prediction server
4. Import the model on the prediction server: `python model_management.py --action import --file_path exported_models/P001_mlstm_20230615123456.pt`
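
Before importing a transferred file, it can be useful to check what it claims to contain. The sketch below parses an exported filename into its parts. The pattern `<product_id>_<model_type>_<timestamp>.pt` is inferred from the single example name in step 4 and is therefore an assumption; `parse_export_name` is a hypothetical helper.

```python
import re
from datetime import datetime
from pathlib import Path

# Pattern inferred from the example name P001_mlstm_20230615123456.pt
EXPORT_NAME = re.compile(
    r"^(?P<product_id>[^_]+)_(?P<model_type>.+)_(?P<stamp>\d{14})\.pt$"
)


def parse_export_name(file_path):
    """Split an exported model filename into product, type, and timestamp."""
    match = EXPORT_NAME.match(Path(file_path).name)
    if match is None:
        raise ValueError(f"unrecognized export filename: {file_path}")
    return {
        "product_id": match["product_id"],
        "model_type": match["model_type"],
        "exported_at": datetime.strptime(match["stamp"], "%Y%m%d%H%M%S"),
    }


info = parse_export_name("exported_models/P001_mlstm_20230615123456.pt")
print(info["product_id"], info["model_type"], info["exported_at"].isoformat())
```

Rejecting names that do not match the pattern gives an early, readable error instead of a failure deep inside the import routine.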
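
## 9. Performance Optimization

### 9.1 Model Optimization

- **Quantization**: Use PyTorch's quantization support to reduce model size and speed up inference
- **Pruning**: Remove neurons that contribute little to the prediction, reducing model complexity
- **Knowledge distillation**: Train a small model to learn the behavior of a large one

### 9.2 Inference Optimization

- **Batching**: Run predictions in batches whenever possible
- **TorchScript**: Convert the model to TorchScript to optimize inference
- **ONNX export**: Export to ONNX format and use a dedicated inference engine

### 9.3 Data Loading Optimization

- **Preprocessing cache**: Cache the preprocessed data
- **Asynchronous data loading**: Load and preprocess data with multiple threads
- **Memory mapping**: Use memory mapping for large datasets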
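
The caching and memory-mapping ideas can be combined: write the preprocessed feature matrix to disk once, then map it read-only in later runs. The sketch below uses `numpy.memmap` for this. The file name, array sizes, and the synthetic data are placeholders chosen for illustration, not values taken from the project.

```python
import os
import tempfile

import numpy as np

n_days, n_features = 1000, 8

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "features.dat")

    # Write the preprocessed feature matrix once (this acts as the cache)
    writer = np.memmap(path, dtype=np.float32, mode="w+", shape=(n_days, n_features))
    writer[:] = np.arange(n_days * n_features, dtype=np.float32).reshape(n_days, n_features)
    writer.flush()
    del writer  # drop the write mapping

    # Later runs map the file read-only; only the slices touched are read
    reader = np.memmap(path, dtype=np.float32, mode="r", shape=(n_days, n_features))
    window = np.array(reader[100:114])  # copy one 14-day window into memory
    del reader  # release the mapping before the directory is removed

print(window.shape, window[0, 0])
```

Because a raw memory-mapped file stores no shape or dtype, both must be recorded elsewhere (or `numpy.save` with `mmap_mode` on load can be used instead).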

## 10. Contribution Guide

### 10.1 Code Style

- Follow the PEP 8 style guide for Python code
- Use type annotations to improve readability
- Write docstrings for all public functions and classes
- Write comments and docstrings in English

### 10.2 Submission Process

1. Fork the project repository
2. Create a feature branch
3. Write code that follows the code style
4. Write tests
5. Submit a pull request
6. Wait for code review

### 10.3 Versioning

The project uses semantic versioning:

- Major version: incompatible API changes
- Minor version: backward-compatible new functionality
- Patch version: backward-compatible bug fixes

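## 11. Frequently Asked Questions

### 11.1 Model Training

**Q: How can I speed up model training?**

A: You can try the following:
- Use GPU acceleration (the system detects and uses a GPU automatically)
- Increase `batch_size` as far as memory allows; larger batches usually improve throughput, so reduce it only if you run out of memory
- Reduce the number of training epochs
- Reduce model complexity (for example, fewer Transformer layers or a smaller hidden size)

**Q: What should I do if saving a model fails?**

A: Check the following:
- Whether there is enough disk space
- Whether you have write permission
- Whether the `predictions` directory exists (if it does not, create it manually)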
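
The directory and permission checks can also be done in code just before saving. The sketch below is one way to do that with the standard library; `ensure_model_dir` is a hypothetical helper, and the nesting by model type and product ID is an assumed layout rather than the project's confirmed one.

```python
import os
import tempfile


def ensure_model_dir(base_dir, model_type, product_id):
    """Create <base_dir>/<model_type>/<product_id> if missing and return it.

    The nesting by model type and product ID is an assumed layout.
    """
    model_dir = os.path.join(base_dir, model_type, product_id)
    os.makedirs(model_dir, exist_ok=True)  # no error if it already exists
    if not os.access(model_dir, os.W_OK):
        raise PermissionError(f"no write permission for {model_dir}")
    return model_dir


with tempfile.TemporaryDirectory() as base:
    first = ensure_model_dir(base, "mlstm", "P001")
    second = ensure_model_dir(base, "mlstm", "P001")  # safe to call again
    created = os.path.isdir(first)
print(created, first == second)
```

Calling such a helper at the start of each training function would remove the need to create `predictions/` by hand.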

### 11.2 API

**Q: What should I do if the API service fails to start?**

A: Common causes include:
- The port is already in use (change the port)
- Missing dependencies (install the API dependencies)
- Permission problems (run as administrator/root)

**Q: How do I rate-limit API access?**

A: You can use the Flask-Limiter extension:
```python
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address

# Keyword arguments are used because the positional order of the
# constructor differs between Flask-Limiter versions
limiter = Limiter(
    key_func=get_remote_address,
    app=app,
    default_limits=["200 per day", "50 per hour"]
)

# Then apply a limit to individual routes
@app.route('/api/products')
@limiter.limit("10 per minute")
def get_products():
    ...
```

### 11.3 Model Management

**Q: How do I clean up old model files to save space?**

A: You can delete old models with the model management tool:
```bash
python model_management.py --action delete --product_id P001 --model_type mlstm
```

Alternatively, write a script that automatically deletes models older than a given date:
```python
import datetime
import glob
import os


def clean_old_models(days=30):
    """Delete models older than the given number of days."""
    cutoff_date = datetime.datetime.now() - datetime.timedelta(days=days)

    for model_dir in glob.glob('predictions/*/*'):
        for model_file in glob.glob(f'{model_dir}/*.pt'):
            # Extract the timestamp from the file name
            try:
                timestamp = os.path.basename(model_file).split('.')[0]
                model_date = datetime.datetime.strptime(timestamp, '%Y%m%d%H%M%S')
            except ValueError:
                # Skip files whose names are not timestamps
                continue

            if model_date < cutoff_date:
                os.remove(model_file)
                print(f"Deleted old model: {model_file}")


if __name__ == '__main__':
    clean_old_models(30)  # Delete models older than 30 days
```