Created
June 19, 2023 20:29
-
-
Save Ryu1845/d510c47575d94b54bff0397bd9138d0a to your computer and use it in GitHub Desktop.
JAX implementation of the Block-State Transformer (pseudocode copied from https://arxiv.org/abs/2306.09539)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Block-State Transformer Layer.""" | |
# Block Transformers are non-recurrent and parallelizable. | |
block_transformer = jax.vmap(BRecT.nonrecurrent_cell) | |
def BST(u): | |
"""Block-State Transformer Layer.""" | |
global MF # True if Multi-Filter, False otherwise (SH/MH) | |
# split inputs into windows (l/w, w, d) | |
u = jnp.split(u, seq_length // win_length, axis=0) | |
# collect context states from SSM outputs | |
context_states = [SH/MH/MF]_context_states(u) | |
# pass the contexts in place of recurrent states | |
y = block_transformer( | |
token_embeddings=u, | |
recurrent_state=context_states, | |
use_cross_attn_causal_mask=not MF, | |
use_cross_positional_emb=MF, # context IDs | |
) | |
return rearrange(y, "lw w d -> (lw w) d") # (l, d) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Context state collection for BST-MF variant.""" | |
# (MF): Multi-Filter | |
def MF_context_states(u): | |
"""Multi-Filter Context Collection.""" | |
h, b = get_filters_[unstruct/s4](channels=num_states) | |
y_s = multichannel_convolution(u, h, b) | |
# y_s: (l, d, s) | |
context_states = jnp.split( | |
y_s, seq_length // win_length, axis=0) | |
# context_states: (l/w, w, d, s) | |
# collect the last context states | |
context_states = context_states[:, -1, ...] # (l/w, d, s) | |
context_states = rearrange( | |
context_states, "lw d s -> lw s d") | |
# shift context states corresponding to windows | |
context_states = jnp.roll(context_states, 1, axis=1) | |
# replace the initial window with trainable weights | |
init_context = get_init_context(num_states) # (d, s) | |
context_states[0] = init_context | |
# lift to multiple heads | |
context_states = dense(context_states) | |
return context_states # (l/w, s, d, h) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Context state collection for BST-MH variant.""" | |
# (MH): Multi-Head | |
def MH_context_states(u): | |
"""Multi-Head Context Collection.""" | |
h, b = get_filters_[unstruct/s4](channels=num_heads) | |
y_h = multichannel_convolution(u, h, b) | |
# y_h: (l, d, h) | |
context_states = jnp.split( | |
y_h, seq_length // win_length, axis=0) | |
return context_states # (l/w, w, d, h) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Context state collection for BST-SH variant.""" | |
num_heads = 8 # (h) | |
num_states = 32 # (s) | |
# (SH): Single-Head | |
def SH_context_states(u): | |
"""Single-Head Context Collection.""" | |
h, b = get_filters_[unstruct/s4](channels=1) | |
y_1 = multichannel_convolution(u, h, b) | |
# y_1: (l, d, 1) | |
# lift to multiple heads | |
y_h = dense(y_1) | |
# y_h: (l, d, h) | |
context_states = jnp.split( | |
y_h, seq_length // win_length, axis=0) | |
return context_states # (l/w, w, d, h) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Unstructured filters and convolutions.""" | |
import jax | |
from jax import numpy as jnp | |
from einops import rearrange | |
win_length = 512  # (w) attention window length in tokens
seq_length = 4096  # (l) full sequence length; assumed divisible by win_length (see jnp.split usage)
def get_filters_unstruct(channels):
    """Return trainable filters and biases.

    Args:
        channels: number of filters.

    Returns:
        h: filters of shape (seq_length, channels, dim).
        b: biases of shape (channels, dim).
    """
    # Filters are an exponentially decaying envelope applied to a learned
    # projection of positional embeddings over normalized time [0, 1].
    timesteps = jnp.linspace(0.0, 1.0, seq_length)
    decay = jnp.exp(-alpha * timesteps)
    filters = decay * dense(positional_emb(timesteps))
    biases = get_bias()
    return filters, biases
def multichannel_convolution(u, h, b): | |
"""Multichannel convolution function. | |
Args: | |
u: input of shape (seq_length, dim) | |
h: filters of shape (seq_length, channels, dim) | |
b: bias of shape (channels, dim) | |
""" | |
h = rearrange(h, "l c d -> c d l") | |
fft_size = seq_length * 2 | |
u_f = jnp.fft.rfft(x, n=fft_size) | |
h_f = jnp.fft.rfft(h, n=fft_size) | |
y = jnp.fft.irfft(h_f * x_f, n=fft_size, norm="forward")[ | |
..., :seq_length] # (c, d, l) | |
y = y + x * b[..., None] # (c, d, l) | |
y = rearrange(y, "c d l -> l d c") | |
return y |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment