Last active
October 23, 2023 23:52
-
-
Save KeAWang/5093bffdf6fac21ab1adaefd5b7ad9a0 to your computer and use it in GitHub Desktop.
Temporal Convolutional Network in PyTorch (https://arxiv.org/abs/1803.01271)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from typing import List | |
import torch.nn.functional as F | |
def receptive_field(kernel_size: int, dilation: int):
    """Return how many input timesteps one dilated convolution kernel spans.

    Consecutive taps of the kernel are ``dilation`` steps apart, so the first
    and last of the ``kernel_size`` taps are ``(kernel_size - 1) * dilation``
    steps apart. Counting the first tap itself adds one more step.
    """
    gap_first_to_last = (kernel_size - 1) * dilation
    return gap_first_to_last + 1
class Seq2SeqConv1d(torch.nn.Module):
    """Length-preserving 1D convolution: N x Cin x T -> N x Cout x T.

    ``torch.nn.Conv1d`` zero-pads the input symmetrically. Any surplus output
    steps are then trimmed ("chomped") from the right, so the output has
    exactly T steps.

    * Non-causal: ``receptive_field // 2`` zeros are padded on each side,
      which centres the kernel. With an even receptive field this produces
      one extra step, and that step is dropped.
    * Causal: ``receptive_field - 1`` zeros are padded on each side, and the
      same number of steps is chomped from the right. Output step ``t``
      therefore depends only on input steps
      ``t - (receptive_field - 1) .. t``, i.e. never on the future.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        groups: int = 1,
        causal: bool = False,
    ):
        super().__init__()
        self.receptive_field = receptive_field(kernel_size, dilation)
        if causal:
            # The whole window (receptive_field - 1 past steps plus the current
            # one) must lie at or before t. Padding only receptive_field // 2
            # would both shorten the output and let step t see future inputs.
            padding = self.receptive_field - 1  # Each side
        else:
            padding = self.receptive_field // 2  # Each side
        self.conv = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Conv1d output length is T + 2 * padding - (receptive_field - 1).
        if causal:
            # T + 2 * (rf - 1) - (rf - 1) = T + (rf - 1): drop the rf - 1
            # trailing steps, which are the ones that would look ahead.
            self.num_chomp = padding
        elif self.receptive_field % 2 == 0:
            # T + 2 * (rf // 2) - (rf - 1) = T + 1 when rf is even.
            self.num_chomp = 1
        else:
            # T + (rf - 1) - (rf - 1) = T when rf is odd: nothing to trim.
            self.num_chomp = 0

    def forward(self, x):
        """Convolve ``x`` (N x Cin x T) and return N x Cout x T."""
        out = self.conv(x)
        if self.num_chomp > 0:
            # Trim from the right; guarded because out[..., :-0] would be empty.
            out = out[..., : -self.num_chomp]
        return out
class Seq2SeqConv1dBlock(torch.nn.Module):
    """Pre-activation residual block built from two length-preserving convs.

    Computes ``conv2(GELU(conv1(GELU(x)))) + projector(x)``. Here
    ``projector`` is a 1x1 convolution that maps ``in_channels`` to
    ``out_channels``, so the skip path can be added to the conv path. Both
    convolutions use the same ``kernel_size``, ``dilation`` and ``causal``
    settings.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int,
        causal: bool = False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.activation = torch.nn.GELU()
        common = dict(kernel_size=kernel_size, dilation=dilation, causal=causal)
        # Construction order (conv1, conv2, projector) is kept fixed so
        # parameter initialisation and registration order stay stable.
        self.conv1 = Seq2SeqConv1d(in_channels=in_channels, out_channels=out_channels, **common)
        self.conv2 = Seq2SeqConv1d(in_channels=out_channels, out_channels=out_channels, **common)
        # 1x1 conv on the skip path so its channel count matches the conv path.
        self.projector = torch.nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
        )
        # TODO: should we initialize weights to N(0, 0.01) like in TCN?
        # TODO: do we need layer norm?
        # TODO: do we need dropout?

    def forward(self, x):
        """Apply the residual block to ``x`` (N x in_channels x T)."""
        skip = self.projector(x)
        hidden = x
        # Pre-activation: apply GELU before each convolution.
        for conv in (self.conv1, self.conv2):
            hidden = conv(self.activation(hidden))
        return hidden + skip
class TemporalConvNet(torch.nn.Module):
    """Residual dilated-conv stack followed by a learned pooling over time.

    Expects N x num_timesteps x input_size.
    Outputs N x hidden_sizes[-1].

    Block ``i`` maps ``hidden_sizes[i - 1]`` channels (``input_size`` for the
    first block) to ``hidden_sizes[i]`` channels, using dilation
    ``int(dilation_factor ** i)``. After the stack, a 1x1 convolution treats
    the ``num_timesteps`` axis as channels. It collapses that axis to one
    step by taking a learned affine combination of all timesteps.
    """

    def __init__(
        self,
        num_timesteps: int,
        input_size: int,
        hidden_sizes: List[int],
        kernel_size: int,
        dilation_factor: int = 1,
        causal: bool = True,
    ):
        super().__init__()
        in_sizes = [input_size, *hidden_sizes[:-1]]
        stack = [
            Seq2SeqConv1dBlock(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel_size,
                dilation=int(dilation_factor ** depth),
                causal=causal,
            )
            for depth, (c_in, c_out) in enumerate(zip(in_sizes, hidden_sizes))
        ]
        self.blocks = torch.nn.ModuleList(stack)
        # Applied to N x T x D, so the T timesteps act as input channels.
        self.linear_comb = Seq2SeqConv1d(
            in_channels=num_timesteps, out_channels=1, kernel_size=1, causal=False
        )
        width = hidden_sizes[-1]
        # NOTE(review): `projection` is built (and holds parameters) but is
        # never applied in forward(); confirm whether that is intended.
        self.projection = torch.nn.Sequential(
            torch.nn.Linear(width, width),
            torch.nn.GELU(),
            torch.nn.Linear(width, width),
        )
        self.num_timesteps = num_timesteps
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes

    def forward(self, x):
        """Map ``x`` (N x num_timesteps x input_size) to N x hidden_sizes[-1]."""
        assert x.shape[1:] == (self.num_timesteps, self.input_size)
        h = x.transpose(1, 2)  # N x T x D -> N x D x T
        for block in self.blocks:
            h = block(h)
        # N x D x T -> N x T x D -> N x 1 x D -> N x D
        pooled = self.linear_comb(h.transpose(1, 2)).squeeze(1)
        assert pooled.shape[1:] == (self.hidden_sizes[-1],)
        return pooled
Note to self: the causal mode has a bug
caused by incorrect padding.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some notes on generalizing causal convs: pytorch/pytorch#1333