64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
"""
|
||
模型注册表
|
||
用于解耦模型的调用和实现,支持插件式扩展新模型。
|
||
"""
|
||
|
||
# 训练器注册表
|
||
TRAINER_REGISTRY = {}
|
||
|
||
def register_trainer(name, func):
|
||
"""
|
||
注册一个模型训练器。
|
||
|
||
参数:
|
||
name (str): 模型类型名称 (e.g., 'xgboost')
|
||
func (function): 对应的训练函数
|
||
"""
|
||
if name in TRAINER_REGISTRY:
|
||
print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
|
||
TRAINER_REGISTRY[name] = func
|
||
print(f"✅ 已注册训练器: {name}")
|
||
|
||
def get_trainer(name):
|
||
"""
|
||
根据模型类型名称获取一个已注册的训练器。
|
||
"""
|
||
if name not in TRAINER_REGISTRY:
|
||
# 在打印可用训练器之前,确保它们已经被加载
|
||
from trainers import discover_trainers
|
||
discover_trainers()
|
||
if name not in TRAINER_REGISTRY:
|
||
raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
|
||
return TRAINER_REGISTRY[name]
|
||
|
||
# --- 预测器注册表 ---
|
||
|
||
# 预测器函数需要一个统一的接口,例如:
|
||
# def predictor_function(model, checkpoint, **kwargs): -> predictions
|
||
|
||
PREDICTOR_REGISTRY = {}
|
||
|
||
def register_predictor(name, func):
|
||
"""
|
||
注册一个模型预测器。
|
||
"""
|
||
if name in PREDICTOR_REGISTRY:
|
||
print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
|
||
PREDICTOR_REGISTRY[name] = func
|
||
|
||
def get_predictor(name):
|
||
"""
|
||
根据模型类型名称获取一个已注册的预测器。
|
||
如果找不到特定预测器,可以返回一个默认的。
|
||
"""
|
||
return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))
|
||
|
||
# 默认的PyTorch预测逻辑可以被注册为 'default'
|
||
def register_default_predictors():
|
||
from predictors.model_predictor import default_pytorch_predictor
|
||
register_predictor('default', default_pytorch_predictor)
|
||
# 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册
|
||
# register_predictor('kan', kan_predictor_func)
|
||
|
||
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
|
||
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。 |