@surya501
Created April 7, 2017 19:03
# PyTorch implementation of LSHash from https://github.com/kayzhu/LSHash
# Homework: try to implement LSHash in PyTorch to speed up mean shift.
# Motivation: why calculate distances to every point when you only need a few?
import numpy as np
import importlib
import torch

import torch_utils
importlib.reload(torch_utils)
from torch_utils import *
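# NOTE (assumption, not in the original gist): `sub` is expected to come from the author's
# torch_utils helpers and to behave like a broadcasted element-wise subtraction, roughly
#   sub(a, b) == a - b.expand_as(a)
# i.e. the single query row is subtracted from every candidate row. If your torch_utils
# defines it differently, adjust the calls below accordingly.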
class PyTorchLSHash(object):
    def __init__(self, hash_size, input_dim, num_hashtables=1):
        self.uniform_planes = [np.random.randn(hash_size, input_dim)
                               for _ in range(num_hashtables)]
        self.hash_tables = [dict() for i in range(num_hashtables)]

    def _hash(self, planes, input_point):
        input_point = np.array(input_point)  # for faster dot product
        projections = np.dot(planes, input_point)
        return "".join(['1' if i > 0 else '0' for i in projections])
    def index(self, input_point):
        value = tuple(input_point)
        for i, table in enumerate(self.hash_tables):
            table.setdefault(self._hash(self.uniform_planes[i], input_point), []).append(value)

    def getCandidates(self, query_point):
        candidates = set()
        for i, table in enumerate(self.hash_tables):
            binary_hash = self._hash(self.uniform_planes[i], query_point)
            candidates.update(table.get(binary_hash, []))
        return candidates

    def getResults(self, candidates, result, num_results):
        candidates = list(zip(candidates, result))
        candidates.sort(key=lambda x: x[1])
        return candidates[:num_results] if num_results else candidates
    # This is the original implementation, condensed by me for reimplementation.
    def query(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        d_func = PyTorchLSHash.euclidean_dist_square
        candidates = [(ix, d_func(query_point, np.asarray(ix)))
                      for ix in candidates]
        candidates.sort(key=lambda x: x[1])
        return candidates[:num_results] if num_results else candidates
    # Numpy version of the same.
    def queryNp(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        diff = np.array(list(candidates)) - query_point
        result = (diff * diff).sum(-1)  # squared euclidean distance per candidate
        return self.getResults(candidates, result, num_results)
    # Torch version, but without batching.
    # This should still work, as the BLAS library takes care of the batching internally.
    def fastQueryNoBatching(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        candidates_t = torch.FloatTensor(list(candidates)).cuda()
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        diff = sub(candidates_t, query_point_t)
        dp = (diff * diff).sum(1)  # squared euclidean distance per candidate
        result = dp.cpu().numpy().flatten().tolist()  # move results back from tensor to python list
        return self.getResults(candidates, result, num_results)
    # Torch version, but without batching; as above, the BLAS library should take care of it.
    # Here the top results are also selected on the GPU itself (num_results is required for topk).
    def fastQueryNoBatchingAllGPU(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        candidates_t = torch.FloatTensor(list(candidates)).cuda()
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        diff = sub(candidates_t, query_point_t)
        dp = (diff * diff).sum(1)  # squared euclidean distance per candidate
        # result = dp.cpu().numpy().flatten().tolist()  # convert tensor to python
        # return self.getResults(candidates, result, num_results)
        n = len(candidates)
        if num_results and n < num_results:
            num_results = n
        _discard, indices = torch.topk(dp, num_results, dim=0, largest=False)
        indices = torch.squeeze(indices)
        dp = torch.unsqueeze(dp, 0)
        distance = torch.index_select(dp, 1, indices).cpu().numpy().flatten().tolist()
        select_candidates = [tuple(candidates_t[x].cpu()) for x in indices]
        return select_candidates
    # Torch version implemented with batching.
    def fastQuery(self, query_point, num_results=None, bs=500):
        candidates = self.getCandidates(query_point)
        n = len(candidates)
        candidates_t = torch.FloatTensor(list(candidates)).cuda()
        dp = torch.FloatTensor(n).cuda()  # tensor to store the distance results
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            diff = sub(candidates_t[s], query_point_t)
            dp[s] = (diff * diff).sum(1)
        result = dp.cpu().numpy().flatten().tolist()  # move results back from tensor to python list
        return self.getResults(candidates, result, num_results)
    # Torch version with batching; the top results are now selected on the GPU
    # (num_results is required for topk).
    def fastQueryAllGPU(self, query_point, num_results=None, bs=500):
        candidates = self.getCandidates(query_point)
        n = len(candidates)
        candidates_t = torch.FloatTensor(list(candidates)).cuda()
        dp = torch.FloatTensor(n).cuda()  # tensor to store the distance results
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            diff = sub(candidates_t[s], query_point_t)
            dp[s] = (diff * diff).sum(1)
        # Instead of sorting in CPU space, use torch.topk to pick the closest candidates on the GPU.
        if num_results and n < num_results:
            num_results = n
        _discard, indices = torch.topk(dp, num_results, dim=0, largest=False)
        indices = torch.squeeze(indices)
        dp = torch.unsqueeze(dp, 0)
        distance = torch.index_select(dp, 1, indices).cpu().numpy().flatten().tolist()
        select_candidates = [tuple(candidates_t[x].cpu()) for x in indices]
        return select_candidates
    # Currently we only implement the euclidean_dist_square function from the LSHash library;
    # the other distance functions should be easy to add in the same way.
    @staticmethod
    def euclidean_dist_square(x, y):
        """This is a hot function, hence some optimizations are made."""
        diff = np.array(x) - y
        return np.dot(diff, diff)
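    # For illustration only (not in the original gist): a minimal sketch of how another LSHash
    # distance, the L1 (manhattan) distance, could be added in the same style.
    @staticmethod
    def l1norm_dist(x, y):
        # sum of absolute coordinate differences
        return np.sum(np.abs(np.array(x) - y))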
# Run the above code as below.
# hash_size = 16
# # hash_size = 2
# input_vector_size = 8
# num_samples = 10000000  # number of samples to insert into the index
# # num_samples = 500
# num_results = 2
# # Generate random integer values in [1, 30) for each vector.
# a = np.random.randint(1, 30, (num_samples, input_vector_size)).tolist()
# b = a[-1:] + np.ones((1, input_vector_size), dtype=int)  # slightly perturb the last element for the search
# # print(b, a[-1:])
# query_item = b[0].tolist()  # this query should result in a distance of input_vector_size
# # Add items to the index. Slow right now, as each item is added one at a time;
# # it could be optimized to use a bulk load (see the bulk_index sketch below), but indexing may not be the hotspot.
# %%time
# pylsh = PyTorchLSHash(hash_size, input_vector_size)
# for x in a:
#     pylsh.index(x)
# %time query_result = pylsh.fastQueryAllGPU(query_item, num_results, bs=500)
# print(query_result, query_item)
# %time query_result = pylsh.fastQuery(query_item, num_results, bs=500)
# print(query_result, query_item)
# %time query_result = pylsh.fastQueryNoBatchingAllGPU(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.fastQueryNoBatching(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.queryNp(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.query(query_item, num_results)
# print(query_result, query_item)
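# For illustration only (not in the original gist): a hedged sketch of the bulk load mentioned
# above, hashing all points for a table with a single matrix multiply instead of per-item index().
def bulk_index(lsh, points):
    pts = np.asarray(points)
    for i, table in enumerate(lsh.hash_tables):
        # sign of every projection at once: a (num_points, hash_size) boolean matrix
        signs = np.dot(pts, lsh.uniform_planes[i].T) > 0
        for row, point in zip(signs, points):
            key = "".join('1' if s else '0' for s in row)
            table.setdefault(key, []).append(tuple(point))
# Usage (hypothetical): bulk_index(pylsh, a) instead of the per-item indexing loop above.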