Phil Wang (lucidrains)
@lucidrains
lucidrains / gte.nim
Last active January 13, 2026 15:50
gte-pure-c.nim
## GTE-Small Embedding Library - Nim Port
## A single-file, self-contained text embedding solution.
##
## Original C implementation by Antirez (Salvatore Sanfilippo)
## Nim port maintains the same algorithm and produces identical results.
##
## MIT License - Copyright (c) 2026 Salvatore Sanfilippo
## See LICENSE file for full terms.
##
## USAGE: Just compile and run - model downloads automatically on first use!
[preview of a separate gist; its header was lost in extraction]
from typing import Tuple
import gc
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import triton.testing
@lucidrains
lucidrains / sugar_bsilu.py
Last active May 31, 2025 16:02
proposed SUGAR with BSiLU
# https://arxiv.org/abs/2505.22074
import torch
from torch.nn import Module
class SugarBSiLU(Module):
    # proposed SUGAR with B-SiLU, section 3.1
    # reported as their best performing variant
    def __init__(
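The preview cuts off at __init__. Below is a minimal sketch of the named technique, assuming the standard SUGAR construction: the forward pass stays plain ReLU while the backward pass takes B-SiLU's gradient via a straight-through estimator. The B-SiLU form (x + alpha) * sigmoid(x) - alpha / 2 with default alpha = 1.67 follows the paper, but treat the class as a reconstruction, not the gist's verbatim code.

import torch
import torch.nn.functional as F
from torch.nn import Module

class SugarBSiLUSketch(Module):
    # assumed reconstruction: value of relu(x), gradient of bsilu(x)
    def __init__(self, alpha = 1.67):  # alpha per the paper's default (assumption)
        super().__init__()
        self.alpha = alpha

    def bsilu(self, x):
        # B-SiLU(x) = (x + alpha) * sigmoid(x) - alpha / 2
        return (x + self.alpha) * x.sigmoid() - self.alpha / 2

    def forward(self, x):
        surrogate = self.bsilu(x)
        # straight-through: forward value equals relu(x), gradient follows the surrogate
        return surrogate + (F.relu(x) - surrogate).detach()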
@lucidrains
lucidrains / rnnify.py
Last active December 7, 2024 15:45
rnnify.py
from __future__ import annotations
from functools import partial
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Module, ModuleList
# helper functions
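The preview stops at the helper comment. In lucidrains' codebases this section conventionally opens with the exists/default pair, so a plausible continuation (an assumption, not the gist's verbatim code) is:

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d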
@lucidrains
lucidrains / liere.py
Last active December 29, 2024 17:57
liere
from __future__ import annotations
import torch
from torch import nn
from torch.nn import Module
from einops import einsum, rearrange, reduce
def apply_liere_pos_emb(rotations, t):
    return einsum(t, rotations, 'b h n d, n d e -> b h n e')
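apply_liere_pos_emb expects one rotation matrix per position. Here is a sketch of how those rotations could be produced in the LieRE style, by exponentiating a learned skew-symmetric generator scaled by position; the sizes and initialization below are illustrative assumptions:

import torch
from torch import nn
from einops import einsum

dim_head, seq_len = 8, 16                  # assumed sizes for illustration
gen = nn.Parameter(torch.randn(dim_head, dim_head) * 0.02)
skew = gen - gen.transpose(-1, -2)         # skew-symmetric, so matrix_exp yields a rotation

pos = torch.arange(seq_len).float()
rotations = torch.matrix_exp(einsum(pos, skew, 'n, d e -> n d e'))  # (n, d, e)

t = torch.randn(1, 4, seq_len, dim_head)   # (batch, heads, seq, dim)
rotated = apply_liere_pos_emb(rotations, t)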
@lucidrains
lucidrains / tree_attn_decode.py
Created August 12, 2024 17:48
Tree Attention Decoding
import torch
from torch import einsum
import torch.distributed as dist
def tree_attn_decode(q, k, v):
    """
    Algorithm 3 proposed in Tree Attention
    https://arxiv.org/abs/2408.04093
    """
@lucidrains
lucidrains / vit_with_mask.py
Created December 8, 2022 20:14
ViT, but you can pass in images with patches masked out
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
# helpers
def pair(t):
    # cast a single value to a (height, width) tuple; lucidrains' usual helper body
    return t if isinstance(t, tuple) else (t, t)
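Per the description, masked-out patches must not participate in attention. A hypothetical illustration of the usual mechanism, filling the logits of masked keys with a large negative value (not the gist's verbatim code):

import torch

def attend_with_patch_mask(q, k, v, mask):
    # mask: (batch, num_patches) bool, True = patch is kept
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * (q.shape[-1] ** -0.5)
    sim = sim.masked_fill(~mask[:, None, None, :], torch.finfo(sim.dtype).min)
    return torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)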
@lucidrains
lucidrains / uniprot_mapping.py
Created January 7, 2022 05:28
uniprot mapping for python3
import urllib
import urllib.parse
from urllib.request import urlopen
def uniprot_mapping(fromtype, totype, identifier):
    base = 'http://www.uniprot.org'
    tool = 'mapping'
    params = {
        'from': fromtype,
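        # assumed completion from here on: the preview cuts off inside params;
        # the remaining keys and request handling follow the legacy UniProt
        # mapping API (since retired), so treat this as a sketch
        'to': totype,
        'format': 'tab',
        'query': identifier,
    }
    data = urllib.parse.urlencode(params).encode('utf-8')
    response = urlopen(f'{base}/{tool}/', data)
    return response.read().decode('utf-8')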
# Schedules with t in [0, 1], e.g. use as lr_sch(step / total_steps)
import jax.numpy as jnp
import numpy as np

def lr_sch(t):
    left_br = 20 * t - 5
    right_br = -(1.45 * t + 2.08)

    def denom(sign):
        # logistic gate centered at t = 0.015 with sharpness 19
        return 1 + jnp.exp(-sign * (19 * (t - 0.015)))

    # blend two log-linear branches; the schedule lives in log10 space
    return 10 ** ((left_br / denom(-1)) + (right_br / denom(+1)))

def wd_sch(t):
    # softplus-shaped decay of the weight decay, also in log10 space
    return 10 ** (-np.log(np.exp(10.7 * t - 2.7) + 1) - 2)
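A hypothetical wiring of the two schedules into a training loop (not part of the gist):

total_steps = 10_000
for step in range(total_steps):
    t = step / total_steps
    lr, wd = lr_sch(t), wd_sch(t)
    # ...feed lr and wd into the optimizer update for this step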
@lucidrains
lucidrains / faster_rng.py
Created June 2, 2021 17:08
faster rng for jax
from typing import Any, Sequence

import numpy as np
from jax import lax

# type aliases assumed by this snippet; the preview does not define them
PRNGKey = Any
Shape = Sequence[int]
Dtype = Any
Array = Any

def hardware_uniform(rng_key: PRNGKey,
                     shape: Shape,
                     dtype: Dtype = np.float32,
                     minval: Array = np.float32(0),
                     maxval: Array = np.float32(1)) -> Array:
    del rng_key  # non-deterministic hardware prng ignores the key
    minval = lax.convert_element_type(minval, dtype)
    maxval = lax.convert_element_type(maxval, dtype)
    return lax.rng_uniform(minval, maxval, shape)
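Since the key is discarded, results are not reproducible from run to run; that is the trade for speed. A hypothetical call, e.g. for a dropout-style mask:

noise = hardware_uniform(rng_key = None, shape = (8, 128))  # the key is ignored
keep_mask = noise < np.float32(0.9)                         # ~90% keep rate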