# ShopTRAINING/server/models/mlstm_model.py

import torch
import torch.nn as nn
from .transformer_model import TransformerEncoder, TransformerDecoder
# mLSTM cell (matrix-memory LSTM from the xLSTM family)
class mLSTMCell(nn.Module):
    """A single mLSTM recurrent cell with a matrix memory.

    The cell carries three pieces of recurrent state per batch element:
      * ``C`` — a (batch, hidden, hidden) matrix memory, updated with an
        outer product of value and key vectors,
      * ``n`` — a (batch, hidden) normalizer state,
      * ``m`` — a (batch, hidden) stabilizer state that keeps the
        exponential input gate numerically safe.
    """

    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Projection weights for the input/forget/output gates and the
        # query/key/value maps, all Xavier-uniform initialized.
        self.W_i = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_f = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_o = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_q = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        # Optional additive biases, one per projection.
        if self.bias:
            self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Run one recurrent step.

        Args:
            x: Input for the current time step; the matmuls against the
                (input_size, hidden_size) weights imply shape
                (batch, input_size).
            internal_state: Previous ``(C, n, m)`` state tuple.

        Returns:
            ``(h_t, (C_t, n_t, m_t))`` — the hidden output and new state.
        """
        # Unpack the previous internal state.
        C, n, m = internal_state
        # Pre-activations for the input/forget/output gates and the
        # query/key/value projections.
        i_tilda = (
            torch.matmul(x, self.W_i) + self.B_i
            if self.bias
            else torch.matmul(x, self.W_i)
        )
        f_tilda = (
            torch.matmul(x, self.W_f) + self.B_f
            if self.bias
            else torch.matmul(x, self.W_f)
        )
        o_tilda = (
            torch.matmul(x, self.W_o) + self.B_o
            if self.bias
            else torch.matmul(x, self.W_o)
        )
        q_t = (
            torch.matmul(x, self.W_q) + self.B_q
            if self.bias
            else torch.matmul(x, self.W_q)
        )
        # Keys are scaled by 1/sqrt(hidden_size), as in scaled dot-product
        # attention. NOTE(review): with bias=True the scaling is applied
        # before adding B_k, so the bias itself is unscaled — confirm intended.
        k_t = (
            torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
            + self.B_k
            if self.bias
            else torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
        )
        v_t = (
            torch.matmul(x, self.W_v) + self.B_v
            if self.bias
            else torch.matmul(x, self.W_v)
        )
        # Exponential input gate (xLSTM-style); sigmoid forget/output gates.
        i_t = torch.exp(i_tilda)
        f_t = torch.sigmoid(f_tilda)
        o_t = torch.sigmoid(o_tilda)
        # Stabilizer state: since log(i_t) == i_tilda, this is
        # m_t = max(log f_t + m_{t-1}, i_tilda).
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))
        # Stabilized input gate so exp() cannot overflow.
        i_prime = torch.exp(i_tilda - m_t)
        # NOTE(review): the forget gate below is applied unstabilized
        # (f_t rather than exp(log f_t + m - m_t)); this differs from the
        # fully-stabilized xLSTM update — confirm this is intentional.
        # Matrix memory update: decay C by f_t and add the outer product
        # v_t k_t^T weighted by the stabilized input gate.
        C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
            "bi, bk -> bik", v_t, k_t
        )
        # Normalizer state follows the same decay/accumulate pattern.
        n_t = f_t * n + i_prime * k_t
        # Per-sample dot products <n_t, q_t>, obtained as the diagonal of a
        # full (batch, batch) matmul.
        normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
        # Lower-bound the normalizer magnitude at 1 for numerical stability.
        divisor = torch.max(
            torch.abs(normalize_inner), torch.ones_like(normalize_inner)
        )
        # Read out of the matrix memory with the query, then normalize.
        h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
        h_t = o_t * h_tilda
        return h_t, (C_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return zeroed ``(C, n, m)`` states; kwargs forward device/dtype."""
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )
# mLSTM layer: stacks mLSTMCell instances and unrolls them over time.
class mLSTM(nn.Module):
    """Multi-layer mLSTM.

    Args:
        input_size: Feature size of each input step.
        hidden_size: Hidden state size of every layer.
        num_layers: Number of stacked cells.
        bias: Whether each cell uses bias terms.
        batch_first: If True, input/output tensors are (batch, seq, feature)
            instead of (seq, batch, feature).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        # The first layer consumes the raw features; deeper layers consume
        # the previous layer's hidden state.
        stacked = []
        for depth in range(num_layers):
            in_dim = input_size if depth == 0 else hidden_size
            stacked.append(mLSTMCell(in_dim, hidden_size, bias))
        self.cells = nn.ModuleList(stacked)

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        """Unroll all layers over the sequence.

        Returns the top layer's per-step outputs and the final per-layer
        states. A caller-supplied ``hidden_states`` list is updated in place.
        """
        # Work internally in (seq, batch, feature) order.
        if self.batch_first:
            x = x.permute(1, 0, 2)
        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Validate caller-supplied states: one per layer, matching batch.
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )
        step_outputs = []
        for step in range(x.size(0)):
            layer_input = x[step]
            for depth, cell in enumerate(self.cells):
                layer_input, hidden_states[depth] = cell(
                    layer_input, hidden_states[depth]
                )
            step_outputs.append(layer_input)
        # Stack the per-step outputs back into a sequence tensor.
        outputs = torch.stack(step_outputs)
        # Restore (batch, seq, feature) order if requested.
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)
        return outputs, hidden_states

    def init_hidden(
        self, batch_size: int, **kwargs
    ):
        """Zeroed (C, n, m) states for every layer; kwargs set device/dtype."""
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]
# Hybrid model: an mLSTM front-end feeding a Transformer encoder-decoder.
class MLSTMTransformer(nn.Module):
    """Sequence-to-sequence forecaster combining mLSTM and Transformer blocks.

    The mLSTM processes the raw feature sequence; its hidden states are
    embedded, position-encoded, passed through a Transformer encoder stack,
    and decoded into ``output_sequence_length`` scalar predictions per sample.
    """

    def __init__(self, num_features, hidden_size=128, mlstm_layers=1,
                 embed_dim=32, dense_dim=32, num_heads=4,
                 dropout_rate=0.1, num_blocks=3, output_sequence_length=1):
        super().__init__()
        # Recurrent front-end over the raw feature sequence.
        self.mlstm = mLSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=mlstm_layers,
            batch_first=True
        )
        # Project recurrent states into the Transformer embedding space.
        self.input_embedding = nn.Linear(hidden_size, embed_dim)
        # Learned positional table; supports sequences of up to 1000 steps.
        self.positional_encoding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.1)
        self.encoders = nn.ModuleList(
            TransformerEncoder(embed_dim, dense_dim, num_heads, dropout_rate)
            for _ in range(num_blocks)
        )
        self.decoders = nn.ModuleList(
            TransformerDecoder(embed_dim, dense_dim, num_heads, dropout_rate)
            for _ in range(num_blocks)
        )
        # Each decoder position is reduced to a single scalar.
        self.output_layer = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, inputs):
        """Predict ``output_sequence_length`` steps from the input sequence."""
        # Run the recurrent front-end; final states are not needed here.
        seq_features, _ = self.mlstm(inputs)
        # Embed into Transformer space and add positional information.
        embedded = self.input_embedding(seq_features)
        embedded = embedded + self.positional_encoding[:, :embedded.size(1), :]
        embedded = self.dropout(embedded)
        # Encoder stack.
        memory = embedded
        for encoder in self.encoders:
            memory = encoder(memory)
        # Seed the decoder by repeating the last encoder state across the
        # output horizon, then run the decoder stack against the memory.
        target = memory[:, -1:, :].expand(-1, self.output_sequence_length, -1)
        for decoder in self.decoders:
            target = decoder(target, memory)
        return self.output_layer(target)