# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model definition, adapted for the pharmacy sales forecasting system.

Original code source: python机器学习回归全家桶
"""

import torch
import torch.nn as nn

# Note: the original code used TensorFlow/Keras layers, so an equivalent PyTorch
# implementation is provided here. This is more robust and fits the existing
# system architecture better.


class Attention(nn.Module):
    """
    Attention mechanism implemented in PyTorch.

    Expects input of shape (batch_size, step_dim, feature_dim) and returns a
    weighted sum over the time dimension of shape (batch_size, feature_dim).
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)

        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim

        # Learnable projection that scores each time step.
        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        # Raw attention scores, one per time step: (batch_size, step_dim)
        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        # Normalize to attention weights (softmax with a small epsilon for stability).
        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        # Weighted sum over the time dimension: (batch_size, feature_dim)
        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)


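# Illustrative shape sketch for the Attention module above (the concrete sizes are
# arbitrary placeholders, not values taken from the original code):
#
#   attn = Attention(feature_dim=256, step_dim=30)
#   x = torch.randn(8, 30, 256)   # (batch_size, step_dim, feature_dim)
#   out = attn(x)                 # -> (8, 256), one context vector per sample

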
class CnnBiLstmAttention(nn.Module):
    """
    PyTorch implementation of the CNN-BiLSTM-Attention model.
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN layer
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        # Max pooling layer (defined but not applied in forward(); with kernel_size=1 it
        # would be an identity operation anyway).
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM layer
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention layer. Note: step_dim assumes the Conv1d preserves the sequence
        # length, which holds for the default cnn_kernel_size=1.
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # Fully connected output layer
        self.dense = nn.Linear(lstm_units * 2, output_dim)

    def forward(self, x):
        # Input x shape: (batch_size, sequence_length, input_dim)

        # CNN processing
        x = x.permute(0, 2, 1)  # to (batch_size, input_dim, sequence_length) for Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # back to (batch_size, sequence_length, cnn_filters)

        # BiLSTM processing
        lstm_out, _ = self.bilstm(x)  # lstm_out shape: (batch_size, sequence_length, lstm_units * 2)

        # Attention processing: score each time step and collapse the sequence into a
        # single context vector of shape (batch_size, lstm_units * 2). Depending on the
        # task, this step may need fine-tuning (e.g. using the LSTM's final hidden
        # state instead).
        attention_out = self.attention(lstm_out)

        # Fully connected output layer
        output = self.dense(attention_out)

        return output


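if __name__ == "__main__":
    # Minimal smoke test: the batch size, sequence length, and feature counts below
    # are arbitrary example values, used only to illustrate the expected tensor shapes.
    batch_size, seq_len, n_features, n_outputs = 4, 30, 8, 1
    model = CnnBiLstmAttention(input_dim=n_features, output_dim=n_outputs, sequence_length=seq_len)
    dummy = torch.randn(batch_size, seq_len, n_features)
    pred = model(dummy)
    print(pred.shape)  # expected: torch.Size([4, 1])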