@wenfahu
Last active September 4, 2017 02:09

Overall process

This toolset provides channel-level pruning of the Inception-ResNet v2 model (for details of the model, please refer to Inception-ResNet v2: https://arxiv.org/abs/1602.07261).

  1. python inf.py [meta] [ckpt] [output_mask] --threshold [threshold]: get the indices (mask) of the convolutional channels whose weights fall below the threshold.
  2. python freeze_graph.py [model_dir] [output_file]: freeze the model weights.
  3. Simplify the TensorFlow graph using the Graph Transform Tool (GTT): https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md#optimizing-for-deployment
  4. python conv_travser.py [graph_path] [mask_path] [output_graph] [output_mask]: prune the model following the layer dependencies of Inception-ResNet v2. output_graph is the pruned model, and output_mask is used for further training.
  5. python zero_ckpt.py [meta] [ckpt] [zidx] [output]: set the pruned model weights using the mask generated above.
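
Steps 1 and 5 can be pictured with a small, framework-free sketch. It assumes the per-channel criterion is the L2 norm of each output channel's weights; the source only says that channel weights are compared against a threshold, so the norm is an assumption. The function names below are hypothetical and are not taken from inf.py or zero_ckpt.py, and a kernel is simplified to a list of per-channel weight lists.

```python
import math

def channel_norms(kernel):
    """L2 norm of each output channel; kernel is a list of per-channel weight lists."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in kernel]

def channels_under_threshold(kernel, threshold):
    """Indices of channels whose norm falls below the threshold (the mask)."""
    return [i for i, norm in enumerate(channel_norms(kernel)) if norm < threshold]

def zero_channels(kernel, mask):
    """Return a copy of the kernel with the masked channels set to zero."""
    masked = set(mask)
    return [[0.0] * len(ch) if i in masked else list(ch)
            for i, ch in enumerate(kernel)]

# Toy kernel with three output channels; the second one is nearly zero.
kernel = [[0.5, -0.4], [0.001, 0.002], [0.3, 0.0]]
mask = channels_under_threshold(kernel, threshold=0.01)
pruned = zero_channels(kernel, mask)
```

In the real pipeline the mask also has to respect the dependencies between layers (step 4), which this sketch leaves out.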

(Re)Training

train_pruned_classifer.py [-h] [--logs_base_dir LOGS_BASE_DIR]
[--models_base_dir MODELS_BASE_DIR] [--gpu_memory_fraction GPU_MEMORY_FRACTION] [--pretrained_model PRETRAINED_MODEL] [--data_dir DATA_DIR] [--model_def MODEL_DEF] [--max_nrof_epochs MAX_NROF_EPOCHS] [--batch_size BATCH_SIZE] [--image_size IMAGE_SIZE] [--epoch_size EPOCH_SIZE] [--embedding_size EMBEDDING_SIZE] [--random_crop] [--random_flip] [--random_rotate] [--keep_probability KEEP_PROBABILITY] [--weight_decay WEIGHT_DECAY] [--decov_loss_factor DECOV_LOSS_FACTOR] [--center_loss_factor CENTER_LOSS_FACTOR] [--center_loss_alfa CENTER_LOSS_ALFA] [--optimizer {ADAGRAD,ADADELTA,ADAM,RMSPROP,MOM}] [--learning_rate LEARNING_RATE] [--learning_rate_decay_epochs LEARNING_RATE_DECAY_EPOCHS] [--learning_rate_decay_factor LEARNING_RATE_DECAY_FACTOR] [--moving_average_decay MOVING_AVERAGE_DECAY] [--seed SEED] [--nrof_preprocess_threads NROF_PREPROCESS_THREADS] [--log_histograms] [--learning_rate_schedule_file LEARNING_RATE_SCHEDULE_FILE] [--filter_filename FILTER_FILENAME] [--filter_percentile FILTER_PERCENTILE] [--filter_min_nrof_images_per_class FILTER_MIN_NROF_IMAGES_PER_CLASS] [--no_store_revision_info] [--lfw_pairs LFW_PAIRS] [--lfw_file_ext {jpg,png}] [--lfw_dir LFW_DIR] [--lfw_batch_size LFW_BATCH_SIZE] [--lfw_nrof_folds LFW_NROF_FOLDS] [--finetune] [--pruning_mask PRUNING_MASK] [--meta_graph META_GRAPH] [--group_lasso_factor GROUP_LASSO_FACTOR]
optional arguments:
-h, --help show this help message and exit
--logs_base_dir LOGS_BASE_DIR
 Directory where to write event logs.
--models_base_dir MODELS_BASE_DIR
 Directory where to write trained models and checkpoints.
--gpu_memory_fraction GPU_MEMORY_FRACTION
 Upper bound on the amount of GPU memory that will be used by the process.
--pretrained_model PRETRAINED_MODEL
 Load a pretrained model before training starts.
--data_dir DATA_DIR
 Path to the data directory containing aligned face patches. Multiple directories are separated with colon.
--model_def MODEL_DEF
 Model definition. Points to a module containing the definition of the inference graph.
--max_nrof_epochs MAX_NROF_EPOCHS
 Number of epochs to run.
--batch_size BATCH_SIZE
 Number of images to process in a batch.
--image_size IMAGE_SIZE
 Image size (height, width) in pixels.
--epoch_size EPOCH_SIZE
 Number of batches per epoch.
--embedding_size EMBEDDING_SIZE
 Dimensionality of the embedding.
--random_crop Performs random cropping of training images. If false, the center image_size pixels from the training images are used. If the size of the images in the data directory is equal to image_size, no cropping is performed.
--random_flip Performs random horizontal flipping of training images.
--random_rotate
 Performs random rotations of training images.
--keep_probability KEEP_PROBABILITY
 Keep probability of dropout for the fully connected layer(s).
--weight_decay WEIGHT_DECAY
 L2 weight regularization.
--decov_loss_factor DECOV_LOSS_FACTOR
 DeCov loss factor.
--center_loss_factor CENTER_LOSS_FACTOR
 Center loss factor.
--center_loss_alfa CENTER_LOSS_ALFA
 Center update rate for center loss.
--optimizer {ADAGRAD,ADADELTA,ADAM,RMSPROP,MOM}
 The optimization algorithm to use.
--learning_rate LEARNING_RATE
 Initial learning rate. If set to a negative value a learning rate schedule can be specified in the file "learning_rate_schedule.txt"
--learning_rate_decay_epochs LEARNING_RATE_DECAY_EPOCHS
 Number of epochs between learning rate decay.
--learning_rate_decay_factor LEARNING_RATE_DECAY_FACTOR
 Learning rate decay factor.
--moving_average_decay MOVING_AVERAGE_DECAY
 Exponential decay for tracking of training parameters.
--seed SEED Random seed.
--nrof_preprocess_threads NROF_PREPROCESS_THREADS
 Number of preprocessing (data loading and augmentation) threads.
--log_histograms
 Enables logging of weight/bias histograms in tensorboard.
--learning_rate_schedule_file LEARNING_RATE_SCHEDULE_FILE
 File containing the learning rate schedule that is used when learning_rate is set to -1.
--filter_filename FILTER_FILENAME
 File containing image data used for dataset filtering.
--filter_percentile FILTER_PERCENTILE
 Keep only the given percentile of images closest to their class center.
--filter_min_nrof_images_per_class FILTER_MIN_NROF_IMAGES_PER_CLASS
 Keep only the classes with this number of examples or more.
--no_store_revision_info
 Disables storing of git revision info in revision_info.txt.
--lfw_pairs LFW_PAIRS
 The file containing the pairs to use for validation.
--lfw_file_ext {jpg,png}
 The file extension for the LFW dataset.
--lfw_dir LFW_DIR
 Path to the data directory containing aligned face patches.
--lfw_batch_size LFW_BATCH_SIZE
 Number of images to process in a batch in the LFW test set.
--lfw_nrof_folds LFW_NROF_FOLDS
 Number of folds to use for cross validation. Mainly used for testing.
--finetune Fine-tune the model.
--pruning_mask PRUNING_MASK
 Pruning mask applied during backpropagation.
--meta_graph META_GRAPH
 Load a pretrained metagraph before training starts.
--group_lasso_factor GROUP_LASSO_FACTOR
 Scale for group lasso regularization.
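
The learning_rate_decay_epochs and learning_rate_decay_factor options above are commonly combined as a staircase exponential schedule. The following is a minimal sketch under that assumption; the help text does not confirm the exact formula, and `decayed_learning_rate` is a hypothetical helper, not a function from the training script.

```python
def decayed_learning_rate(initial_lr, epoch, decay_epochs, decay_factor):
    """Multiply the rate by decay_factor once every decay_epochs epochs (assumed staircase)."""
    return initial_lr * decay_factor ** (epoch // decay_epochs)

# With decay_epochs=100 and decay_factor=0.1, the rate drops tenfold at epochs 100, 200, ...
lr_start = decayed_learning_rate(0.1, epoch=0, decay_epochs=100, decay_factor=0.1)
lr_later = decayed_learning_rate(0.1, epoch=250, decay_epochs=100, decay_factor=0.1)
```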

The group_lasso_factor is generally 8e-5 for the Inception-ResNet v2 model. If pruning_mask is provided, the gradients of the parameters listed in the mask are zeroed out, which prevents the pruned parameters from being updated.
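
A minimal sketch of both ideas follows. It assumes each output channel forms one group, so the penalty is the scaled sum of per-channel L2 norms, and that the mask is a list of channel indices. Both are assumptions for illustration, as are the helper names; the actual training script operates on TensorFlow variables and gradients.

```python
import math

def group_lasso_penalty(kernel, factor=8e-5):
    """Scaled sum of per-channel L2 norms; each inner list is one channel (group)."""
    return factor * sum(math.sqrt(sum(w * w for w in ch)) for ch in kernel)

def mask_gradients(grads, mask):
    """Zero the gradients of every channel whose index is listed in the mask."""
    masked = set(mask)
    return [[0.0] * len(g) if i in masked else list(g)
            for i, g in enumerate(grads)]

def sgd_step(kernel, grads, lr):
    """Plain gradient descent update, used here only to show the effect."""
    return [[w - lr * g for w, g in zip(ch, gr)]
            for ch, gr in zip(kernel, grads)]

kernel = [[3.0, 4.0], [0.0, 0.0]]   # channel 1 was pruned earlier
grads = [[0.1, 0.2], [0.3, 0.4]]    # raw gradients would revive channel 1
penalty = group_lasso_penalty(kernel)
updated = sgd_step(kernel, mask_gradients(grads, mask=[1]), lr=0.5)
```

Because the masked gradients are zero, the pruned channel stays at zero after the update while the remaining channel trains normally.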

This toolset depends heavily on David Sandberg's facenet: https://github.com/davidsandberg/facenet. The model is trained on the MS-Celeb-1M dataset: https://www.microsoft.com/en-us/research/project/ms-celeb-1m-challenge-recognizing-one-million-celebrities-real-world/ and validated with the LFW face verification protocol: http://vis-www.cs.umass.edu/lfw/
