Skip to content

Instantly share code, notes, and snippets.

@betatim
Forked from fcharras/knn.py
Created July 20, 2023 11:20
Show Gist options
  • Save betatim/68219c95f539df51afad96cd9cd14a1c to your computer and use it in GitHub Desktop.
Save betatim/68219c95f539df51afad96cd9cd14a1c to your computer and use it in GitHub Desktop.
import time
import math
import numpy as np
import torch
from cuml.neighbors import NearestNeighbors
def cuml_kneighbors(query, data, n_neighbors):
    """Brute-force k-nearest-neighbour search with cuML, used as the reference
    implementation in the benchmark below.

    Fits a cuML ``NearestNeighbors`` estimator (``algorithm="brute"``) on
    ``data`` and queries it with ``query``.

    Returns the ``(distances, indices)`` pair produced by
    ``NearestNeighbors.kneighbors``.
    """
    # NOTE(review): the benchmark passes CUDA torch tensors here; presumably cuML
    # consumes them through the CUDA array interface — confirm for your cuML version.
    model = NearestNeighbors(algorithm="brute", n_neighbors=n_neighbors)
    model.fit(data)
    dists, neighbor_ids = model.kneighbors(query)
    return dists, neighbor_ids
def kneighbors(
    # NB: best performance might depend on the layout for `query` and `data`
    # TODO: benchmark and warn or error out if the layout is not adapted
    query,  # (n_queries, n_features)
    data,  # (n_samples, n_features)
    n_neighbors,  # int
    metric="euclidean",  # str
    max_compute_buffer_bytes=1073741824,  # int (default 1 GiB)
):
    """Find the ``n_neighbors`` nearest rows of ``data`` for every row of ``query``.

    Brute force: pairwise distances are computed with ``torch.cdist`` and the
    smallest ones are selected with ``torch.topk``. Query rows are processed in
    batches sized so that the ``(batch_size, n_samples)`` distance buffer does
    not exceed ``max_compute_buffer_bytes``.

    Parameters
    ----------
    query : torch.Tensor, shape (n_queries, n_features)
    data : torch.Tensor, shape (n_samples, n_features)
    n_neighbors : int
        Number of neighbours returned per query row.
    metric : str, default "euclidean"
        Only ``"euclidean"`` is supported.
    max_compute_buffer_bytes : int, default 1 GiB
        Upper bound on the size of the per-batch pairwise-distance buffer.

    Returns
    -------
    distances : torch.Tensor, shape (n_queries, n_neighbors), dtype ``query.dtype``
        Distances to the neighbours, sorted in increasing order in each row.
    indices : torch.Tensor, shape (n_queries, n_neighbors), dtype int64
        Row indices into ``data`` of the corresponding neighbours.

    Raises
    ------
    ValueError
        If ``metric`` is not ``"euclidean"``.
    RuntimeError
        If the buffer cannot hold the distances of even one query row.
    """
    if metric != "euclidean":
        raise ValueError(
            f"Unsupported metric {metric!r}; only 'euclidean' is implemented."
        )

    n_queries = query.shape[0]
    n_samples = data.shape[0]

    # Each query row yields one row of n_samples distances. The buffer element
    # size is taken from `query` (the result buffer below uses query.dtype too).
    # `element_size()` avoids a device->host copy and works for empty tensors
    # and dtypes numpy does not know (e.g. bfloat16).
    bytes_per_query = n_samples * query.element_size()
    # max(..., 1) guards against a zero-width row (empty `data`); the actual
    # error for that case is then reported by torch.topk.
    batch_size = int(max_compute_buffer_bytes // max(bytes_per_query, 1))
    if batch_size < 1:
        raise RuntimeError("Buffer size is too small")

    distances = torch.empty(
        n_queries, n_neighbors, dtype=query.dtype, device=query.device
    )
    indices = torch.empty(
        n_queries, n_neighbors, dtype=torch.int64, device=query.device
    )

    # Batches run over *query* rows; slicing past the end is clamped, so the
    # last batch may be smaller than batch_size.
    # TODO: investigate if it's possible to fuse pairwise distance computation
    # and topk search. (seems there's no profitable way to do it on gpu)
    for start in range(0, n_queries, batch_size):
        batch = slice(start, start + batch_size)
        pairwise_distance = torch.cdist(query[batch], data)
        # ???: should we pass `sorted=False` ?
        torch.topk(
            pairwise_distance,
            n_neighbors,
            largest=False,
            sorted=True,
            out=(distances[batch], indices[batch]),
        )
        del pairwise_distance
        # HACK: force synchronization to avoid memory overflow, similar to
        # torch.cuda.synchronize(query.device) but device-agnostic, for a
        # negligible cost. Skipped when there is no element to read.
        if distances.numel():
            distances[-1, -1].item()

    return distances, indices
def _get_batch_properties(
expected_bytes_per_sample, max_compute_buffer_bytes, dataset_n_samples
):
batch_size = max_compute_buffer_bytes / expected_bytes_per_sample
if batch_size < 1:
raise RuntimeError("Buffer size is too small")
batch_size = min(math.floor(batch_size), dataset_n_samples)
n_batches = math.ceil(dataset_n_samples / batch_size)
n_full_batches = n_batches - 1
last_batch_size = ((dataset_n_samples - 1) % batch_size) + 1
return batch_size, n_batches, n_full_batches, last_batch_size
if __name__ == "__main__":
    # Benchmark: torch brute-force k-NN (`kneighbors`) vs cuML (`cuml_kneighbors`)
    # on uniformly random data on a CUDA device.
    n_samples = 5_000_000  # common sizes: 10000, 100000, 1000000
    n_features = 100
    n_queries = 1000
    n_neighbors = 100
    device = "cuda"
    dtype = torch.float32
    # Same value the generator was previously seeded with, so the data is unchanged.
    seed = 543212345

    rng = torch.Generator(device=device).manual_seed(seed)
    data = torch.rand(n_samples, n_features, generator=rng, dtype=dtype, device=device)
    query = torch.rand(n_queries, n_features, generator=rng, dtype=dtype, device=device)

    # Warm-up on a single query so one-time library initialization is not
    # attributed to whichever implementation happens to run first.
    kneighbors(
        query[:1],
        data,
        n_neighbors,
        metric="euclidean",
        max_compute_buffer_bytes=1073741824,
    )
    cuml_kneighbors(query[:1], data, n_neighbors)
    # CUDA kernels run asynchronously: make sure data generation and warm-up
    # are finished before starting the clock.
    torch.cuda.synchronize()

    tic = time.perf_counter()
    kneighbors(
        query,
        data,
        n_neighbors,
        metric="euclidean",
        max_compute_buffer_bytes=1073741824,
    )
    torch.cuda.synchronize()
    toc = time.perf_counter()
    print(f"torch {toc-tic=}")

    tic = time.perf_counter()
    cuml_kneighbors(
        query,
        data,
        n_neighbors,
    )
    # Wait for all queued device work (any stream) before stopping the clock.
    torch.cuda.synchronize()
    toc = time.perf_counter()
    print(f"cuml {toc-tic=}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment