Created
July 16, 2021 03:40
-
-
Save Manojbhat09/92c4d4eea3f25a23958220d6152063d1 to your computer and use it in GitHub Desktop.
Pure python DBSCAN algorithm. For simple interview practice.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# machine learning 101: dbscan
# wonderful unsupervised clustering, fast and works efficiently
# check out HDBSCAN for an even more beautiful algorithm
'''
How to write dbscan:
1. look into the problem of clustering
2. start with the first sample
3. compute the distances, get the neighbors
4. if there are at least min_samples neighbors, store the neighbors and keep expanding through their neighbors in turn (recursively, or with an explicit stack), saving everything reached to the same cluster; keep track of visited points to avoid cycles
5. after going through all the samples, store the clusters
6. label the clusters 0, 1, 2; anything left over gets the last label, 3
hyperparameters:
min_samples, max_radius
(-10,-10) (6,6)
(5,5)
(1,1)
(0,0)
'''
# Sample 2-D input for the demo: (-10, -10) sits far from the other four points.
points = [(-10, -10), (0, 0), (1, 1), (5, 5), (6, 6)]
def dbscan(points, min_samples=2, max_radius=10):
    """Cluster ``points`` with a pure-Python DBSCAN and return one label per point.

    Definitions used by this implementation:

    * A *neighbour* of a point is any OTHER point (identified by index, so
      exact duplicates count) whose Euclidean distance is strictly less than
      ``max_radius``.
    * A *core* point has at least ``min_samples`` neighbours (the point itself
      is not counted).
    * A *border* point is a non-core neighbour of a core point; it joins the
      first cluster that reaches it but does not extend that cluster.
    * A point that no cluster reaches is *noise* and receives a label of its
      own.

    Args:
        points: Sequence of coordinate sequences (any dimensionality; all
            points are expected to have the same number of coordinates).
        min_samples: Minimum number of neighbours required for a core point.
        max_radius: Exclusive upper bound on the distance between neighbours.

    Returns:
        A list with one integer label per input point. Labels are consecutive
        integers starting at 0, numbered in the order in which clusters (and
        noise points) are first encountered while scanning ``points``.
    """
    import math  # local import: this file has no import section of its own

    def distance(a, b):
        # Euclidean distance over however many coordinates the points carry.
        return math.sqrt(sum((p - q) ** 2 for p, q in zip(a, b)))

    # point index -> neighbour indices if the point is core, [] otherwise.
    neighs = {}

    def get_neighs(idx):
        """Return the neighbours of ``idx`` if it is a core point, else ``[]``."""
        if idx not in neighs:
            origin = points[idx]
            # Skip the point itself by index rather than by "distance == 0",
            # so that exact duplicates are still counted as neighbours.
            nearby = [
                other
                for other, candidate in enumerate(points)
                if other != idx and distance(candidate, origin) < max_radius
            ]
            neighs[idx] = nearby if len(nearby) >= min_samples else []
        return neighs[idx]

    # membership[i] is the cluster (a list of indices) that currently owns point i.
    membership = [None] * len(points)
    # Indices tentatively classified as noise; a later cluster may still claim
    # them as border points.
    tentative_noise = set()
    clusters = []

    for sample_id in range(len(points)):
        if membership[sample_id] is not None:
            continue

        cluster = [sample_id]
        membership[sample_id] = cluster
        clusters.append(cluster)

        core_neighs = get_neighs(sample_id)
        if not core_neighs:
            tentative_noise.add(sample_id)
            continue

        # Grow the cluster with an explicit stack instead of recursion so
        # that long chains of core points cannot exhaust the call stack.
        frontier = list(core_neighs)
        while frontier:
            idx = frontier.pop()
            if idx in tentative_noise:
                # Previously seen as noise, but it lies next to a core point:
                # empty its singleton cluster and absorb it as a border point.
                tentative_noise.discard(idx)
                membership[idx].clear()
            elif membership[idx] is not None:
                continue
            membership[idx] = cluster
            cluster.append(idx)
            # get_neighs() is [] for non-core points, so border points join
            # the cluster without spreading it any further.
            frontier.extend(get_neighs(idx))

    # Singleton clusters emptied by a later claim are dropped before numbering.
    labels = [0] * len(points)
    surviving = [members for members in clusters if members]
    for label, members in enumerate(surviving):
        for pt_idx in members:
            labels[pt_idx] = label
    return labels
# Run the demo with the default hyperparameters and show one label per point.
labels = dbscan(points)
print(labels)  # [0,1,1,1,1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment