@lucidrains
Last active December 29, 2024 17:57
liere
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from einops import einsum, rearrange, reduce

def apply_liere_pos_emb(rotations, t):
    # rotate each token's features by the rotation matrix for its position
    return einsum(t, rotations, 'b h n d, n d e -> b h n e')

class Liere(Module):
    def __init__(
        self,
        dim,
        num_dim = 2, # 2 for images, 3 for video, etc.
    ):
        super().__init__()
        self.num_dim = num_dim

        # one learned dense generator matrix per axis, initialized uniformly in [0, 2π)
        self.generator_params = nn.Parameter(torch.rand(
            num_dim,
            dim,
            dim
        ) * 2 * torch.pi)

    def forward(
        self,
        dimensions: tuple[int, ...]
    ):
        device = self.generator_params.device
        assert len(dimensions) == self.num_dim

        # strict upper triangle minus its transpose gives a skew-symmetric generator per axis
        upper_tri = self.generator_params.triu(1)
        skew = upper_tri - rearrange(upper_tri, '... i j -> ... j i')

        # integer coordinates of every position on the n-dimensional grid
        dim_aranges = [torch.arange(d, device = device) for d in dimensions]
        positions = torch.stack(torch.meshgrid(dim_aranges, indexing = 'ij'), dim = -1)
        positions = rearrange(positions, '... p -> (...) p 1 1')

        # scale each generator by the coordinate along its axis, sum over axes, then
        # exponentiate - the matrix exponential of a skew-symmetric matrix is a rotation
        matrices = reduce(skew * positions, 'n p d e -> n d e', 'sum')
        rotations = torch.matrix_exp(matrices.float().contiguous())
        return rotations

if __name__ == '__main__':
    # queries and keys for a 64 x 32 grid of tokens, 8 heads, head dimension 64
    q = torch.randn(1, 8, 64 * 32, 64)
    k = torch.randn(1, 8, 64 * 32, 64)

    liere = Liere(64)

    # one rotation matrix per grid position -> (64 * 32, 64, 64)
    rotations = liere((64, 32))

    q = apply_liere_pos_emb(rotations, q)
    k = apply_liere_pos_emb(rotations, k)
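
A quick sanity check (a minimal sketch, not part of the original gist, assuming it is appended to the __main__ block above and reuses the rotations computed there): the matrix exponential of a real skew-symmetric matrix is orthogonal, so each per-position rotation should satisfy R @ R.T ≈ I, and applying it should leave per-token norms unchanged up to float32 roundoff. The names v, v_rotated and the deviation variables are illustrative only.

    # how far each rotation is from being exactly orthogonal (should be ~ float32 roundoff)
    eye = torch.eye(rotations.shape[-1])
    ortho_deviation = (rotations @ rotations.transpose(-1, -2) - eye).abs().max()
    print(f'max deviation from orthogonality: {ortho_deviation.item():.2e}')

    # orthogonal rotations preserve norms, so rotating fresh queries / keys
    # should not change their per-token norms
    v = torch.randn(1, 8, 64 * 32, 64)
    v_rotated = apply_liere_pos_emb(rotations, v)
    norm_deviation = (v.norm(dim = -1) - v_rotated.norm(dim = -1)).abs().max()
    print(f'max change in token norm after rotation: {norm_deviation.item():.2e}')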