""" | |
Simplified Implementation of the Linear Recurrent Unit | |
------------------------------------------------------ | |
We present here a simplified JAX implementation of the Linear Recurrent Unit (LRU). | |
The state of the LRU is driven by the input $(u_k)_{k=1}^L$ of sequence length $L$ | |
according to the following formula (and efficiently parallelized using an associative scan): | |
$x_{k} = \Lambda x_{k-1} +\exp(\gamma^{\log})\odot (B u_{k})$, | |
and the output is computed at each timestamp $k$ as follows: $y_k = C x_k + D u_k$. | |
In our code, $B,C$ follow Glorot initialization, with $B$ scaled additionally by a factor 2 | |
to account for halving the state variance by taking the real part of the output projection. |
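# --------------------------------------------------------------------------
# Minimal sketch (not the original file's code): one way to evaluate the
# recurrence x_k = Lambda * x_{k-1} + exp(gamma_log) * (B u_k) (elementwise)
# and the output y_k = Re(C x_k) + D u_k in parallel with
# jax.lax.associative_scan. Parameter names (nu_log, theta_log, gamma_log,
# ...) and shapes are illustrative assumptions, not the original code.
# --------------------------------------------------------------------------
import jax
import jax.numpy as jnp


def binary_operator_diag(element_i, element_j):
    # Associative operator for a diagonal linear recurrence:
    # composing (a_i, b_i) followed by (a_j, b_j) gives (a_j*a_i, a_j*b_i + b_j).
    a_i, bu_i = element_i
    a_j, bu_j = element_j
    return a_j * a_i, a_j * bu_i + bu_j


def lru_forward(nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log, u):
    """Apply the LRU to an input sequence u of shape (L, H); returns (L, H)."""
    # Diagonal complex state matrix Lambda, stored through log-polar parameters.
    Lambda = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    # Input projection, with rows scaled by exp(gamma_log) as in the recurrence above.
    B_norm = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    C = C_re + 1j * C_im

    # Per-step scan elements (Lambda, B u_k); the associative scan returns all x_k.
    Lambda_elements = jnp.repeat(Lambda[None, :], u.shape[0], axis=0)
    Bu_elements = jax.vmap(lambda u_k: B_norm @ u_k)(u)
    _, states = jax.lax.associative_scan(binary_operator_diag,
                                         (Lambda_elements, Bu_elements))

    # Output: real part of C x_k plus an elementwise skip connection D u_k.
    return jax.vmap(lambda x_k, u_k: (C @ x_k).real + D * u_k)(states, u)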
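# --------------------------------------------------------------------------
# Illustrative sketch of the initialization described above, under assumed
# conventions: B and C are complex Glorot-style Gaussians with entry variance
# proportional to 1/fan_in, and B's variance carries an extra factor of 2 to
# compensate for keeping only the real part of C x_k. Exact constants, names,
# and shapes (N = state size, H = model size) are assumptions.
# --------------------------------------------------------------------------
import numpy as np


def init_projections(N, H, seed=0):
    rng = np.random.default_rng(seed)
    # Complex B with entry variance ~ 2/H: Glorot-style 1/H times the factor of 2.
    B_re = rng.normal(size=(N, H)) / np.sqrt(H)
    B_im = rng.normal(size=(N, H)) / np.sqrt(H)
    # Complex C with entry variance ~ 1/N (each part contributes half).
    C_re = rng.normal(size=(H, N)) / np.sqrt(2 * N)
    C_im = rng.normal(size=(H, N)) / np.sqrt(2 * N)
    # Elementwise skip connection D.
    D = rng.normal(size=(H,))
    return B_re, B_im, C_re, C_im, D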
# Separate snippet: download an image from a URL and display it.
import io
import urllib.request

import numpy as np
import PIL.Image
import matplotlib.pyplot as plt

meme_url = 'https://i.kym-cdn.com/photos/images/original/001/100/432/0f5.jpg'
# Fetch the image bytes and open them from an in-memory buffer.
img = PIL.Image.open(
    io.BytesIO(urllib.request.urlopen(meme_url).read())
)
plt.imshow(np.asarray(img))
plt.axis('off')
plt.show()