import torch
import torch.nn as nn


# Define the sLSTM cell
class sLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        # Store the input and hidden sizes
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Combine the input and recurrent weights into a single matrix
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )

        # Combine the biases into a single vector
        if self.bias:
            self.B = nn.Parameter(
                torch.zeros(4 * self.hidden_size), requires_grad=True
            )

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ]:
        # Unpack the internal state
        h, c, n, m = internal_state  # each (batch_size, hidden_size)

        # Combine the input with the previous hidden state
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)

        # Compute the linear transformation
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)

        # Add the bias if it is included
        if self.bias:
            gates += self.B

        # Split the result into the cell input, input, forget and output gates
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

        # Tanh activation for the cell input
        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        # Exponential activation for the input gate
        i_t = torch.exp(i_tilda)  # (batch_size, hidden_size)
        # Sigmoid activation for the forget gate
        f_t = torch.sigmoid(f_tilda)  # (batch_size, hidden_size)
        # Sigmoid activation for the output gate
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)

        # Compute the stabilizer state
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))  # (batch_size, hidden_size)
        # Stabilized input gate: i' = exp(log(i) - m_t)
        i_prime = torch.exp(i_tilda - m_t)  # (batch_size, hidden_size)
        # Stabilized forget gate: f' = exp(log(f) + m_{t-1} - m_t)
        f_prime = torch.exp(torch.log(f_t) + m - m_t)  # (batch_size, hidden_size)

        # Compute the new cell and normalizer states
        c_t = f_prime * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_prime * n + i_prime  # (batch_size, hidden_size)

        # Compute the normalized hidden state
        h_tilda = c_t / n_t  # (batch_size, hidden_size)
        # Compute the new hidden state
        h_t = o_t * h_tilda  # (batch_size, hidden_size)

        return h_t, (h_t, c_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )


# Define the sLSTM layer
class sLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # Permute the input tensor if batch_first is True
        if self.batch_first:
            x = x.permute(1, 0, 2)

        # Initialize the hidden states if none are provided
        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that the hidden states have the correct length
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        # Process the sequence step by step through each layer
        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the outputs into a sequence
        outputs = torch.stack(outputs)

        # Permute the output tensor back if batch_first is True
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(self, batch_size: int, **kwargs):
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]
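

# --- Minimal usage sketch (illustrative only, not part of the original listing) ---
# Shows how the sLSTM layer defined above can be driven end to end.
# The sizes below (batch of 4, sequence of 8, feature size 16, hidden size 32,
# 2 stacked layers) are arbitrary assumptions chosen for demonstration.
if __name__ == "__main__":
    batch_size, seq_len, input_size, hidden_size = 4, 8, 16, 32

    model = sLSTM(input_size, hidden_size, num_layers=2, batch_first=True)
    x = torch.randn(batch_size, seq_len, input_size)

    # hidden_states defaults to the zero states created by init_hidden
    output, hidden_states = model(x)

    # With batch_first=True the output keeps the (batch, seq, hidden) layout
    print(output.shape)  # torch.Size([4, 8, 32])
    # One (h, c, n, m) tuple per layer, each entry of shape (batch, hidden)
    print(len(hidden_states), hidden_states[0][0].shape)  # 2 torch.Size([4, 32])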