162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
|
||
# Define the sLSTM cell
class sLSTMCell(nn.Module):
    """Single-step sLSTM cell with a log-domain stabilizer state.

    One call advances the recurrence by one time step. The recurrent state is
    the tuple ``(h, c, n, m)``:

    * ``h`` - hidden state (also the cell output),
    * ``c`` - cell state,
    * ``n`` - normalizer state,
    * ``m`` - stabilizer state (running max of the log-gate values).

    ``c`` and ``n`` are stored scaled by ``exp(-m)`` so the exponential input
    gate cannot overflow; that common factor cancels in ``h = o * c / n``.

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in each state tensor.
        bias: If ``True``, add the learnable bias ``B`` to the gate
            pre-activations; otherwise ``B`` is ``None``.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        # Store input / hidden sizes.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Input and recurrent weights of all four gates fused into one matrix:
        # rows = [x features, h features], columns = [z | i | f | o] blocks.
        # (randn is kept before xavier init so seeded initialization is unchanged.)
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )

        # Fused gate biases, same [z | i | f | o] layout as the columns of W.
        if self.bias:
            self.B = nn.Parameter(torch.zeros(4 * self.hidden_size), requires_grad=True)
        else:
            # Register the name so ``self.B`` always exists (as None).
            self.register_parameter("B", None)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(batch_size, input_size)``.
            internal_state: Previous ``(h, c, n, m)``, each of shape
                ``(batch_size, hidden_size)`` (see :meth:`init_hidden`).

        Returns:
            ``(h_t, (h_t, c_t, n_t, m_t))`` - the new hidden state followed by
            the complete new state tuple for the next step.
        """
        h, c, n, m = internal_state  # each (batch_size, hidden_size)

        # Pre-activations of all four gates in a single matmul.
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)
        if self.bias:
            gates = gates + self.B

        # Split into cell-input (z), input (i), forget (f) and output (o) blocks.
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)

        # Work in log space. The input gate is exp(i_tilda), so its log is
        # i_tilda exactly; computing log(exp(i_tilda)) overflowed to inf for
        # large pre-activations. logsigmoid avoids log(0) when sigmoid underflows.
        log_i = i_tilda
        log_f = nn.functional.logsigmoid(f_tilda)

        # New stabilizer: m_t = max(log f_t + m_{t-1}, log i_t).
        m_t = torch.maximum(log_f + m, log_i)

        # Stabilized gates. Both exponents are <= 0 by construction of m_t, so
        # neither can overflow. The forget gate must also carry exp(m_{t-1} - m_t)
        # because c and n were stored relative to the previous stabilizer.
        i_prime = torch.exp(log_i - m_t)
        f_prime = torch.exp(log_f + m - m_t)

        c_t = f_prime * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_prime * n + i_prime  # (batch_size, hidden_size)

        # The shared exp(-m_t) scale of c_t and n_t cancels in the ratio.
        h_t = o_t * (c_t / n_t)  # (batch_size, hidden_size)

        return h_t, (h_t, c_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, c, n, m)`` state.

        Args:
            batch_size: Leading dimension of every state tensor.
            **kwargs: Forwarded to :func:`torch.zeros` (e.g. ``device``, ``dtype``).
        """
        return (
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )


# Define the sLSTM layer
class sLSTM(nn.Module):
    """Multi-layer sLSTM that unrolls :class:`sLSTMCell` over a sequence.

    Args:
        input_size: Number of features per input time step.
        hidden_size: Number of features in every layer's state.
        num_layers: Number of stacked cells; layer ``k > 0`` consumes the
            hidden state produced by layer ``k - 1`` at the same time step.
        bias: Passed to every :class:`sLSTMCell`.
        batch_first: If ``True``, inputs and outputs use the layout
            ``(batch, seq, feature)`` instead of ``(seq, batch, feature)``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        # The first layer reads the input; deeper layers read hidden states.
        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        """Run the stacked cells over every time step of ``x``.

        Args:
            x: Input of shape ``(seq_len, batch, input_size)``, or
                ``(batch, seq_len, input_size)`` when ``batch_first`` is set.
            hidden_states: Optional sequence of ``num_layers`` per-layer
                ``(h, c, n, m)`` state tuples; zeros from :meth:`init_hidden`
                are used when omitted. The caller's container is not modified.

        Returns:
            ``(outputs, hidden_states)``: ``outputs`` stacks the last layer's
            hidden state at every time step (same layout as ``x``, with
            ``hidden_size`` features) and ``hidden_states`` is a new list of
            the final per-layer state tuples.

        Raises:
            ValueError: If ``hidden_states`` does not have one entry per layer,
                or a layer's batch size differs from that of ``x``.
        """
        # Work internally in (seq_len, batch, feature) layout.
        if self.batch_first:
            x = x.permute(1, 0, 2)

        seq_len, batch_size = x.size(0), x.size(1)

        if hidden_states is None:
            hidden_states = self.init_hidden(batch_size, device=x.device, dtype=x.dtype)
        else:
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            for state in hidden_states:
                if state[0].size(0) != batch_size:
                    # Report the size of the layer that actually mismatches.
                    raise ValueError(
                        f"Expected hidden states of batch size {batch_size}, but got {state[0].size(0)}"
                    )
            # Copy so the caller's list is never mutated (and tuples are accepted).
            hidden_states = list(hidden_states)

        outputs = []
        for t in range(seq_len):
            layer_input = x[t]
            for layer, cell in enumerate(self.cells):
                layer_input, hidden_states[layer] = cell(layer_input, hidden_states[layer])
            outputs.append(layer_input)

        if outputs:
            output = torch.stack(outputs)  # (seq_len, batch, hidden_size)
        else:
            # torch.stack rejects an empty list; return an empty sequence instead.
            output = x.new_empty((0, batch_size, self.hidden_size))

        # Restore the caller's layout.
        if self.batch_first:
            output = output.permute(1, 0, 2)

        return output, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        """Return a fresh list of zero ``(h, c, n, m)`` states, one per layer.

        Args:
            batch_size: Leading dimension of every state tensor.
            **kwargs: Forwarded to each cell's ``init_hidden`` (e.g. ``device``,
                ``dtype``).
        """
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]