138 lines
4.2 KiB
Markdown
138 lines
4.2 KiB
Markdown
![]() |
# API.py 文件调整总结
|
|||
|
|
|||
|
## 主要修改内容
|
|||
|
|
|||
|
为了使 api.py 适应新的模块化架构,我们进行了以下主要修改:
|
|||
|
|
|||
|
### 1. 更新导入语句
|
|||
|
|
|||
|
```python
|
|||
|
# 导入核心预测器类
|
|||
|
from core.predictor import PharmacyPredictor
|
|||
|
|
|||
|
# 导入训练函数
|
|||
|
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
|||
|
from trainers.kan_trainer import train_product_model_with_kan
|
|||
|
from trainers.tcn_trainer import train_product_model_with_tcn
|
|||
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|||
|
|
|||
|
# 导入预测函数
|
|||
|
from predictors.model_predictor import load_model_and_predict
|
|||
|
|
|||
|
# 导入分析函数
|
|||
|
from analysis.trend_analysis import analyze_prediction_result
|
|||
|
from analysis.metrics import evaluate_model, compare_models
|
|||
|
```
|
|||
|
|
|||
|
### 2. 使用 PharmacyPredictor 类进行模型训练
|
|||
|
|
|||
|
```python
|
|||
|
# 创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
|
|||
|
# 使用预测器进行训练
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
epochs=epochs
|
|||
|
)
|
|||
|
```
|
|||
|
|
|||
|
### 3. 更新预测功能
|
|||
|
|
|||
|
```python
|
|||
|
# 创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
|
|||
|
# 使用预测器进行预测
|
|||
|
result = predictor.predict(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
future_days=future_days,
|
|||
|
start_date=start_date,
|
|||
|
analyze_result=True
|
|||
|
)
|
|||
|
```
|
|||
|
|
|||
|
### 4. 更新模型比较功能
|
|||
|
|
|||
|
```python
|
|||
|
# 执行每个模型的预测
|
|||
|
predictions = {}
|
|||
|
metrics = {}
|
|||
|
|
|||
|
for model_type in model_types:
|
|||
|
result = predictor.predict(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
future_days=future_days,
|
|||
|
start_date=start_date,
|
|||
|
analyze_result=True
|
|||
|
)
|
|||
|
|
|||
|
if result and 'predictions' in result:
|
|||
|
predictions[model_type] = result['predictions']
|
|||
|
|
|||
|
# 如果有分析结果,提取评估指标
|
|||
|
if 'analysis' in result and result['analysis']:
|
|||
|
metrics[model_type] = result['analysis'].get('metrics', {})
|
|||
|
|
|||
|
# 比较模型性能
|
|||
|
comparison_result = {}
|
|||
|
if len(metrics) >= 2:
|
|||
|
comparison_result = compare_models(metrics)
|
|||
|
```
|
|||
|
|
|||
|
### 5. 更新预测分析功能
|
|||
|
|
|||
|
```python
|
|||
|
# 使用分析函数
|
|||
|
from analysis.trend_analysis import analyze_prediction_result
|
|||
|
analysis = analyze_prediction_result(product_id, model_type, predictions_array, features)
|
|||
|
```
|
|||
|
|
|||
|
### 6. 更新服务器启动配置
|
|||
|
|
|||
|
```python
|
|||
|
# 确保目录存在
|
|||
|
os.makedirs('static/plots', exist_ok=True)
|
|||
|
os.makedirs('static/csv', exist_ok=True)
|
|||
|
os.makedirs('static/predictions/compare', exist_ok=True)
|
|||
|
os.makedirs('saved_models', exist_ok=True)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
|
|||
|
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
|
|||
|
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
|
|||
|
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
|
|||
|
|
|||
|
args = parser.parse_args()
|
|||
|
|
|||
|
print(f"启动API服务,地址: {args.host}:{args.port}")
|
|||
|
print(f"API文档地址: http://{args.host}:{args.port}/swagger/")
|
|||
|
print(f"UI界面地址: http://{args.host}:{args.port}/ui/")
|
|||
|
|
|||
|
app.run(host=args.host, port=args.port, debug=args.debug)
|
|||
|
```
|
|||
|
|
|||
|
## 调整优势
|
|||
|
|
|||
|
1. **更好的代码组织**:API服务现在使用模块化架构中的类和函数,使代码结构更清晰。
|
|||
|
|
|||
|
2. **统一的接口**:通过 PharmacyPredictor 类提供统一的接口进行训练和预测,简化了API实现。
|
|||
|
|
|||
|
3. **功能分离**:训练、预测和分析功能现在分别由专门的模块提供,提高了代码的可维护性。
|
|||
|
|
|||
|
4. **保持API兼容性**:尽管内部实现发生了变化,但API接口保持不变,确保现有前端应用可以继续使用。
|
|||
|
|
|||
|
5. **更灵活的配置**:添加了命令行参数解析,使服务器配置更加灵活。
|
|||
|
|
|||
|
## 注意事项
|
|||
|
|
|||
|
1. 确保所有必要的目录存在,特别是 `saved_models` 目录,用于存储训练好的模型。
|
|||
|
|
|||
|
2. API服务现在依赖于新的模块化结构,如果模块名称或路径发生变化,需要相应地更新导入语句。
|
|||
|
|
|||
|
3. 如果模型的保存格式或路径约定发生变化,可能需要更新 `get_latest_model_id` 函数。
|
|||
|
|
|||
|
4. 在部署前,建议进行全面测试,确保所有API端点都能正常工作。
|