Skip to content

Instantly share code, notes, and snippets.

@BarclayII
Created June 21, 2018 08:44
Show Gist options
  • Save BarclayII/533393733ab987dcaebec27ca694581f to your computer and use it in GitHub Desktop.
Save BarclayII/533393733ab987dcaebec27ca694581f to your computer and use it in GitHub Desktop.
Maximum Inner Product Search with Asymmetric LSH + Random Projection
# References:
# https://arxiv.org/pdf/1405.5869.pdf
# https://arxiv.org/abs/1507.05910
import numpy as np
from scipy.spatial.distance import cosine
# Rescaling and appending new components to "normalize" data vectors
# (Sign-ALSH data transform P).  Every data vector is scaled by the same
# factor so the largest norm becomes U < 1, then M extra coordinates
# 1/2 - ||x||^(2^i), i = 1..M, are appended.  The squared norm of the
# augmented vector telescopes to M/4 + ||x||^(2^(M+1)), which is almost
# constant for U < 1, so cosine similarity against the query transform
# [q; 0, ..., 0] tracks the inner product x . q.
U = 0.83  # largest norm after rescaling; must be < 1
M = 6     # number of appended norm-power coordinates
X = np.random.randn(10000, 100)
Xn = np.sqrt((X ** 2).sum(1))
maxXn = Xn.max()
X_s = X / maxXn * U
Xn = Xn / maxXn * U  # now the norms of the rows of X_s
p = np.arange(1, M + 1)
# Column i of the appended part is 1/2 - ||x_s||^(2^(i+1)).
X_p = np.concatenate([X_s, 0.5 - Xn[:, None] ** (2 ** p[None, :])], 1)
# Compute hash by random projection and collect the bins.
# Each augmented data vector gets N_BITS sign bits: bit j is 1 iff the
# projection onto random hyperplane j is positive.  The bits are packed
# MSB-first into one integer per vector, so each vector is now represented
# as an integer.  `bins` holds the distinct hash values that occur, sorted.
N_BITS = 10
# Width follows X_p so changing the data dimension or the number of
# appended coordinates cannot silently break the matrix product.
bases = np.random.randn(N_BITS, X_p.shape[1])
h = (X_p @ bases.T > 0).astype('int')
h_s = h @ (1 << np.arange(N_BITS - 1, -1, -1))
bins = np.unique(h_s)
# Build the query and apply the asymmetric query transform Q(q) = [q; 0, ..., 0],
# with zeros in the appended norm-power coordinates.  The query is not
# rescaled: each hash bit is sign(q . b), which is unchanged by positive
# scaling of q.
Y = np.random.randn(1, X.shape[1])
Y_q = np.concatenate([Y, np.zeros((1, bases.shape[1] - X.shape[1]))], 1)
# Bit j is 1 iff the projection onto hyperplane j is positive.  This is the
# same rule used for the data hashes; clipping the raw projection without
# taking its sign first would turn projections in (0, 1) into 0 bits.
# NOTE(review): the statistics quoted at the end of this script were
# presumably recorded with that unsigned version; re-run to refresh them.
h_q = (Y_q @ bases.T > 0).astype('int')
h_q_s = h_q @ (1 << np.arange(h_q.shape[1] - 1, -1, -1))  # MSB-first packing
# Compare number of different bits (Hamming distance) between each occupied
# bin's hash and the query hash.  The vector with maximum dot product should
# be more likely to appear in the bins with fewer differing bits.
MAX_BINS = 20  # soft cap on the number of bins whose items are scanned
diff = np.array([bin(i).count('1') for i in bins ^ h_q_s[0]])
diff_values = np.unique(diff)  # distinct distances, already sorted ascending
# Take whole distance groups, nearest first, until adding the next group
# would exceed MAX_BINS.  The nearest group is always taken, even if it alone
# exceeds the cap, so at least one bin is selected.
closest = []
for v in diff_values:
    new_bins = (diff == v).nonzero()[0]
    if closest and len(closest) + len(new_bins) > MAX_BINS:
        break
    closest.extend(new_bins)
# Check each candidate within all the selected bins and keep the one whose
# augmented vector has the highest cosine similarity with the augmented
# query (the quantity that sign-random-projection hashing approximates).
max_sim = -np.inf
max_dot = -np.inf
max_item = None
candids = np.isin(h_s, bins[closest]).nonzero()[0]
for item in candids:
    cos_sim = 1 - cosine(X_p[item], Y_q[0])
    dot = X[item] @ Y[0]
    # Sanity check: cosine similarity should order items the same way as the
    # true inner product.  This holds only approximately.  The augmented
    # data norm is sqrt(6/4 + ||x_s||^128) with the 6 appended coordinates
    # used here, and it varies slightly between items, so extreme near-ties
    # could in principle trip this assertion.
    assert (max_dot - dot) * (max_sim - cos_sim) >= 0
    if max_sim < cos_sim:
        max_sim = cos_sim
        max_dot = dot
        max_item = item
# Compare against the real result obtained by brute force.
real_dots = (X @ Y.T).flatten()
# rank = number of items whose exact dot product is strictly smaller than the
# returned item's, i.e. its 0-based position in ascending order.  The exact
# maximum gets rank len(X) - 1.  The returned item's value is read from
# real_dots itself rather than using max_dot: max_dot comes from a 1-D dot
# product and may round differently from the matrix product.  If max_dot
# were slightly larger, the item itself would be counted and the rank would
# shift by one.
rank = int((real_dots < real_dots[max_item]).sum())
real_max_item = np.argmax(real_dots)
real_max_dot = np.max(real_dots)
# Hamming distance between the true best item's hash and the query hash
# (not printed; kept for inspection).
real_diff = bin(h_s[real_max_item] ^ h_q_s[0]).count('1')
print(rank, len(closest), len(candids))
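# Ran 100 simulations with the parameters above and got the following
# distribution of `rank` (the 0-based position of the returned item's dot
# product among all 10000 items in ascending order, so 9999 means the exact
# maximum was found):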
# Min = 8518
# 25% = 9923
# 50% = 9978
# 75% = 9992
# Max = 9999
# Also the number of candidates that had to be checked (`len(candids)`):
# Min = 2
# 25% = 29
# 50% = 71
# 75% = 327
# Max = 3454
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment