138 lines
4.2 KiB
Markdown
138 lines
4.2 KiB
Markdown
# API.py 文件调整总结
|
||
|
||
## 主要修改内容
|
||
|
||
为了使 api.py 适应新的模块化架构,我们进行了以下主要修改:
|
||
|
||
### 1. 更新导入语句
|
||
|
||
```python
|
||
# 导入核心预测器类
|
||
from core.predictor import PharmacyPredictor
|
||
|
||
# 导入训练函数
|
||
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
||
from trainers.kan_trainer import train_product_model_with_kan
|
||
from trainers.tcn_trainer import train_product_model_with_tcn
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
|
||
# 导入预测函数
|
||
from predictors.model_predictor import load_model_and_predict
|
||
|
||
# 导入分析函数
|
||
from analysis.trend_analysis import analyze_prediction_result
|
||
from analysis.metrics import evaluate_model, compare_models
|
||
```
|
||
|
||
### 2. 使用 PharmacyPredictor 类进行模型训练
|
||
|
||
```python
|
||
# 创建预测器实例
|
||
predictor = PharmacyPredictor()
|
||
|
||
# 使用预测器进行训练
|
||
metrics = predictor.train_model(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
epochs=epochs
|
||
)
|
||
```
|
||
|
||
### 3. 更新预测功能
|
||
|
||
```python
|
||
# 创建预测器实例
|
||
predictor = PharmacyPredictor()
|
||
|
||
# 使用预测器进行预测
|
||
result = predictor.predict(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
future_days=future_days,
|
||
start_date=start_date,
|
||
analyze_result=True
|
||
)
|
||
```
|
||
|
||
### 4. 更新模型比较功能
|
||
|
||
```python
|
||
# 执行每个模型的预测
|
||
predictions = {}
|
||
metrics = {}
|
||
|
||
for model_type in model_types:
|
||
result = predictor.predict(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
future_days=future_days,
|
||
start_date=start_date,
|
||
analyze_result=True
|
||
)
|
||
|
||
if result and 'predictions' in result:
|
||
predictions[model_type] = result['predictions']
|
||
|
||
# 如果有分析结果,提取评估指标
|
||
if 'analysis' in result and result['analysis']:
|
||
metrics[model_type] = result['analysis'].get('metrics', {})
|
||
|
||
# 比较模型性能
|
||
comparison_result = {}
|
||
if len(metrics) >= 2:
|
||
comparison_result = compare_models(metrics)
|
||
```
|
||
|
||
### 5. 更新预测分析功能
|
||
|
||
```python
|
||
# 使用分析函数
|
||
from analysis.trend_analysis import analyze_prediction_result
|
||
analysis = analyze_prediction_result(product_id, model_type, predictions_array, features)
|
||
```
|
||
|
||
### 6. 更新服务器启动配置
|
||
|
||
```python
|
||
# 确保目录存在
|
||
os.makedirs('static/plots', exist_ok=True)
|
||
os.makedirs('static/csv', exist_ok=True)
|
||
os.makedirs('static/predictions/compare', exist_ok=True)
|
||
os.makedirs('saved_models', exist_ok=True)
|
||
|
||
if __name__ == '__main__':
|
||
parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
|
||
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
|
||
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
|
||
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
|
||
|
||
args = parser.parse_args()
|
||
|
||
print(f"启动API服务,地址: {args.host}:{args.port}")
|
||
print(f"API文档地址: http://{args.host}:{args.port}/swagger/")
|
||
print(f"UI界面地址: http://{args.host}:{args.port}/ui/")
|
||
|
||
app.run(host=args.host, port=args.port, debug=args.debug)
|
||
```
|
||
|
||
## 调整优势
|
||
|
||
1. **更好的代码组织**:API服务现在使用模块化架构中的类和函数,使代码结构更清晰。
|
||
|
||
2. **统一的接口**:通过 PharmacyPredictor 类提供统一的接口进行训练和预测,简化了API实现。
|
||
|
||
3. **功能分离**:训练、预测和分析功能现在分别由专门的模块提供,提高了代码的可维护性。
|
||
|
||
4. **保持API兼容性**:尽管内部实现发生了变化,但API接口保持不变,确保现有前端应用可以继续使用。
|
||
|
||
5. **更灵活的配置**:添加了命令行参数解析,使服务器配置更加灵活。
|
||
|
||
## 注意事项
|
||
|
||
1. 确保所有必要的目录存在,特别是 `saved_models` 目录,用于存储训练好的模型。
|
||
|
||
2. API服务现在依赖于新的模块化结构,如果模块名称或路径发生变化,需要相应地更新导入语句。
|
||
|
||
3. 如果模型的保存格式或路径约定发生变化,可能需要更新 `get_latest_model_id` 函数。
|
||
|
||
4. 在部署前,建议进行全面测试,确保所有API端点都能正常工作。 |