@gaphex
Created May 9, 2019 16:41
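# Assumes TF 1.x (tf.contrib) and the google-research/bert repository on sys.path.
# bert_config, input_files and the upper-case constants are defined earlier in the notebook.
import tensorflow as tf
from run_pretraining import input_fn_builder, model_fn_builder

# Build the BERT pre-training model_fn (masked-LM + next-sentence-prediction losses).
model_fn = model_fn_builder(
    bert_config=bert_config,
    init_checkpoint=INIT_CHECKPOINT,
    learning_rate=LEARNING_RATE,
    num_train_steps=TRAIN_STEPS,
    num_warmup_steps=10,
    use_tpu=USE_TPU,
    use_one_hot_embeddings=True)  # one-hot embedding lookup is much faster on TPU

# Locate the TPU worker and configure checkpointing and the on-device training loop.
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=BERT_GCS_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=SAVE_CHECKPOINTS_STEPS,  # TPU steps per host round trip
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))

# With use_tpu=False the same model_fn runs on CPU/GPU.
estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=USE_TPU,
    model_fn=model_fn,
    config=run_config,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE)

# Training input pipeline over the pre-training TFRecord files.
train_input_fn = input_fn_builder(
    input_files=input_files,
    max_seq_length=MAX_SEQ_LENGTH,
    max_predictions_per_seq=MAX_PREDICTIONS,
    is_training=True)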
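`model_fn_builder` forwards `learning_rate`, `num_train_steps` and `num_warmup_steps` to the BERT repository's optimizer. Assuming it follows `optimization.create_optimizer` in google-research/bert, the schedule is a linear warmup from 0 to the peak rate over `num_warmup_steps`, followed by a linear decay that reaches 0 at `num_train_steps`. The pure-Python sketch below reproduces that schedule under this assumption; `bert_learning_rate` is a hypothetical helper, not part of the repository.

```python
def bert_learning_rate(step: int, init_lr: float,
                       num_train_steps: int, num_warmup_steps: int) -> float:
    """Learning rate at a global step, assuming BERT's warmup + linear-decay schedule."""
    if num_warmup_steps and step < num_warmup_steps:
        # Linear warmup: 0 at step 0, rising toward init_lr.
        return init_lr * step / num_warmup_steps
    # Polynomial decay with power 1.0 (linear) to 0, held at 0 past num_train_steps.
    progress = min(step, num_train_steps) / num_train_steps
    return init_lr * (1.0 - progress)


if __name__ == "__main__":
    # Hypothetical values: peak rate 1e-4, 1000 total steps, 10 warmup steps as in the gist.
    for s in (0, 5, 10, 500, 1000):
        print(s, bert_learning_rate(s, 1e-4, 1000, 10))
```

Unless `TRAIN_STEPS` is very small, ten warmup steps are a negligible fraction of training, so the rate reaches its peak almost immediately and then decays linearly for the rest of the run.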
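The `RunConfig` sets `iterations_per_loop` equal to `SAVE_CHECKPOINTS_STEPS`. According to the `TPUConfig` documentation, `iterations_per_loop` is the number of training steps the TPU runs before returning control to the host, and it is recommended to match the checkpoint interval because checkpoints are written between loops. TPUEstimator also splits `train_batch_size` evenly across the `num_shards` cores, so it must be divisible by `NUM_TPU_CORES`. The sketch below does this arithmetic for a candidate configuration; `tpu_training_plan` and the example numbers are hypothetical.

```python
import math


def tpu_training_plan(train_batch_size: int, num_cores: int,
                      train_steps: int, iterations_per_loop: int) -> dict:
    """Hypothetical helper: per-core batch size and host round trips for a TPU run."""
    if train_batch_size % num_cores != 0:
        # The global batch is sharded evenly across the TPU cores.
        raise ValueError("train_batch_size must be divisible by the number of TPU cores")
    return {
        "per_core_batch_size": train_batch_size // num_cores,
        # Each host round trip runs up to iterations_per_loop steps on the TPU.
        "num_loops": math.ceil(train_steps / iterations_per_loop),
        # Total training examples consumed over the whole run.
        "examples_seen": train_batch_size * train_steps,
    }


if __name__ == "__main__":
    # Hypothetical configuration: 8 cores, global batch 128, checkpoint every 250 steps.
    print(tpu_training_plan(128, 8, 1000, 250))
```

Setting `iterations_per_loop` much higher than the checkpoint interval would delay checkpoints until the loop returns, while setting it very low adds host round-trip overhead; tying the two values together, as the gist does, sidesteps both.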