""" 模型注册表 用于解耦模型的调用和实现,支持插件式扩展新模型。 """ # 训练器注册表 TRAINER_REGISTRY = {} def register_trainer(name, func): """ 注册一个模型训练器。 参数: name (str): 模型类型名称 (e.g., 'xgboost') func (function): 对应的训练函数 """ if name in TRAINER_REGISTRY: print(f"警告: 模型训练器 '{name}' 已被覆盖注册。") TRAINER_REGISTRY[name] = func print(f"✅ 已注册训练器: {name}") def get_trainer(name): """ 根据模型类型名称获取一个已注册的训练器。 """ if name not in TRAINER_REGISTRY: # 在打印可用训练器之前,确保它们已经被加载 from trainers import discover_trainers discover_trainers() if name not in TRAINER_REGISTRY: raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}") return TRAINER_REGISTRY[name] # --- 预测器注册表 --- # 预测器函数需要一个统一的接口,例如: # def predictor_function(model, checkpoint, **kwargs): -> predictions PREDICTOR_REGISTRY = {} def register_predictor(name, func): """ 注册一个模型预测器。 """ if name in PREDICTOR_REGISTRY: print(f"警告: 模型预测器 '{name}' 已被覆盖注册。") PREDICTOR_REGISTRY[name] = func def get_predictor(name): """ 根据模型类型名称获取一个已注册的预测器。 如果找不到特定预测器,可以返回一个默认的。 """ return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default')) # 默认的PyTorch预测逻辑可以被注册为 'default' def register_default_predictors(): from predictors.model_predictor import default_pytorch_predictor register_predictor('default', default_pytorch_predictor) # 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册 # register_predictor('kan', kan_predictor_func) # 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。 # 我们可以暂时在 model_predictor.py 导入注册表后调用它。