LRU with self-attention
import math

import torch


def forward_sequential(h, xs, U, W, nu, theta):
    """Forward pass through the network sequentially over input `xs` of any length.

    NOTE: has no batch dimension. To be batched with `vmap`.

    Args:
        h (torch.tensor): shape [D_h, ]; initial state
        xs (torch.tensor): shape [T, D_x]; input sequence
        U (torch.tensor): parameter matrix of shape [D_h, D_x]
        W (torch.tensor): parameter matrix of shape [D_h, D_x]
        nu (torch.tensor): parameter vector of shape [D_h, ]; decay rates
        theta (torch.tensor): parameter vector of shape [D_h, ]; phases

    Returns:
        hs (torch.tensor): shape [T, D_h]; output sequence
    """
    T = xs.shape[0]
    D_h = h.shape[0]
    # The recurrence is complex-valued, so the output buffer must be complex as well.
    hs = torch.zeros(T, D_h, dtype=torch.cfloat, device=xs.device)
    for t in range(T):
        h = torch.exp(U @ xs[t] - nu - theta * 1j) * h + W @ xs[t]
        hs[t] = h
    return hs.real

def forward_parallel(h, xs, U, W, nu, theta):
    """Forward pass through the network in parallel over input `xs` of any length by using
    the exact solution of the recurrence relation.

    NOTE: has no batch dimension. To be batched with `vmap`.

    Args:
        h (torch.tensor): shape [D_h, ]; initial state
        xs (torch.tensor): shape [T, D_x]; input sequence
        U (torch.tensor): parameter matrix of shape [D_h, D_x]
        W (torch.tensor): parameter matrix of shape [D_h, D_x]
        nu (torch.tensor): parameter vector of shape [D_h, ]; decay rates
        theta (torch.tensor): parameter vector of shape [D_h, ]; phases

    Returns:
        hs (torch.tensor): shape [T, D_h]; output sequence
    """
    gammas = torch.cumsum(torch.matmul(xs, U.T) - nu - theta * 1j, dim=0)  # [T, D_h]
    betas = torch.matmul(xs, W.T)  # [T, D_h]
    source = torch.cumsum(torch.exp(-gammas) * betas, dim=0)  # [T, D_h]
    hs = torch.exp(gammas) * (h[None] + source)
    return hs.real
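For reference, a sketch of the recurrence and the closed-form solution that `forward_parallel` relies on (notation mine: \Gamma_t corresponds to `gammas`, W x_s to `betas`):

    h_t = e^{\lambda_t} h_{t-1} + W x_t,                          \lambda_t = U x_t - \nu - i\,\theta
    h_t = e^{\Gamma_t} \big( h_0 + \sum_{s=1}^{t} e^{-\Gamma_s} W x_s \big),   \Gamma_t = \sum_{s=1}^{t} \lambda_s

so the two cumulative sums (`gammas` and `source`) produce all states at once instead of looping over time.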

#### Benchmark code:

device = torch.device('cuda')
D_h = 256
D_x = 64
U = torch.randn(D_h, D_x, device=device)
W = torch.randn(D_h, D_x, device=device)
nu = torch.linspace(0.001, 0.5, D_h, device=device)  # per-channel decay rates
theta = torch.linspace(0, 2 * math.pi * (D_h - 1) / D_h, D_h, device=device)  # per-channel phases
T = 1024
xs = torch.randn(T, D_x, device=device)
h = torch.randn(D_h, device=device)


def sequential_timer():
    hs_seq = forward_sequential(h, xs, U, W, nu, theta)
    torch.cuda.synchronize()


def parallel_timer():
    hs_par = forward_parallel(h, xs, U, W, nu, theta)
    torch.cuda.synchronize()
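
A minimal sketch of how one might sanity-check and time the two implementations (not part of the original gist). The correctness check uses a short, scaled-down input (illustrative names `T_chk`, `xs_chk`) because the random benchmark parameters above can make exp() overflow over long sequences; the batching example assumes PyTorch 2.x with `torch.func.vmap`.

# Sanity check that the sequential and parallel forward passes agree,
# on a short, small-magnitude input so the exponentials stay finite:
T_chk = 32
xs_chk = 0.01 * torch.randn(T_chk, D_x, device=device)
hs_seq = forward_sequential(h, xs_chk, U, W, nu, theta)
hs_par = forward_parallel(h, xs_chk, U, W, nu, theta)
print("max abs diff:", (hs_seq - hs_par).abs().max().item())

# Rough timing of the two timer functions defined above:
from torch.utils import benchmark
t_seq = benchmark.Timer(stmt="fn()", globals={"fn": sequential_timer}).timeit(10)
t_par = benchmark.Timer(stmt="fn()", globals={"fn": parallel_timer}).timeit(10)
print(f"sequential: {t_seq.mean * 1e3:.2f} ms, parallel: {t_par.mean * 1e3:.2f} ms")

# Batching with vmap, as the docstrings suggest (assumes PyTorch 2.x / torch.func;
# only h and xs get a leading batch dimension, the parameters are shared):
from torch.func import vmap
batched_parallel = vmap(forward_parallel, in_dims=(0, 0, None, None, None, None))
hs_batch = batched_parallel(h.expand(4, -1), xs_chk.expand(4, -1, -1), U, W, nu, theta)  # [4, T_chk, D_h]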
Substack post with benchmark results and more here: https://open.substack.com/pub/heikkiarponen/p/linear-recurrent-units-and-attention?r=8u5xp&utm_campaign=post&utm_medium=web