import math

import torch
import torch.nn as nn

from .transformer_model import TransformerEncoder, TransformerDecoder


# mLSTM cell definition
class mLSTMCell(nn.Module):
    """Matrix-LSTM (mLSTM) cell with a matrix memory and exponential gating.

    Per batch element the cell keeps a matrix memory ``C`` of shape
    (hidden_size, hidden_size), a normalizer state ``n`` and a stabilizer
    state ``m``, updated from query/key/value projections of the input.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        """
        Args:
            input_size: Dimensionality of the input features.
            hidden_size: Dimensionality of the hidden state.
            bias: If True, learnable biases are added to every projection.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        def _xavier_weight() -> nn.Parameter:
            # Xavier-uniform initialised (input_size, hidden_size) projection.
            return nn.Parameter(
                nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
                requires_grad=True,
            )

        # Gate projections (input, forget, output) and q/k/v head projections.
        self.W_i = _xavier_weight()
        self.W_f = _xavier_weight()
        self.W_o = _xavier_weight()
        self.W_q = _xavier_weight()
        self.W_k = _xavier_weight()
        self.W_v = _xavier_weight()

        if self.bias:
            self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Run one time step of the cell.

        Args:
            x: Input of shape (batch, input_size).
            internal_state: Tuple ``(C, n, m)`` with the matrix memory
                (batch, hidden, hidden), normalizer (batch, hidden) and
                stabilizer (batch, hidden).

        Returns:
            ``(h_t, (C_t, n_t, m_t))``: hidden output and updated state.
        """
        # Unpack the internal state
        C, n, m = internal_state

        def _proj(W: torch.Tensor, bias_name: str) -> torch.Tensor:
            # x @ W (+ bias when enabled).
            out = torch.matmul(x, W)
            return out + getattr(self, bias_name) if self.bias else out

        # Pre-activations for the input, forget and output gates and the
        # query / key / value heads.
        i_tilda = _proj(self.W_i, "B_i")
        f_tilda = _proj(self.W_f, "B_f")
        o_tilda = _proj(self.W_o, "B_o")
        q_t = _proj(self.W_q, "B_q")
        # Keys are scaled by 1/sqrt(hidden_size) before the bias is added, as
        # in scaled dot-product attention.  math.sqrt is used because the
        # original torch.sqrt(torch.tensor(hidden_size)) raises on integer
        # (Long) tensors.
        k_scaled = torch.matmul(x, self.W_k) / math.sqrt(self.hidden_size)
        k_t = k_scaled + self.B_k if self.bias else k_scaled
        v_t = _proj(self.W_v, "B_v")

        # Sigmoid activations for the forget and output gates.
        f_t = torch.sigmoid(f_tilda)
        o_t = torch.sigmoid(o_tilda)

        # Stabilizer state: m_t = max(log f_t + m, i_tilda).  The input-gate
        # term uses i_tilda directly (== log(exp(i_tilda))), which avoids the
        # overflow of exp() for large pre-activations that the original
        # log(exp(i_tilda)) round trip suffered from.
        m_t = torch.max(torch.log(f_t) + m, i_tilda)
        # Stabilised exponential input gate.
        i_prime = torch.exp(i_tilda - m_t)

        # Matrix memory and normalizer updates: forget-gated decay plus the
        # input-gated outer product of values and keys.
        # NOTE(review): the forget gate is applied unstabilised here (f_t
        # rather than exp(log f_t + m - m_t)); kept as in the original code.
        C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
            "bi, bk -> bik", v_t, k_t
        )
        n_t = f_t * n + i_prime * k_t

        # Retrieve from memory with the query and normalise by
        # max(|n_t . q_t|, 1) per batch element.
        normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
        divisor = torch.max(
            torch.abs(normalize_inner), torch.ones_like(normalize_inner)
        )
        h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
        h_t = o_t * h_tilda

        return h_t, (C_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return zeroed ``(C, n, m)`` state for a batch (kwargs: device/dtype)."""
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )


# mLSTM layer definition
class mLSTM(nn.Module):
    """Multi-layer mLSTM: stacks mLSTMCells and unrolls them over a sequence."""

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        """
        Args:
            input_size: Feature size of the input sequence.
            hidden_size: Hidden size of every layer.
            num_layers: Number of stacked mLSTM cells.
            bias: Passed through to each cell.
            batch_first: If True, inputs/outputs are (batch, seq, feature)
                instead of (seq, batch, feature).
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        # First cell consumes the raw input; subsequent cells consume the
        # previous layer's hidden output.
        self.cells = nn.ModuleList(
            [
                mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        """Unroll the stacked cells over the time dimension of ``x``.

        Args:
            x: Sequence tensor, (batch, seq, feature) when ``batch_first``
                else (seq, batch, feature).
            hidden_states: Optional list of per-layer ``(C, n, m)`` tuples;
                zero-initialised when None.

        Returns:
            ``(outputs, hidden_states)``: the top-layer hidden sequence and
            the final per-layer states.

        Raises:
            ValueError: If ``hidden_states`` has the wrong length or batch size.
        """
        # Internally work in (seq, batch, feature) layout.
        if self.batch_first:
            x = x.permute(1, 0, 2)

        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Validate the provided states before unrolling.
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            # Feed each time step through the layer stack, threading the
            # per-layer state forward.
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the per-step outputs into a sequence.
        outputs = torch.stack(outputs)

        # Restore (batch, seq, feature) layout if requested.
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)
        return outputs, hidden_states

    def init_hidden(self, batch_size: int, **kwargs):
        """Return a list of zeroed per-layer states (kwargs: device/dtype)."""
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]


# Model combining mLSTM and Transformer
class MLSTMTransformer(nn.Module):
    """mLSTM encoder followed by a Transformer encoder-decoder head.

    The mLSTM consumes the raw feature sequence; its hidden sequence is
    embedded, position-encoded, passed through stacked Transformer encoders,
    and decoded into ``output_sequence_length`` scalar predictions.
    """

    def __init__(
        self,
        num_features,
        hidden_size=128,
        mlstm_layers=1,
        embed_dim=32,
        dense_dim=32,
        num_heads=4,
        dropout_rate=0.1,
        num_blocks=3,
        output_sequence_length=1,
        max_seq_len=1000,
    ):
        """
        Args:
            num_features: Input feature dimensionality.
            hidden_size: mLSTM hidden size.
            mlstm_layers: Number of stacked mLSTM layers.
            embed_dim: Transformer embedding dimension.
            dense_dim: Transformer feed-forward dimension.
            num_heads: Attention head count.
            dropout_rate: Dropout probability.
            num_blocks: Number of encoder and decoder blocks.
            output_sequence_length: Number of prediction steps.
            max_seq_len: Maximum supported input sequence length for the
                learned positional encoding (default 1000, as before).
        """
        super().__init__()
        # mLSTM stage
        self.mlstm = mLSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=mlstm_layers,
            batch_first=True,
        )
        # Transformer stage
        self.input_embedding = nn.Linear(hidden_size, embed_dim)
        # Learned positional encoding, small init to keep early training stable.
        self.positional_encoding = nn.Parameter(
            torch.randn(1, max_seq_len, embed_dim) * 0.1
        )
        self.encoders = nn.ModuleList(
            [
                TransformerEncoder(embed_dim, dense_dim, num_heads, dropout_rate)
                for _ in range(num_blocks)
            ]
        )
        self.decoders = nn.ModuleList(
            [
                TransformerDecoder(embed_dim, dense_dim, num_heads, dropout_rate)
                for _ in range(num_blocks)
            ]
        )
        self.output_layer = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, inputs):
        """Predict ``output_sequence_length`` values from the input sequence.

        Args:
            inputs: Tensor of shape (batch, seq, num_features).

        Returns:
            Tensor of shape (batch, output_sequence_length, 1).
        """
        # mLSTM processing
        mlstm_output, _ = self.mlstm(inputs)

        # Embed the mLSTM outputs into the Transformer space and add
        # positional information.
        x = self.input_embedding(mlstm_output)
        x = x + self.positional_encoding[:, : x.size(1), :]
        x = self.dropout(x)

        # Encoder stack
        encoder_outputs = x
        for encoder in self.encoders:
            encoder_outputs = encoder(encoder_outputs)

        # Decoder stack: seed the decoder with the last encoder step repeated
        # for each prediction step, cross-attending to the encoder outputs.
        decoder_inputs = encoder_outputs[:, -1:, :].expand(
            -1, self.output_sequence_length, -1
        )
        decoder_outputs = decoder_inputs
        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        return self.output_layer(decoder_outputs)