# API.py 文件调整总结 ## 主要修改内容 为了使 api.py 适应新的模块化架构,我们进行了以下主要修改: ### 1. 更新导入语句 ```python # 导入核心预测器类 from core.predictor import PharmacyPredictor # 导入训练函数 from trainers.mlstm_trainer import train_product_model_with_mlstm from trainers.kan_trainer import train_product_model_with_kan from trainers.tcn_trainer import train_product_model_with_tcn from trainers.transformer_trainer import train_product_model_with_transformer # 导入预测函数 from predictors.model_predictor import load_model_and_predict # 导入分析函数 from analysis.trend_analysis import analyze_prediction_result from analysis.metrics import evaluate_model, compare_models ``` ### 2. 使用 PharmacyPredictor 类进行模型训练 ```python # 创建预测器实例 predictor = PharmacyPredictor() # 使用预测器进行训练 metrics = predictor.train_model( product_id=product_id, model_type=model_type, epochs=epochs ) ``` ### 3. 更新预测功能 ```python # 创建预测器实例 predictor = PharmacyPredictor() # 使用预测器进行预测 result = predictor.predict( product_id=product_id, model_type=model_type, future_days=future_days, start_date=start_date, analyze_result=True ) ``` ### 4. 更新模型比较功能 ```python # 执行每个模型的预测 predictions = {} metrics = {} for model_type in model_types: result = predictor.predict( product_id=product_id, model_type=model_type, future_days=future_days, start_date=start_date, analyze_result=True ) if result and 'predictions' in result: predictions[model_type] = result['predictions'] # 如果有分析结果,提取评估指标 if 'analysis' in result and result['analysis']: metrics[model_type] = result['analysis'].get('metrics', {}) # 比较模型性能 comparison_result = {} if len(metrics) >= 2: comparison_result = compare_models(metrics) ``` ### 5. 更新预测分析功能 ```python # 使用分析函数 from analysis.trend_analysis import analyze_prediction_result analysis = analyze_prediction_result(product_id, model_type, predictions_array, features) ``` ### 6. 更新服务器启动配置 ```python # 确保目录存在 os.makedirs('static/plots', exist_ok=True) os.makedirs('static/csv', exist_ok=True) os.makedirs('static/predictions/compare', exist_ok=True) os.makedirs('saved_models', exist_ok=True) if __name__ == '__main__': parser = argparse.ArgumentParser(description='药店销售预测系统API服务') parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址') parser.add_argument('--port', type=int, default=5000, help='服务器端口') parser.add_argument('--debug', action='store_true', help='是否启用调试模式') args = parser.parse_args() print(f"启动API服务,地址: {args.host}:{args.port}") print(f"API文档地址: http://{args.host}:{args.port}/swagger/") print(f"UI界面地址: http://{args.host}:{args.port}/ui/") app.run(host=args.host, port=args.port, debug=args.debug) ``` ## 调整优势 1. **更好的代码组织**:API服务现在使用模块化架构中的类和函数,使代码结构更清晰。 2. **统一的接口**:通过 PharmacyPredictor 类提供统一的接口进行训练和预测,简化了API实现。 3. **功能分离**:训练、预测和分析功能现在分别由专门的模块提供,提高了代码的可维护性。 4. **保持API兼容性**:尽管内部实现发生了变化,但API接口保持不变,确保现有前端应用可以继续使用。 5. **更灵活的配置**:添加了命令行参数解析,使服务器配置更加灵活。 ## 注意事项 1. 确保所有必要的目录存在,特别是 `saved_models` 目录,用于存储训练好的模型。 2. API服务现在依赖于新的模块化结构,如果模块名称或路径发生变化,需要相应地更新导入语句。 3. 如果模型的保存格式或路径约定发生变化,可能需要更新 `get_latest_model_id` 函数。 4. 在部署前,建议进行全面测试,确保所有API端点都能正常工作。