Skip to content

Instantly share code, notes, and snippets.

@neerajvashistha
Created October 8, 2018 10:55
Show Gist options
  • Select an option

  • Save neerajvashistha/66ef66a57208cca1d0f14b4a6c1d6f8b to your computer and use it in GitHub Desktop.

Select an option

Save neerajvashistha/66ef66a57208cca1d0f14b4a6c1d6f8b to your computer and use it in GitHub Desktop.
import numpy as np
def distance(p1, p2):
    """Return the Euclidean distance between points p1 and p2.

    p1, p2: coordinates of the two points, given as NumPy arrays or as
        plain sequences of numbers (lists, tuples) of matching shape.

    Returns the square root of the sum of squared coordinate differences.
    """
    # Plain lists/tuples do not support elementwise "-", so convert first;
    # arrays pass through np.asarray unchanged.
    delta = np.asarray(p2) - np.asarray(p1)
    return np.sqrt(np.sum(np.power(delta, 2)))
def majority_vote(votes):
    """
    Pick the most frequent value among votes.

    votes: iterable of hashable vote values.

    If several values share the highest count, one of them is drawn at
    random. The author notes scipy.stats.mode(votes) as an alternative.
    """
    from collections import Counter
    import random

    tally = Counter(votes)
    top_count = max(tally.values())
    # Every value that reaches the top count is a candidate winner.
    leaders = [candidate for candidate, n in tally.items() if n == top_count]
    return random.choice(leaders)
def find_nearest_neighbors(p, points, k=5):
    """Find the k nearest neighbours of point p and return their indices.

    p: the query point (array or sequence of coordinates).
    points: array(-like) holding one candidate point per row.
    k: how many neighbours to return; if k exceeds the number of points,
        all indices are returned.

    Returns an array of indices into points, ordered from nearest to
    farthest by Euclidean distance. Points at equal distance keep their
    original relative order.

    Raises ValueError if k is negative.
    """
    if k < 0:
        # A negative k would silently slice off the *farthest* points
        # (ind[:-k]) instead of selecting the nearest ones.
        raise ValueError("k must be non-negative, got %r" % (k,))
    points = np.asarray(points)
    squared = np.power(points - np.asarray(p), 2)
    # Sum over every axis except the first, so each row collapses to one
    # distance; for 1-D points there is nothing to reduce (axis=()).
    distances = np.sqrt(np.sum(squared, axis=tuple(range(1, squared.ndim))))
    # A stable sort makes the order of equidistant points deterministic.
    ind = np.argsort(distances, kind="stable")
    return ind[:k]
def knn_predict(p, points, outcomes, k=5):
    """
    Classify point p by a majority vote among its k nearest neighbours.

    p: the point to classify
    points: array(-like) with the x, y coordinates of n points, one per row
    outcomes: array(-like) with the class of each of the n points above
    k: number of nearest neighbours that take part in the vote

    Returns the class selected by majority_vote from the neighbours' classes.
    """
    # Convert up front so plain lists work too: indexing outcomes with an
    # index array below requires an ndarray.
    p = np.asarray(p)
    points = np.asarray(points)
    outcomes = np.asarray(outcomes)
    ind = find_nearest_neighbors(p, points, k)
    return majority_vote(outcomes[ind])
if __name__ == '__main__':
    # Demo: exercise the helpers, then classify one query point and plot it.
    p1 = np.array([1, 1])
    p2 = np.array([4, 4])
    # Result is discarded; the call only exercises distance().
    distance(p1, p2)
    print(majority_vote(np.array([0, 1, 1, 1, 3, 4, 1, 3, 1, 4, 1])))

    points = np.array([[1, 1], [1, 2], [2, 2], [2, 1], [3, 1],
                       [3, 3], [1, 4], [4, 4], [2, 5]])
    p = np.array([2, 4.3])

    # Imported here so the functions above can be used without matplotlib.
    import matplotlib.pyplot as plt

    # Figure 1: all training points (red) and the query point (blue).
    plt.plot(points[:, 0], points[:, 1], "ro")
    plt.plot(p[0], p[1], "bo")
    plt.show()

    # Result is discarded; the lookup only exercises find_nearest_neighbors().
    points[find_nearest_neighbors(p, points, 2)]
    outcomes = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1])

    # Predict once and reuse the result: majority_vote breaks ties at
    # random, so two separate calls could print one class and plot another.
    out = knn_predict(p, points, outcomes, k=2)
    print(out)

    # Figure 2: training points coloured by class, query point still blue.
    plt.plot(p[0], p[1], "bo")
    point_0 = points[np.where(outcomes == 0)]
    plt.plot(point_0[:, 0], point_0[:, 1], "go")
    point_1 = points[np.where(outcomes == 1)]
    plt.plot(point_1[:, 0], point_1[:, 1], "ro")
    plt.show()

    # Figure 3: the query point takes the colour of its predicted class.
    if out == 0:
        plt.plot(p[0], p[1], "go")
    else:
        plt.plot(p[0], p[1], "ro")
    plt.plot(point_0[:, 0], point_0[:, 1], "go")
    plt.plot(point_1[:, 0], point_1[:, 1], "ro")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment