import os import torch import glob import pandas as pd import matplotlib.pyplot as plt from datetime import datetime import json import shutil from .utils import get_device, to_device from .mlstm_model import MLSTMTransformer from .transformer_model import TimeSeriesTransformer from .kan_model import KANForecaster class ModelManager: """ 模型管理类:此类现在主要负责提供模型类的映射。 注意:所有与文件系统交互的逻辑(保存、加载、删除等)已被移除, 并由 server.utils.file_save.ModelPathManager 统一处理, 以遵循新的扁平化文件存储规范。 """ def __init__(self): """ 初始化模型管理器 """ # 模型类型到其对应类的映射 self.model_types = { 'mlstm': MLSTMTransformer, 'transformer': TimeSeriesTransformer, 'kan': KANForecaster } def get_model_class(self, model_type: str): """ 根据模型类型字符串获取模型类。 Args: model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。 Returns: 模型类,如果不存在则返回 None。 """ return self.model_types.get(model_type)