Last active
August 2, 2022 17:21
-
-
Save recamshak/f5761d2d7bcd4cb0fe109aba01d8331c to your computer and use it in GitHub Desktop.
sklearn DBSCAN with O(n) memory budget
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sklearn.datasets import make_blobs | |
from sklearn.cluster import dbscan | |
from sklearn.cluster._dbscan_inner import dbscan_inner | |
from sklearn.metrics import pairwise_distances_chunked | |
from scipy.sparse import csr_matrix | |
import numpy as np | |
# Dataset: `n` samples with 100 features each, generated around 50 blob centers.
# The arguments are passed by keyword because `centers` is keyword-only in
# current scikit-learn releases; the positional form `make_blobs(n, 100, 50)`
# raises a TypeError there.
n = 50000
ds, _ = make_blobs(n_samples=n, n_features=100, centers=50)

# DBSCAN parameters.
# eps: distance threshold used to decide whether two samples are neighbors.
# min_samples: neighborhood size from which a sample counts as a core sample.
eps = 20
min_samples = 5
# Build a sparse adjacency matrix in CSR form and stream it to disk chunk by chunk.
# Two samples are adjacent if their Euclidean distance is at most `eps`.
# The memory usage can be tuned by adjusting `working_memory` in `pairwise_distances_chunked`.
# See https://scikit-learn.org/stable/modules/generated/sklearn.metrics.pairwise_distances_chunked.html
#
# Remark: `indptr` is O(n), so it does not need to be stored in a file and memmapped,
# but it is convenient to have it along with `indices`.
with open("indices.dat", "wb") as f_indices, open("indptr.dat", "wb") as f_indptr:
    # Running count of stored entries; used to turn each chunk's local row
    # pointers into global offsets into `indices.dat`.
    nnz = 0
    # CSR row pointers always start with 0.
    f_indptr.write(np.intp(0).tobytes())
    for block in pairwise_distances_chunked(ds):
        # `<=` (not `<`) so that samples exactly at distance `eps` are neighbors,
        # matching scikit-learn's radius queries which include boundary points.
        csr = csr_matrix(block <= eps)
        f_indices.write(csr.indices.astype(np.intp).tobytes())
        # Widen to intp BEFORE adding the global offset: `csr.indptr` is usually
        # int32 and the running total can exceed 2**31 for large datasets, which
        # would overflow (or raise under NumPy 2 promotion rules) if added first.
        f_indptr.write((csr.indptr[1:].astype(np.intp) + nnz).tobytes())
        nnz += csr.nnz

# Re-open both files read-only as memory-mapped arrays so the full adjacency
# structure never has to be resident in memory at once.
indices = np.memmap("indices.dat", np.intp, mode="r")
indptr = np.memmap("indptr.dat", np.intp, mode="r")
# The following is an adaptation of the original DBSCAN code from sklearn.

# Neighbor count per sample: difference between consecutive CSR row pointers.
n_neighbors = np.ediff1d(indptr)

# One array of neighbor indices per sample, each a slice of the memmapped `indices`.
# The object array is filled element by element: assigning the whole list at once
# (`neighborhoods[:] = np.split(...)`) makes NumPy stack the pieces into a 2-D
# array whenever every neighborhood has the same length, and the assignment then
# fails with a broadcast ValueError.
neighborhoods = np.empty(n, dtype=object)
for i in range(n):
    neighborhoods[i] = indices[indptr[i]:indptr[i + 1]]

# Every sample starts with label -1; `dbscan_inner` fills `labels` in place.
labels = np.full(n, -1, dtype=np.intp)
# Core-sample mask, stored as uint8 for `dbscan_inner`.
core_samples = np.asarray(n_neighbors >= min_samples, dtype=np.uint8)
dbscan_inner(core_samples, neighborhoods, labels)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment