Skip to content

Instantly share code, notes, and snippets.

View lucidrains's full-sized avatar

Phil Wang lucidrains

View GitHub Profile
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
# Make float64 the default floating-point dtype for tensors created after this
# call. This is process-wide torch state, not local to this script.
torch.set_default_dtype(torch.float64)
import torch
from torch import nn, einsum
from einops import rearrange, repeat
class FixedPositionalEmbedding(nn.Module):
    """Store the inverse frequencies of a fixed (non-learned) sinusoidal embedding.

    The constructor registers a buffer ``inv_freq`` of length ``ceil(dim / 2)``,
    whose i-th entry is ``1 / 10000 ** (2 * i / dim)``.

    NOTE(review): no ``forward`` is visible next to this constructor. Confirm
    how ``inv_freq`` is consumed.
    """

    def __init__(self, dim):
        super().__init__()
        # Even positions 0, 2, 4, ... below dim, scaled into [0, 1).
        # .float() pins float32 regardless of the process-wide default dtype.
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1. / (10000 ** exponents)
        # A buffer, not a Parameter: it follows .to()/.cuda() and is saved in
        # state_dict, but it is never trained.
        self.register_buffer('inv_freq', inv_freq)
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from en_transformer.en_transformer import EnTransformer
# Make float64 the default floating-point dtype for tensors created after this
# call. This is process-wide torch state, not local to this script.
torch.set_default_dtype(torch.float64)
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat
import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer
# Make float64 the default floating-point dtype for tensors created after this
# call. This is process-wide torch state, not local to this script.
torch.set_default_dtype(torch.float64)
@lucidrains
lucidrains / egnn.py
Last active February 26, 2021 19:08
import torch
from torch import nn, einsum
from einops import rearrange, repeat
class EGNN(nn.Module):
def __init__(
self,
dim,
edge_dim,
m_dim = 16
def unitwise_norm(x):
if len(x.squeeze().shape) <= 1:
dim = None
keepdim = False
elif len(x.shape) in (2, 3):
dim = 1
keepdim = True
elif len(x.shape) == 4:
dim = (1, 2, 3) # pytorch convolution kernel is OIHW
keepdim = True
class ViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
super().__init__()
assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
num_patches = (image_size // patch_size) ** 3
patch_dim = channels * patch_size ** 3
self.patch_size = patch_size
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
# The full implementation is available as a package: https://github.com/lucidrains/slot-attention
import torch
from torch import nn
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
import torch
from torch import nn
def expand_dim(t, dim, k):
    """Insert a new axis at ``dim`` and broadcast it to length ``k``.

    Every pre-existing axis keeps its size (``-1`` tells ``Tensor.expand`` to
    leave it unchanged). The result is an expanded view: no data is copied,
    so all ``k`` slices along ``dim`` share storage with ``t``. Negative
    ``dim`` values work because indexing happens after the unsqueeze.
    """
    unsqueezed = t.unsqueeze(dim)
    target_sizes = [-1 for _ in range(unsqueezed.dim())]
    target_sizes[dim] = k
    return unsqueezed.expand(*target_sizes)
class PKM(nn.Module):
import torch
from torch import nn
import torch.nn.functional as F
def psi(x):
    """Apply the shifted ELU ``elu(x) + 1`` elementwise.

    ``F.elu`` with its default ``alpha=1`` is bounded below by -1, so the
    output is non-negative.

    NOTE(review): presumably this is the feature map used by the linear
    attention module that follows in the file. Its body is not visible here,
    so confirm.
    """
    activated = F.elu(x)
    return activated + 1
class LinearAttention(nn.Module):
def __init__(self, dim, heads):
super().__init__()