import torch
import torch.nn as nn
from typing import Tuple


# Define the sLSTM cell
class sLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        # Store the input and hidden sizes
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Combine the input and recurrent weights into a single matrix
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )
        # Combine the biases into a single vector
        if self.bias:
            self.B = nn.Parameter(
                torch.zeros(4 * self.hidden_size), requires_grad=True
            )

    def forward(
        self,
        x: torch.Tensor,
        internal_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[
        torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        # Unpack the internal state
        h, c, n, m = internal_state  # (batch_size, hidden_size)

        # Concatenate the input with the previous hidden state
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)
        # Compute the linear transformation
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)

        # Add the bias if it is included
        if self.bias:
            gates += self.B

        # Split into the cell input and the input, forget, and output gate pre-activations
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

        # Tanh activation of the cell input
        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        # Exponential activation of the input gate
        i_t = torch.exp(i_tilda)  # (batch_size, hidden_size)
        # Sigmoid activation of the forget gate
        f_t = torch.sigmoid(f_tilda)  # (batch_size, hidden_size)

        # Sigmoid activation of the output gate
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)
        # Compute the stabilizer state
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))  # (batch_size, hidden_size)
        # Compute the stabilized input gate
        i_prime = torch.exp(i_tilda - m_t)  # (batch_size, hidden_size)
        # Compute the stabilized forget gate so the cell and normalizer states stay on
        # the same scale as the stabilizer m_t (mirrors i_prime above)
        f_prime = torch.exp(torch.log(f_t) + m - m_t)  # (batch_size, hidden_size)

        # Compute the new internal states
        c_t = f_prime * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_prime * n + i_prime  # (batch_size, hidden_size)

        # Compute the stabilized hidden state
        h_tilda = c_t / n_t  # (batch_size, hidden_size)

        # Compute the new hidden state
        h_t = o_t * h_tilda  # (batch_size, hidden_size)
        return h_t, (
            h_t,
            c_t,
            n_t,
            m_t,
        )

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )
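

# For reference, one sLSTMCell step above implements the stabilized sLSTM
# recurrence (a restatement of the code, not an additional computation):
#   z_t = tanh(z~),  i_t = exp(i~),  f_t = sigmoid(f~),  o_t = sigmoid(o~)
#   m_t = max(log f_t + m_{t-1}, log i_t)
#   i'_t = exp(i~ - m_t),  f'_t = exp(log f_t + m_{t-1} - m_t)
#   c_t = f'_t * c_{t-1} + i'_t * z_t
#   n_t = f'_t * n_{t-1} + i'_t
#   h_t = o_t * c_t / n_t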


# Define the sLSTM layer
class sLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # If batch_first is True, permute the input to (seq_len, batch_size, input_size)
        if self.batch_first:
            x = x.permute(1, 0, 2)

        # Initialize the hidden states if none are provided
        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that the provided hidden states have the expected length
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        # Process the sequence through each layer, one time step at a time
        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the outputs into a sequence
        outputs = torch.stack(outputs)

        # If batch_first is True, permute the output back to (batch_size, seq_len, hidden_size)
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]
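

# A minimal usage sketch, not part of the original module: the sizes below are
# illustrative assumptions, chosen only to smoke-test the classes defined above.
if __name__ == "__main__":
    # Batch of 4 sequences, 7 time steps, 10 input features (arbitrary sizes)
    x = torch.randn(4, 7, 10)

    model = sLSTM(input_size=10, hidden_size=20, num_layers=2, batch_first=True)

    # First call: the layer initializes its own hidden states
    out, states = model(x)
    print(out.shape)    # torch.Size([4, 7, 20])
    print(len(states))  # 2, one (h, c, n, m) tuple per layer

    # The returned states can be fed back in to continue the sequence
    out2, _ = model(x, hidden_states=states)
    print(out2.shape)   # torch.Size([4, 7, 20])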