56 lines
2.0 KiB
Python
56 lines
2.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from pytorch_tcn import TCN
|
|
|
|
|
|
class TCNForecaster(nn.Module):
    """
    A TCN-based forecasting model.

    The input window is encoded by a temporal convolutional network (TCN).
    The feature vector at the final timestep is then projected by a linear
    layer onto the forecast horizon, yielding one value per future step.
    """

    def __init__(self, num_features, output_sequence_length, num_channels=(64, 128, 256), kernel_size=3, dropout=0.1):
        """
        Initializes the TCNForecaster.

        Args:
            num_features (int): The number of input features.
            output_sequence_length (int): The length of the forecast sequence.
            num_channels (sequence of int): The number of channels in each
                TCN layer. The last entry is the width of the features fed
                to the linear head.
            kernel_size (int): The size of the convolutional kernel.
            dropout (float): The dropout rate.
        """
        super().__init__()

        # Take a private copy: the default is an immutable tuple (a list
        # default would be shared between all instances), and a list passed
        # in by the caller can be mutated later without affecting this model.
        channels = list(num_channels)

        self.tcn = TCN(
            num_inputs=num_features,
            num_channels=channels,
            kernel_size=kernel_size,
            dropout=dropout,
        )
        self.linear = nn.Linear(channels[-1], output_sequence_length)
        self.output_sequence_length = output_sequence_length

    def forward(self, x):
        """
        Forward pass.

        Args:
            x (torch.Tensor): The input tensor of shape
                (batch_size, sequence_length, num_features).

        Returns:
            torch.Tensor: The output tensor of shape
                (batch_size, output_sequence_length, 1).
        """
        # TCN expects input of shape (batch_size, num_features, sequence_length).
        # Our input is (batch_size, sequence_length, num_features), so permute it.
        x = x.permute(0, 2, 1)

        tcn_out = self.tcn(x)

        # We only need the output of the last timestep:
        # (batch_size, channels, sequence_length) -> (batch_size, channels).
        last_timestep_out = tcn_out[:, :, -1]

        # Pass it through the linear layer to get the forecast:
        # (batch_size, output_sequence_length).
        output = self.linear(last_timestep_out)

        # Add a trailing singleton dimension to match the expected target
        # shape (batch_size, output_sequence_length, 1).
        return output.unsqueeze(-1)