Last active
December 4, 2020 16:24
-
-
Save ZviBaratz/c0084fa88747c6b2f7f15f540ef25d93 to your computer and use it in GitHub Desktop.
Naive implementation of a k-NN estimator. This code is written entirely for educational purposes and should not be relied upon in practical applications.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class KNearestNeighbors:
    """
    Simple implementation of a k-NN estimator.

    Classification is done by majority vote among the *k* training samples
    closest (Euclidean distance) to the observation being classified.

    Tie-breaking is deterministic:

    * equidistant training samples are ranked by their position in the
      training set (earlier samples first);
    * when several classes share the highest vote count, the smallest
      class (in the sorted order produced by ``np.unique``) wins.

    If *n_neighbors* exceeds the number of training samples, all training
    samples take part in the vote.
    """

    def __init__(self, n_neighbors: int = 1) -> None:
        """
        Create an unfitted estimator voting over *n_neighbors* neighbors.

        Raises ``ValueError`` if *n_neighbors* is smaller than 1.
        """
        # A non-positive k would either leave no voters (k == 0) or, through
        # negative slicing, silently drop the farthest samples (k < 0).
        if n_neighbors < 1:
            raise ValueError(
                f"n_neighbors must be at least 1, got {n_neighbors!r}"
            )
        self.k = n_neighbors
        self.X_train = None
        self.y_train = None

    def fit(self, X_train: np.ndarray, y_train: np.ndarray) -> None:
        """
        Set the train dataset attributes to be used for prediction.

        *X_train* must be two-dimensional (samples x features) and *y_train*
        must hold one label per sample. Array-likes such as nested lists are
        accepted and converted to arrays.

        Raises ``ValueError`` if the shapes are inconsistent.
        """
        # Convert up front so that the broadcasting and fancy indexing done
        # at prediction time also work for plain Python sequences.
        X_train = np.asarray(X_train)
        y_train = np.asarray(y_train)
        if X_train.ndim != 2:
            raise ValueError(
                f"X_train must be 2-dimensional, got {X_train.ndim} dimension(s)"
            )
        if y_train.ndim == 0 or y_train.shape[0] != X_train.shape[0]:
            raise ValueError(
                "X_train and y_train must contain the same number of samples"
            )
        self.X_train = X_train
        self.y_train = y_train

    def _check_is_fitted(self) -> None:
        """
        Raise ``RuntimeError`` if :meth:`fit` has not been called yet.
        """
        if self.X_train is None or self.y_train is None:
            raise RuntimeError(
                "This KNearestNeighbors instance is not fitted yet; "
                "call fit() before predicting."
            )

    def get_neighbor_classes(self, observation: np.ndarray) -> np.ndarray:
        """
        Returns an array of the classes of the *k* nearest neighbors.
        """
        self._check_is_fitted()
        # Euclidean distance from the observation to every training sample.
        distances = np.sqrt(
            np.sum((self.X_train - observation) ** 2, axis=1)
        )
        # Create an array of training set indices ordered by their
        # distance from the current observation. A stable sort keeps
        # equidistant samples in training-set order, so results are
        # reproducible.
        indices = np.argsort(distances, kind="stable")
        selected_indices = indices[: self.k]
        return self.y_train[selected_indices]

    def estimate_class(self, observation: np.ndarray) -> np.generic:
        """
        Estimates to which class a given row (*observation*) belongs.

        Returns the most frequent class among the nearest neighbors as a
        NumPy scalar of the training labels' dtype.
        """
        neighbor_classes = self.get_neighbor_classes(observation)
        classes, counts = np.unique(neighbor_classes, return_counts=True)
        # ``classes`` is sorted and ``argmax`` returns the first maximum,
        # so vote ties resolve to the smallest class.
        return classes[np.argmax(counts)]

    def predict(self, X: np.ndarray) -> np.ndarray:
        """
        Apply k-NN estimation for each row in a given dataset.

        Returns a one-dimensional array with one predicted class per row of
        *X*, using the dtype of the training labels. An *X* with zero rows
        yields an empty array.

        Raises ``ValueError`` if *X* is not two-dimensional and
        ``RuntimeError`` if the estimator has not been fitted.
        """
        self._check_is_fitted()
        X = np.asarray(X)
        if X.ndim != 2:
            raise ValueError(
                f"X must be 2-dimensional, got {X.ndim} dimension(s)"
            )
        # Allocate the output with the label dtype explicitly.
        # ``np.apply_along_axis`` infers the dtype from the first result
        # only, which truncates string labels longer than the first
        # prediction and fails outright on an empty dataset.
        predictions = np.empty(X.shape[0], dtype=self.y_train.dtype)
        for row_index, observation in enumerate(X):
            predictions[row_index] = self.estimate_class(observation)
        return predictions
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment