Skip to content

Instantly share code, notes, and snippets.

@Manojbhat09
Created July 16, 2021 03:40
Show Gist options
  • Save Manojbhat09/92c4d4eea3f25a23958220d6152063d1 to your computer and use it in GitHub Desktop.
Save Manojbhat09/92c4d4eea3f25a23958220d6152063d1 to your computer and use it in GitHub Desktop.
Pure python DBSCAN algorithm. For simple interview practice.
# Machine learning 101: DBSCAN
# Wonderful unsupervised clustering: fast, and works efficiently.
# Check out HDBSCAN for an even more beautiful algorithm.
'''
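How to write DBSCAN:
1. Look into the problem of clustering.
2. Start with the first sample.
3. Compute its distance to every other sample and collect the neighbors that lie within max_radius.
4. If it has at least min_samples neighbors, it is a core point: store its neighbors and expand the cluster through neighbors of neighbors, saving them to the same cluster; keep track of visited samples to avoid cycles.
5. Repeat for every sample that has not been visited yet, storing each cluster found.
6. Label the clusters 0, 1, 2, ... in the order they were found; a sample that ends up in no cluster (noise) gets a label of its own.
Hyperparameters:
min_samples (neighbors needed to be a core point), max_radius (neighborhood radius)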
(-10,-10) (6,6)
(5,5)
(1,1)
(0,0)
'''
# Sample 2-D points for the demo call below.
points = [
    (-10, -10),
    (0, 0),
    (1, 1),
    (5, 5),
    (6, 6),
]
def dbscan(points, min_samples=2, max_radius=10):
    """Cluster 2-D points with DBSCAN and return one integer label per point.

    A point is a *core* point when at least ``min_samples`` OTHER points lie
    strictly closer than ``max_radius`` to it (Euclidean distance on the first
    two coordinates). A cluster grows outward from a core point: every point
    within ``max_radius`` of a core point joins that core point's cluster
    (border points included), and growth continues only from members that are
    themselves core points. A border point reachable from several clusters
    joins the first one discovered.

    Points that end up in no cluster are noise. Each noise point receives a
    label of its own, matching the behaviour of the original implementation.

    Labels are consecutive integers starting at 0. They are assigned in the
    order clusters / noise points are discovered while scanning ``points``
    from index 0 upwards.

    Args:
        points: sequence of ``(x, y)`` pairs.
        min_samples: minimum number of other points within ``max_radius``
            required for a point to be a core point.
        max_radius: neighbourhood radius; the comparison is strict (``<``).

    Returns:
        list[int]: ``labels[i]`` is the label assigned to ``points[i]``.
    """
    # The file never imports math at module level, so import it here.
    import math

    def distance(p, q):
        # Euclidean distance using only the first two coordinates.
        return math.sqrt((p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2)

    core_cache = {}  # point index -> its neighbours if core, else []

    def core_neighbours(idx):
        """Return the neighbours of ``idx`` if it is a core point, otherwise []."""
        if idx not in core_cache:
            # Exclude the point itself by index (not by dist != 0) so that
            # coincident duplicate points still count as each other's neighbours.
            found = [
                j
                for j, other in enumerate(points)
                if j != idx and distance(other, points[idx]) < max_radius
            ]
            core_cache[idx] = found if len(found) >= min_samples else []
        return core_cache[idx]

    cluster_of = [None] * len(points)  # index into `groups`; None = unassigned
    noise_group = {}  # tentative noise point -> index of its singleton group
    groups = []  # member lists in discovery order (clusters and noise singletons)

    for seed in range(len(points)):
        if cluster_of[seed] is not None:
            continue
        seed_neighs = core_neighbours(seed)
        if not seed_neighs:
            # Not a core point: noise for now, but a cluster discovered later
            # may still claim it as a border point.
            noise_group[seed] = len(groups)
            groups.append([seed])
            continue

        group_id = len(groups)
        members = [seed]
        groups.append(members)
        cluster_of[seed] = group_id
        # Expand with an explicit stack rather than recursion so long,
        # chain-like clusters cannot exceed Python's recursion limit.
        frontier = list(seed_neighs)
        while frontier:
            j = frontier.pop()
            if cluster_of[j] is not None:
                continue
            if j in noise_group:
                # Previously seen as noise, but it is within reach of a core
                # point, so it is a border point of this cluster.
                groups[noise_group.pop(j)].clear()
            cluster_of[j] = group_id
            members.append(j)
            # Only core points propagate the cluster; border points return [].
            frontier.extend(core_neighbours(j))

    # Number the non-empty groups consecutively; a group is empty when its
    # noise point was later absorbed into a cluster.
    labels = [0] * len(points)
    next_label = 0
    for members in groups:
        if not members:
            continue
        for j in members:
            labels[j] = next_label
        next_label += 1
    return labels
# Demo run on the sample points. Expected output: [0, 1, 1, 1, 1]
# ((-10, -10) is isolated; the remaining four points form one cluster).
labels = dbscan(points=points)
print(labels)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment