Skip to content

Instantly share code, notes, and snippets.

View radi-cho's full-sized avatar

Radi Cho radi-cho

View GitHub Profile
# NOTE(review): fragment of a gist preview. The enclosing class, the method
# headers, and the definitions of `num_patches` and `v` are not visible here,
# and indentation was lost, so these lines do not run on their own.
# Dense layer with `num_patches` output units; its bias starts at 1
# (the "Ones" initializer).
self.spatial_projection = layers.Dense(units=num_patches, bias_initializer="Ones")
# tf.linalg.matrix_transpose swaps the last two axes of `v`. The Dense layer
# acts on the last axis, so it mixes along what was v's second-to-last axis.
# Given units=num_patches, that is presumably the patch axis (TODO confirm).
v_channels = tf.linalg.matrix_transpose(v)
v_projected = self.spatial_projection(v_channels)
# Swap the last two axes back to restore the original axis order.
v_projected = tf.linalg.matrix_transpose(v_projected)
# NOTE(review): truncated fragment of a layer's `call` method from a gist
# preview. The body's indentation was lost, and its remaining lines (for
# example the return statement) are not visible. `normalize1`,
# `channel_projection1`, `spatial_gating_unit` and `channel_projection2` are
# attributes set elsewhere, presumably in __init__.
def call(self, inputs):
# Applies, in order: self.normalize1, self.channel_projection1,
# self.spatial_gating_unit, self.channel_projection2.
x = self.normalize1(inputs)
x_projected = self.channel_projection1(x)
# x_spatial shape: [batch_size, num_patches, embedding_dim].
x_spatial = self.spatial_gating_unit(x_projected)
# x_projected shape: [batch_size, num_patches, embedding_dim].
x_projected = self.channel_projection2(x_spatial)
# NOTE(review): truncated fragment from a gist preview. Indentation was lost,
# and the `keras.Sequential(` call below is never closed; the rest of
# __init__ (and any other methods) is not visible. `num_patches` is not used
# in the visible lines.
class gMLPLayer(layers.Layer):
def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
# Extra positional/keyword arguments are forwarded to layers.Layer.
super(gMLPLayer, self).__init__(*args, **kwargs)
# First channel projection: Dense to 2 * embedding_dim units, then ReLU, then
# Dropout with the given rate.
self.channel_projection1 = keras.Sequential(
[
layers.Dense(units=embedding_dim * 2),
layers.ReLU(),
layers.Dropout(rate=dropout_rate),
]
# A sequential stack of `num_blocks` separately constructed gMLPLayer
# instances. Every instance receives the same num_patches, embedding_dim and
# dropout_rate arguments.
gmlp_blocks = keras.Sequential(
    [
        gMLPLayer(num_patches, embedding_dim, dropout_rate)
        for _block_index in range(num_blocks)
    ]
)
import tensorflow as tf
from datasets import load_dataset
from transformers import TFMT5ForConditionalGeneration, MT5Tokenizer, DataCollatorForSeq2Seq
from tensorflow.keras.optimizers import Adam
# Training data: the local train.csv file, read with the datasets "csv"
# builder. Its "train" split is shuffled with a fixed seed (42).
dataset = load_dataset("csv", data_files="train.csv")["train"].shuffle(seed=42)

# Pretrained mT5-small tokenizer and TensorFlow seq2seq model, both loaded
# via from_pretrained from the same checkpoint id.
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
def preprocess_function(examples):
    """Tokenize a batch of source/target texts for seq2seq training.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps column
    names to lists of values and must contain the ``"Text"`` (model input)
    and ``"Expected"`` (target) columns.

    Returns:
        The tokenizer output for the inputs (padded/truncated to 200 tokens)
        plus a ``"labels"`` entry holding the tokenized targets. Padding
        positions in the labels are set to -100 so the training loss ignores
        them.
    """
    padding = "max_length"
    max_length = 200
    inputs = list(examples["Text"])
    targets = list(examples["Expected"])
    model_inputs = tokenizer(inputs, max_length=max_length, padding=padding, truncation=True)

    # Tokenize the targets with the same settings. MT5Tokenizer applies no
    # target-specific processing, so a plain call is used; on
    # transformers >= 4.22, ``text_target=targets`` is the explicit form.
    labels = tokenizer(targets, max_length=max_length, padding=padding, truncation=True)

    # Fixed-length padding fills label rows with the pad token id. Replace
    # those ids with -100, the label value the transformers seq2seq loss
    # ignores, so padding does not contribute to the loss.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [-100 if token_id == pad_id else token_id for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs
# Tokenize all rows. With batched=True the mapping function receives a dict of
# column lists per call.
train_dataset = dataset.map(preprocess_function, batched=True, desc="Running tokenizer")

# Batch collator returning NumPy arrays, padding each batch up to a multiple
# of 64 tokens.
# Labels are padded with -100 rather than the pad token id. The transformers
# seq2seq loss ignores label positions equal to -100, so padding no longer
# contributes to the loss. TF T5 models map -100 back to the pad token when
# shifting labels into decoder inputs.
# NOTE(review): preprocess_function already pads every example to 200 tokens,
# so pad_to_multiple_of=64 appears to extend each batch to 256. Consider
# max_length=192/256 or dynamic padding.
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,
    pad_to_multiple_of=64,
    return_tensors="np",
)
# NOTE(review): this call looks truncated (gist preview). Its arguments and
# closing parenthesis are missing, so the following lines end up inside it and
# the script does not parse. Presumably `train_dataset` and `data_collator`
# (defined above and otherwise unused) were meant to be passed here, e.g. as
# the dataset and `collate_fn`. TODO restore.
tf_train_dataset = model.prepare_tf_dataset(
# Compile with Adam at learning rate 3e-5. No `loss` argument is passed, so
# the transformers TF model falls back to its internal loss.
model.compile(optimizer=Adam(3e-5))
# Stop after 3 epochs without improvement in 'loss'. fit() receives no
# validation data, so this monitors the training loss.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Train for at most 10 epochs.
model.fit(tf_train_dataset, epochs=10, callbacks=[early_stopping])