@gngdb
Last active November 20, 2018 12:03
Inefficient HashedNet implementation: https://arxiv.org/abs/1504.04788
# implementation of HashedNets ("Compressing Neural Networks with the
# Hashing Trick", Chen et al. 2015): https://arxiv.org/abs/1504.04788
import torch
import torch.nn as nn
import torch.nn.functional as F
import xxhash

class HashFunction(object):
    """Hash function as described in the paper: maps a key (i, j) to an
    integer in {0, ..., max_index - 1} (the paper's {1, ..., K^l}, shifted
    to zero-based indexing), used to index the shared weight vector."""
    def __init__(self, max_index, seed=0):
        self.xx = xxhash.xxh32(seed=seed)
        self.range_scale = float(max_index)/float(2**32)

    def __call__(self, i, j):
        # reset the hash state so the digest depends only on (i, j),
        # not on the history of previous calls
        self.xx.reset()
        self.xx.update(i.to_bytes(32, 'big'))
        self.xx.update(j.to_bytes(32, 'big'))
        k = self.xx.intdigest()
        # scale the uniform 32-bit digest down to {0, ..., max_index - 1}
        return int(float(k)*self.range_scale)
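
# Illustrative helper (an addition, not part of the original gist): because
# the digest above is roughly uniform over {0, ..., budget - 1}, a dense
# layer's in_features*out_features keys are spread roughly evenly over the
# budget slots, so each shared weight is reused about
# (in_features*out_features)/budget times on average.
def expected_sharing(in_features, out_features, budget):
    """Average number of (i, j) keys hashed to each shared weight."""
    return (in_features * out_features) / budget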

class HashedLinear(nn.Linear):
    """A Linear layer implemented with parameter sharing using the hashing
    trick: only `budget` real weights are stored, and the full weight matrix
    is assembled by indexing into them with a hash of each position."""
    def __init__(self, in_features, out_features, budget, bias=True):
        original_params = in_features*out_features
        assert budget < original_params,\
            f'Budget {budget} too large for {original_params} parameters'
        super(HashedLinear, self).__init__(in_features, out_features, bias=bias)
        # truncate the initialised weight matrix to the budget; clone so the
        # full-size storage can be freed
        budgeted = self.weight.data.view(-1)[:budget].clone()
        del self.weight
        # register the budgeted weights as the layer's only weight parameter
        self.register_parameter('weight', nn.Parameter(budgeted))
        # precompute (inefficiently) the index matrix, seeding the hash
        # from the torch RNG
        seed = int(torch.randint(high=2**32, size=(1,)).item())
        self.h = HashFunction(budget, seed=seed)
        idxs = torch.zeros((out_features, in_features)).long()
        for i in range(out_features):
            for j in range(in_features):
                idxs[i, j] = self.h(i, j)
        # register the indices as a buffer so they are saved in the
        # state_dict and moved to the GPU along with the module
        self.register_buffer('idxs', idxs)

    def forward(self, x):
        # first assemble the full weight matrix by indexing the shared weights
        W = self.weight[self.idxs]
        # then complete the forward pass as normal
        return F.linear(x, W, bias=self.bias)
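
# A minimal usage sketch (an addition, not part of the original gist):
# HashedLinear drops in wherever nn.Linear would be used, e.g. stacked into
# a small MLP. The sizes and budgets below are arbitrary illustrative choices.
class HashedMLP(nn.Module):
    def __init__(self):
        super(HashedMLP, self).__init__()
        self.fc1 = HashedLinear(32, 64, budget=256)
        self.fc2 = HashedLinear(64, 10, budget=64)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))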

if __name__ == '__main__':
    X = torch.randn(16, 32)
    l = HashedLinear(32, 10, 5, bias=False)
    # check we really have only 5 weights
    for p in l.parameters():
        print(p.size())
    # check the forward pass is deterministic
    Y = l(X)
    assert torch.abs(Y - l(X)).max() < 1e-3
    # check we can save and load from a state_dict
    sd = l.state_dict()
    l = HashedLinear(32, 10, 5, bias=False)
    l.load_state_dict(sd)
    assert torch.abs(Y - l(X)).max() < 1e-3
    # check the layer runs on the GPU
    l = l.cuda()
    X = X.cuda()
    print(l(X).size())
    # check the gradient reaches the shared weights
    out = l(X)
    out.mean().backward()
    print(l.weight.grad)
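
    # Illustrative addition (not part of the original gist): compare the
    # number of stored parameters against a dense nn.Linear of the same shape
    dense = nn.Linear(32, 10, bias=False)
    n_dense = sum(p.numel() for p in dense.parameters())
    n_hashed = sum(p.numel() for p in l.parameters())
    print(f'dense: {n_dense} params, hashed: {n_hashed} params')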