import torch
import torch.nn as nn

from .transformer_model import TransformerEncoder, TransformerDecoder


# mLSTM cell: LSTM variant with a matrix memory C, a normalizer state n and a
# stabilizer state m, plus query/key/value projections of the input.
class mLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Initialize the input-to-hidden weights (Xavier-uniform) and biases
        self.W_i = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_f = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_o = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_q = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )

        if self.bias:
            self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        # Unpack the internal state: C (batch, hidden, hidden), n (batch, hidden),
        # m (batch, hidden)
        C, n, m = internal_state

        # Pre-activations of the input, forget and output gates and the
        # query/key/value projections
        i_tilda = (
            torch.matmul(x, self.W_i) + self.B_i
            if self.bias
            else torch.matmul(x, self.W_i)
        )
        f_tilda = (
            torch.matmul(x, self.W_f) + self.B_f
            if self.bias
            else torch.matmul(x, self.W_f)
        )
        o_tilda = (
            torch.matmul(x, self.W_o) + self.B_o
            if self.bias
            else torch.matmul(x, self.W_o)
        )
        q_t = (
            torch.matmul(x, self.W_q) + self.B_q
            if self.bias
            else torch.matmul(x, self.W_q)
        )
        # Keys are scaled by 1/sqrt(hidden_size), as in attention
        k_t = (
            torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
            + self.B_k
            if self.bias
            else torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
        )
        v_t = (
            torch.matmul(x, self.W_v) + self.B_v
            if self.bias
            else torch.matmul(x, self.W_v)
        )

        # Exponential activation for the input gate, sigmoid for forget and output
        i_t = torch.exp(i_tilda)
        f_t = torch.sigmoid(f_tilda)
        o_t = torch.sigmoid(o_tilda)

        # Stabilizer state keeps the exponential input gate numerically bounded
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))
        i_prime = torch.exp(i_tilda - m_t)

        # Matrix (covariance) memory and normalizer state updates
        C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
            "bi, bk -> bik", v_t, k_t
        )
        n_t = f_t * n + i_prime * k_t

        # Per-sample dot product <n_t, q_t>, clipped to magnitude >= 1
        normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
        divisor = torch.max(
            torch.abs(normalize_inner), torch.ones_like(normalize_inner)
        )
        h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
        h_t = o_t * h_tilda

        return h_t, (C_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Zero-initialized (C, n, m) state
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )
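

def _example_mlstm_cell_step() -> None:
    # Illustrative sketch only (the sizes are assumptions and this helper is not
    # called anywhere in the module): run a single mLSTMCell step on a random batch.
    cell = mLSTMCell(input_size=8, hidden_size=16)
    x = torch.randn(4, 8)        # (batch, input_size)
    state = cell.init_hidden(4)  # zero-initialized (C, n, m)
    h, state = cell(x, state)    # h: (batch, hidden_size)
    assert h.shape == (4, 16)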


# mLSTM layer: a stack of mLSTMCells unrolled over the time dimension
class mLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        self.cells = nn.ModuleList(
            [
                mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # If batch_first, permute the input to (seq_len, batch, features)
        if self.batch_first:
            x = x.permute(1, 0, 2)

        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that the provided hidden states match the layer count and batch size
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the per-step outputs of the last layer into a sequence
        outputs = torch.stack(outputs)

        # If batch_first, permute the output back to (batch, seq_len, hidden_size)
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]


# Model combining an mLSTM front end with a Transformer encoder/decoder stack
class MLSTMTransformer(nn.Module):
    def __init__(self, num_features, hidden_size=128, mlstm_layers=1,
                 embed_dim=32, dense_dim=32, num_heads=4,
                 dropout_rate=0.1, num_blocks=3, output_sequence_length=1):
        super().__init__()

        # mLSTM part
        self.mlstm = mLSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=mlstm_layers,
            batch_first=True,
        )

        # Transformer part
        self.input_embedding = nn.Linear(hidden_size, embed_dim)
        # Learned positional encoding for sequences of up to 1000 steps
        self.positional_encoding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.1)

        self.encoders = nn.ModuleList(
            [TransformerEncoder(embed_dim, dense_dim, num_heads, dropout_rate) for _ in range(num_blocks)]
        )
        self.decoders = nn.ModuleList(
            [TransformerDecoder(embed_dim, dense_dim, num_heads, dropout_rate) for _ in range(num_blocks)]
        )

        self.output_layer = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, inputs):
        # mLSTM over the input: (batch, seq_len, num_features) -> (batch, seq_len, hidden_size)
        mlstm_output, _ = self.mlstm(inputs)

        # Embed the mLSTM outputs into the Transformer space and add positional encoding
        x = self.input_embedding(mlstm_output)
        x = x + self.positional_encoding[:, :x.size(1), :]
        x = self.dropout(x)

        # Encoder stack
        encoder_outputs = x
        for encoder in self.encoders:
            encoder_outputs = encoder(encoder_outputs)

        # Decoder stack: seed the decoder with the last encoder step, repeated for
        # each output step
        decoder_inputs = encoder_outputs[:, -1:, :].expand(-1, self.output_sequence_length, -1)
        decoder_outputs = decoder_inputs

        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        # One scalar prediction per output step: (batch, output_sequence_length, 1)
        return self.output_layer(decoder_outputs)
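

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the sizes below are assumptions).
    # Because of the relative import at the top of the file, this must be run as
    # part of its package, e.g. ``python -m <package>.<this_module>``.
    batch_size, seq_len, num_features = 2, 24, 6
    seq = torch.randn(batch_size, seq_len, num_features)

    # Standalone mLSTM: (batch, seq_len, features) -> (batch, seq_len, hidden_size)
    lstm = mLSTM(input_size=num_features, hidden_size=16, num_layers=2, batch_first=True)
    out, states = lstm(seq)
    print("mLSTM output:", tuple(out.shape))               # (2, 24, 16)

    # Full model: one scalar prediction per output step
    model = MLSTMTransformer(num_features=num_features, output_sequence_length=4)
    pred = model(seq)
    print("MLSTMTransformer output:", tuple(pred.shape))   # (2, 4, 1)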