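"""mLSTM (matrix LSTM) cell and layer combined with a Transformer
encoder-decoder (MLSTMTransformer) for multi-step sequence prediction."""
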
import torch
import torch.nn as nn

from .transformer_model import TransformerEncoder, TransformerDecoder


# Define the mLSTM (matrix LSTM) cell
class mLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Initialize the projection weights (Xavier-uniform) for the input,
        # forget and output gates and for the query/key/value projections
        self.W_i = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_f = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_o = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_q = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )

        # Optional bias terms for each projection
        if self.bias:
            self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        # Unpack the internal state: matrix memory C, normalizer n, stabilizer m
        C, n, m = internal_state

        # Compute the input, forget and output gate pre-activations and the
        # query, key and value projections
        i_tilda = (
            torch.matmul(x, self.W_i) + self.B_i
            if self.bias
            else torch.matmul(x, self.W_i)
        )
        f_tilda = (
            torch.matmul(x, self.W_f) + self.B_f
            if self.bias
            else torch.matmul(x, self.W_f)
        )
        o_tilda = (
            torch.matmul(x, self.W_o) + self.B_o
            if self.bias
            else torch.matmul(x, self.W_o)
        )
        q_t = (
            torch.matmul(x, self.W_q) + self.B_q
            if self.bias
            else torch.matmul(x, self.W_q)
        )
        # The key projection is scaled by 1/sqrt(hidden_size)
        k_t = (
            torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
            + self.B_k
            if self.bias
            else torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
        )
        v_t = (
            torch.matmul(x, self.W_v) + self.B_v
            if self.bias
            else torch.matmul(x, self.W_v)
        )

        # Exponential activation for the input gate, sigmoid for forget and output gates
        i_t = torch.exp(i_tilda)
        f_t = torch.sigmoid(f_tilda)
        o_t = torch.sigmoid(o_tilda)

        # Stabilizer state keeps the exponential input gate numerically safe
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))
        i_prime = torch.exp(i_tilda - m_t)

        # Update the matrix (covariance) memory and the normalizer state
        C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
            "bi, bk -> bik", v_t, k_t
        )
        n_t = f_t * n + i_prime * k_t

        # Read out from the matrix memory with the query, normalized by
        # max(|n_t . q_t|, 1) to keep the hidden state bounded
        normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
        divisor = torch.max(
            torch.abs(normalize_inner), torch.ones_like(normalize_inner)
        )
        h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
        h_t = o_t * h_tilda

        return h_t, (C_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Zero-initialized (C, n, m) state for a fresh sequence
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )

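# Summary of the stabilized mLSTM recurrence implemented by the cell above, in the
# notation of its forward pass: with m_t = max(log(f_t) + m_{t-1}, i_tilda) and
# i'_t = exp(i_tilda - m_t),
#   C_t = f_t * C_{t-1} + i'_t * (v_t k_t^T)
#   n_t = f_t * n_{t-1} + i'_t * k_t
#   h_t = o_t * (C_t q_t) / max(|n_t . q_t|, 1)
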
# Define the mLSTM layer (a stack of mLSTM cells unrolled over time)
class mLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        self.cells = nn.ModuleList(
            [
                mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # If batch_first is True, move the time dimension to the front
        if self.batch_first:
            x = x.permute(1, 0, 2)

        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that the provided hidden states match the layer count and batch size
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        # Unroll over time, feeding each layer's output to the next layer
        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the per-step outputs into a sequence
        outputs = torch.stack(outputs)

        # If batch_first is True, move the batch dimension back to the front
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(self, batch_size: int, **kwargs):
        # One (C, n, m) state tuple per layer
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]

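# Shape convention for the layer above: with batch_first=True the input is
# (batch, seq_len, input_size) and the output is (batch, seq_len, hidden_size);
# with batch_first=False the time dimension comes first, mirroring torch.nn.LSTM.
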
# Model combining the mLSTM with a Transformer encoder-decoder
class MLSTMTransformer(nn.Module):
    def __init__(self, num_features, hidden_size=128, mlstm_layers=1,
                 embed_dim=32, dense_dim=32, num_heads=4,
                 dropout_rate=0.1, num_blocks=3, output_sequence_length=1):
        super().__init__()

        # mLSTM part
        self.mlstm = mLSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=mlstm_layers,
            batch_first=True
        )

        # Transformer part: project mLSTM features to embed_dim and add a
        # learned positional encoding (supports sequences up to 1000 steps)
        self.input_embedding = nn.Linear(hidden_size, embed_dim)
        self.positional_encoding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.1)

        self.encoders = nn.ModuleList([TransformerEncoder(embed_dim, dense_dim, num_heads, dropout_rate) for _ in range(num_blocks)])
        self.decoders = nn.ModuleList([TransformerDecoder(embed_dim, dense_dim, num_heads, dropout_rate) for _ in range(num_blocks)])

        self.output_layer = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, inputs):
        # mLSTM pass over the input sequence
        mlstm_output, _ = self.mlstm(inputs)

        # Embed the mLSTM outputs into the Transformer space
        x = self.input_embedding(mlstm_output)
        x = x + self.positional_encoding[:, :x.size(1), :]
        x = self.dropout(x)

        # Encoder stack
        encoder_outputs = x
        for encoder in self.encoders:
            encoder_outputs = encoder(encoder_outputs)

        # Decoder stack: repeat the last encoder step as the decoder input
        decoder_inputs = encoder_outputs[:, -1:, :].expand(-1, self.output_sequence_length, -1)
        decoder_outputs = decoder_inputs

        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        return self.output_layer(decoder_outputs)
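

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only): the hyperparameters and
    # shapes below are arbitrary. Running it assumes the sibling
    # transformer_model module (TransformerEncoder / TransformerDecoder) is
    # importable via the package-relative import at the top of this file and
    # that both blocks preserve the (batch, seq_len, embed_dim) shape.
    model = MLSTMTransformer(num_features=8, hidden_size=64, mlstm_layers=1,
                             output_sequence_length=4)
    dummy_inputs = torch.randn(2, 16, 8)  # (batch, seq_len, num_features)
    predictions = model(dummy_inputs)
    print(predictions.shape)  # expected: torch.Size([2, 4, 1])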