Skip to content

Instantly share code, notes, and snippets.

@luxedo
Created April 16, 2019 17:16
Show Gist options
  • Save luxedo/73de2cc4dd69d6ab88aefd0aefabe9a8 to your computer and use it in GitHub Desktop.
Save luxedo/73de2cc4dd69d6ab88aefd0aefabe9a8 to your computer and use it in GitHub Desktop.
Script to truncate the output layer of .npz models for tensorpack
#!/usr/bin/env python3
import sys
import numpy as np
def main(input, output, categories):
    """Truncate the 'maskrcnn/conv' output layer stored in an .npz file.

    Loads every array from ``input``, keeps only the first ``categories``
    entries along the fourth axis of ``maskrcnn/conv/W:0`` and along the
    first axis of ``maskrcnn/conv/b:0``, and writes all arrays (modified and
    untouched) to ``output`` under their original names.

    Args:
        input: Path (or file object) of the source .npz archive.
        output: Path (or file object) of the archive to write. Note that
            ``np.savez`` appends ``.npz`` to a path that lacks it.
        categories: Number of leading entries to keep.
            NOTE(review): presumably the number of output classes of the
            mask head -- confirm against the model definition.

    Raises:
        KeyError: If either 'maskrcnn/conv' array is missing from ``input``.
    """
    # Materialize the arrays into a plain dict so they stay usable after the
    # archive is closed.
    with np.load(input) as npz:
        data = dict(npz)
    print("Chopping last layer")
    data['maskrcnn/conv/W:0'] = data['maskrcnn/conv/W:0'][:, :, :, :categories]
    data['maskrcnn/conv/b:0'] = data['maskrcnn/conv/b:0'][:categories]
    # Unpack the dict so every array is stored under its own name; passing
    # the dict positionally would store it as one pickled object ("arr_0").
    np.savez(output, **data)
    print("Saved {}".format(output))
if __name__ == '__main__':
    # Command-line entry point: INPUT_FILE OUTPUT_FILE NUMBER_OF_CATEGORIES.
    # Only argument parsing is guarded, so an IndexError/ValueError raised
    # inside main() surfaces with its traceback instead of being misreported
    # as a usage error.
    try:
        input_path = sys.argv[1]
        output_path = sys.argv[2]
        categories = int(sys.argv[3])
    except (IndexError, ValueError):
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit("Usage: python edit-last-layer.py INPUT_FILE OUTPUT_FILE NUMBER_OF_CATEGORIES:int")
    main(input_path, output_path, categories)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment