Skip to content

Instantly share code, notes, and snippets.

@jayendra13
Created December 2, 2020 16:28
Show Gist options
  • Save jayendra13/224bacf22b411224fb444fcbbd25da21 to your computer and use it in GitHub Desktop.
import t5
import tensorflow.compat.v1 as tf
# Resolve the Cloud TPU named "node-1" and grab its gRPC master address,
# which MtfModel uses below to connect to the TPU worker.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver("node-1")
tpu_address = tpu.get_master()
# Run TF in v1-compat mode (graph mode, v1 APIs), which the Mesh-TF based
# t5.models.MtfModel requires.
# NOTE(review): disable_v2_behavior() is usually called immediately after
# import, before any other TF calls — confirm that creating the
# TPUClusterResolver first does not matter here.
tf.disable_v2_behavior()
def data_gen(split, shuffle_files=False):
    """Return a toy dataset of 100 copies of the string "foobar".

    Args:
        split: Dataset split name (e.g. "train"); ignored — every split
            yields the same data.
        shuffle_files: Accepted for t5 dataset_fn compatibility; unused.

    Returns:
        A tf.data.Dataset of 100 scalar string elements.
    """
    # The paste had lost this body's indentation; restored here.
    ds = tf.data.Dataset.from_tensor_slices(["foobar"]).repeat(100)
    return ds
def preprocessor_fn(ds):
    """Map each example string to a {"inputs", "targets"} feature dict.

    "inputs" is the example prefixed with "pqrs:"; "targets" is the raw
    example. This is the text_preprocessor expected by t5 tasks.

    Args:
        ds: A tf.data.Dataset of scalar string examples.

    Returns:
        A tf.data.Dataset of {"inputs": ..., "targets": ...} dicts.
    """
    # The paste had lost this body's indentation; restored here.
    def to_inputs_and_targets(ex):
        return {
            "inputs": tf.strings.join(["pqrs:", ex]),
            "targets": ex,
        }
    return ds.map(to_inputs_and_targets,
                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
# Re-register the toy "foo" task (remove first so re-running this script
# does not raise on a duplicate name).
t5.data.TaskRegistry.remove("foo")
foo_task_kwargs = dict(
    dataset_fn=data_gen,
    splits=["train"],
    metric_fns=[t5.evaluation.metrics.accuracy],
    text_preprocessor=[preprocessor_fn],
    num_input_examples={"train": 100},
)
t5.data.TaskRegistry.add("foo", **foo_task_kwargs)
# Wrap the single "foo" task in a mixture so finetune() can consume it;
# remove-then-add keeps the script re-runnable.
t5.data.MixtureRegistry.remove("foo_mixture")
t5.data.MixtureRegistry.add("foo_mixture", ["foo"], default_rate=1.0)
# Training configuration for the Mesh-TF T5 model.
model_parallelism = 1
train_batch_size = 256
keep_checkpoint_max = 16
model_dir = "<BUCKET_LOCATION>"  # GCS bucket for checkpoints/summaries

model = t5.models.MtfModel(
    model_dir=model_dir,
    tpu=tpu_address,
    tpu_topology="v3-8",
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    # Max token lengths; longer sequences are truncated by the task pipeline.
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max,
    iterations_per_loop=100,
)
# Fine-tune from the published T5-small checkpoint for a fixed step budget.
PRETRAINED_DIR = "gs://t5-data/pretrained_models/small"
FINETUNE_STEPS = 2500

finetune_kwargs = dict(
    mixture_or_task_name="foo_mixture",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS,
)
model.finetune(**finetune_kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment