ShopTRAINING/server/models/model_registry.py

64 lines
2.1 KiB
Python
Raw Permalink Normal View History

2025-07-22 15:40:37 +08:00
"""
模型注册表
用于解耦模型的调用和实现支持插件式扩展新模型
"""
# 训练器注册表
TRAINER_REGISTRY = {}
def register_trainer(name, func):
"""
注册一个模型训练器
参数:
name (str): 模型类型名称 (e.g., 'xgboost')
func (function): 对应的训练函数
"""
if name in TRAINER_REGISTRY:
print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
TRAINER_REGISTRY[name] = func
print(f"✅ 已注册训练器: {name}")
def get_trainer(name):
"""
根据模型类型名称获取一个已注册的训练器
"""
if name not in TRAINER_REGISTRY:
# 在打印可用训练器之前,确保它们已经被加载
from trainers import discover_trainers
discover_trainers()
if name not in TRAINER_REGISTRY:
raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
return TRAINER_REGISTRY[name]
# --- 预测器注册表 ---
# 预测器函数需要一个统一的接口,例如:
# def predictor_function(model, checkpoint, **kwargs): -> predictions
PREDICTOR_REGISTRY = {}
def register_predictor(name, func):
"""
注册一个模型预测器
"""
if name in PREDICTOR_REGISTRY:
print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
PREDICTOR_REGISTRY[name] = func
def get_predictor(name):
"""
根据模型类型名称获取一个已注册的预测器
如果找不到特定预测器可以返回一个默认的
"""
return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))
# 默认的PyTorch预测逻辑可以被注册为 'default'
def register_default_predictors():
from predictors.model_predictor import default_pytorch_predictor
register_predictor('default', default_pytorch_predictor)
# 如果其他PyTorch模型有特殊预测逻辑也可以在这里注册
# register_predictor('kan', kan_predictor_func)
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。