Minimal Keras implementation of "Perceiver: General Perception with Iterative Attention" (Jaegle et al.)
# Cleaned and minimal perceiver transformer, originally from code https://github.com/Rishit-dagli/Perceiver
# Original paper: Perceiver: General Perception with Iterative Attention. Jaegle et al. https://arxiv.org/pdf/2103.03206.pdf

import math
import tensorflow as tf
from typing import Callable
from einops import rearrange, repeat


def fourier_encode(x, max_freq, num_bands = 4, base = 2):
    # Fourier positional features: each coordinate is mapped to sin/cos at log-spaced
    # frequencies, with the raw coordinate appended at the end.
    x = tf.expand_dims(x, -1)
    x = tf.cast(x, dtype = tf.float32)
    orig_x = x
    scales = tf.experimental.numpy.logspace(1.0, math.log(max_freq / 2) / math.log(base), num = num_bands, base = base, dtype = tf.float32)
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
    x = x * scales * math.pi
    x = tf.concat([tf.math.sin(x), tf.math.cos(x)], axis = -1)
    x = tf.concat((x, orig_x), axis = -1)
    return x
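
# Illustrative shape check (assumed example, not from the original gist): for a 64x64 grid of
# 2-D coordinates, fourier_encode(tf.zeros([64, 64, 2]), max_freq = 10, num_bands = 6) returns a
# tensor of shape [64, 64, 2, 13], i.e. sin and cos for each of the 6 bands plus the raw coordinate.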

class PreNorm(tf.keras.layers.Layer):
    def __init__(self, dim, fn, context_dim = None):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = tf.keras.layers.LayerNormalization(axis = -1)
        if context_dim is None: self.norm_context = None
        else: self.norm_context = tf.keras.layers.LayerNormalization(axis = -1)

    def call(self, x, **kwargs):
        # Normalize the input (and the cross-attention context, when one is supplied)
        # before handing everything to the wrapped layer.
        x = self.norm(x)
        if self.norm_context is not None and "context" in kwargs:
            kwargs["context"] = self.norm_context(kwargs["context"])
        return self.fn(x, **kwargs)

class Perceiver(tf.keras.Model):
    def __init__(self, num_freq_bands, depth, max_freq, freq_base = 2, input_channels = 3, input_axis = 2, num_latents = 512, latent_dim = 512, cross_heads = 1, latent_heads = 8, cross_dim_head = 64, latent_dim_head = 64, num_classes = 1000, attn_dropout = 0.0, ff_dropout = 0.0):
        super(Perceiver, self).__init__()
        self.input_axis = input_axis
        self.max_freq = max_freq
        self.num_freq_bands = num_freq_bands
        self.freq_base = freq_base
        # Per-point input width: Fourier features (2 * bands + 1 per spatial axis) plus the raw channels.
        input_dim = input_axis * ((num_freq_bands * 2) + 1) + input_channels
        # Learned latent array, shared across the batch.
        self.latents = tf.Variable(tf.random.normal([num_latents, latent_dim]))
        get_cross_attn: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_cross_ff: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        get_latent_attn: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_latent_ff: Callable[[], PreNorm] = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))
        # One (cross-attention, cross-FF, latent self-attention, latent FF) block per depth step.
        # Kept as a plain list of tuples (not a Sequential) so the cross-attention can receive the
        # input array as context and residual connections can be applied in call().
        self.existing_layers = list()
        for i in range(depth):
            self.existing_layers.append((get_cross_attn(), get_cross_ff(), get_latent_attn(), get_latent_ff()))
        self.to_logits = tf.keras.Sequential([tf.keras.layers.LayerNormalization(axis = -1), tf.keras.layers.Dense(num_classes, input_dim = latent_dim)])
    def call(self, data, mask = None):
        # data: [batch, *spatial_axes, channels]. The batch size must be concrete (eager call),
        # since einops repeat needs an integer for b.
        b, *axis, _ = data.shape
        # Coordinates in [-1, 1] along every spatial axis, Fourier-encoded as positional features.
        axis_pos = list(map(lambda size: tf.linspace(-1.0, 1.0, num = size), axis))
        pos = tf.stack(tf.meshgrid(*axis_pos, indexing = "ij"), axis = -1)
        enc_pos = fourier_encode(pos, self.max_freq, self.num_freq_bands, base = self.freq_base)
        enc_pos = rearrange(enc_pos, "... n d -> ... (n d)")
        enc_pos = repeat(enc_pos, "... -> b ...", b = b)
        # Append the positional features to the raw input and flatten the spatial axes.
        data = tf.concat((data, enc_pos), axis = -1)
        data = rearrange(data, "b ... d -> b (...) d")
        # Broadcast the learned latents over the batch, then iterate: cross-attend to the input
        # array, followed by latent self-attention, each with a residual connection.
        x = repeat(self.latents, "n d -> b n d", b = b)
        for cross_attn, cross_ff, latent_attn, latent_ff in self.existing_layers:
            x = cross_attn(x, context = data, mask = mask) + x
            x = cross_ff(x) + x
            x = latent_attn(x) + x
            x = latent_ff(x) + x
        # Pool over the latent index and classify.
        x = tf.math.reduce_mean(x, axis = -2)
        return self.to_logits(x)

class GEGLU(tf.keras.layers.Layer):
    # Gated GELU: split the last axis in half and gate one half with the GELU of the other.
    def call(self, x):
        x, gates = tf.split(x, 2, axis = -1)
        return x * tf.nn.gelu(gates)

class Attention(tf.keras.layers.Layer):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, dropout = 0.0):
        super(Attention, self).__init__()
        inner_dim = dim_head * heads
        if context_dim is None: context_dim = query_dim
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_queries = tf.keras.layers.Dense(inner_dim, input_dim = query_dim, use_bias = False)
        self.to_keys_values = tf.keras.layers.Dense(inner_dim * 2, input_dim = context_dim, use_bias = False)
        # The output projection maps the concatenated heads back to the query dimension,
        # so the result can be added to the residual stream.
        self.to_out = tf.keras.Sequential([tf.keras.layers.Dense(query_dim, input_dim = inner_dim), tf.keras.layers.Dropout(dropout)])
    def call(self, x, context = None, mask = None):
        h = self.heads
        queries = self.to_queries(x)
        if context is None: context = x
        kv = self.to_keys_values(context)
        keys, values = tf.split(kv, num_or_size_splits = 2, axis = -1)
        # Split the heads out of the channel axis and fold them into the batch axis.
        queries, keys, values = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h = h), (queries, keys, values))
        sim = tf.einsum("b i d, b j d -> b i j", queries, keys) * self.scale
        if mask is not None:
            # Mask out padded context positions with a large negative value before the softmax.
            mask = rearrange(mask, "b ... -> b (...)")
            max_neg_value = -tf.experimental.numpy.finfo(sim.dtype).max
            mask = repeat(mask, "b j -> (b h) () j", h = h)
            sim = tf.where(tf.math.logical_not(mask), max_neg_value, sim)
        attn = tf.nn.softmax(sim, axis = -1)
        out = tf.einsum("b i j, b j d -> b i d", attn, values)
        out = rearrange(out, "(b h) n d -> b n (h d)", h = h)
        out = self.to_out(out)
        return out

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, dim, mult = 4, dropout = 0.0):
        super(FeedForward, self).__init__()
        self.net = tf.keras.Sequential([tf.keras.layers.Dense(dim * mult * 2, input_dim = dim), GEGLU(), tf.keras.layers.Dropout(dropout), tf.keras.layers.Dense(dim, input_dim = dim * mult)])

    def call(self, inputs): return self.net(inputs)
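
# --- Usage sketch (not part of the original gist) ---
# A minimal smoke test, assuming eager execution and a concrete batch size (einops' repeat needs
# an integer batch). The tiny hyperparameters below are arbitrary, chosen only to keep it fast.
if __name__ == "__main__":
    model = Perceiver(num_freq_bands = 6, depth = 2, max_freq = 10.0,
                      input_channels = 3, input_axis = 2,
                      num_latents = 32, latent_dim = 64,
                      cross_dim_head = 32, latent_dim_head = 32,
                      num_classes = 10)
    images = tf.random.normal([4, 32, 32, 3])  # [batch, height, width, channels]
    logits = model(images)                     # latents cross-attend to the flattened pixels, then classify
    print(logits.shape)                        # expected: (4, 10)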