"""Train a mesh-segmentation classifier on the TF Graphics dancer dataset.

Downloads the demo train/test TFRecord data, builds the mesh model with
either the sparse-matmul or gather-sum graph-convolution implementation
(selected by the --sparse flag), and trains with TensorBoard logging.
"""

import glob
import os

import tensorflow as tf
from absl import app, flags

from dataio import get_batched_dataset
from model import compile_classifier, mesh_model

flags.DEFINE_boolean('sparse', default=False,
                     help='use proposed sparse implementation')
flags.DEFINE_string('tb_dir', default='/tmp/graphics/',
                    help='root directory to store tensorboard data')
flags.DEFINE_integer('epochs', default=10, help='number of epochs to train')

# Number of segmentation classes in the dancer dataset.
NUM_CLASSES = 16


def get_datasets():
    """Download the demo data and return (train_dataset, test_dataset).

    Both datasets are produced by `get_batched_dataset` from the TFRecord
    files extracted next to the downloaded zip archives.

    Returns:
        A `(train_dataset, test_dataset)` tuple of batched `tf.data.Dataset`s.
    """
    url = ('https://storage.googleapis.com/tensorflow-graphics/notebooks/'
           'mesh_segmentation/{}')

    # Test data: a single known TFRecord file inside the extracted archive.
    path_to_data_zip = tf.keras.utils.get_file(
        'data.zip', origin=url.format('data.zip'), extract=True)
    test_data_files = [
        os.path.join(os.path.dirname(path_to_data_zip),
                     'data/Dancer_test_sequence.tfrecords')
    ]
    test_dataset = get_batched_dataset(test_data_files)

    # Training data: glob for every *train* TFRecord shard in the archive.
    path_to_train_data_zip = tf.keras.utils.get_file(
        'train_data.zip', origin=url.format('train_data.zip'), extract=True)
    train_data_files = glob.glob(
        os.path.join(os.path.dirname(path_to_train_data_zip),
                     '*train*.tfrecords'))
    train_dataset = get_batched_dataset(train_data_files)

    return train_dataset, test_dataset


def main(_):
    """Build, compile, and fit the mesh model per command-line flags."""
    FLAGS = flags.FLAGS
    sparse = FLAGS.sparse
    # The implementation name doubles as the TensorBoard run subdirectory.
    sparse_impl = 'sparse_matmul' if sparse else 'gather_sum'

    train_dataset, test_dataset = get_datasets()
    # Infer the per-vertex input feature width from the dataset's element
    # spec rather than hard-coding it.
    initial_vertex_feature_dim = train_dataset.element_spec[0][0].shape[-1]

    model = mesh_model(num_classes=NUM_CLASSES,
                       initial_vertex_feature_dim=initial_vertex_feature_dim,
                       sparse_impl=sparse_impl)
    compile_classifier(model)

    tb_dir = os.path.join(FLAGS.tb_dir, sparse_impl)
    model.fit(train_dataset,
              validation_data=test_dataset,
              epochs=FLAGS.epochs,
              callbacks=[tf.keras.callbacks.TensorBoard(tb_dir)])


if __name__ == '__main__':
    app.run(main)