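"""Brute-force k-nearest-neighbors search on GPU.

Benchmarks a batched PyTorch implementation (pairwise distances via
`torch.cdist` followed by `torch.topk`) against
`cuml.neighbors.NearestNeighbors` with `algorithm="brute"`.
"""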
import math
import time

import numpy as np
import torch
from cuml.neighbors import NearestNeighbors


def cuml_kneighbors(query, data, n_neighbors):
    # Reference implementation: cuML's brute-force nearest neighbors search.
    knn = NearestNeighbors(
        n_neighbors=n_neighbors,
        algorithm="brute",
    )
    knn.fit(data)
    distances, indices = knn.kneighbors(query)
    return distances, indices


def kneighbors(
    # NB: best performance might depend on the memory layout of `query` and `data`.
    # TODO: benchmark, and warn or error out if the layout is not well suited.
    query,  # (n_queries, n_features)
    data,  # (n_samples, n_features)
    n_neighbors,  # int
    metric="euclidean",  # str
    max_compute_buffer_bytes=1073741824,  # int (default 1 GiB)
):
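    """Return the `n_neighbors` nearest neighbors in `data` for each row of `query`.

    Pairwise distances are computed batch-by-batch over the queries so that the
    intermediate distance buffer never exceeds `max_compute_buffer_bytes`. Only
    the Euclidean metric is supported (distances come from `torch.cdist`).

    Returns a `(distances, indices)` pair of tensors, each of shape
    `(n_queries, n_neighbors)`.
    """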
    n_queries, n_features = query.shape
    n_samples = data.shape[0]

    compute_dtype = query[0, 0].cpu().numpy().dtype.type
    compute_dtype_itemsize = np.dtype(compute_dtype).itemsize

    # The computation is batched over the queries, and the batch size is chosen so
    # that the buffer of pairwise distances computed for one batch does not exceed
    # `max_compute_buffer_bytes`.
    (batch_size, n_batches, n_full_batches, last_batch_size) = _get_batch_properties(
        expected_bytes_per_sample=n_samples * compute_dtype_itemsize,
        max_compute_buffer_bytes=max_compute_buffer_bytes,
        # The batching is over the query rows, not the dataset rows.
        dataset_n_samples=n_queries,
    )

    if batch_size < 1:
        raise RuntimeError("Buffer size is too small")

    result = torch.empty(
        n_queries, n_neighbors, dtype=query.dtype, device=query.device
    )
    idx = torch.empty(n_queries, n_neighbors, dtype=torch.int64, device=query.device)

    batch_start_idx = batch_end_idx = 0

    # TODO: investigate whether the pairwise distance computation and the top-k
    # search can be fused (there seems to be no profitable way to do it on GPU).
    for batch_idx in range(n_batches):
        if batch_idx == n_full_batches:
            batch_end_idx += last_batch_size
        else:
            batch_end_idx += batch_size

        batch_slice = slice(batch_start_idx, batch_end_idx)
        pairwise_distance = torch.cdist(query[batch_slice], data)

        # ???: should we pass `sorted=False`?
        torch.topk(
            pairwise_distance,
            n_neighbors,
            largest=False,
            sorted=True,
            out=(result[batch_slice], idx[batch_slice]),
        )
        del pairwise_distance

        batch_start_idx += batch_size

    # HACK: force synchronization to avoid memory overflow, similar to
    # torch.cuda.synchronize(query.device) but with device interoperability, for a
    # negligible cost.
    result[-1, -1].item()

    return result, idx


def _get_batch_properties(
    expected_bytes_per_sample, max_compute_buffer_bytes, dataset_n_samples
):
    batch_size = max_compute_buffer_bytes / expected_bytes_per_sample
    if batch_size < 1:
        raise RuntimeError("Buffer size is too small")

    batch_size = min(math.floor(batch_size), dataset_n_samples)
    n_batches = math.ceil(dataset_n_samples / batch_size)
    n_full_batches = n_batches - 1
    last_batch_size = ((dataset_n_samples - 1) % batch_size) + 1

    return batch_size, n_batches, n_full_batches, last_batch_size
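

# For illustration, with the benchmark settings below (5_000_000 float32 samples,
# 1_000 queries, a 1 GiB buffer) the batching works out to:
#   batch_size      = floor(1073741824 / (5_000_000 * 4)) = 53 queries per batch
#   n_batches       = ceil(1_000 / 53)                    = 19
#   n_full_batches  = 19 - 1                              = 18
#   last_batch_size = ((1_000 - 1) % 53) + 1              = 46
# so each batch materialises at most a (53, 5_000_000) float32 distance matrix,
# about 1.06 GB, i.e. just under the requested limit.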


if __name__ == "__main__":
    n_samples = 5_000_000  # common sizes: 10000, 100000, 1000000
    n_features = 100
    n_queries = 1000
    n_neighbors = 100
    device = "cuda"
    dtype = torch.float32
    seed = 123

    rng = torch.Generator(device=device).manual_seed(543212345)
    data = torch.rand(n_samples, n_features, generator=rng, dtype=dtype, device=device)
    query = torch.rand(n_queries, n_features, generator=rng, dtype=dtype, device=device)

    tic = time.time()
    kneighbors(
        query,
        data,
        n_neighbors,
        metric="euclidean",
        max_compute_buffer_bytes=1073741824,
    )
    toc = time.time()
    print(f"torch {toc-tic=}")

    tic = time.time()
    cuml_kneighbors(
        query,
        data,
        n_neighbors,
    )
    toc = time.time()
    print(f"cuml {toc-tic=}")