# PyTorch implementation of LSHash from https://github.com/kayzhu/LSHash
# Homework: try to implement LSHash in PyTorch to speed up mean shift.
# Motivation: why calculate all pairwise distances when you only need a few?
import numpy as np
import importlib
import torch
import torch_utils
importlib.reload(torch_utils)
from torch_utils import *
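# Note: torch_utils is assumed to provide sub(), the broadcasting element-wise
# subtraction used in the fastQuery* methods below. If it is unavailable, a
# minimal stand-in (an assumption, not the original helper) would be:
# def sub(a, b): return a - b  # plain broadcasting subtraction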
class PyTorchLSHash(object):
    def __init__(self, hash_size, input_dim, num_hashtables=1):
        # One set of random hyperplanes and one bucket dict per hash table.
        self.uniform_planes = [np.random.randn(hash_size, input_dim)
                               for _ in range(num_hashtables)]
        self.hash_tables = [dict() for _ in range(num_hashtables)]

    def _hash(self, planes, input_point):
        # Project the point onto each hyperplane; the sign pattern is the key.
        input_point = np.array(input_point)  # for faster dot product
        projections = np.dot(planes, input_point)
        return "".join(['1' if i > 0 else '0' for i in projections])
    def index(self, input_point):
        value = tuple(input_point)
        for i, table in enumerate(self.hash_tables):
            table.setdefault(self._hash(self.uniform_planes[i], input_point),
                             []).append(value)
    def getCandidates(self, query_point):
        # Union of the buckets the query falls into, across all hash tables.
        candidates = set()
        for i, table in enumerate(self.hash_tables):
            binary_hash = self._hash(self.uniform_planes[i], query_point)
            candidates.update(table.get(binary_hash, []))
        # Return a list so callers get a stable iteration order when they
        # later zip candidates with their computed distances.
        return list(candidates)

    def getResults(self, candidates, result, num_results):
        candidates = list(zip(candidates, result))
        candidates.sort(key=lambda x: x[1])
        return candidates[:num_results] if num_results else candidates
    # This is the original implementation, condensed by me for reimplementation.
    def query(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        d_func = PyTorchLSHash.euclidean_dist_square
        candidates = [(ix, d_func(query_point, np.asarray(ix)))
                      for ix in candidates]
        candidates.sort(key=lambda x: x[1])
        return candidates[:num_results] if num_results else candidates
    # NumPy version of the same.
    def queryNp(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        diff = np.array(candidates) - query_point
        result = (diff * diff).sum(-1)  # squared Euclidean distance per row
        return self.getResults(candidates, result, num_results)
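    # Illustrative example (hypothetical numbers) of the broadcast step above:
    #   candidates = [(1, 2), (3, 4)], query_point = [1, 1]
    #   diff = [[0, 1], [2, 3]] and (diff * diff).sum(-1) = [1, 13],
    # i.e. the squared Euclidean distance of each candidate to the query.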
    # Torch version, but without batching. This should work because the
    # underlying BLAS/CUDA kernels take care of batching internally.
    def fastQueryNoBatching(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        candidates_t = torch.FloatTensor(candidates).cuda()
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        diff = sub(candidates_t, query_point_t)
        dp = (diff * diff).sum(1)  # squared distance per candidate
        result = dp.cpu().numpy().flatten().tolist()  # tensor -> Python list
        return self.getResults(candidates, result, num_results)
    # Torch version, still without batching, but the top results are also
    # selected on the GPU with torch.topk, so only the winners come back.
    def fastQueryNoBatchingAllGPU(self, query_point, num_results=None):
        candidates = self.getCandidates(query_point)
        candidates_t = torch.FloatTensor(candidates).cuda()
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        diff = sub(candidates_t, query_point_t)
        dp = (diff * diff).sum(1)  # squared distance per candidate
        n = len(candidates)
        if num_results is None or n < num_results:
            num_results = n
        _discard, indices = torch.topk(dp, num_results, dim=0, largest=False)
        indices = torch.squeeze(indices)
        dp = torch.unsqueeze(dp, 0)
        # Distances of the selected candidates (computed but not returned).
        distance = torch.index_select(dp, 1, indices).cpu().numpy().flatten().tolist()
        select_candidates = [tuple(candidates_t[x].cpu()) for x in indices]
        return select_candidates
    # Torch version implemented with explicit batching over the candidates.
    def fastQuery(self, query_point, num_results=None, bs=500):
        candidates = self.getCandidates(query_point)
        n = len(candidates)
        candidates_t = torch.FloatTensor(candidates).cuda()
        dp = torch.FloatTensor(n).cuda()  # tensor to store distance results
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            diff = sub(candidates_t[s], query_point_t)
            dp[s] = (diff * diff).sum(1)
        result = dp.cpu().numpy().flatten().tolist()  # tensor -> Python list
        return self.getResults(candidates, result, num_results)
    # Batched torch version where the top results are also selected on the GPU.
    def fastQueryAllGPU(self, query_point, num_results=None, bs=500):
        candidates = self.getCandidates(query_point)
        n = len(candidates)
        candidates_t = torch.FloatTensor(candidates).cuda()
        dp = torch.FloatTensor(n).cuda()  # tensor to store distance results
        query_point_t = torch.FloatTensor(query_point).cuda().unsqueeze_(0)
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            diff = sub(candidates_t[s], query_point_t)
            dp[s] = (diff * diff).sum(1)
        # Instead of sorting in CPU space, use torch.topk on the GPU.
        if num_results is None or n < num_results:
            num_results = n
        _discard, indices = torch.topk(dp, num_results, dim=0, largest=False)
        indices = torch.squeeze(indices)
        dp = torch.unsqueeze(dp, 0)
        # Distances of the selected candidates (computed but not returned).
        distance = torch.index_select(dp, 1, indices).cpu().numpy().flatten().tolist()
        select_candidates = [tuple(candidates_t[x].cpu()) for x in indices]
        return select_candidates
    # Currently we only implement the euclidean_dist_square function from the
    # LSHash library; the other distance functions should be easy to add.
    @staticmethod
    def euclidean_dist_square(x, y):
        """This is a hot function, hence some optimizations are made."""
        diff = np.array(x) - y
        return np.dot(diff, diff)
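    # Quick sanity check (hypothetical numbers):
    #   euclidean_dist_square([0, 0], np.array([3, 4])) -> 25.0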
# Run the above code like below (in a Jupyter notebook, for the %time magics):
# hash_size = 16
# # hash_size = 2
# input_vector_size = 8
# num_samples = 10000000  # number of samples to insert into the index
# # num_samples = 500
# num_results = 2
# # Generate random values between 1 and 30 for each vector.
# a = np.random.randint(1, 30, (num_samples, input_vector_size)).tolist()
# b = a[-1:] + np.ones((1, input_vector_size), dtype=int)  # slightly perturb the last element for the search
# # print(b, a[-1:])
# query_item = b[0].tolist()  # this query should give a distance of input_vector_size
# # Add items to the index. Slow right now, as each item is added one at a
# # time; it could be optimized to bulk load, but indexing may not be a hotspot.
# %%time
# pylsh = PyTorchLSHash(hash_size, input_vector_size)
# for x in a:
#     pylsh.index(x)
# %time query_result = pylsh.fastQueryAllGPU(query_item, num_results, bs=500)
# print(query_result, query_item)
# %time query_result = pylsh.fastQuery(query_item, num_results, bs=500)
# print(query_result, query_item)
# %time query_result = pylsh.fastQueryNoBatchingAllGPU(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.fastQueryNoBatching(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.queryNp(query_item, num_results)
# print(query_result, query_item)
# %time query_result = pylsh.query(query_item, num_results)
# print(query_result, query_item)
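# A minimal sketch for running the same comparison outside a notebook
# (assumes a CUDA device and a torch_utils module providing sub(); the
# sizes below are illustrative, not the benchmark settings above):
#
# import time
# pylsh = PyTorchLSHash(hash_size=16, input_dim=8)
# data = np.random.randint(1, 30, (100000, 8)).tolist()
# for x in data:
#     pylsh.index(x)
# query_item = list(data[-1])
# start = time.time()
# print(pylsh.fastQuery(query_item, num_results=2, bs=500))
# print('fastQuery took %.3f s' % (time.time() - start))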