Skip to content

Instantly share code, notes, and snippets.

@lakshmanok
Created August 28, 2018 16:11
Show Gist options
  • Save lakshmanok/4858f69a1a456ef169135a8b0e018952 to your computer and use it in GitHub Desktop.
Save lakshmanok/4858f69a1a456ef169135a8b0e018952 to your computer and use it in GitHub Desktop.
def train_and_evaluate(output_dir, hparams):
    """Log the run settings and construct a TPUEstimator for `image_classifier`.

    Args:
        output_dir: Passed as `model_dir` to the estimator, and to the
            RunConfig when running on a TPU.
        hparams: Mapping of hyperparameters. Reads 'train_steps',
            'num_eval_images', 'train_batch_size' and 'use_tpu'; when
            'use_tpu' is truthy it also reads 'tpu', 'tpu_zone' and
            'project'. The whole mapping is forwarded to the estimator as
            `params`.
    """
    steps_per_eval = 1000
    max_steps = hparams['train_steps']

    # Cap the eval batch at 1024 images, then round down to a multiple of 8
    # so that it is divisible by num_cores.
    # NOTE(review): with fewer than 8 eval images this rounds down to 0 --
    # confirm callers never pass such a small 'num_eval_images'.
    capped_eval_size = min(1024, hparams['num_eval_images'])
    eval_batch_size = capped_eval_size - capped_eval_size % 8

    tf.logging.info(
        'train_batch_size=%d eval_batch_size=%d max_steps=%d',
        hparams['train_batch_size'], eval_batch_size, max_steps)

    # TPU change 3
    if not hparams['use_tpu']:
        # No TPU requested: fall back to a default RunConfig.
        config = tf.contrib.tpu.RunConfig()
    else:
        resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            hparams['tpu'],
            zone=hparams['tpu_zone'],
            project=hparams['project'])
        # The same step count is used for the checkpoint interval and for
        # iterations_per_loop.
        tpu_settings = tf.contrib.tpu.TPUConfig(
            iterations_per_loop=steps_per_eval,
            per_host_input_for_training=True)
        config = tf.contrib.tpu.RunConfig(
            cluster=resolver,
            model_dir=output_dir,
            save_checkpoints_steps=steps_per_eval,
            tpu_config=tpu_settings)

    # TPU change 4
    # NOTE(review): the visible snippet stops after building the estimator;
    # nothing is trained or returned here.
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=image_classifier,
        config=config,
        params=hparams,
        model_dir=output_dir,
        train_batch_size=hparams['train_batch_size'],
        eval_batch_size=eval_batch_size,
        use_tpu=hparams['use_tpu'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment