@Ryu1845 · Created February 26, 2025 11:43
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Assumption: the gist's bare `sdpa` refers to PyTorch's fused attention kernel
# (available in PyTorch >= 2.0).
from torch.nn.functional import scaled_dot_product_attention as sdpa


class Attention(nn.Module):
    """Single-head pre-norm self-attention with a residual connection."""

    def __init__(self, dim):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.to_qkv = nn.Linear(dim, 3 * dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x):
        x = self.pre_norm(x)
        qkv = self.to_qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        attn_out = sdpa(q, k, v)
        # As in the original: the residual branch is the normalized input.
        return x + self.to_out(attn_out)
class MLP(nn.Module):
    """Pre-norm feed-forward block with a 4x hidden expansion."""

    def __init__(self, dim):
        super().__init__()
        self.pre_norm = nn.LayerNorm(dim)
        self.proj_in = nn.Linear(dim, 4 * dim)
        self.proj_out = nn.Linear(4 * dim, dim)

    def forward(self, x):
        x = self.pre_norm(x)
        skip = x
        x = self.proj_in(x)
        x = F.gelu(x)
        x = self.proj_out(x)
        return skip + x
class Block(nn.Module):
    """A standard transformer block: attention followed by an MLP."""

    def __init__(self, dim):
        super().__init__()
        self.attn = Attention(dim)
        self.mlp = MLP(dim)

    def forward(self, x):
        x = self.attn(x)
        x = self.mlp(x)
        return x
class JackAttention(nn.Module):
    """Hierarchical attention: each successive block runs at 1/8 the channel
    width and 8x the sequence length of the previous one.

    Assumes dim is a power of 8 so the folding divides evenly at every level.
    """

    def __init__(self, dim):
        super().__init__()
        # One block per factor of 8 in the feature dimension. round() guards
        # against floating-point error in math.log (e.g. log(512, 8) may not
        # be exactly 3.0).
        n_blocks = round(math.log(dim, 8))
        self.blocks = nn.ModuleList()
        inner_dim = dim
        for _ in range(n_blocks):
            self.blocks.append(Block(inner_dim))
            inner_dim //= 8  # integer division keeps inner_dim an int

    def forward(self, x):
        batch, time, dim = x.shape
        for block in self.blocks:
            x = block(x)
            # Trade channels for sequence length: fold a factor of 8 from the
            # feature dim into the time dim before the next, narrower block.
            x = rearrange(x, "b t (eight d) -> b (t eight) d", eight=8)
        # The total element count never changes, so the input shape is restored.
        x = x.reshape(batch, time, dim)
        return x
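
# --- Not part of the original gist: a minimal usage sketch. It assumes
# dim is a power of 8 (here 512, giving blocks at widths 512 -> 64 -> 8) and
# a PyTorch recent enough to provide scaled_dot_product_attention.
if __name__ == "__main__":
    model = JackAttention(dim=512)
    x = torch.randn(2, 16, 512)    # (batch, time, dim)
    out = model(x)
    assert out.shape == x.shape    # the hierarchy folds back to the input shape
    print(out.shape)               # torch.Size([2, 16, 512])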