import torch
from torch import nn, einsum

from einops import rearrange, repeat
class FixedPositionalEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # standard sinusoidal frequencies, 10000 ** (-2i / dim)
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim = 1, offset = 0):
        t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq) + offset
        sinusoid_inp = einsum('i , j -> i j', t, self.inv_freq)
        # interleave sin and cos along the features: (n, d / 2, 2) -> (n, d)
        emb = torch.stack((sinusoid_inp.sin(), sinusoid_inp.cos()), dim = -1)
        emb = rearrange(emb, 'n ... -> n (...)')
        return emb
dim = 512
dim_head = 64
heads = 8

inp = torch.randn(1, 1024, dim)

# usual sinusoidal pos emb, sized to the per-head dimension
pos_emb = FixedPositionalEmbedding(dim_head)
sinu_pos = pos_emb(inp)  # (1024, dim_head)
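
# quick sanity check (illustrative, not in the original gist): the embedding is
# (seq_len, dim_head) with sin and cos interleaved, so position 0 alternates
# sin(0) = 0 and cos(0) = 1
assert sinu_pos.shape == (1024, dim_head)
assert torch.allclose(sinu_pos[0, 0::2], torch.zeros(dim_head // 2))
assert torch.allclose(sinu_pos[0, 1::2], torch.ones(dim_head // 2))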
# parameters
to_q = nn.Linear(dim, dim_head * heads)
to_k = nn.Linear(dim, dim_head * heads)

# project to get queries and keys
q = to_q(inp)
k = to_k(inp)

# merge heads into the batch dimension: (b, n, h * d) -> (b * h, n, d)
q = rearrange(q, 'b n (h d) -> (b h) n d', h = heads)
k = rearrange(k, 'b n (h d) -> (b h) n d', h = heads)
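
# shape check (illustrative): heads are now folded into the batch dimension
assert q.shape == (heads, 1024, dim_head)
assert k.shape == (heads, 1024, dim_head)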
# main novelty of roformer
def rotate_every_two(x):
    # pair adjacent features and rotate each pair a quarter turn: (x0, x1) -> (-x1, x0)
    x = rearrange(x, '... (d j) -> ... d j', j = 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    return rearrange(x, '... d j -> ... (d j)')
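
# tiny worked example (added for illustration)
assert torch.equal(rotate_every_two(torch.tensor([1., 2., 3., 4.])), torch.tensor([-2., 1., -4., 3.]))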
def apply_rotary_pos_emb(q, k, sinu_pos):
    # split the interleaved embedding back into sin and cos, then repeat each
    # value so it lines up with the paired features of q and k
    sinu_pos = rearrange(sinu_pos, 'n (d j) -> n d j', j = 2)
    sin, cos = sinu_pos.unbind(dim = -1)
    sin, cos = map(lambda t: repeat(t, 'n d -> n (d j)', j = 2), (sin, cos))

    # rotate each feature pair by its position-dependent angle
    q_rotated = rotate_every_two(q)
    k_rotated = rotate_every_two(k)

    q = (q * cos) + (q_rotated * sin)
    k = (k * cos) + (k_rotated * sin)
    return q, k
q, k = apply_rotary_pos_emb(q, k, sinu_pos)
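
# illustrative check (added; `rotate_at` is a hypothetical helper, not part of
# the gist): after rotation, the score between a query and a key depends only
# on their relative offset - the defining property of rotary embeddings
def rotate_at(x, pos, sinu_pos):
    sp = rearrange(sinu_pos[pos:pos + 1], 'n (d j) -> n d j', j = 2)
    sin, cos = sp.unbind(dim = -1)
    sin, cos = map(lambda t: repeat(t, 'n d -> n (d j)', j = 2), (sin, cos))
    return (x * cos) + (rotate_every_two(x) * sin)

xq, xk = torch.randn(2, 1, dim_head)
score_a = (rotate_at(xq, 3, sinu_pos) * rotate_at(xk, 7, sinu_pos)).sum()
score_b = (rotate_at(xq, 10, sinu_pos) * rotate_at(xk, 14, sinu_pos)).sum()
assert torch.allclose(score_a, score_b, atol = 1e-4)  # same offset (4) -> same score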
# do rest of attention, linear or otherwise
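
# a minimal sketch of finishing with standard softmax attention (an assumed
# continuation - `to_v`, the scaling, and the head merge below are additions,
# not part of the original gist)
to_v = nn.Linear(dim, dim_head * heads)
v = rearrange(to_v(inp), 'b n (h d) -> (b h) n d', h = heads)

scale = dim_head ** -0.5
sim = einsum('b i d, b j d -> b i j', q, k) * scale
attn = sim.softmax(dim = -1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h = heads)  # (1, 1024, dim)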