Created
December 2, 2020 16:28
-
-
Save jayendra13/224bacf22b411224fb444fcbbd25da21 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import t5
import tensorflow.compat.v1 as tf

# Resolve the TPU node named "node-1" and keep its master address for the
# model constructor further down.
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="node-1")
tpu_address = tpu.get_master()

# Switch off TF2 behavior after the TPU address has been resolved.
tf.disable_v2_behavior()
def data_gen(split, shuffle_files=False):
    """Build a toy dataset holding the string "foobar" 100 times.

    Both ``split`` and ``shuffle_files`` are accepted but not used: the same
    dataset is returned whatever values are passed.
    """
    single_example = tf.data.Dataset.from_tensor_slices(["foobar"])
    return single_example.repeat(100)
def preprocessor_fn(ds):
    """Map every element of ``ds`` to an ``inputs``/``targets`` dictionary.

    ``inputs`` is the element with "pqrs:" joined in front of it and
    ``targets`` is the element itself, unchanged.
    """

    def _as_text_pair(example):
        # Join the fixed "pqrs:" string with the raw example text.
        prefixed = tf.strings.join(["pqrs:", example])
        return {"inputs": prefixed, "targets": example}

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(_as_text_pair, num_parallel_calls=autotune)
# Drop any existing "foo" task first, then register it again
# (presumably so the script can be re-run in one session; confirm).
t5.data.TaskRegistry.remove("foo")
t5.data.TaskRegistry.add(
    "foo",
    # Data source and the text-to-text conversion applied to it.
    dataset_fn=data_gen,
    text_preprocessor=[preprocessor_fn],
    # Only a "train" split is declared; 100 matches the repeat count used
    # in data_gen.
    splits=["train"],
    num_input_examples={"train": 100},
    metric_fns=[t5.evaluation.metrics.accuracy],
)
# Drop any existing "foo_mixture" first, then register a mixture made of
# the single "foo" task at a default rate of 1.0.
t5.data.MixtureRegistry.remove("foo_mixture")
t5.data.MixtureRegistry.add("foo_mixture", ["foo"], default_rate=1.0)
# Training knobs forwarded to the model constructor below.
model_parallelism = 1
train_batch_size = 256
keep_checkpoint_max = 16

# Placeholder: replace with the bucket location to use as the model directory.
model_dir = "<BUCKET_LOCATION>"

model = t5.models.MtfModel(
    model_dir=model_dir,
    # TPU connection details; tpu_address comes from the resolver at the
    # top of the script.
    tpu=tpu_address,
    tpu_topology="v3-8",
    model_parallelism=model_parallelism,
    # Batch and sequence sizes.
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    # A single constant value is passed as the learning-rate schedule.
    learning_rate_schedule=0.003,
    # Checkpointing and loop settings.
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max,
    iterations_per_loop=100,
)
# Pretrained checkpoint location and the number of fine-tuning steps.
PRETRAINED_DIR = "gs://t5-data/pretrained_models/small"
FINETUNE_STEPS = 2500

# Fine-tune on the "foo_mixture" mixture registered above.
model.finetune(
    mixture_or_task_name="foo_mixture",
    finetune_steps=FINETUNE_STEPS,
    pretrained_model_dir=PRETRAINED_DIR,
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment