Last active
November 26, 2017 12:36
-
-
Save joao-timescale/419f08b47a6a33c21a9412083615cd58 to your computer and use it in GitHub Desktop.
k-nearest neighbors algorithm
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.stats import mode | |
from scipy.spatial import KDTree | |
class KNearestNeighbors(object):
    """k-nearest neighbors classifier backed by a SciPy k-d tree.

    Each query point is assigned the most frequent label (statistical
    mode) among its ``k`` nearest training points.

    Parameters
    ----------
    k : int, optional
        Number of neighbors that vote on each prediction (default 1).
        The value is converted with ``int()`` after validation.

    Raises
    ------
    ValueError
        If ``k`` is smaller than 1.
    """

    def __init__(self, k=1):
        # Raise explicitly rather than assert: assertions are stripped when
        # Python runs with -O, which would silently accept an invalid k.
        if k < 1:
            raise ValueError("k must be >= 1, got {!r}".format(k))
        self.__k = int(k)

    def fit(self, X, y):
        """Store the training set.

        Parameters
        ----------
        X : array_like
            Training points, passed unchanged to ``scipy.spatial.KDTree``.
        y : array_like
            Labels of the training points, indexed by training-point
            position in ``predict``.
        """
        # Build a k-d tree from the training data
        self.__X = KDTree(X)
        # Coerce to an ndarray so the labels can be fancy-indexed with the
        # neighbor index array in predict(), even if y is a plain list.
        self.__y = np.asarray(y)

    def predict(self, X):
        """Classify query points by majority vote of their nearest neighbors.

        Parameters
        ----------
        X : array_like
            Query point(s), passed unchanged to ``KDTree.query``. A single
            point is accepted and yields a length-1 result.

        Returns
        -------
        numpy.ndarray
            1-D array of dtype ``np.intp`` holding one predicted label per
            query point.

        Raises
        ------
        ValueError
            If ``k`` is larger than the number of training points.
        """
        n_train = self.__X.n
        if self.__k > n_train:
            # KDTree.query pads missing neighbors with the out-of-range
            # index n, which would otherwise surface as an opaque
            # IndexError when looking up the labels.
            raise ValueError(
                "k={} exceeds the number of training points ({})".format(
                    self.__k, n_train
                )
            )

        # Query the k-d tree for the k nearest neighbors and keep only the
        # indices of the neighboring training points.
        _, neighbor_ind = self.__X.query(X, k=self.__k)

        # query() squeezes the neighbor axis when k == 1 (and the point axis
        # for a single query point). Restore a (n_queries, k) layout so the
        # mode below can always be taken along axis 1.
        neighbor_ind = np.reshape(neighbor_ind, (-1, self.__k))

        # Map the neighbor indices to the corresponding class labels.
        classes = self.__y[neighbor_ind]

        # The statistical mode of each row is the predicted class.
        classes_mode, _ = mode(classes, axis=1)

        # Flatten to a 1-D integer array and return the classification.
        return np.asarray(classes_mode.ravel(), dtype=np.intp)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment