import os
import torch
import glob
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import json
import shutil
from .utils import get_device, to_device
from .mlstm_model import MLSTMTransformer
from .transformer_model import TimeSeriesTransformer
from .kan_model import KANForecaster
class ModelManager:
"""
模型管理类:此类现在主要负责提供模型类的映射。
注意:所有与文件系统交互的逻辑(保存、加载、删除等)已被移除,
并由 server.utils.file_save.ModelPathManager 统一处理,
以遵循新的扁平化文件存储规范。
def __init__(self):
初始化模型管理器
# 模型类型到其对应类的映射
self.model_types = {
'mlstm': MLSTMTransformer,
'transformer': TimeSeriesTransformer,
'kan': KANForecaster
}
def get_model_class(self, model_type: str):
根据模型类型字符串获取模型类。
Args:
model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。
Returns:
模型类,如果不存在则返回 None。
return self.model_types.get(model_type)