This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def call(self, inputs): | |
x = self.normalize1(inputs) | |
x_projected = self.channel_projection1(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
v_channels = tf.linalg.matrix_transpose(v) | |
v_projected = self.spatial_projection(v_channels) | |
v_projected = tf.linalg.matrix_transpose(v_projected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
self.spatial_projection = layers.Dense(units=num_patches, bias_initializer="Ones") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def spatial_gating_unit(self, x): | |
# u and v shape: [batch_size, num_patchs, embedding_dim] | |
u, v = tf.split(x, num_or_size_splits=2, axis=2) | |
v = self.normalize2(v) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow import keras | |
from tensorflow.keras import layers | |
class gMLPLayer(layers.Layer): | |
def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs): | |
super(gMLPLayer, self).__init__(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
self.normalize1 = layers.LayerNormalization(epsilon=1e-6) | |
self.normalize2 = layers.LayerNormalization(epsilon=1e-6) | |
self.channel_projection1 = keras.Sequential( | |
[ | |
layers.Dense(units=embedding_dim * 2), | |
layers.ReLU(), | |
layers.Dropout(rate=dropout_rate), | |
] | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
tf.one_hot(tf.argmax(p), depth = len(p)) | |
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 0.], dtype=float32)> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
temperature = 0.01 | |
dist = tfp.distributions.RelaxedOneHotCategorical(temperature, probs=p) | |
dist.sample() | |
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0., 1., 0.], dtype=float32)> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
temperature = 10 | |
dist = tfp.distributions.RelaxedOneHotCategorical(temperature, probs=p) | |
dist.sample() | |
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([0.31916314, 0.34642866, 0.33440822], dtype=float32)> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
JSON.safeStringify = (obj, indent = 2) => { | |
let cache = []; | |
const retVal = JSON.stringify( | |
obj, | |
(key, value) => | |
typeof value === "object" && value !== null | |
? cache.includes(value) | |
? undefined // Duplicate reference found, discard key | |
: cache.push(value) && value // Store value in our collection | |
: value, |