Phil Wang (lucidrains), GitHub gists

import torch
from torch import nn
import torch.nn.functional as F

def cumavg(t, dim):
    # cumulative average along `dim`: cumulative sum divided by the running count
    r = torch.arange(1, t.shape[dim] + 1, device=t.device, dtype=t.dtype)
    expand_slice = [None] * len(t.shape)
    expand_slice[dim] = slice(None, None)
    return t.cumsum(dim=dim) / r[tuple(expand_slice)]
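
A quick usage check of cumavg (values computed by hand):

x = torch.tensor([[1., 2., 3., 4.]])
cumavg(x, dim = 1)  # tensor([[1.0000, 1.5000, 2.0000, 2.5000]])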

def group_fc(conv1d, t):
    ...  # definition truncated in the source preview

class SelfAttention(nn.Module):
    def __init__(self, dim, heads, dim_heads = None):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads

        self.heads = heads
        # fused projection to queries, keys and values; split into heads in the forward pass
        self.to_qkv = nn.Linear(dim, 3 * dim_hidden, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)
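
The preview ends before the forward pass. A minimal sketch of how a module with this __init__ is usually completed (standard scaled dot-product attention); the forward logic below is an assumption, not the gist's actual continuation:

import torch
from torch import nn
from einops import rearrange

class SelfAttentionSketch(nn.Module):
    def __init__(self, dim, heads, dim_heads = None):
        super().__init__()
        self.dim_heads = (dim // heads) if dim_heads is None else dim_heads
        dim_hidden = self.dim_heads * heads
        self.heads = heads
        self.to_qkv = nn.Linear(dim, 3 * dim_hidden, bias = False)
        self.to_out = nn.Linear(dim_hidden, dim, bias = False)

    def forward(self, x):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.dim_heads ** -0.5
        attn = sim.softmax(dim = -1)
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)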

import torch
from torch import nn
import torch.nn.functional as F

def psi(x):
    # elu + 1 keeps activations strictly positive, a common feature map for linear attention
    return F.elu(x) + 1

class LinearAttention(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
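
The class body is cut off by the preview. Below is a minimal sketch of non-causal linear attention built on the psi feature map above (the "Transformers are RNNs" formulation); the projection layers and shapes are assumptions:

import torch
from torch import nn
from einops import rearrange

class LinearAttentionSketch(nn.Module):
    def __init__(self, dim, heads):
        super().__init__()
        assert dim % heads == 0
        self.heads = heads
        self.to_qkv = nn.Linear(dim, 3 * dim, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, eps = 1e-6):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q, k = psi(q), psi(k)
        # linear-time attention: aggregate psi(K)^T V once, then project the queries onto it
        context = torch.einsum('b h n d, b h n e -> b h d e', k, v)
        normalizer = torch.einsum('b h n d, b h d -> b h n', q, k.sum(dim = 2)) + eps
        out = torch.einsum('b h n d, b h d e -> b h n e', q, context) / normalizer.unsqueeze(-1)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)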

import torch
from torch import nn

def expand_dim(t, dim, k):
    # insert a new axis at `dim` and broadcast it to length k
    t = t.unsqueeze(dim)
    expand_shape = [-1] * len(t.shape)
    expand_shape[dim] = k
    return t.expand(*expand_shape)

class PKM(nn.Module):
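
PKM here stands for product-key memory (Lample et al., "Large Memory Layers with Product Keys"), whose body the preview omits. A rough, self-contained sketch of such a layer follows; every hyperparameter and layer name below is an assumption, only the class name and the expand_dim helper come from the gist:

import torch
from torch import nn

class PKMSketch(nn.Module):
    def __init__(self, dim, heads = 4, num_keys = 128, topk = 32, dim_head = 256):
        super().__init__()
        self.heads = heads
        self.num_keys = num_keys
        self.topk = topk
        self.dim_head = dim_head
        # each query is split into two halves, each scored against its own set of sub-keys
        self.to_queries = nn.Linear(dim, dim_head * heads, bias = False)
        self.keys = nn.Parameter(torch.randn(heads, num_keys, 2, dim_head // 2))
        # the value table has num_keys ** 2 entries, addressed by the Cartesian product of sub-key ids
        self.values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')

    def forward(self, x):
        b, n, _ = x.shape
        h, k = self.heads, self.topk
        q = self.to_queries(x).view(b, n, h, 2, self.dim_head // 2)
        # similarity of each query half to the corresponding sub-keys: (b, n, h, 2, num_keys)
        dots = torch.einsum('b n h p d, h m p d -> b n h p m', q, self.keys)
        scores, indices = dots.topk(k, dim = -1)
        # combine the two halves into k * k candidate product keys, then keep the best k overall
        scores = scores[..., 0, :, None] + scores[..., 1, None, :]
        indices = indices[..., 0, :, None] * self.num_keys + indices[..., 1, None, :]
        scores, indices = scores.flatten(-2), indices.flatten(-2)
        final_scores, idx = scores.topk(k, dim = -1)
        final_indices = indices.gather(-1, idx)
        attn = final_scores.softmax(dim = -1)
        # weighted sum of the selected value embeddings, then summed over heads
        out = self.values(final_indices.reshape(-1, k), per_sample_weights = attn.reshape(-1, k))
        return out.view(b, n, h, -1).sum(dim = 2)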

# link to package https://github.com/lucidrains/slot-attention
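
A brief usage example of that package, as I recall its README API (the constructor arguments and expected input shape are assumptions, so treat this as a sketch rather than authoritative documentation):

import torch
from slot_attention import SlotAttention

slot_attn = SlotAttention(
    num_slots = 5,   # number of object slots
    dim = 512,       # feature dimension of the inputs and slots
    iters = 3        # iterations of the attention routing
)

inputs = torch.randn(2, 1024, 512)   # (batch, num_inputs, dim)
slots = slot_attn(inputs)            # (2, 5, 512)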

import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # residual connection around the wrapped module (the canonical completion of this wrapper)
        return self.fn(x) + x
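
Example usage of the Residual wrapper; the wrapped module is arbitrary and chosen only for illustration:

block = Residual(nn.Sequential(nn.LayerNorm(64), nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)))
x = torch.randn(2, 16, 64)
out = block(x)  # same shape as x: (2, 16, 64)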

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        # the cubed terms suggest volumetric (3D) inputs: patches are cubes of side patch_size
        num_patches = (image_size // patch_size) ** 3
        patch_dim = channels * patch_size ** 3
        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
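
The rest of the class is not shown. A compact sketch of how such a 3D ViT is typically assembled, using einops for the cube-to-token rearrange and a plain nn.TransformerEncoder as a stand-in backbone (both choices are assumptions, not the gist's code):

import torch
from torch import nn
from einops import rearrange

class ViT3DSketch(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
        super().__init__()
        assert image_size % patch_size == 0 and dim % heads == 0
        num_patches = (image_size // patch_size) ** 3
        patch_dim = channels * patch_size ** 3
        self.patch_size = patch_size
        self.to_patch_embedding = nn.Linear(patch_dim, dim)
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, mlp_dim, batch_first = True)
        self.transformer = nn.TransformerEncoder(layer, depth)
        self.to_logits = nn.Linear(dim, num_classes)

    def forward(self, volume):
        # volume: (batch, channels, depth, height, width), each spatial dim divisible by patch_size
        p = self.patch_size
        x = rearrange(volume, 'b c (x p1) (y p2) (z p3) -> b (x y z) (p1 p2 p3 c)', p1 = p, p2 = p, p3 = p)
        x = self.to_patch_embedding(x)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim = 1)
        x = x + self.pos_embedding[:, :x.shape[1]]
        x = self.transformer(x)
        return self.to_logits(x[:, 0])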

def unitwise_norm(x):
    # per-unit L2 norms of a parameter tensor, the helper used by adaptive gradient clipping
    if len(x.squeeze().shape) <= 1:   # scalars, biases, gains
        dim = None
        keepdim = False
    elif len(x.shape) in (2, 3):      # linear weights
        dim = 1
        keepdim = True
    elif len(x.shape) == 4:
        dim = (1, 2, 3) # pytorch convolution kernel is OIHW
        keepdim = True
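
The function above ends mid-definition in the preview. Below is a hedged sketch of adaptive gradient clipping (Brock et al., NFNets), the setting this helper belongs to; the completed tail of unitwise_norm and the clipping routine are the standard formulation, assumed rather than copied from the gist:

import torch

def unitwise_norm_sketch(x):
    if x.squeeze().ndim <= 1:        # scalars, biases, gains
        dim, keepdim = None, False
    elif x.ndim in (2, 3):           # linear weights
        dim, keepdim = 1, True
    elif x.ndim == 4:                # conv kernels, OIHW layout
        dim, keepdim = (1, 2, 3), True
    else:
        raise ValueError(f'unsupported tensor of rank {x.ndim}')
    return x.norm(p = 2, dim = dim, keepdim = keepdim)

def adaptive_clip_grad_(parameters, clip = 0.01, eps = 1e-3):
    # rescale each gradient whose unit-wise norm exceeds `clip` times the parameter norm
    for p in parameters:
        if p.grad is None:
            continue
        p_norm = unitwise_norm_sketch(p.detach()).clamp_(min = eps)
        g_norm = unitwise_norm_sketch(p.grad.detach())
        max_norm = p_norm * clip
        clipped = p.grad * (max_norm / g_norm.clamp(min = 1e-6))
        p.grad.detach().copy_(torch.where(g_norm > max_norm, clipped, p.grad))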

# egnn.py (gist last active February 26, 2021)
import torch
from torch import nn, einsum
from einops import rearrange, repeat

class EGNN(nn.Module):
    def __init__(
        self,
        dim,
        edge_dim,
        m_dim = 16      # hidden dimension of the per-edge messages
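
The signature is cut off above. For context, here is a self-contained sketch of a single E(n)-equivariant GNN layer in the spirit of Satorras et al., the architecture the EGNN class implements; the MLP shapes and aggregation choices are assumptions, only dim, edge_dim and m_dim come from the visible signature:

import torch
from torch import nn

class EGNNLayerSketch(nn.Module):
    def __init__(self, dim, edge_dim = 0, m_dim = 16):
        super().__init__()
        edge_input_dim = 2 * dim + edge_dim + 1   # h_i, h_j, optional edge features, squared distance
        self.edge_mlp = nn.Sequential(nn.Linear(edge_input_dim, m_dim), nn.SiLU(), nn.Linear(m_dim, m_dim), nn.SiLU())
        self.coor_mlp = nn.Sequential(nn.Linear(m_dim, m_dim), nn.SiLU(), nn.Linear(m_dim, 1))
        self.node_mlp = nn.Sequential(nn.Linear(dim + m_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, feats, coors, edges = None):
        # feats: (b, n, dim), coors: (b, n, 3), edges: (b, n, n, edge_dim) if edge_dim > 0
        n = feats.shape[1]
        rel_coors = coors.unsqueeze(2) - coors.unsqueeze(1)         # (b, n, n, 3)
        rel_dist = (rel_coors ** 2).sum(dim = -1, keepdim = True)   # (b, n, n, 1)

        h_i = feats.unsqueeze(2).expand(-1, -1, n, -1)
        h_j = feats.unsqueeze(1).expand(-1, n, -1, -1)
        edge_input = [h_i, h_j, rel_dist] if edges is None else [h_i, h_j, edges, rel_dist]
        m_ij = self.edge_mlp(torch.cat(edge_input, dim = -1))       # per-pair messages

        # equivariant coordinate update: move each point along relative vectors, weighted per pair
        coors_out = coors + (rel_coors * self.coor_mlp(m_ij)).mean(dim = 2)
        # invariant feature update from the aggregated messages
        feats_out = feats + self.node_mlp(torch.cat((feats, m_ij.sum(dim = 2)), dim = -1))
        return feats_out, coors_out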

# script preamble: sidechainnet data, the SE3Transformer model, double-precision defaults
import torch
import torch.nn.functional as F
from torch.optim import Adam
from einops import rearrange, repeat

import sidechainnet as scn
from se3_transformer_pytorch.se3_transformer_pytorch import SE3Transformer

torch.set_default_dtype(torch.float64)

# preamble of the analogous script built around the EnTransformer
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from einops import rearrange, repeat

import sidechainnet as scn
from en_transformer.en_transformer import EnTransformer

torch.set_default_dtype(torch.float64)
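
Scripts with this preamble typically continue into a coordinate-denoising loop over sidechainnet batches. The sketch below shows that general pattern with a tiny stand-in model; the scn.load arguments, the batch field name crds, and the model call signature are all assumptions rather than the gists' actual code:

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam

import sidechainnet as scn

torch.set_default_dtype(torch.float64)

class TinyDenoiser(nn.Module):
    # placeholder for SE3Transformer / EnTransformer: maps (features, coordinates) to denoised coordinates
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 3)

    def forward(self, feats, coors):
        return coors + self.net(coors)

model = TinyDenoiser()
opt = Adam(model.parameters(), lr = 3e-4)

# assumed sidechainnet call returning a dict of dataloaders keyed by split
data = scn.load(casp_version = 12, thinning = 30, with_pytorch = 'dataloaders', batch_size = 1)

for batch in data['train']:
    coords = batch.crds.double()                        # assumed field: ground-truth coordinates
    feats = torch.zeros(*coords.shape[:-1], 32)         # placeholder per-point features

    noised = coords + torch.randn_like(coords) * 0.1    # corrupt coordinates with Gaussian noise
    loss = F.mse_loss(model(feats, noised), coords)

    loss.backward()
    opt.step()
    opt.zero_grad()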