ShopTRAINING/docs/api_update_summary.md

4.2 KiB
Raw Permalink Blame History

API.py 文件调整总结

主要修改内容

为了使 api.py 适应新的模块化架构,我们进行了以下主要修改:

1. 更新导入语句

# 导入核心预测器类
from core.predictor import PharmacyPredictor

# 导入训练函数
from trainers.mlstm_trainer import train_product_model_with_mlstm
from trainers.kan_trainer import train_product_model_with_kan
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.transformer_trainer import train_product_model_with_transformer

# 导入预测函数
from predictors.model_predictor import load_model_and_predict

# 导入分析函数
from analysis.trend_analysis import analyze_prediction_result
from analysis.metrics import evaluate_model, compare_models

2. 使用 PharmacyPredictor 类进行模型训练

# 创建预测器实例
predictor = PharmacyPredictor()

# 使用预测器进行训练
metrics = predictor.train_model(
    product_id=product_id,
    model_type=model_type,
    epochs=epochs
)

3. 更新预测功能

# 创建预测器实例
predictor = PharmacyPredictor()

# 使用预测器进行预测
result = predictor.predict(
    product_id=product_id,
    model_type=model_type,
    future_days=future_days,
    start_date=start_date,
    analyze_result=True
)

4. 更新模型比较功能

# 执行每个模型的预测
predictions = {}
metrics = {}

for model_type in model_types:
    result = predictor.predict(
        product_id=product_id,
        model_type=model_type,
        future_days=future_days,
        start_date=start_date,
        analyze_result=True
    )
    
    if result and 'predictions' in result:
        predictions[model_type] = result['predictions']
        
        # 如果有分析结果,提取评估指标
        if 'analysis' in result and result['analysis']:
            metrics[model_type] = result['analysis'].get('metrics', {})

# 比较模型性能
comparison_result = {}
if len(metrics) >= 2:
    comparison_result = compare_models(metrics)

5. 更新预测分析功能

# 使用分析函数
from analysis.trend_analysis import analyze_prediction_result
analysis = analyze_prediction_result(product_id, model_type, predictions_array, features)

6. 更新服务器启动配置

# 确保目录存在
os.makedirs('static/plots', exist_ok=True)
os.makedirs('static/csv', exist_ok=True)
os.makedirs('static/predictions/compare', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
    parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
    parser.add_argument('--port', type=int, default=5000, help='服务器端口')
    parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
    
    args = parser.parse_args()
    
    print(f"启动API服务地址: {args.host}:{args.port}")
    print(f"API文档地址: http://{args.host}:{args.port}/swagger/")
    print(f"UI界面地址: http://{args.host}:{args.port}/ui/")
    
    app.run(host=args.host, port=args.port, debug=args.debug)

调整优势

  1. 更好的代码组织API服务现在使用模块化架构中的类和函数使代码结构更清晰。

  2. 统一的接口:通过 PharmacyPredictor 类提供统一的接口进行训练和预测简化了API实现。

  3. 功能分离:训练、预测和分析功能现在分别由专门的模块提供,提高了代码的可维护性。

  4. 保持API兼容性尽管内部实现发生了变化但API接口保持不变确保现有前端应用可以继续使用。

  5. 更灵活的配置:添加了命令行参数解析,使服务器配置更加灵活。

注意事项

  1. 确保所有必要的目录存在,特别是 saved_models 目录,用于存储训练好的模型。

  2. API服务现在依赖于新的模块化结构如果模块名称或路径发生变化需要相应地更新导入语句。

  3. 如果模型的保存格式或路径约定发生变化,可能需要更新 get_latest_model_id 函数。

  4. 在部署前建议进行全面测试确保所有API端点都能正常工作。