from collections.abc import Sequence

import torch
import torch.nn as nn
from pytorch_tcn import TCN


class TCNForecaster(nn.Module):
    """A TCN-based multi-step forecasting model.

    A temporal convolutional network (``pytorch_tcn.TCN``) encodes the input
    window. The feature vector at the final timestep is then projected by a
    single linear layer onto ``output_sequence_length`` forecast values.
    """

    def __init__(
        self,
        num_features: int,
        output_sequence_length: int,
        num_channels: Sequence[int] = (64, 128, 256),
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        """Initialize the TCNForecaster.

        Args:
            num_features: The number of input features per timestep.
            output_sequence_length: The length of the forecast sequence.
            num_channels: The number of channels in each TCN layer; its
                length is the number of layers. Any sequence of ints is
                accepted (a private list copy is passed to ``TCN``).
            kernel_size: The size of the convolutional kernel.
            dropout: The dropout rate.

        Raises:
            ValueError: If ``num_channels`` is empty (the linear head needs
                the width of the last TCN layer).
        """
        super().__init__()
        # Copy into a fresh list so no list object is shared between
        # instances (the classic mutable-default pitfall), and so tuples or
        # other sequences from callers work too.
        num_channels = list(num_channels)
        if not num_channels:
            raise ValueError("num_channels must contain at least one layer width")

        self.tcn = TCN(
            num_inputs=num_features,
            num_channels=num_channels,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        # Maps the last TCN layer's channel vector to one value per forecast step.
        self.linear = nn.Linear(num_channels[-1], output_sequence_length)
        self.output_sequence_length = output_sequence_length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, num_features).

        Returns:
            Tensor of shape (batch_size, output_sequence_length, 1).
        """
        # TCN is fed channels-first input, (batch_size, num_features,
        # sequence_length). This relies on pytorch_tcn's default input layout,
        # so swap the last two axes of our channels-last input.
        x = x.permute(0, 2, 1)
        tcn_out = self.tcn(x)

        # Keep only the last timestep: (batch_size, num_channels[-1]).
        # NOTE(review): this assumes TCN's default causal padding, so the final
        # step has seen the whole window — confirm if TCN options change.
        last_timestep_out = tcn_out[:, :, -1]

        # Project onto the forecast horizon: (batch_size, output_sequence_length).
        output = self.linear(last_timestep_out)

        # Add a trailing singleton axis to match the target shape
        # (batch_size, output_sequence_length, 1). Unlike view(-1, ...), this
        # never reinterprets the batch dimension.
        return output.unsqueeze(-1)