ShopTRAINING/server/models/model_registry.py
2025-07-22 15:41:05 +08:00

64 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
模型注册表
用于解耦模型的调用和实现,支持插件式扩展新模型。
"""
# 训练器注册表
TRAINER_REGISTRY = {}
def register_trainer(name, func):
"""
注册一个模型训练器。
参数:
name (str): 模型类型名称 (e.g., 'xgboost')
func (function): 对应的训练函数
"""
if name in TRAINER_REGISTRY:
print(f"警告: 模型训练器 '{name}' 已被覆盖注册。")
TRAINER_REGISTRY[name] = func
print(f"✅ 已注册训练器: {name}")
def get_trainer(name):
"""
根据模型类型名称获取一个已注册的训练器。
"""
if name not in TRAINER_REGISTRY:
# 在打印可用训练器之前,确保它们已经被加载
from trainers import discover_trainers
discover_trainers()
if name not in TRAINER_REGISTRY:
raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}")
return TRAINER_REGISTRY[name]
# --- 预测器注册表 ---
# 预测器函数需要一个统一的接口,例如:
# def predictor_function(model, checkpoint, **kwargs): -> predictions
PREDICTOR_REGISTRY = {}
def register_predictor(name, func):
"""
注册一个模型预测器。
"""
if name in PREDICTOR_REGISTRY:
print(f"警告: 模型预测器 '{name}' 已被覆盖注册。")
PREDICTOR_REGISTRY[name] = func
def get_predictor(name):
"""
根据模型类型名称获取一个已注册的预测器。
如果找不到特定预测器,可以返回一个默认的。
"""
return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default'))
# 默认的PyTorch预测逻辑可以被注册为 'default'
def register_default_predictors():
from predictors.model_predictor import default_pytorch_predictor
register_predictor('default', default_pytorch_predictor)
# 如果其他PyTorch模型有特殊预测逻辑也可以在这里注册
# register_predictor('kan', kan_predictor_func)
# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。
# 我们可以暂时在 model_predictor.py 导入注册表后调用它。