
@sahandilshan
Created February 22, 2021 13:02
A code snippet to train a model on the MNIST dataset and compress it using pruning with PyTorch. The complete code can be found here: https://github.com/sahandilshan/Simple-NN-Compression
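from torch.nn.utils import prune

# pruned_model is assumed to be the trained MNIST network with four
# nn.Linear layers named fc1..fc4 (defined in the full repository).
pruning_percentage = 0.40  # prune 40% of all weights in the listed layers

# (module, parameter name) pairs that are pruned together.
parameters_to_prune = (
    (pruned_model.fc1, 'weight'),
    (pruned_model.fc2, 'weight'),
    (pruned_model.fc3, 'weight'),
    (pruned_model.fc4, 'weight'),
)

# Global L1 unstructured pruning: rank every weight in all four layers by
# absolute value and zero out the smallest 40% overall, so individual layers
# may end up with different sparsity levels.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=pruning_percentage,
)

# Make the pruning permanent: fold weight_orig * weight_mask into 'weight'
# and drop the pruning reparametrization (the mask buffer and forward hook).
for module, name in parameters_to_prune:
    prune.remove(module, name)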
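To make the effect of global L1 unstructured pruning concrete, here is a minimal pure-Python sketch of the idea. It is an illustration, not PyTorch's implementation: plain lists of floats stand in for the flattened weights of fc1..fc4, and the function names `global_l1_prune` and `sparsity` as well as the toy weight values are hypothetical. Because all weights are ranked together, the overall fraction pruned is 40%, but each layer's individual sparsity can differ. Ties between equal magnitudes may be broken differently than in PyTorch.

```python
# Hypothetical pure-Python sketch of global L1 (magnitude) unstructured pruning.
# Plain lists of floats stand in for the flattened weight tensors of fc1..fc4.

def global_l1_prune(layers, amount):
    """Zero the `amount` fraction of smallest-magnitude weights across ALL layers.

    Returns (pruned_layers, masks). A mask entry is 1.0 for a kept weight and
    0.0 for a pruned one, similar in spirit to PyTorch's `weight_mask` buffer.
    """
    # Rank every weight from every layer together by absolute value.
    ranked = sorted(
        (abs(w), li, wi)
        for li, layer in enumerate(layers)
        for wi, w in enumerate(layer)
    )
    # Convert the fractional amount into a count of weights to prune.
    n_prune = round(amount * len(ranked))

    masks = [[1.0] * len(layer) for layer in layers]
    for _, li, wi in ranked[:n_prune]:
        masks[li][wi] = 0.0

    # Apply the masks; this is roughly what the weights look like once the
    # pruning is made permanent (the analogue of prune.remove).
    pruned = [
        [w * m for w, m in zip(layer, mask)]
        for layer, mask in zip(layers, masks)
    ]
    return pruned, masks


def sparsity(layer):
    """Fraction of entries in a layer that are exactly zero."""
    return sum(1 for w in layer if w == 0.0) / len(layer)


# Toy "weights" for four layers (made-up values, not from the MNIST model).
layers = [
    [0.9, -0.05, 0.3, 0.01],  # fc1
    [0.2, -0.8, 0.02],        # fc2
    [-0.4, 0.07],             # fc3
    [0.6],                    # fc4
]
pruned, masks = global_l1_prune(layers, amount=0.40)
for name, layer in zip(["fc1", "fc2", "fc3", "fc4"], pruned):
    print(name, round(sparsity(layer), 2))
```

With these values, 4 of the 10 weights are pruned and the per-layer sparsity comes out as 0.5, 0.33, 0.5 and 0.0: fc4 keeps its only weight because it is larger than the smallest weights elsewhere. Per-layer pruning (for example `prune.l1_unstructured` on each layer separately) would instead force the same fraction onto every layer.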