162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
![]() |
import torch
|
|||
|
import torch.nn as nn
|
|||
|
|
|||
|
# Define the sLSTM cell.
class sLSTMCell(nn.Module):
    """Single sLSTM cell with exponential input gating and log-space stabilization.

    The recurrent state is a 4-tuple ``(h, c, n, m)``, each tensor of shape
    ``(batch_size, hidden_size)``:

    * ``h`` - hidden state (also the cell output),
    * ``c`` - cell state, stored in the stabilized scale ``exp(-m)``,
    * ``n`` - normalizer state, stored in the same scale as ``c``,
    * ``m`` - log-space stabilizer.

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden state.
        bias: If ``True``, add a learnable bias ``B`` to the gate pre-activations;
            otherwise ``B`` is registered as ``None``.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        # Store the input and hidden sizes.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Input and recurrent weights for all four gates (z, i, f, o) are stacked
        # into one matrix, so a single matmul on cat(x, h) yields every gate.
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )

        # Biases for all four gates in a single vector.
        if self.bias:
            self.B = nn.Parameter(torch.zeros(4 * self.hidden_size), requires_grad=True)
        else:
            # Register as None (like nn.Linear) so ``self.B`` always exists.
            self.register_parameter("B", None)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Advance the cell by one time step.

        Args:
            x: Input at the current step, shape ``(batch_size, input_size)``.
            internal_state: Previous ``(h, c, n, m)`` state, each of shape
                ``(batch_size, hidden_size)``.

        Returns:
            ``(h_t, (h_t, c_t, n_t, m_t))``: the new hidden state and the full
            new internal state.
        """
        h, c, n, m = internal_state  # each (batch_size, hidden_size)

        # One matmul computes the pre-activations of all four gates.
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)
        if self.bias:
            gates = gates + self.B

        # Cell input, input gate, forget gate and output gate pre-activations.
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)

        # Work with the gates in log space. The input gate is exp(i_tilda), so its
        # log is i_tilda itself. Computing exp() and then log() would overflow to
        # inf for large pre-activations and turn h into NaN. logsigmoid avoids
        # log(0) when the sigmoid underflows.
        log_i = i_tilda
        log_f = nn.functional.logsigmoid(f_tilda)

        # Stabilizer state: the running maximum of the log gate magnitudes.
        m_t = torch.max(log_f + m, log_i)  # (batch_size, hidden_size)

        # Stabilized gates. Both are <= 1 because m_t bounds both exponents.
        i_prime = torch.exp(log_i - m_t)
        # c and n are stored in the scale exp(-m). f_prime carries the previous
        # state from scale exp(-m_{t-1}) to exp(-m_t); using the raw forget gate
        # here would mix scales and corrupt h = c / n.
        f_prime = torch.exp(log_f + m - m_t)

        # New internal states.
        c_t = f_prime * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_prime * n + i_prime  # (batch_size, hidden_size)

        # Normalized hidden state, gated by the output gate.
        h_t = o_t * (c_t / n_t)  # (batch_size, hidden_size)

        return h_t, (h_t, c_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, c, n, m)`` state for ``batch_size`` rows.

        Extra keyword arguments (e.g. ``device``, ``dtype``) are forwarded to
        ``torch.zeros``.
        """
        return tuple(
            torch.zeros(batch_size, self.hidden_size, **kwargs) for _ in range(4)
        )
# Define the sLSTM layer: a stack of sLSTMCell modules unrolled over time.
class sLSTM(nn.Module):
    """Multi-layer sLSTM that runs a stack of ``sLSTMCell`` over a sequence.

    Args:
        input_size: Number of features in each input step.
        hidden_size: Number of features in every layer's hidden state.
        num_layers: Number of stacked cells. Layer ``k > 0`` consumes the hidden
            state produced by layer ``k - 1`` at the same time step.
        bias: Forwarded to every ``sLSTMCell``.
        batch_first: If ``True``, input and output are laid out as
            ``(batch, seq, feature)``; otherwise ``(seq, batch, feature)``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        # The first layer reads the input features; deeper layers read the
        # previous layer's hidden state.
        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        """Run the whole stack over the sequence ``x``.

        Args:
            x: Input sequence, ``(seq_len, batch, input_size)``, or
                ``(batch, seq_len, input_size)`` when ``batch_first`` is set.
            hidden_states: Optional sequence of ``num_layers`` per-layer
                ``(h, c, n, m)`` states. It is copied into a new list and never
                modified in place; when omitted, zero states are used.

        Returns:
            ``(outputs, hidden_states)``: the last layer's hidden state at every
            time step (same layout as ``x``), and a list holding each layer's
            final ``(h, c, n, m)`` state.

        Raises:
            ValueError: If ``hidden_states`` does not have one entry per layer,
                or an entry's batch size does not match ``x``.
        """
        # Work internally in (seq_len, batch, feature) layout.
        if self.batch_first:
            x = x.permute(1, 0, 2)

        seq_len, batch_size = x.size(0), x.size(1)

        if hidden_states is None:
            # No state given: start every layer from zeros.
            hidden_states = self.init_hidden(batch_size, device=x.device, dtype=x.dtype)
        else:
            # Copy into a fresh list. This accepts tuples and leaves the
            # caller's container untouched.
            hidden_states = list(hidden_states)
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            for state in hidden_states:
                if state[0].size(0) != batch_size:
                    # Report the batch size of the offending layer, not layer 0.
                    raise ValueError(
                        f"Expected hidden states of batch size {batch_size}, but got {state[0].size(0)}"
                    )

        # Step through time; at each step feed the input up through the stack.
        outputs = []
        for t in range(seq_len):
            layer_input = x[t]
            for layer, cell in enumerate(self.cells):
                layer_input, hidden_states[layer] = cell(layer_input, hidden_states[layer])
            outputs.append(layer_input)

        # torch.stack rejects an empty list, so build an empty output explicitly
        # for zero-length sequences.
        if outputs:
            outputs = torch.stack(outputs)
        else:
            outputs = x.new_zeros(0, batch_size, self.hidden_size)

        # Restore the caller's layout.
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        """Return a list with one zero ``(h, c, n, m)`` state per layer.

        Extra keyword arguments (e.g. ``device``, ``dtype``) are forwarded to
        each cell's ``init_hidden``.
        """
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]