ShopTRAINING/server/models/slstm_model.py

import torch
import torch.nn as nn
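

# The sLSTM cell below keeps four per-layer states (h, c, n, m) and updates
# them as follows (this summary is read directly from the forward pass below):
#   z_t = tanh(z~),  i_t = exp(i~),  f_t = sigmoid(f~),  o_t = sigmoid(o~)
#   m_t  = max(log(f_t) + m_{t-1}, log(i_t))     # stabilizer state
#   i'_t = exp(i~ - m_t)                         # stabilized input gate
#   c_t  = f_t * c_{t-1} + i'_t * z_t            # cell state
#   n_t  = f_t * n_{t-1} + i'_t                  # normalizer state
#   h_t  = o_t * c_t / n_t                       # hidden state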

# Define the sLSTM cell
class sLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()
        # Store the input and hidden sizes
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Combine the input and recurrent weights into a single matrix
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.randn(self.input_size + self.hidden_size, 4 * self.hidden_size)
            ),
            requires_grad=True,
        )
        # Combine the biases into a single vector
        if self.bias:
            self.B = nn.Parameter(
                torch.zeros(4 * self.hidden_size), requires_grad=True
            )

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[
        torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        # Unpack the internal state
        h, c, n, m = internal_state  # (batch_size, hidden_size)
        # Concatenate the input with the previous hidden state
        combined = torch.cat((x, h), dim=1)  # (batch_size, input_size + hidden_size)
        # Apply the combined linear transformation
        gates = torch.matmul(combined, self.W)  # (batch_size, 4 * hidden_size)
        # Add the bias if enabled
        if self.bias:
            gates += self.B
        # Split into cell-input, input-gate, forget-gate, and output-gate pre-activations
        z_tilda, i_tilda, f_tilda, o_tilda = torch.split(gates, self.hidden_size, dim=1)
        # Cell-input activation
        z_t = torch.tanh(z_tilda)  # (batch_size, hidden_size)
        # Exponential activation of the input gate
        i_t = torch.exp(i_tilda)  # (batch_size, hidden_size)
        # Sigmoid activation of the forget gate
        f_t = torch.sigmoid(f_tilda)  # (batch_size, hidden_size)
        # Sigmoid activation of the output gate
        o_t = torch.sigmoid(o_tilda)  # (batch_size, hidden_size)
        # Update the stabilizer state
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))  # (batch_size, hidden_size)
        # Stabilized input gate
        i_prime = torch.exp(i_tilda - m_t)  # (batch_size, hidden_size)
        # Update the cell and normalizer states
        c_t = f_t * c + i_prime * z_t  # (batch_size, hidden_size)
        n_t = f_t * n + i_prime  # (batch_size, hidden_size)
        # Normalized (stabilized) hidden state
        h_tilda = c_t / n_t  # (batch_size, hidden_size)
        # New hidden state
        h_t = o_t * h_tilda  # (batch_size, hidden_size)
        return h_t, (
            h_t,
            c_t,
            n_t,
            m_t,
        )

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )


# Define the sLSTM layer (a stack of sLSTM cells)
class sLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.cells = nn.ModuleList(
            [
                sLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # If batch_first, permute the input to (seq_len, batch_size, input_size)
        if self.batch_first:
            x = x.permute(1, 0, 2)
        # Initialize the hidden states if none are provided
        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that one hidden state is provided per layer
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )
        # Process the sequence one time step at a time through every layer
        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)
        # Stack the per-step outputs of the last layer into a sequence
        outputs = torch.stack(outputs)
        # If batch_first, permute the output back to (batch_size, seq_len, hidden_size)
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)
        return outputs, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]
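

# Minimal usage sketch (illustrative only; the shapes and hyperparameters below
# are arbitrary assumptions, not values used elsewhere in ShopTRAINING).
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, seq_len, input_size, hidden_size, num_layers = 4, 10, 8, 16, 2
    model = sLSTM(input_size, hidden_size, num_layers, batch_first=True)
    x = torch.randn(batch_size, seq_len, input_size)
    # First call: hidden states are initialized internally.
    out, states = model(x)
    print(out.shape)  # torch.Size([4, 10, 16])
    # Second call: reuse the returned states to continue the same sequences.
    out, states = model(x, states)
    print(out.shape)  # torch.Size([4, 10, 16])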