# ShopTRAINING/server/models/mlstm_model.py

import torch
import torch.nn as nn
from .transformer_model import TransformerEncoder, TransformerDecoder


# Matrix LSTM (mLSTM) cell
class mLSTMCell(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, bias: bool = True) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias

        # Initialize the gate and query/key/value projection weights
        self.W_i = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_f = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_o = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_q = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_k = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )
        self.W_v = nn.Parameter(
            nn.init.xavier_uniform_(torch.zeros(input_size, hidden_size)),
            requires_grad=True,
        )

        # Optional biases for each projection
        if self.bias:
            self.B_i = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_f = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_o = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_q = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_k = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
            self.B_v = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)

    def forward(
        self,
        x: torch.Tensor,
        internal_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        # Unpack the internal state: matrix memory C, normalizer n, stabilizer m
        C, n, m = internal_state

        # Input, forget and output gate pre-activations plus query/key/value projections
        i_tilda = (
            torch.matmul(x, self.W_i) + self.B_i
            if self.bias
            else torch.matmul(x, self.W_i)
        )
        f_tilda = (
            torch.matmul(x, self.W_f) + self.B_f
            if self.bias
            else torch.matmul(x, self.W_f)
        )
        o_tilda = (
            torch.matmul(x, self.W_o) + self.B_o
            if self.bias
            else torch.matmul(x, self.W_o)
        )
        q_t = (
            torch.matmul(x, self.W_q) + self.B_q
            if self.bias
            else torch.matmul(x, self.W_q)
        )
        # Keys are scaled by 1 / sqrt(hidden_size)
        k_t = (
            torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
            + self.B_k
            if self.bias
            else torch.matmul(x, self.W_k) / torch.sqrt(torch.tensor(self.hidden_size))
        )
        v_t = (
            torch.matmul(x, self.W_v) + self.B_v
            if self.bias
            else torch.matmul(x, self.W_v)
        )

        # Exponential input gate, sigmoid forget and output gates
        i_t = torch.exp(i_tilda)
        f_t = torch.sigmoid(f_tilda)
        o_t = torch.sigmoid(o_tilda)

        # Stabilizer state: keep the exponential gating numerically stable in log space
        m_t = torch.max(torch.log(f_t) + m, torch.log(i_t))
        i_prime = torch.exp(i_tilda - m_t)

        # Update the covariance (matrix memory) and normalizer states
        C_t = f_t.unsqueeze(-1) * C + i_prime.unsqueeze(-1) * torch.einsum(
            "bi, bk -> bik", v_t, k_t
        )
        n_t = f_t * n + i_prime * k_t

        # Retrieve from memory with the query and normalize by max(|n_t . q_t|, 1)
        normalize_inner = torch.diagonal(torch.matmul(n_t, q_t.T))
        divisor = torch.max(
            torch.abs(normalize_inner), torch.ones_like(normalize_inner)
        )
        h_tilda = torch.einsum("bkj,bj -> bk", C_t, q_t) / divisor.view(-1, 1)
        h_t = o_t * h_tilda

        return h_t, (C_t, n_t, m_t)

    def init_hidden(
        self, batch_size: int, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return (
            torch.zeros(batch_size, self.hidden_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
            torch.zeros(batch_size, self.hidden_size, **kwargs),
        )
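
# Illustrative single-step usage (not part of the original file; sizes are arbitrary).
# The cell maps an input of shape (batch, input_size) together with a (C, n, m) state
# to a hidden state of shape (batch, hidden_size):
#
#   cell = mLSTMCell(input_size=8, hidden_size=16)
#   state = cell.init_hidden(batch_size=4)
#   h_t, state = cell(torch.randn(4, 8), state)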


# mLSTM layer: a stack of mLSTM cells unrolled over a sequence
class mLSTM(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        bias: bool = True,
        batch_first: bool = False,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first

        # The first cell consumes the raw input, later cells consume hidden states
        self.cells = nn.ModuleList(
            [
                mLSTMCell(input_size if layer == 0 else hidden_size, hidden_size, bias)
                for layer in range(num_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        hidden_states=None,
    ):
        # If batch_first, permute the input to (seq_len, batch, features)
        if self.batch_first:
            x = x.permute(1, 0, 2)

        if hidden_states is None:
            hidden_states = self.init_hidden(x.size(1), device=x.device, dtype=x.dtype)
        else:
            # Check that the provided hidden states match the layer count and batch size
            if len(hidden_states) != self.num_layers:
                raise ValueError(
                    f"Expected hidden states of length {self.num_layers}, but got {len(hidden_states)}"
                )
            if any(state[0].size(0) != x.size(1) for state in hidden_states):
                raise ValueError(
                    f"Expected hidden states of batch size {x.size(1)}, but got {hidden_states[0][0].size(0)}"
                )

        # Unroll over time, feeding each layer's hidden state into the next layer
        outputs = []
        for t in range(x.size(0)):
            x_t = x[t]
            for layer in range(self.num_layers):
                if layer == 0:
                    h_t, state = self.cells[layer](x_t, hidden_states[layer])
                else:
                    h_t, state = self.cells[layer](h_t, hidden_states[layer])
                hidden_states[layer] = state
            outputs.append(h_t)

        # Stack the per-step outputs into a (seq_len, batch, hidden) tensor
        outputs = torch.stack(outputs)

        # If batch_first, permute the output back to (batch, seq_len, hidden)
        if self.batch_first:
            outputs = outputs.permute(1, 0, 2)

        return outputs, hidden_states

    def init_hidden(self, batch_size: int, **kwargs):
        return [cell.init_hidden(batch_size, **kwargs) for cell in self.cells]
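
# Illustrative usage of the stacked layer (not part of the original file; sizes are
# arbitrary). With batch_first=True the layer consumes (batch, seq_len, input_size)
# and returns per-step hidden states plus the final (C, n, m) state of each layer:
#
#   layer = mLSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
#   out, states = layer(torch.randn(4, 10, 8))  # out: (4, 10, 16)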


# Model combining an mLSTM front end with Transformer encoder/decoder blocks
class MLSTMTransformer(nn.Module):
    def __init__(self, num_features, hidden_size=128, mlstm_layers=1,
                 embed_dim=32, dense_dim=32, num_heads=4,
                 dropout_rate=0.1, num_blocks=3, output_sequence_length=1):
        super().__init__()

        # mLSTM part
        self.mlstm = mLSTM(
            input_size=num_features,
            hidden_size=hidden_size,
            num_layers=mlstm_layers,
            batch_first=True
        )

        # Transformer part
        self.input_embedding = nn.Linear(hidden_size, embed_dim)
        self.positional_encoding = nn.Parameter(torch.randn(1, 1000, embed_dim) * 0.1)
        self.encoders = nn.ModuleList([
            TransformerEncoder(embed_dim, dense_dim, num_heads, dropout_rate)
            for _ in range(num_blocks)
        ])
        self.decoders = nn.ModuleList([
            TransformerDecoder(embed_dim, dense_dim, num_heads, dropout_rate)
            for _ in range(num_blocks)
        ])
        self.output_layer = nn.Linear(embed_dim, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.output_sequence_length = output_sequence_length

    def forward(self, inputs):
        # mLSTM stage: encode the raw feature sequence
        mlstm_output, _ = self.mlstm(inputs)

        # Project the mLSTM outputs into the Transformer embedding space
        x = self.input_embedding(mlstm_output)
        x = x + self.positional_encoding[:, :x.size(1), :]
        x = self.dropout(x)

        # Encoder stack
        encoder_outputs = x
        for encoder in self.encoders:
            encoder_outputs = encoder(encoder_outputs)

        # Decoder stack: repeat the last encoder step as the decoder input sequence
        decoder_inputs = encoder_outputs[:, -1:, :].expand(-1, self.output_sequence_length, -1)
        decoder_outputs = decoder_inputs
        for decoder in self.decoders:
            decoder_outputs = decoder(decoder_outputs, encoder_outputs)

        return self.output_layer(decoder_outputs)
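

# Minimal smoke test: a sketch that is not part of the original file. It assumes
# `.transformer_model` provides TransformerEncoder / TransformerDecoder with the
# constructor and forward signatures used above, and it must be run as a module
# (e.g. `python -m server.models.mlstm_model`) so the relative import resolves.
if __name__ == "__main__":
    model = MLSTMTransformer(num_features=5, output_sequence_length=7)
    dummy = torch.randn(4, 30, 5)  # (batch, seq_len, num_features)
    preds = model(dummy)
    print(preds.shape)  # expected: torch.Size([4, 7, 1])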