Skip to content

Instantly share code, notes, and snippets.

@TerenceLiu98
Created October 8, 2021 03:51
Show Gist options
  • Save TerenceLiu98/609736a939bbecd3de57159a59dd61ce to your computer and use it in GitHub Desktop.
Save TerenceLiu98/609736a939bbecd3de57159a59dd61ce to your computer and use it in GitHub Desktop.
from torch import nn
def CausalConv1d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    """Build an ``nn.Conv1d`` padded by the dilated kernel span plus one.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Integer kernel width.
        dilation: Integer spacing between kernel taps (default 1).
        **kwargs: Forwarded unchanged to ``nn.Conv1d`` (e.g. ``bias``,
            ``stride``). ``padding`` must not be passed here; it is already
            supplied by this factory and a duplicate raises ``TypeError``.

    Returns:
        An ``nn.Conv1d`` whose ``padding`` is
        ``(kernel_size - 1) * dilation + 1``.

    Note:
        ``nn.Conv1d`` applies ``padding`` to both ends of the sequence and
        this factory does not trim the output, so with stride 1 the output
        is longer than the input. Callers that need a causal result must
        drop the trailing samples themselves.
    """
    # NOTE(review): a conventional causal convolution pads by
    # (kernel_size - 1) * dilation; confirm that the extra "+ 1" is intended.
    pad = (kernel_size - 1) * dilation + 1
    return nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size,
        padding=pad,
        dilation=dilation,
        **kwargs,
    )
def CausalConv2d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    """Build an ``nn.Conv2d`` padded by the dilated kernel span plus one.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int (used for both spatial
            dimensions) or a 2-element tuple/list.
        dilation: Tap spacing, either an int or a 2-element tuple/list
            (default 1).
        **kwargs: Forwarded unchanged to ``nn.Conv2d`` (e.g. ``bias``,
            ``stride``). ``padding`` must not be passed here; it is already
            supplied by this factory and a duplicate raises ``TypeError``.

    Returns:
        An ``nn.Conv2d`` whose padding in each spatial dimension is
        ``(kernel_size - 1) * dilation + 1`` for that dimension.

    Note:
        ``nn.Conv2d`` applies ``padding`` to both sides of every spatial
        dimension and this factory does not trim the output; callers that
        need a causal result must crop the output themselves.
    """

    def _per_dim(value):
        # Broadcast a scalar to both spatial dimensions; keep sequences as is.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    # Computing the padding per dimension lets tuple kernel sizes/dilations
    # work; for plain ints this equals the scalar formula in both dimensions.
    pad = tuple(
        (k - 1) * d + 1
        for k, d in zip(_per_dim(kernel_size), _per_dim(dilation))
    )
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=pad,
        dilation=dilation,
        **kwargs,
    )


# Backward-compatible alias: this factory was originally published under the
# misspelled name, so existing callers keep working.
CasalConv2d = CausalConv2d
def CausalConv3d(in_channels, out_channels, kernel_size, dilation=1, **kwargs):
    """Build an ``nn.Conv3d`` padded by the dilated kernel span plus one.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, either an int (used for all three
            dimensions) or a 3-element tuple/list.
        dilation: Tap spacing, either an int or a 3-element tuple/list
            (default 1).
        **kwargs: Forwarded unchanged to ``nn.Conv3d`` (e.g. ``bias``,
            ``stride``). ``padding`` must not be passed here; it is already
            supplied by this factory and a duplicate raises ``TypeError``.

    Returns:
        An ``nn.Conv3d`` whose padding in each dimension is
        ``(kernel_size - 1) * dilation + 1`` for that dimension.

    Note:
        ``nn.Conv3d`` applies ``padding`` to both sides of every dimension
        and this factory does not trim the output; callers that need a
        causal result must crop the output themselves.
    """

    def _per_dim(value):
        # Broadcast a scalar to all three dimensions; keep sequences as is.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value, value)

    # Computing the padding per dimension lets tuple kernel sizes/dilations
    # work; for plain ints this equals the scalar formula in every dimension.
    pad = tuple(
        (k - 1) * d + 1
        for k, d in zip(_per_dim(kernel_size), _per_dim(dilation))
    )
    return nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size,
        padding=pad,
        dilation=dilation,
        **kwargs,
    )


# Backward-compatible alias: this factory was originally published under the
# misspelled name, so existing callers keep working.
CasalConv3d = CausalConv3d
# Example: a dilated causal 1-D convolution layer
# (one input channel, one output channel, kernel size 3, dilation 2, no bias).
m = CausalConv1d(
    in_channels=1,
    out_channels=1,
    kernel_size=3,
    dilation=2,
    bias=False,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment