Pruning and converting a sparse network to dense in PyTorch
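The script below zeroes the 20% smallest-magnitude weights of a toy linear model, re-packs the pruned layer as a set of small dense layers that keep only the surviving weights, and checks that the output is unchanged while the parameter count drops.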
#!/usr/bin/python

import copy
import torch
import numpy as np

torch.manual_seed(0)

class Original(torch.nn.Module):
    """Toy model: a single 3x3 linear layer."""
    def __init__(self):
        super(Original, self).__init__()
        self.l1 = torch.nn.Linear(3, 3, bias=True)

    def forward(self, x):
        return self.l1(x)

class SparseLinear(torch.nn.Module):
    """Re-packs a pruned Linear layer as one small dense Linear per output
    row, keeping only that row's nonzero weights."""
    def __init__(self, original):
        super(SparseLinear, self).__init__()
        nonzero_weight = (original.weight != 0)
        needs_bias = original.bias is not None
        self.linears = torch.nn.ModuleList()
        for i, weight in enumerate(nonzero_weight):
            capture_indices = weight.nonzero().view(-1)
            l = torch.nn.Linear(int(weight.sum()), 1, bias=needs_bias)
            # keep the surviving weights as a 1-D vector; F.linear accepts
            # a 1-D weight and returns a scalar output for a 1-D input
            l.weight.data = original.weight.data[i, capture_indices].view(-1)
            if needs_bias:
                l.bias.data = original.bias.data[i]
            l.register_buffer('weight_mask', weight)
            l.register_buffer('capture_indices', capture_indices)
            self.linears.append(l)

    def forward(self, x):
        # assumes a single 1-D input: gather the inputs each row needs,
        # run its small dense layer, and stack the scalar outputs
        y = []
        for linear in self.linears:
            _x = x[linear.capture_indices].view(-1)
            _y = linear(_x)
            y += [_y]
        return torch.stack(y)

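# Note: the forward above expects a single 1-D input vector, which is how
# the script below calls it. A batch-aware variant (an illustrative sketch,
# not part of the original gist) would gather columns and stack the per-row
# outputs instead:
def sparse_linear_batched_forward(sparse, x):
    # x: (batch, in_features) -> (batch, out_features)
    return torch.stack(
        [l(x[:, l.capture_indices]) for l in sparse.linears], dim=1)
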
def prune(model):
    """Zero out the k% smallest-magnitude weights across the whole model
    (1-D parameters, i.e. biases, are left untouched)."""
    k = 20
    all_weights = []
    for p in model.parameters():
        if len(p.data.size()) != 1:
            all_weights += list(p.cpu().data.abs().numpy().flatten())
    threshold = np.percentile(np.array(all_weights), k)
    for p in model.parameters():
        if len(p.data.size()) != 1:
            # torch.autograd.Variable with volatile= is long deprecated;
            # a plain boolean mask does the same job
            mask = (p.data.abs() > threshold).float()
            p.data = p.data * mask

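# For reference: PyTorch also ships torch.nn.utils.prune, whose
# global_unstructured applies the same global magnitude criterion. A minimal
# sketch (assuming the Linear-only layout of this gist; not called below):
def prune_builtin(model, amount=0.20):
    import torch.nn.utils.prune as torch_prune
    params = [(m, 'weight') for m in model.modules()
              if isinstance(m, torch.nn.Linear)]
    torch_prune.global_unstructured(
        params, pruning_method=torch_prune.L1Unstructured, amount=amount)
    for module, name in params:
        # fold the mask into the weight so the zeros become permanent
        torch_prune.remove(module, name)
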
def prune2dense(model):
    """Replace every module that owns a `weight` with its SparseLinear
    re-packing. Note: setattr with the named_modules name only handles
    direct children, which is all this toy model has."""
    for name, module in model.named_modules():
        if hasattr(module, 'weight'):
            setattr(model, name, SparseLinear(module))

def countparams(model):
    total = sum(p.numel() for p in model.parameters())
    withgrad = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Parameters:: total: {total}, withgrad: {withgrad}")

def print_state_dict(model):
    dic = model.state_dict()
    print("\tstate_dict:: ", dic)
    print()

original = Original()
original.eval()
x = torch.FloatTensor([1, 2, 3])

# dense baseline
# print_state_dict(original)
countparams(original)
print("y::", original(x))

# prune in place: same parameter count, some weights zeroed
prune(original)
# print_state_dict(original)
countparams(original)
print("y::", original(x))

# re-pack the pruned layer: fewer parameters, identical output
condensed = copy.deepcopy(original)
prune2dense(condensed)
condensed.eval()
# print_state_dict(condensed)
countparams(condensed)
print("y::", condensed(x))

##### Results
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8104, -0.4052, 0.7504], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 9, withgrad: 9
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<SqueezeBackward3>)
# Parameters:: total: 7, withgrad: 7
# y:: tensor([-0.8061, -0.4052, 0.7618], grad_fn=<StackBackward>)
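The condensed model reproduces the pruned model's output exactly while storing fewer parameters.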