Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active July 4, 2025 13:35
Show Gist options
  • Select an option

  • Save wassname/7eb4095a4f3d3b5eea8adaaf4419c822 to your computer and use it in GitHub Desktop.

Select an option

Save wassname/7eb4095a4f3d3b5eea8adaaf4419c822 to your computer and use it in GitHub Desktop.
PyTorch Causal Conv2d
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.utils import _pair
class CausalConv2d(nn.Conv2d):
    """2-D convolution whose padding is applied only on the top and left.

    The wrapped ``nn.Conv2d`` is built with ``padding=0``; all padding is
    added in :meth:`forward` in front of the height and width axes.  With the
    default padding and ``stride=1`` the output has the same spatial size as
    the input, and output position ``(i, j)`` only reads input positions
    ``(<= i, <= j)``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size, an int or a ``(height, width)`` pair.
        stride: Convolution stride, an int or a pair.
        padding: ``None`` (default) pads each axis by
            ``(kernel_size - 1) * dilation``.  Otherwise an int or a
            ``(height, width)`` pair of *symmetric* padding amounts; each is
            doubled because the whole amount is placed on one side only.
        dilation: Kernel dilation, an int or a pair.
        groups: Number of blocked connections, forwarded to ``nn.Conv2d``.
        bias: Whether to learn an additive bias, forwarded to ``nn.Conv2d``.

    Attributes:
        left_padding: ``(top, left)`` number of zero rows/columns that
            :meth:`forward` prepends to the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=None, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        dilation = _pair(dilation)

        if padding is None:
            # Exactly enough padding for the kernel's receptive field to end
            # at the current position.
            left_padding = tuple(
                int((k - 1) * d) for k, d in zip(kernel_size, dilation)
            )
        else:
            # Normalise to a pair *before* doubling: ``padding * 2`` on a
            # tuple would repeat it instead of doubling each entry.
            left_padding = tuple(2 * p for p in _pair(padding))

        # The parent convolution never pads; forward() does it one-sidedly.
        super().__init__(in_channels, out_channels, kernel_size,
                         stride=stride, padding=0, dilation=dilation,
                         groups=groups, bias=bias)

        # Assigned after nn.Module is initialised so that attribute
        # registration does not depend on a half-constructed module.
        self.left_padding = left_padding

    def forward(self, inputs):
        """Pad ``inputs`` on the top/left and apply the convolution.

        Args:
            inputs: Tensor whose last two dimensions are height and width.

        Returns:
            The result of ``nn.Conv2d.forward`` on the padded tensor.
        """
        pad_top, pad_left = self.left_padding[0], self.left_padding[1]
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        padded = F.pad(inputs, (pad_left, 0, pad_top, 0))
        return super().forward(padded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment