Created
November 14, 2023 18:42
-
-
Save KeAWang/db737b1fe43a864fb15fb4b4c9005ef5 to your computer and use it in GitHub Desktop.
PyTorch NaN embedder
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class NanWrapper(torch.nn.Module):
    """Wrapper module around a torch Module that handles incoming NaNs.

    Any position whose feature vector (the slice along ``dim``) contains at
    least one NaN is zero-filled before being passed to the wrapped module,
    and the corresponding output positions are forced to exactly zero
    afterwards. All other output positions are left as the wrapped module
    computed them.

    Args:
        module: The module to wrap. Its output must keep every dimension of
            the input except ``dim`` (whose size may change, e.g. a projection
            from ``in_features`` to ``out_features``) so that the
            per-position mask can be broadcast onto it.
        dim: The feature/channel dimension that is masked as a whole.
            Defaults to ``-1`` (channel-last); use ``1`` for channel-first
            ``(batch, channels, seq_len)`` inputs.
    """

    def __init__(self, module, dim=-1):
        super().__init__()
        self.module = module
        self.dim = dim

    def forward(self, x):
        """Apply the wrapped module, masking out feature vectors that contain NaN.

        Args:
            x: Input tensor; a NaN anywhere along ``self.dim`` marks that
                whole position as missing.

        Returns:
            The wrapped module's output with masked positions set to exactly 0.
        """
        # True where the feature vector is NaN-free. keepdim=True gives size 1
        # along `dim`, so the mask broadcasts against both the input and the
        # (possibly resized) output.
        mask = ~torch.isnan(x).any(dim=self.dim, keepdim=True)
        # Replace NaNs with zeros *before* the module runs so it never sees a NaN.
        masked_x = torch.where(mask, x, torch.zeros_like(x))
        fx = self.module(masked_x)
        # Select rather than multiply: `inf * 0` is NaN, so a module that yields
        # inf/NaN on the zero-filled input would otherwise leak NaNs here.
        return torch.where(mask, fx, torch.zeros_like(fx))
class ConvLinear(torch.nn.Module):
    """Linear projection implemented as a 1x1 convolution; useful for sequence data.

    Args:
        in_features: Size of each input feature vector.
        out_features: Size of each output feature vector.
        bias: Whether the underlying ``Conv1d`` learns an additive bias.
        channel_last: If True, inputs and outputs are laid out as
            ``(batch_size, seq_len, features)``; if False, as
            ``(batch_size, features, seq_len)`` (the native ``Conv1d`` layout).
    """

    def __init__(self, in_features, out_features, bias=True, channel_last=True):
        super().__init__()
        self.conv = torch.nn.Conv1d(in_features, out_features, kernel_size=1, bias=bias)
        self.channel_last = channel_last

    def forward(self, x):
        """Project the feature dimension of a 3-D tensor.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, in_features)`` when
                ``channel_last`` is True, else ``(batch_size, in_features, seq_len)``.

        Returns:
            Tensor in the same layout as ``x`` with ``out_features`` features.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        # Raise explicitly instead of `assert`, which is stripped under `python -O`.
        if x.ndim != 3:
            raise ValueError(
                "Expected input to be (batch_size, seq_len, input_size) or "
                f"(batch_size, input_size, seq_len), got shape {tuple(x.shape)}"
            )
        if not self.channel_last:
            return self.conv(x)
        # Conv1d expects channels in dim 1, so move the features there and back.
        return self.conv(x.transpose(1, 2)).transpose(1, 2)
# Example usage: embed a batch of sequences in which one time step is missing.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Two equivalent 3 -> 4 feature projections, each wrapped so that time
    # steps containing NaN produce an all-zero embedding instead of NaNs.
    nan_embedder1 = NanWrapper(ConvLinear(3, 4))
    nan_embedder2 = NanWrapper(torch.nn.Linear(3, 4))

    # Tie parameters so both embedders compute the same projection:
    # Conv1d weights are (out, in, 1) while Linear weights are (out, in).
    with torch.no_grad():
        nan_embedder2.module.weight.copy_(nan_embedder1.module.conv.weight.squeeze(-1))
        nan_embedder2.module.bias.copy_(nan_embedder1.module.conv.bias)

    x = torch.randn(2, 5, 3)  # (batch_size, seq_len, input_size)
    x[0, 1, 2] = float("nan")  # one missing feature masks the whole time step

    out1 = nan_embedder1(x)
    out2 = nan_embedder2(x)
    print("output shape:", tuple(out1.shape))
    print("any NaN in output:", bool(torch.isnan(out1).any()))
    print("masked step is zero:", bool((out1[0, 1] == 0).all()))
    print("embedders agree:", torch.allclose(out1, out2, atol=1e-6))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment