ShopTRAINING/server/models/cnn_bilstm_attention.py
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model definition, adapted for the pharmacy sales prediction system.
Original code source: "python machine learning regression toolkit" (python机器学习回归全家桶)
"""
import torch
import torch.nn as nn

# Note: the original code used TensorFlow/Keras layers. Here we provide an equivalent
# PyTorch implementation, which is more robust and better matches the existing system architecture.


class Attention(nn.Module):
    """
    Attention mechanism implemented in PyTorch.

    Scores each time step, normalizes the scores into attention weights,
    and returns the weighted sum of the input over the time dimension.
    """
    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim

        weight = torch.zeros(feature_dim, 1)
        nn.init.xavier_uniform_(weight)
        self.weight = nn.Parameter(weight)

        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        # x: (batch_size, step_dim, feature_dim)
        feature_dim = self.feature_dim
        step_dim = self.step_dim

        # Score each time step: (batch_size, step_dim)
        eij = torch.mm(
            x.contiguous().view(-1, feature_dim),
            self.weight
        ).view(-1, step_dim)

        if self.bias:
            eij = eij + self.b

        eij = torch.tanh(eij)
        a = torch.exp(eij)

        if mask is not None:
            a = a * mask

        # Normalize into attention weights (softmax-like, with a small epsilon for stability)
        a = a / (torch.sum(a, 1, keepdim=True) + 1e-10)

        # Weighted sum over the time dimension: (batch_size, feature_dim)
        weighted_input = x * torch.unsqueeze(a, -1)
        return torch.sum(weighted_input, 1)
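

# Minimal shape sketch for the Attention module (illustrative only: the helper name
# and the example sizes below are assumptions, not part of the original training code).
def _attention_shape_example():
    batch_size, step_dim, feature_dim = 4, 30, 256  # hypothetical sizes
    attn = Attention(feature_dim=feature_dim, step_dim=step_dim)
    x = torch.randn(batch_size, step_dim, feature_dim)
    pooled = attn(x)
    # Attention pools over the time dimension: (batch, step, feature) -> (batch, feature)
    assert pooled.shape == (batch_size, feature_dim)
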

class CnnBiLstmAttention(nn.Module):
    """
    PyTorch implementation of the CNN-BiLSTM-Attention model.
    """
    def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128):
        super(CnnBiLstmAttention, self).__init__()
        self.sequence_length = sequence_length
        self.cnn_filters = cnn_filters
        self.lstm_units = lstm_units

        # CNN layer. With the default kernel_size=1 the sequence length is preserved,
        # which keeps it consistent with the Attention layer's step_dim below.
        self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size)
        self.relu = nn.ReLU()
        # Kept from the original Keras-style architecture; with kernel_size=1 it is a
        # no-op and it is not referenced in forward().
        self.maxpool = nn.MaxPool1d(kernel_size=1)

        # BiLSTM layer
        self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True)

        # Attention layer
        self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length)

        # Fully connected output layer
        self.dense = nn.Linear(lstm_units * 2, output_dim)
    def forward(self, x):
        # Input x shape: (batch_size, sequence_length, input_dim)

        # CNN stage
        x = x.permute(0, 2, 1)  # to (batch_size, input_dim, sequence_length) for Conv1d
        x = self.conv1d(x)
        x = self.relu(x)
        x = x.permute(0, 2, 1)  # back to (batch_size, sequence_length, cnn_filters)

        # BiLSTM stage
        lstm_out, _ = self.bilstm(x)  # lstm_out shape: (batch_size, sequence_length, lstm_units * 2)

        # Attention stage: pool the BiLSTM outputs over time into a single vector.
        # Note: this attention implementation may need tuning for the specific task;
        # a simpler alternative would be to use the LSTM's final hidden state directly.
        attention_out = self.attention(lstm_out)  # (batch_size, lstm_units * 2)

        # Fully connected output layer
        output = self.dense(attention_out)
        return output
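

# Hedged smoke test: a quick way to verify the tensor shapes end to end. The dimensions
# used here (8 input features, 30-step window, batch of 2) are illustrative assumptions,
# not values taken from the pharmacy sales prediction pipeline.
if __name__ == "__main__":
    model = CnnBiLstmAttention(input_dim=8, output_dim=1, sequence_length=30)
    dummy = torch.randn(2, 30, 8)  # (batch_size, sequence_length, input_dim)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 1])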