from absl import app, flags
import glob
import os
import tensorflow as tf
from model import mesh_model, compile_classifier
from dataio import get_batched_dataset

# Command-line flags; absl parses them when app.run(main) is invoked.
# --sparse: when set, main() builds the model with sparse_impl='sparse_matmul'
# instead of 'gather_sum'.
flags.DEFINE_boolean('sparse',
                     default=False,
                     help='use proposed sparse implementation')
# --tb_dir: TensorBoard root; main() logs into <tb_dir>/<sparse_impl>.
flags.DEFINE_string('tb_dir',
                    default='/tmp/graphics/',
                    help='root directory to store tensorboard data')
# --epochs: passed straight to model.fit. lower_bound rejects values below 1
# at flag-parsing time; with 0 or a negative value model.fit would perform no
# training epochs and the run would appear to succeed.
flags.DEFINE_integer('epochs',
                     default=10,
                     help='number of epochs to train',
                     lower_bound=1)

# Number of output classes passed to mesh_model(num_classes=...).
# NOTE(review): presumably the number of part labels in the mesh_segmentation
# dataset — confirm against the data.
NUM_CLASSES = 16


def get_datasets():
    """Download the mesh-segmentation archives and build train/test datasets.

    Both archives are fetched with ``tf.keras.utils.get_file(..., extract=True)``.
    The file paths below assume the archive contents end up in the same
    directory as the downloaded zip files.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple, each built by
        ``get_batched_dataset`` from a list of TFRecord file paths.

    Raises:
        FileNotFoundError: If the test TFRecord file is missing after
            extraction, or if no ``*train*.tfrecords`` files are found after
            extracting the training archive.
    """
    url = ('https://storage.googleapis.com/tensorflow-graphics/notebooks/'
           'mesh_segmentation/{}')
    path_to_data_zip = tf.keras.utils.get_file('data.zip',
                                               origin=url.format('data.zip'),
                                               extract=True)

    test_data_file = os.path.join(os.path.dirname(path_to_data_zip),
                                  'data/Dancer_test_sequence.tfrecords')
    # Fail fast with a clear message. NOTE(review): if get_batched_dataset
    # reads files lazily, a missing file would otherwise only surface when
    # the dataset is first iterated (e.g. at the first validation pass).
    if not os.path.exists(test_data_file):
        raise FileNotFoundError(
            f'Test data file {test_data_file!r} not found; data.zip may not '
            'have been extracted as expected.')
    test_dataset = get_batched_dataset([test_data_file])

    path_to_train_data_zip = tf.keras.utils.get_file(
        'train_data.zip', origin=url.format('train_data.zip'), extract=True)

    train_pattern = os.path.join(os.path.dirname(path_to_train_data_zip),
                                 '*train*.tfrecords')
    # glob.glob returns matches in arbitrary, filesystem-dependent order;
    # sort so the file list handed to get_batched_dataset is deterministic.
    train_data_files = sorted(glob.glob(train_pattern))
    if not train_data_files:
        # An empty list would otherwise silently produce an empty train set.
        raise FileNotFoundError(
            f'No training files matched {train_pattern!r}; train_data.zip '
            'may not have been extracted as expected.')
    train_dataset = get_batched_dataset(train_data_files)
    return train_dataset, test_dataset


def main(_):
    """Build, compile and train the mesh model, logging to TensorBoard.

    Flags read:
        --sparse: selects ``sparse_impl`` ('sparse_matmul' when set,
            otherwise 'gather_sum') for ``mesh_model``.
        --tb_dir: TensorBoard root. Logs go to ``<tb_dir>/<sparse_impl>``, so
            runs with and without --sparse write to separate directories.
        --epochs: number of epochs passed to ``model.fit``.

    Args:
        _: Leftover command-line arguments supplied by ``app.run``; unused.
    """
    flag_values = flags.FLAGS
    if flag_values.sparse:
        impl_name = 'sparse_matmul'
    else:
        impl_name = 'gather_sum'

    train_ds, test_ds = get_datasets()
    # Width of the last axis of the first feature tensor in the training
    # element spec. NOTE(review): assumes element_spec[0][0] is the per-vertex
    # feature tensor — confirm against dataio.get_batched_dataset.
    feature_spec = train_ds.element_spec[0][0]
    vertex_feature_dim = feature_spec.shape[-1]

    classifier = mesh_model(num_classes=NUM_CLASSES,
                            initial_vertex_feature_dim=vertex_feature_dim,
                            sparse_impl=impl_name)
    compile_classifier(classifier)

    log_dir = os.path.join(flag_values.tb_dir, impl_name)
    tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir)
    classifier.fit(train_ds,
                   epochs=flag_values.epochs,
                   validation_data=test_ds,
                   callbacks=[tensorboard_cb])


if __name__ == '__main__':
    # absl parses the flags defined above from sys.argv, then calls main()
    # with the remaining positional arguments.
    app.run(main=main)