Created
August 6, 2018 19:17
-
-
Save wwwaldo/8422a3eb499cc2ed4e021fc22bd173e8 to your computer and use it in GitHub Desktop.
KNN example from Grokking Algorithms.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import math | |
import unittest | |
import sys | |
import matplotlib.pyplot as plt | |
import heapq | |
# k-nearest-neighbours vote between oranges and grapefruits.
def knn(k, datastore, fruit):
    """Classify ``fruit`` by majority vote among its ``k`` nearest stored samples.

    ``datastore.data`` is iterated as ``(numeric_val, fruit_type)`` pairs and
    closeness is measured with ``np.linalg.norm(numeric_val - fruit)``.

    Returns 'Orange' only when oranges hold a strict majority among the kept
    neighbours; every other outcome (including a tie) yields 'Grapefruit'.
    """
    # heapq is a min-heap, so distances are stored negated: the smallest entry
    # is then the farthest candidate, which is the one heappushpop discards.
    nearest = []
    for sample, label in datastore.data:
        entry = (-np.linalg.norm(sample - fruit), label)
        if len(nearest) >= k:
            # Heap is full: add the newcomer and drop the farthest candidate.
            heapq.heappushpop(nearest, entry)
        else:
            heapq.heappush(nearest, entry)

    # Tally the labels of the surviving neighbours.
    votes = [label for _, label in nearest]
    orange_votes = votes.count('Orange')
    grapefruit_votes = votes.count('Grapefruit')

    if orange_votes > grapefruit_votes:
        return 'Orange'
    return 'Grapefruit'  # ties fall through to grapefruit
class Datastore:
    """Holds labelled samples as a list of ``(datum, label)`` pairs in ``data``."""

    def __init__(self, data):
        self.data = data

    @classmethod
    def fromData(cls, oranges, grapefruits):
        """Build a Datastore from two unlabelled collections.

        Every item of ``oranges`` is tagged 'Orange' and every item of
        ``grapefruits`` is tagged 'Grapefruit'; oranges come first in the
        resulting list.
        """
        labelled = []
        for label, samples in (('Orange', oranges), ('Grapefruit', grapefruits)):
            labelled.extend((sample, label) for sample in samples)
        return cls(labelled)
# Make a batch of synthetic fruit (100 by default).
def make_fruit(diameter, colour, count=100):
    """Sample ``count`` fruit as rows of ``(diameter, colour)``.

    Each column is drawn from a normal distribution centred on the given
    ``diameter`` / ``colour`` value. No ``scale`` is passed, so NumPy's
    default standard deviation of 1.0 applies to both columns.

    Args:
        diameter: mean of the first column.
        colour: mean of the second column.
        count: number of rows to generate; defaults to 100, the batch size
            that was previously hard-coded.

    Returns:
        A NumPy array of shape ``(count, 2)``.
    """
    mean = (diameter, colour)
    return np.random.normal(loc=mean, size=(count, 2))
# Grapefruits: 3 cherries across, with maximal redness.
def make_grapefruit():
    """Return a batch of grapefruit samples centred on diameter 3, redness 1."""
    diameter, redness = 3, 1.
    return make_fruit(diameter, redness)
# Oranges: 2 cherries across, with 0.5 redness.
def make_orange():
    """Return a batch of orange samples centred on diameter 2, redness 0.5."""
    diameter, redness = 2, 0.5
    return make_fruit(diameter, redness)
# Demo: generate both fruit populations, plot them, classify one query point.
if __name__ == "__main__":
    print("Grapefruits are neat")
    grapefruits = make_grapefruit()
    oranges = make_orange()

    # Successive plt.plot calls draw on the same axes ('hold on' is the default).
    plt.plot(grapefruits[:, 0], grapefruits[:, 1], 'ro')  # grapefruits: red dots
    plt.plot(oranges[:, 0], oranges[:, 1], 'bo')  # oranges: blue dots

    # make the data store
    data = Datastore.fromData(oranges, grapefruits)

    # The query point sits exactly on the grapefruit mean (diameter 3,
    # redness 1), so a grapefruit verdict is the likely outcome.
    myfruit = np.array([[3, 1]])
    decision = knn(10, data, myfruit)
    plt.plot(myfruit[:, 0], myfruit[:, 1], 'co')  # query point: cyan dot

    print(f"The jury has decided: {decision}")
    plt.show()

    # Use sys.exit rather than the bare exit() helper: exit() is injected by
    # the site module for interactive use and is not guaranteed to exist.
    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment