@rahulvigneswaran
Created October 1, 2019 06:17
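The following function prints the number of zero and non-zero weights in each parameter of a PyTorch model, followed by an overall summary. It returns the percentage of weights that are still non-zero.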
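For reference, the summary line and the return value reduce to three ratios. Here $n$ is the non-zero count and $N$ the total element count, both summed over all parameters (the symbols are introduced only for this note):

```latex
\text{compression} = \frac{N}{n}, \qquad
\text{pruned} = 100 \cdot \frac{N - n}{N}\,\%, \qquad
\text{returned} = \operatorname{round}\!\left(100 \cdot \frac{n}{N},\ 1\right)
```

For example, with $N = 262$ and $n = 162$ the function reports a compression rate of 1.62x and 38.17% pruned, and returns 61.8.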
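import numpy as np


def print_nonzeros(model) -> float:
    """Print per-parameter and overall zero/non-zero weight counts for a PyTorch model.

    Returns the percentage of weights that are still non-zero, rounded to one decimal.
    """
    nonzero = total = 0
    for name, p in model.named_parameters():
        tensor = p.detach().cpu().numpy()
        nz_count = int(np.count_nonzero(tensor))
        total_params = tensor.size  # number of elements in this parameter (always an int)
        nonzero += nz_count
        total += total_params
        print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count:7} | shape = {tensor.shape}')
    # Avoid a ZeroDivisionError when every weight has been pruned.
    rate = total / nonzero if nonzero else float('inf')
    print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {rate:10.2f}x ({100 * (total - nonzero) / total:6.2f}% pruned)')
    return round(100 * nonzero / total, 1)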
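A minimal usage sketch, assuming PyTorch 1.7 or later (for `torch.count_nonzero`) and the built-in `torch.nn.utils.prune` module. It restates the counting in pure PyTorch through a hypothetical helper `count_nonzeros` and applies it to a toy model. One caveat matters for the function above: `prune.l1_unstructured` reparametrizes the layer, so `named_parameters()` yields the unpruned `weight_orig` and the new zeros stay invisible until `prune.remove` folds the mask back into `weight`.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def count_nonzeros(model: nn.Module):
    """Hypothetical helper: return (non-zero, total) element counts over all parameters."""
    nonzero = total = 0
    for _, p in model.named_parameters():
        nonzero += int(torch.count_nonzero(p.detach()))
        total += p.numel()
    return nonzero, total


torch.manual_seed(0)
# Toy model: 10*20 + 20 + 20*2 + 2 = 262 parameters.
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Zero out the 50% smallest-magnitude entries (100 of 200) of the first layer's weight.
prune.l1_unstructured(model[0], name="weight", amount=0.5)
# Still counts the unpruned `weight_orig` parameter; the mask lives in a buffer.
before_remove = count_nonzeros(model)

# Make the pruning permanent so `weight` becomes a plain Parameter containing the zeros.
prune.remove(model[0], "weight")
after_remove = count_nonzeros(model)

print(before_remove, after_remove)
```

Calling `print_nonzeros(model)` after `prune.remove` should likewise report the pruned entries of `0.weight`; calling it before would show that layer as fully dense.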