"""Example ImageNet-style ResNet training scenario with synthetic data.

Trains a Keras ResNet50 backbone (randomly initialised) plus a dense
classification head on randomly generated images and one-hot labels that
are fed through a ``tf.data`` pipeline.  Written for TensorFlow 1.x graph
mode (``tf.Session``, ``tf.layers``, one-shot iterators).

Author: Mike Dusenberry
"""
import argparse

import numpy as np
import tensorflow as tf


def _positive_int(value):
    """argparse ``type=`` callable: parse ``value`` as an int that must be > 0."""
    ivalue = int(value)
    if ivalue <= 0:
        raise argparse.ArgumentTypeError(
            "must be a positive integer, got {}".format(value))
    return ivalue


def _parse_args(argv=None):
    """Parse command-line flags (``argv`` defaults to ``sys.argv[1:]``)."""
    # add_help=False frees `-h` for the example height; `--help` is re-added below.
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("-n", type=_positive_int, default=175,
                        help="num examples to generate")
    parser.add_argument("-h", type=_positive_int, default=224,
                        help="example height")
    parser.add_argument("-w", type=_positive_int, default=224,
                        help="example width")
    parser.add_argument("-k", type=_positive_int, default=1000,
                        help="num classes")
    parser.add_argument("--batch_size", type=_positive_int, default=32,
                        help="batch size")
    parser.add_argument("--lr", type=float, default=0.001,
                        help="learning rate")
    parser.add_argument("--steps", type=int, default=1000,
                        help="training steps")
    parser.add_argument("--log_interval", type=_positive_int, default=100,
                        help="how often to print the loss")
    parser.add_argument("--buffer", type=int, default=100,
                        help="size of prefetch buffer in batches")
    parser.add_argument("--shuffle_buffer", type=_positive_int, default=100,
                        help="size of shuffle buffer in examples")
    parser.add_argument("--save_path", default=None,
                        help="if set, save a checkpoint here after training "
                             "(e.g. model.ckpt)")
    parser.add_argument("--help", action="help",
                        help="show this help message and exit")
    return parser.parse_args(argv)


def main(argv=None):
    """Build the synthetic input pipeline and model, then run the train loop."""
    flags = _parse_args(argv)

    # synthetic data: standard-normal NHWC images and one-hot float32 labels
    x = np.random.randn(flags.n, flags.h, flags.w, 3).astype(np.float32)
    labels = np.random.randint(flags.k, size=flags.n)
    y = np.eye(flags.k, dtype=np.float32)[labels]

    # tf data
    # NOTE: from_tensor_slices embeds x/y in the graph as constants, so a very
    # large -n can exceed TensorFlow's 2GB GraphDef limit.
    dataset = (tf.data.Dataset.from_tensor_slices((x, y))
               .shuffle(flags.shuffle_buffer)
               .batch(flags.batch_size)
               .repeat(-1)  # cycle forever; the loop below bounds the steps
               .prefetch(flags.buffer))
    x_batch, y_batch = dataset.make_one_shot_iterator().get_next()

    # tf model
    # weights=None: the variables are (re)initialised by
    # global_variables_initializer() in our own session below anyway, so
    # downloading/loading ImageNet weights would only be discarded.
    resnet = tf.keras.applications.ResNet50(
        include_top=False, weights=None, input_tensor=x_batch,
        input_shape=(flags.h, flags.w, 3))
    features = tf.layers.flatten(resnet.output)
    logits = tf.layers.dense(features, flags.k)

    # tf loss
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_batch,
                                                   logits=logits))

    # tf optimizer
    # tf.keras layers (e.g. BatchNormalization moving mean/variance) expose
    # their state updates via `Model.updates`; outside `model.fit` they only
    # run if attached to the train op explicitly.
    opt = tf.train.AdamOptimizer(flags.lr)
    with tf.control_dependencies(resnet.updates):
        train_op = opt.minimize(loss)

    # Created after minimize() so optimizer slot variables are included.
    saver = tf.train.Saver() if flags.save_path else None

    # Loop-invariant: put Keras layers (e.g. BatchNorm) in training mode.
    feed_dict = {tf.keras.backend.learning_phase(): True}

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(flags.steps):
            if i % flags.log_interval == 0:
                _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
                print("loss: {}".format(loss_value))
            else:
                sess.run(train_op, feed_dict=feed_dict)
        if saver is not None:
            saver.save(sess, flags.save_path)


if __name__ == "__main__":
    main()