4.2 KiB
4.2 KiB
API.py 文件调整总结
主要修改内容
为了使 api.py 适应新的模块化架构,我们进行了以下主要修改:
1. 更新导入语句
# 导入核心预测器类
from core.predictor import PharmacyPredictor
# 导入训练函数
from trainers.mlstm_trainer import train_product_model_with_mlstm
from trainers.kan_trainer import train_product_model_with_kan
from trainers.tcn_trainer import train_product_model_with_tcn
from trainers.transformer_trainer import train_product_model_with_transformer
# 导入预测函数
from predictors.model_predictor import load_model_and_predict
# 导入分析函数
from analysis.trend_analysis import analyze_prediction_result
from analysis.metrics import evaluate_model, compare_models
2. 使用 PharmacyPredictor 类进行模型训练
# 创建预测器实例
predictor = PharmacyPredictor()
# 使用预测器进行训练
metrics = predictor.train_model(
product_id=product_id,
model_type=model_type,
epochs=epochs
)
3. 更新预测功能
# 创建预测器实例
predictor = PharmacyPredictor()
# 使用预测器进行预测
result = predictor.predict(
product_id=product_id,
model_type=model_type,
future_days=future_days,
start_date=start_date,
analyze_result=True
)
4. 更新模型比较功能
# 执行每个模型的预测
predictions = {}
metrics = {}
for model_type in model_types:
result = predictor.predict(
product_id=product_id,
model_type=model_type,
future_days=future_days,
start_date=start_date,
analyze_result=True
)
if result and 'predictions' in result:
predictions[model_type] = result['predictions']
# 如果有分析结果,提取评估指标
if 'analysis' in result and result['analysis']:
metrics[model_type] = result['analysis'].get('metrics', {})
# 比较模型性能
comparison_result = {}
if len(metrics) >= 2:
comparison_result = compare_models(metrics)
5. 更新预测分析功能
# 使用分析函数
from analysis.trend_analysis import analyze_prediction_result
analysis = analyze_prediction_result(product_id, model_type, predictions_array, features)
6. 更新服务器启动配置
# 确保目录存在
os.makedirs('static/plots', exist_ok=True)
os.makedirs('static/csv', exist_ok=True)
os.makedirs('static/predictions/compare', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='药店销售预测系统API服务')
parser.add_argument('--host', default='0.0.0.0', help='服务器主机地址')
parser.add_argument('--port', type=int, default=5000, help='服务器端口')
parser.add_argument('--debug', action='store_true', help='是否启用调试模式')
args = parser.parse_args()
print(f"启动API服务,地址: {args.host}:{args.port}")
print(f"API文档地址: http://{args.host}:{args.port}/swagger/")
print(f"UI界面地址: http://{args.host}:{args.port}/ui/")
app.run(host=args.host, port=args.port, debug=args.debug)
调整优势
-
更好的代码组织:API服务现在使用模块化架构中的类和函数,使代码结构更清晰。
-
统一的接口:通过 PharmacyPredictor 类提供统一的接口进行训练和预测,简化了API实现。
-
功能分离:训练、预测和分析功能现在分别由专门的模块提供,提高了代码的可维护性。
-
保持API兼容性:尽管内部实现发生了变化,但API接口保持不变,确保现有前端应用可以继续使用。
-
更灵活的配置:添加了命令行参数解析,使服务器配置更加灵活。
注意事项
-
确保所有必要的目录存在,特别是
saved_models
目录,用于存储训练好的模型。 -
API服务现在依赖于新的模块化结构,如果模块名称或路径发生变化,需要相应地更新导入语句。
-
如果模型的保存格式或路径约定发生变化,可能需要更新
get_latest_model_id
函数。 -
在部署前,建议进行全面测试,确保所有API端点都能正常工作。