# Imports for fitting an SE3-Transformer to SidechainNet protein structure data
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat

import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

# double precision for numerical stability
torch.set_default_dtype(torch.float64)
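Only the imports are shown above. A minimal sketch of how the pieces are typically wired together follows; the constructor arguments, the scn.load call, and all hyperparameter values are illustrative assumptions rather than the original script.

# hypothetical setup -- values chosen for illustration only
model = SE3Transformer(
    dim = 64,
    heads = 4,
    depth = 2,
    dim_head = 16,
    num_degrees = 2
)

# SidechainNet can hand back ready-made PyTorch dataloaders over protein structures
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = 1
)

opt = Adam(model.parameters(), lr = 1e-3)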
import torch
from torch import nn, einsum
from einops import rearrange, repeat

# fixed (non-learned) sinusoidal positional embedding
class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # inverse frequencies follow the 1 / 10000^(2i/d) schedule from "Attention Is All You Need"
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
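    # hypothetical forward pass -- the preview ends above; the standard sinusoidal
    # completion would look roughly like this
    def forward(self, x, seq_dim = 1):
        # positions 0 .. n-1 along the sequence dimension
        t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq)
        # outer product of positions and inverse frequencies
        sinusoid = einsum('i, j -> i j', t, self.inv_freq)
        # concatenate sin and cos halves and add a leading batch dimension
        emb = torch.cat((sinusoid.sin(), sinusoid.cos()), dim = -1)
        return emb[None, :, :]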
# Imports for fitting an E(n)-Transformer to SidechainNet protein structure data
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from einops import rearrange, repeat

import sidechainnet as scn
from en_transformer.en_transformer import EnTransformer

# double precision for numerical stability
torch.set_default_dtype(torch.float64)
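Again only the imports survive here; the sketch below shows how such a model is typically constructed and called. Argument names follow common usage of the en_transformer package, but the values and the dummy tensors are assumptions.

# hypothetical instantiation -- hyperparameters are illustrative only
model = EnTransformer(
    dim = 32,
    depth = 4,
    dim_head = 64,
    heads = 4,
    neighbors = 16        # attend only to the nearest neighbours in 3D space
)

feats = torch.randn(1, 64, 32)    # per-residue features
coors = torch.randn(1, 64, 3)     # 3D coordinates
feats_out, coors_out = model(feats, coors)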
import torch
from torch import nn, einsum
from einops import rearrange, repeat

# E(n)-Equivariant Graph Neural Network layer (Satorras et al. 2021)
class EGNN(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim,
        m_dim = 16
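        # hypothetical completion from here on -- the preview stops mid-signature;
        # the rest follows the E(n)-GNN update equations of Satorras et al. (2021)
        # with assumed MLP widths, as a sketch rather than the original code
    ):
        super().__init__()

        edge_input_dim = dim * 2 + edge_dim + 1      # +1 for the squared pairwise distance

        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, m_dim),
            nn.SiLU(),
            nn.Linear(m_dim, m_dim),
            nn.SiLU()
        )

        self.coors_mlp = nn.Sequential(
            nn.Linear(m_dim, m_dim),
            nn.SiLU(),
            nn.Linear(m_dim, 1)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(dim + m_dim, dim * 2),
            nn.SiLU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, feats, coors, edges = None):
        b, n, _ = feats.shape

        # relative positions and squared distances between all node pairs
        rel_coors = rearrange(coors, 'b i c -> b i () c') - rearrange(coors, 'b j c -> b () j c')
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)

        # pairwise message inputs: both endpoint features, the distance, and optional
        # edge features (expected to have edge_dim features per pair)
        feats_i = repeat(feats, 'b i d -> b i j d', j = n)
        feats_j = repeat(feats, 'b j d -> b i j d', i = n)
        edge_input = torch.cat((feats_i, feats_j, rel_dist), dim = -1)

        if edges is not None:
            edge_input = torch.cat((edge_input, edges), dim = -1)

        m_ij = self.edge_mlp(edge_input)

        # equivariant coordinate update: weighted mean of relative vectors
        coor_weights = self.coors_mlp(m_ij)
        coors = coors + (rel_coors * coor_weights).sum(dim = 2) / (n - 1)

        # invariant node feature update from summed incoming messages
        m_i = m_ij.sum(dim = 2)
        feats = feats + self.node_mlp(torch.cat((feats, m_i), dim = -1))

        return feats, coors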
import torch

# per-parameter "unitwise" norm used by Adaptive Gradient Clipping (NFNets, Brock et al. 2021)
def unitwise_norm(x):
    if len(x.squeeze().shape) <= 1:      # scalars and biases: one global norm
        dim = None
        keepdim = False
    elif len(x.shape) in (2, 3):         # linear / embedding weights: one norm per output unit
        dim = 1
        keepdim = True
    elif len(x.shape) == 4:
        dim = (1, 2, 3)                  # pytorch convolution kernel is OIHW
        keepdim = True
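    # hypothetical completion -- the preview ends above; the remaining branch and the
    # final reduction follow the usual AGC formulation
    else:
        raise ValueError(f'unexpected tensor of {len(x.shape)} dimensions')

    return x.norm(dim = dim, keepdim = keepdim)

# sketch of how the unitwise norm is typically used for Adaptive Gradient Clipping;
# the function name and default values here are assumptions, not from the source
def adaptive_clip_grad_(parameters, clipping = 0.01, eps = 1e-3):
    for p in parameters:
        if p.grad is None:
            continue
        p_norm = unitwise_norm(p.detach()).clamp(min = eps)
        g_norm = unitwise_norm(p.grad.detach())
        max_norm = p_norm * clipping
        # rescale only the units whose gradient norm exceeds the allowed fraction of the weight norm
        clipped = p.grad * (max_norm / g_norm.clamp(min = 1e-6))
        p.grad.detach().copy_(torch.where(g_norm > max_norm, clipped, p.grad))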
# Vision Transformer variant with cubic patches: the cubed exponents below imply
# volumetric (3D) input rather than 2D images
class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 3
        patch_dim = channels * patch_size ** 3
        self.patch_size = patch_size
        # learned positional embedding, with one extra slot for a CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
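        # hypothetical continuation -- the preview ends here. A typical completion would
        # project each flattened cubic patch to `dim` and prepend the CLS token, e.g.
        #
        #   self.patch_to_embedding = nn.Linear(patch_dim, dim)
        #   self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        #
        # and the forward pass would start by flattening p*p*p voxel blocks into tokens
        # with einops, along the lines of
        #
        #   rearrange(vol, 'b c (x p1) (y p2) (z p3) -> b (x y z) (p1 p2 p3 c)',
        #             p1 = self.patch_size, p2 = self.patch_size, p3 = self.patch_size)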
# link to package https://github.com/lucidrains/slot-attention
import torch
from torch import nn

# standard residual wrapper: adds the module's input back onto its output
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x
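For context, the wrapper is used by composing it around any module; a quick (hypothetical) usage example:

mlp = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 64))
block = Residual(mlp)

x = torch.randn(2, 10, 64)
out = block(x)        # equals mlp(x) + x, same shape as x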
import torch
from torch import nn

# insert a new axis at `dim` and broadcast it to size k (expand creates a view, not a copy)
def expand_dim(t, dim, k):
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)
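A quick illustration of the helper's effect on shapes (values arbitrary):

t = torch.randn(4, 8)
expand_dim(t, dim = 1, k = 3).shape    # torch.Size([4, 3, 8])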
# PKM is presumably the Product-Key Memory layer of Lample et al. (2019),
# "Large Memory Layers with Product Keys"
class PKM(nn.Module):
import torch
from torch import nn
import torch.nn.functional as F

# positive feature map from "Transformers are RNNs" (Katharopoulos et al. 2020);
# elu(x) + 1 keeps features strictly positive so the attention normaliser stays valid
def psi(x):
    return F.elu(x) + 1

class LinearAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
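The rest of the class is cut off. For reference, the core of linear attention with this feature map replaces softmax(QK^T)V with the associativity-exploiting form psi(Q)(psi(K)^T V), normalised by psi(Q)(psi(K)^T 1); a standalone sketch over pre-split heads (the shapes and eps value are assumptions) is:

def linear_attention(q, k, v, eps = 1e-8):
    # q, k, v: (batch, heads, seq, dim_head)
    q, k = psi(q), psi(k)

    # sum_j psi(k_j) v_j^T  -> (batch, heads, dim_head, dim_head)
    context = torch.einsum('b h n d, b h n e -> b h d e', k, v)

    # per-query normaliser psi(q_i) . sum_j psi(k_j)  -> (batch, heads, seq)
    denom = torch.einsum('b h n d, b h d -> b h n', q, k.sum(dim = 2))

    out = torch.einsum('b h n d, b h d e -> b h n e', q, context)
    return out / (denom.unsqueeze(-1) + eps)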