import torch
from torch import nn, einsum
from einops import rearrange, repeat
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal position table in the interleaved layout used for rotary embeddings.

    With inverse frequencies ``f_i = 1 / 10000 ** (2i / dim)`` and position ``t``,
    each output row is ``[sin(t f_0), cos(t f_0), sin(t f_1), cos(t f_1), ...]``.

    Args:
        dim: Requested table width. The table has ``2 * ceil(dim / 2)`` columns,
            i.e. exactly ``dim`` when ``dim`` is even.
    """

    def __init__(self, dim):
        # nn.Module.__init__ must run before register_buffer: it creates the
        # internal buffer registry, without which register_buffer raises
        # AttributeError and the module cannot be constructed.
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Stored as a buffer rather than a Parameter: it follows .to()/.cuda()
        # and is included in state_dict, but it is not a learnable parameter.
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x, seq_dim = 1, offset = 0):
        """Return the position table for the sequence axis of ``x``.

        Only ``x.shape[seq_dim]`` and ``x.device`` are read; the values in
        ``x`` are ignored.

        Args:
            x: Tensor whose ``seq_dim`` axis is the sequence axis.
            seq_dim: Axis of ``x`` that holds the sequence length.
            offset: Added to every position index, so positions run from
                ``offset`` to ``offset + x.shape[seq_dim] - 1``.

        Returns:
            Tensor of shape ``(x.shape[seq_dim], 2 * len(self.inv_freq))`` with
            the dtype of ``self.inv_freq``, laid out as
            ``[sin_0, cos_0, sin_1, cos_1, ...]`` along the last axis.
        """
        t = torch.arange(x.shape[seq_dim], device = x.device).type_as(self.inv_freq) + offset
        # Outer product of positions and inverse frequencies: (n, len(inv_freq)).
        sinusoid_inp = einsum('i , j -> i j', t, self.inv_freq)
        # Interleave sin/cos per frequency: (n, F, 2) -> (n, 2F).
        emb = torch.stack((sinusoid_inp.sin(), sinusoid_inp.cos()), dim = -1)
        return emb.flatten(1)
# Demo configuration: model width, width of each attention head, head count.
dim = 512
dim_head = 64
heads = 8

# Random input of shape (batch = 1, seq = 1024, dim).
inp = torch.randn(1, 1024, dim)

# Ordinary sinusoidal position table, built at the per-head width because the
# rotation is later applied to each head's features. Shape: (1024, dim_head).
pos_emb = FixedPositionalEmbedding(dim_head)
sinu_pos = pos_emb(inp)

# Query / key projections (freshly initialized nn.Linear layers).
to_q = nn.Linear(dim, dim_head * heads)
to_k = nn.Linear(dim, dim_head * heads)

# Project the input, then fold the head axis into the batch axis:
# (b, n, h * d) -> (b * h, n, d).
q, k = (
    rearrange(proj(inp), 'b n (h d) -> (b h) n d', h = heads)
    for proj in (to_q, to_k)
)
# Rotation helper at the heart of RoFormer's rotary position embedding.
def rotate_every_two(x):
    """Rotate each adjacent feature pair of ``x`` by +90 degrees.

    The last axis is read as consecutive pairs ``(a, b) = (x[..., 2i], x[..., 2i + 1])``
    and each pair is replaced by ``(-b, a)``; all leading axes are untouched.
    The last axis must have even length.

    Used as ``x * cos + rotate_every_two(x) * sin``, this turns every pair
    into a 2-D rotation by its angle.
    """
    pairs = rearrange(x, '... (d j) -> ... d j', j = 2)
    first, second = pairs[..., 0], pairs[..., 1]
    swapped = torch.stack((-second, first), dim = -1)
    return rearrange(swapped, '... d j -> ... (d j)')
def apply_rotory_pos_emb(q, k, sinu_pos):
    """Rotate queries and keys by position-dependent angles (rotary embedding).

    ``sinu_pos`` is a 2-D table of shape ``(n, D)``. Its columns are expected
    in the interleaved layout produced by ``FixedPositionalEmbedding``:
    ``[sin_0, cos_0, sin_1, cos_1, ...]``. Each sin/cos column is duplicated
    so that it lines up with the feature pair ``(x[2i], x[2i + 1])`` that
    ``rotate_every_two`` acts on. Each input is then mapped to
    ``x * cos + rotate_every_two(x) * sin``.

    Args:
        q: Queries; must broadcast against shape ``(n, D)``.
        k: Keys; must broadcast against shape ``(n, D)``.
        sinu_pos: Interleaved sin/cos table of shape ``(n, D)``.

    Returns:
        Tuple ``(q_rotated, k_rotated)`` of new tensors; inputs are not modified.
    """
    table = rearrange(sinu_pos, 'n (d j) -> n d j', j = 2)
    sin, cos = table[..., 0], table[..., 1]
    # Duplicate each angle so feature columns 2i and 2i + 1 share it.
    sin = repeat(sin, 'n d -> n (d j)', j = 2)
    cos = repeat(cos, 'n d -> n (d j)', j = 2)

    def rotate(t):
        return (t * cos) + (rotate_every_two(t) * sin)

    return rotate(q), rotate(k)
# Rotate the per-head queries and keys by their positions' angles.
q, k = apply_rotory_pos_emb(q = q, k = k, sinu_pos = sinu_pos)
# From here, continue with the rest of attention (standard or linear) using the rotated q and k.
