ShopTRAINING/server/models/slstm_model.py

162 lines
5.4 KiB
Python
Raw Normal View History

import torch
import torch.nn as nn
# Define the sLSTM cell
class sLSTMCell(nn.Module):
    """Single sLSTM cell (scalar-memory LSTM with exponential input gating).

    One call advances the recurrence by one time step. All four gate
    pre-activations come from a single fused weight matrix applied to the
    concatenation ``[x, h]``.

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden / cell / normalizer /
            stabilizer states.
        bias: If ``True``, a learnable bias ``B`` is added to the gate
            pre-activations.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()
        # Store input and hidden sizes
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Fused input + recurrent weights for the four gates (z, i, f, o):
        # shape (input_size + hidden_size, 4 * hidden_size), Xavier-uniform init.
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )
        # Fused bias for the four gates, zero-initialised.
        if self.bias:
            self.B = nn.Parameter(
                torch.zeros(4 * self.hidden_size), requires_grad=True
            )

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """Advance the cell by one time step.

        Args:
            x: Input of shape ``(batch_size, input_size)``.
            internal_state: ``(h, c, n, m)`` — hidden, cell, normalizer and
                stabilizer states, each expected as ``(batch_size, hidden_size)``.

        Returns:
            ``(h_t, (h_t, c_t, n_t, m_t))``: the new hidden state and the full
            updated state tuple to feed into the next step.
        """
        h, c, n, m = internal_state
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)
        if self.bias:
            gates = gates + self.B
        # Split into cell-input (z), input (i), forget (f) and output (o) pre-activations.
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)

        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)

        # Work in log space: log(exp(i~)) is just i~, and logsigmoid avoids
        # exp overflow and log(0) = -inf that exp/log round-trips would cause.
        log_f_plus_m = nn.functional.logsigmoid(f_tilda) + m
        # Stabilizer. With an empty normalizer (fresh state, n == 0) any value is
        # valid since c == n == 0; choosing i~ makes the first i' equal 1, so
        # n_t > 0 and c_t / n_t is well-defined.
        m_t = torch.where(n == 0, i_tilda, torch.maximum(log_f_plus_m, i_tilda))

        # Both gates must be rescaled by the same exp(-m_t) (and the forget gate
        # by exp(m_{t-1}) as well) so that c_t / n_t equals the unstabilized ratio.
        i_prime = torch.exp(i_tilda - m_t)  # (batch_size, hidden_size)
        f_prime = torch.exp(log_f_plus_m - m_t)  # (batch_size, hidden_size)

        c_t = f_prime * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_prime * n + i_prime  # (batch_size, hidden_size)
        h_t = o_t * (c_t / n_t)  # (batch_size, hidden_size)
        return h_t, (h_t, c_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, c, n, m)`` state.

        Args:
            batch_size: Leading dimension of each state tensor.
            **kwargs: Forwarded to ``torch.zeros`` (e.g. ``device``, ``dtype``).
        """
        return tuple(
            torch.zeros(batch_size, self.hidden_size, **kwargs) for _ in range(4)
        )
# Define the sLSTM layer (a stack of sLSTM cells)
class sLSTM(nn.Module):
    """Multi-layer sLSTM that unrolls stacked ``sLSTMCell`` modules over a sequence.

    Args:
        input_size: Number of features in each input time step.
        hidden_size: Number of hidden features in every layer.
        num_layers: Number of stacked cells. Layer 0 consumes the input; each
            later layer consumes the hidden output of the layer below.
        bias: Forwarded to each ``sLSTMCell``.
        batch_first: If ``True``, inputs and outputs are laid out as
            ``(batch, seq, feature)``; otherwise ``(seq, batch, feature)``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        """Run the stacked cells over every time step of ``x``.

        Args:
            x: Input sequence, ``(seq, batch, input_size)`` or
                ``(batch, seq, input_size)`` when ``batch_first``.
            hidden_states: Optional sequence of per-layer ``(h, c, n, m)``
                tuples. Zeros (matching ``x``'s device and dtype) are used when
                omitted. The caller's container is not modified.

        Returns:
            ``(outputs, hidden_states)``: the top layer's hidden state at each
            time step (same layout convention as ``x``, with ``hidden_size``
            features) and a list of the per-layer states after the last step.

        Raises:
            ValueError: If ``hidden_states`` has the wrong number of layers or
                a layer's batch size does not match ``x``.
        """
        # Work internally in (seq, batch, feature) layout.
        if self.batch_first:
            x = x.permute(1, 0, 2)
        batch_size = x.size(1)

        if hidden_states is None:
            hidden_states = self.init_hidden(batch_size, device=x.device, dtype=x.dtype)
        else:
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            for state in hidden_states:
                if state[0].size(0) != batch_size:
                    # Report the size of the layer that actually mismatched.
                    raise ValueError(
                        f"Expected hidden states of batch size {batch_size}, but got {state[0].size(0)}"
                    )
            # Copy so a tuple is accepted and the caller's list is not mutated.
            hidden_states = list(hidden_states)

        outputs = []
        for x_t in x.unbind(0):
            layer_input = x_t
            # Feed each layer's hidden output into the next layer.
            for layer, cell in enumerate(self.cells):
                layer_input, hidden_states[layer] = cell(layer_input, hidden_states[layer])
            outputs.append(layer_input)

        # torch.stack rejects an empty list, so build an empty result explicitly.
        if outputs:
            output = torch.stack(outputs)
        else:
            output = x.new_zeros(0, batch_size, self.hidden_size)

        if self.batch_first:
            output = output.permute(1, 0, 2)
        return output, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        """Return a list with one all-zero ``(h, c, n, m)`` state per layer.

        ``**kwargs`` are forwarded to each cell's ``init_hidden`` (e.g. ``device``,
        ``dtype``).
        """
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]