Skip to content

Instantly share code, notes, and snippets.

Created April 15, 2019 07:08
Show Gist options
  • Save nielsuit227/a38e5ce43ed969dda36f43bb9ca2b617 to your computer and use it in GitHub Desktop.
Save nielsuit227/a38e5ce43ed969dda36f43bb9ca2b617 to your computer and use it in GitHub Desktop.
Class function
import numpy as np
import matplotlib.pyplot as plt
class DBSCAN(object):
    """Density-based clustering (DBSCAN) with an optional neighbour-search tree.

    A sample is a *core* sample when at least ``minpoints`` samples (itself
    included) lie strictly closer than ``radius`` to it.  Clusters grow
    outward from core samples; samples never reached by any cluster keep the
    label -1 (outlier).

    Parameters
    ----------
    radius : float
        Neighbourhood radius (eps).  The distance test is strict (``<``).
    minpoints : int
        Minimum neighbourhood size for a sample to count as core.
    tree : object, optional
        Spatial index used instead of brute force.  It must expose
        ``exact_r_nn(radius, query, maxeval=...)`` returning indices of rows
        of the fitted data within ``radius`` of ``query``.
        NOTE(review): this interface is inferred from the call site only;
        confirm it against the tree implementation.  When omitted, a
        brute-force Euclidean search over all samples is used.
    """

    def __init__(self, radius=1.0, minpoints=50, tree=None):
        # Keep the caller's tree; None selects the brute-force search.
        self._tree = tree
        self._eps = radius
        self._minpoints = minpoints
        self._next_cluster_id = 0
        # Placeholders, populated by fit()
        self._n = []
        self._m = []
        self._data = []
        self._clusterid = []   # per-sample label: -1 = outlier / unassigned, else cluster id
        self._checked = []     # per-sample bool: True once its neighbourhood has been queried
        self._outliers = []

    def fit(self, data, verbose=True):
        """Cluster ``data`` and return the per-sample cluster labels.

        Parameters
        ----------
        data : array_like
            Samples as rows (n_samples, n_features).  A flat vector is
            treated as n_samples single-feature samples.
        verbose : bool
            When True (default), print a progress line each time an
            unlabelled sample is visited.

        Returns
        -------
        numpy.ndarray of float
            Cluster id (1, 2, ...) for every sample, or -1 for outliers.
        """
        self._data = np.asarray(data)
        if self._data.ndim == 1:
            # A flat vector is read as n samples of a single feature.
            self._data = self._data.reshape(-1, 1)
        self._n, self._m = self._data.shape
        self._clusterid = -1 * np.ones(self._n)
        # Must be a real boolean mask: it is negated with ~ during expansion.
        self._checked = np.zeros(self._n, dtype=bool)
        # Restart numbering so refitting yields ids 1, 2, ... again.
        self._next_cluster_id = 0
        for i in range(self._n):
            if self._clusterid[i] != -1:
                continue  # already absorbed by an earlier cluster
            self.expandcluster(i)
            if verbose:
                print('[%.2f %%] Last Cluster Size: %.0f'
                      % (100*i/self._n, np.sum(self._clusterid == np.max(self._clusterid))))
        return self._clusterid

    def expandcluster(self, point):
        """Start a new cluster at sample index ``point`` if it is a core sample.

        All samples density-reachable from ``point`` that are still labelled
        -1 receive the new cluster id.  Samples already labelled by an
        earlier cluster keep their label.

        Returns
        -------
        numpy.ndarray
            The (possibly updated) label array.
        """
        seeds = self.neighbor(point)
        if np.sum(seeds) < self._minpoints:
            # Not a core sample: leave it at -1.  It may still be claimed
            # later as a border sample of another cluster.
            return self._clusterid
        claim = np.logical_and(seeds, self._clusterid == -1)
        self._clusterid[claim] = self.nextid(claim)
        cluster = self._next_cluster_id
        # Work list: cluster members whose neighbourhoods are not yet queried.
        frontier = np.logical_and(seeds, ~self._checked)
        while frontier.any():
            npoint = np.flatnonzero(frontier)[0]
            nseeds = self.neighbor(npoint)  # also marks npoint as checked
            if np.sum(nseeds) >= self._minpoints:
                # npoint is core: its unlabelled neighbours join the cluster
                # and are queued so their neighbourhoods get examined too.
                self._clusterid[np.logical_and(nseeds, self._clusterid == -1)] = cluster
                frontier = np.logical_or(frontier, nseeds)
            # Drop everything already examined, including npoint itself.
            frontier = np.logical_and(frontier, ~self._checked)
        return self._clusterid

    def neighbor(self, point):
        """Return a boolean mask over the fitted samples within ``radius`` of ``point``.

        ``point`` is either a sample index (an integer scalar or a
        one-element integer array) or a coordinate vector.  Passing an index
        marks that sample as checked.
        """
        seeds = np.zeros(self._n, dtype='bool')
        arr = np.asarray(point)
        if arr.size == 1 and np.issubdtype(arr.dtype, np.integer):
            index = int(arr.reshape(-1)[0])
            self._checked[index] = True
            query = self._data[index, :]
        else:
            # Coordinates (including a 1-feature float value) are used as-is.
            query = point
        if self._tree is None:
            seedsindex = self.bruteforce(query)
        else:
            seedsindex = self._tree.exact_r_nn(self._eps, query, maxeval=10)
        seeds[seedsindex] = True
        return seeds

    def nextid(self, seeds):
        """Advance the cluster counter; return the new id once per True entry of ``seeds``."""
        self._next_cluster_id += 1
        return np.ones(int(np.sum(seeds))) * self._next_cluster_id

    def bruteforce(self, point):
        """Return ``np.where`` indices of samples whose Euclidean distance to ``point`` is < radius."""
        distance = np.sqrt(np.sum((self._data - point) ** 2, 1))
        return np.where(distance < self._eps)

    def plot(self):
        """Scatter the first two features: every sample in green, outliers (-1) overdrawn in red.

        Requires a prior ``fit`` on data with at least two features.  The
        figure is only drawn on the current matplotlib axes, not shown.
        """
        self._outliers = self._clusterid == -1
        plt.scatter(self._data[:, 0], self._data[:, 1], c='g', s=1)
        plt.scatter(self._data[self._outliers, 0], self._data[self._outliers, 1], c='r', s=1)
        plt.suptitle('Estimated clusters')

    def predict_outlier(self, point):
        """Return True unless ``point`` has at least ``minpoints`` fitted samples within ``radius``.

        This uses the same core-sample threshold as ``expandcluster``.  For a
        coordinate vector that is not itself in the fitted data, the count
        does not include the point itself.
        """
        seeds = self.neighbor(point)
        return not np.sum(seeds) >= self._minpoints
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment