ShopTRAINING/server/models/cnn_bilstm_attention.py

# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model definition, adapted for the pharmacy sales forecasting system.
Original code source: python机器学习回归全家桶
"""
import torch
import torch.nn as nn

# Note: the original code was written with TensorFlow/Keras layers. The classes below are
# an equivalent PyTorch implementation, which is more robust and better matches the
# existing system architecture.
class Attention(nn.Module):
    """
    Attention mechanism implemented in PyTorch.

    Scores every time step, normalizes the scores with a softmax, and returns the
    weighted sum of the input sequence along the time dimension.
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        # Score each time step: (batch * step, feature) @ (feature, 1) -> (batch, step)
        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)

        # Numerically stable softmax over the time dimension
        if mask is not None:
            # Masked positions get a large negative score so their weight is ~0 after softmax
            eij = eij.masked_fill(mask == 0, -1e9)
        a = torch.nn.functional.softmax(eij, dim=1)

        # Weight each time step and sum over the sequence -> (batch, feature_dim)
        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)
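

# Illustrative sketch, not part of the original module: a small helper showing the shapes
# the Attention layer expects and returns. The function name and the dummy dimensions are
# assumptions for demonstration only.
def _attention_shape_demo():
    batch_size, step_dim, feature_dim = 4, 7, 16        # hypothetical sizes
    layer = Attention(feature_dim=feature_dim, step_dim=step_dim)
    x = torch.randn(batch_size, step_dim, feature_dim)  # (batch, time, features)
    mask = torch.ones(batch_size, step_dim)             # optional mask, 1 = keep the step
    out = layer(x, mask=mask)
    assert out.shape == (batch_size, feature_dim)       # the time dimension is summed out
    return out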


class CnnBiLstmAttention(nn.Module):
    """
    PyTorch implementation of the CNN-BiLSTM-Attention model.
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN layer (kernel_size=1 keeps the sequence length unchanged, which the
        # attention layer's step_dim relies on)
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()

        # BiLSTM layer
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention layer
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # Fully connected output layer
        self.dense = nn.Linear(lstm_units * 2, output_dim)
    def forward(self, x):
        # Input x shape: (batch_size, sequence_length, input_dim)

        # CNN stage
        x = x.permute(0, 2, 1)  # to (batch_size, input_dim, sequence_length) for Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # back to (batch_size, sequence_length, cnn_filters)

        # BiLSTM stage
        lstm_out, _ = self.bilstm(x)  # lstm_out shape: (batch_size, sequence_length, lstm_units * 2)

        # Attention stage
        # Note: this attention layer may need task-specific tuning; a simpler alternative
        # would be to use the final LSTM hidden state (or a flattened LSTM output) directly.
        attention_out = self.attention(lstm_out)

        # Fully connected output
        output = self.dense(attention_out)
        return output
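

# Minimal smoke test, a sketch added for illustration rather than part of the original file.
# The tensor sizes below are assumptions; in the real system, input_dim, output_dim and
# sequence_length come from the training pipeline's configuration.
if __name__ == "__main__":
    batch_size, sequence_length, input_dim, output_dim = 8, 14, 10, 1  # hypothetical sizes
    model = CnnBiLstmAttention(input_dim=input_dim, output_dim=output_dim, sequence_length=sequence_length)
    dummy_x = torch.randn(batch_size, sequence_length, input_dim)
    pred = model(dummy_x)
    print(pred.shape)  # expected: torch.Size([8, 1])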