liere
from __future__ import annotations

import torch
from torch import nn
from torch.nn import Module

from einops import einsum, rearrange, reduce

def apply_liere_pos_emb(rotations, t):
    # rotate each token's head dimension by its per-position rotation matrix
    return einsum(t, rotations, 'b h n d, n d e -> b h n e')

class Liere(Module):
    def __init__(
        self,
        dim,
        num_dim = 2,  # 2 for images, 3 for video, etc.
    ):
        super().__init__()
        self.num_dim = num_dim

        # learned generator parameters, one (dim x dim) matrix per spatial axis
        self.generator_params = nn.Parameter(torch.rand(
            num_dim,
            dim,
            dim
        ) * 2 * torch.pi)

    def forward(
        self,
        dimensions: tuple[int, ...]
    ):
        device = self.generator_params.device
        assert len(dimensions) == self.num_dim

        # build skew-symmetric generators from the strict upper triangle
        upper_tri = self.generator_params.triu(1)
        skew = upper_tri - rearrange(upper_tri, '... i j -> ... j i')

        # integer coordinates for every position, flattened to (num positions, num_dim)
        dim_aranges = [torch.arange(d, device = device) for d in dimensions]

        positions = torch.stack(torch.meshgrid(dim_aranges, indexing = 'ij'), dim = -1)
        positions = rearrange(positions, '... p -> (...) p 1 1')

        # weight each axis generator by its coordinate and sum over axes
        matrices = reduce(skew * positions, 'n p d e -> n d e', 'sum')

        # matrix exponential of a skew-symmetric matrix yields a rotation matrix
        rotations = torch.matrix_exp(matrices.float().contiguous())
        return rotations

if __name__ == '__main__':
    q = torch.randn(1, 8, 64 * 32, 64)
    k = torch.randn(1, 8, 64 * 32, 64)

    liere = Liere(64)

    rotations = liere((64, 32))

    q = apply_liere_pos_emb(rotations, q)
    k = apply_liere_pos_emb(rotations, k)
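As a usage sketch, the rotated queries and keys can be fed straight into standard attention, with values left unrotated. The snippet below reuses Liere and apply_liere_pos_emb from above with the same shapes as the demo; the value tensor v and torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+) are illustrative choices, not part of the gist.

import torch
import torch.nn.functional as F

# shapes mirror the demo: (batch 1, heads 8, 64 * 32 positions, head dim 64)
liere = Liere(64)
rotations = liere((64, 32))

q = apply_liere_pos_emb(rotations, torch.randn(1, 8, 64 * 32, 64))
k = apply_liere_pos_emb(rotations, torch.randn(1, 8, 64 * 32, 64))
v = torch.randn(1, 8, 64 * 32, 64)  # values are typically not rotated

out = F.scaled_dot_product_attention(q, k, v)  # (1, 8, 2048, 64)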