"""Example ImageNet-style resnet training scenario with synthetic data.

Author: Mike Dusenberry
"""
import argparse

import numpy as np
import tensorflow as tf

# args
parser = argparse.ArgumentParser(add_help=False)  # disable built-in `-h` so it can be used for height
parser.add_argument("-n", type=int, default=175, help="num examples to generate")
parser.add_argument("-h", type=int, default=224, help="example height")
parser.add_argument("-w", type=int, default=224, help="example width")
parser.add_argument("-k", type=int, default=1000, help="num classes")
parser.add_argument("--batch_size", type=int, default=32, help="batch size")
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
parser.add_argument("--steps", type=int, default=1000, help="training steps")
parser.add_argument("--log_interval", type=int, default=100, help="how often to print the loss")
parser.add_argument("--buffer", type=int, default=100, help="size of prefetch buffer in batches")
parser.add_argument("--help", action='help', help="show this help message and exit")
FLAGS = parser.parse_args()

# synthetic data
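# images are standard-normal noise; labels are one-hot rows selected from a
# k x k identity matrix with random class indices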
x = np.random.randn(FLAGS.n, FLAGS.h, FLAGS.w, 3).astype(np.float32)
y = np.eye(FLAGS.k)[np.random.randint(FLAGS.k, size=FLAGS.n)].astype(np.float32)

# tf data
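# shuffle within a fixed buffer, batch, repeat indefinitely, and prefetch
# batches so input preparation overlaps with training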
dataset = tf.data.Dataset.from_tensor_slices((x, y))
dataset = dataset.shuffle(100)
dataset = dataset.batch(FLAGS.batch_size)
dataset = dataset.repeat()  # repeat indefinitely
dataset = dataset.prefetch(FLAGS.buffer)
iterator = dataset.make_one_shot_iterator()
x_batch, y_batch = iterator.get_next()

# tf model
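# wire ResNet50 directly to the pipeline tensor via input_tensor; with
# include_top=False, the final conv feature map is flattened and projected
# to k logits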
resnet = tf.keras.applications.ResNet50(
    include_top=False, weights=None,  # train from scratch; skip the ImageNet download
    input_tensor=x_batch, input_shape=(FLAGS.h, FLAGS.w, 3))
out = tf.layers.flatten(resnet.output)
logits = tf.layers.dense(out, FLAGS.k)

# tf loss
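# the _v2 op also backprops into the labels, which is harmless here since the
# labels are non-trainable tensors from the pipeline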
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_batch, logits=logits))

# tf optimizer
opt = tf.train.AdamOptimizer(FLAGS.lr)
# run the batch norm moving-average updates tracked by the Keras model along
# with each training step; otherwise the moving statistics never change
with tf.control_dependencies(resnet.updates):
  train_op = opt.minimize(loss)

# saver
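# construct the saver after the model so it captures all model variables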
saver = tf.train.Saver()

# init
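# variable values are per-session in TF 1.x, so initialize them in the same
# session used for training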
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# train loop
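# feed learning_phase=True so Keras layers such as batch norm run in
# training mode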
for i in range(FLAGS.steps):
  feed_dict = {tf.keras.backend.learning_phase(): True}
  if i % FLAGS.log_interval == 0:
    _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
    print("step: {}, loss: {}".format(i, loss_value))
  else:
    sess.run(train_op, feed_dict=feed_dict)

# saver.save(sess, "model.ckpt")
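
# a minimal inference sketch (optional, hence commented out like the save
# above): run the existing logits tensor on one more batch with
# learning_phase=False so batch norm uses its moving statistics
# preds = sess.run(tf.argmax(logits, axis=-1),
#                  feed_dict={tf.keras.backend.learning_phase(): False})
# print("predictions: {}".format(preds[:5]))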