@lucidrains
Last active January 7, 2021 16:41
# link to package https://github.com/lucidrains/slot-attention
import torch
from torch import nn

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)
class SlotAttention(nn.Module):
    def __init__(self, num_slots, dim, iters = 3, eps = 1e-8, mlp_hidden_size = 128):
        super().__init__()
        self.num_slots = num_slots
        self.iters = iters
        self.eps = eps
        self.scale = dim ** -0.5

        # learned mean and (unconstrained) spread used to sample the initial slots
        self.slots_mu = nn.Parameter(torch.randn(1, 1, dim))
        self.slots_sigma = nn.Parameter(torch.randn(1, 1, dim))

        # queries come from the slots, keys and values from the inputs
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)
        self.to_v = nn.Linear(dim, dim, bias = False)

        self.gru = nn.GRU(dim, dim)

        self.mlp = Residual(PreNorm(dim, nn.Sequential(
            nn.Linear(dim, mlp_hidden_size),
            nn.ReLU(inplace = True),
            nn.Linear(mlp_hidden_size, dim)
        )))

        self.norm_input = nn.LayerNorm(dim)
        self.norm_slots = nn.LayerNorm(dim)
    def forward(self, inputs):
        b, n, d, n_s = *inputs.shape, self.num_slots

        # sample the initial slots from the learned Gaussian
        # (reparameterized, since slots_sigma is unconstrained and torch.normal requires std >= 0)
        mu = self.slots_mu.expand(b, n_s, -1)
        sigma = self.slots_sigma.expand(b, n_s, -1)
        slots = mu + sigma * torch.randn_like(mu)

        inputs = self.norm_input(inputs)
        k, v = self.to_k(inputs), self.to_v(inputs)
        for _ in range(self.iters):
            slots_prev = slots
            slots = self.norm_slots(slots)
            q = self.to_q(slots)

            # softmax over the slot dimension, so slots compete for each input
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn = dots.softmax(dim=1) + self.eps
            attn = attn / attn.sum(dim=-1, keepdim=True)

            # weighted mean of the values, one update vector per slot
            updates = torch.einsum('bjd,bij->bid', v, attn)

            # GRU update, folding (batch, slots) into the GRU batch dimension
            slots, _ = self.gru(
                updates.reshape(1, -1, d),
                slots_prev.reshape(1, -1, d)
            )
            slots = slots.reshape(b, -1, d)
            slots = self.mlp(slots)

        return slots

slot_attn = SlotAttention(num_slots=5, dim=512)
inputs = torch.randn(1, 1024, 512)
slots = slot_attn(inputs)  # (1, 5, 512): one vector per slot
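
For context, the Slot Attention paper (Locatello et al., 2020) runs this module over flattened CNN feature maps, treating each spatial position as one input token. Below is a minimal sketch of that wiring; the ToyEncoder and all of its sizes are illustrative assumptions, not part of the gist or the slot-attention package, and the paper's positional embedding is omitted for brevity.

# Illustrative sketch only: a toy convolutional encoder feeding SlotAttention.
# ToyEncoder and its hyperparameters are hypothetical, not from the gist or package.
class ToyEncoder(nn.Module):
    def __init__(self, dim = 64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, dim, 5, padding = 2), nn.ReLU(),
            nn.Conv2d(dim, dim, 5, padding = 2), nn.ReLU()
        )

    def forward(self, images):                   # images: (b, 3, h, w)
        feats = self.conv(images)                # (b, dim, h, w)
        return feats.flatten(2).transpose(1, 2)  # (b, h * w, dim): one token per position

encoder = ToyEncoder(dim = 64)
object_slots = SlotAttention(num_slots = 7, dim = 64)

images = torch.randn(4, 3, 32, 32)
tokens = encoder(images)       # (4, 1024, 64)
slots = object_slots(tokens)   # (4, 7, 64): one vector per discovered slot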

tkipf commented Jun 29, 2020

Thanks!


psteinb commented Jul 1, 2020

Maybe I have overlooked it, but how is this wonderful work licensed?
