56 lines
2.0 KiB
Python

import torch
import torch.nn as nn
from pytorch_tcn import TCN
class TCNForecaster(nn.Module):
    """Forecast a fixed-length horizon from a multivariate input window.

    The input window is encoded by a temporal convolutional network (``TCN``),
    the feature vector at the final timestep is taken as a summary of the
    window, and a linear layer maps that vector to
    ``output_sequence_length`` forecast values.
    """

    def __init__(
        self,
        num_features: int,
        output_sequence_length: int,
        num_channels=(64, 128, 256),
        kernel_size: int = 3,
        dropout: float = 0.1,
    ):
        """Build the TCN encoder and the linear forecasting head.

        Args:
            num_features: Number of input features per timestep.
            output_sequence_length: Number of forecast steps to produce.
            num_channels: Sequence of ints giving the channel width of each
                TCN layer. The last entry is the input width of the linear
                head.
            kernel_size: Size of the convolutional kernel.
            dropout: Dropout rate passed to the TCN.

        Raises:
            ValueError: If ``num_channels`` is empty.
        """
        super().__init__()

        # Copy into a fresh list: the default is an immutable tuple (no
        # shared mutable default), and a caller-supplied list is never
        # aliased by the model.
        channels = list(num_channels)
        if not channels:
            raise ValueError(
                "num_channels must contain at least one layer width"
            )

        self.tcn = TCN(
            num_inputs=num_features,
            num_channels=channels,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        self.linear = nn.Linear(channels[-1], output_sequence_length)
        self.output_sequence_length = output_sequence_length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forecaster on a batch of input windows.

        Args:
            x: Tensor of shape (batch_size, sequence_length, num_features).

        Returns:
            Tensor of shape (batch_size, output_sequence_length, 1).

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected input of shape (batch_size, sequence_length, "
                f"num_features), got {tuple(x.shape)}"
            )

        # The TCN is fed channels-first data, so move the feature axis
        # ahead of the time axis:
        # (batch, seq_len, features) -> (batch, features, seq_len).
        x = x.permute(0, 2, 1)
        tcn_out = self.tcn(x)

        # Keep only the final timestep: (batch, channels, seq_len) ->
        # (batch, channels). This relies on the TCN returning channels-first
        # output whose channel width equals the last entry of num_channels.
        last_timestep_out = tcn_out[:, :, -1]

        # Project the summary vector onto the forecast horizon.
        output = self.linear(last_timestep_out)

        # Add a trailing singleton axis to match the target shape
        # (batch_size, output_sequence_length, 1).
        return output.unsqueeze(-1)