Created August 17, 2017 10:13
Save saliksyed/bb914f19f3e26f6713d5fc7c966add13 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Thu Aug 17 10:02:54 2017 | |
@author: saliksyed | |
""" | |
from sklearn import datasets | |
from sklearn import preprocessing | |
from sklearn.neighbors import KDTree | |
import numpy as np | |
import random | |
# Load the Iris dataset once; `features` is the full 2-D feature array
# (every row, every column) and `labels` holds the target for each row.
iris = datasets.load_iris()
features, labels = iris.data[:, :], iris.target
# The train/test shuffle below is left unseeded, so each run uses a
# different split; call random.seed(1) here for a reproducible one.
def dist(pt1, pt2):
    """Return the Euclidean (L2) distance between two points.

    Args:
        pt1: First point; a NumPy array or any array-like (list, tuple).
        pt2: Second point, with the same dimension as ``pt1``.

    Returns:
        ``np.linalg.norm(pt1 - pt2)``, the L2 norm of the difference.
    """
    # Convert first so plain Python sequences subtract element-wise instead
    # of raising TypeError (lists and tuples do not support ``-``).
    return np.linalg.norm(np.asarray(pt1) - np.asarray(pt2))
def classify_point(query_point, tree, training_labels, k=3):
    """Predict a label for ``query_point`` by majority vote of its k nearest neighbors.

    Args:
        query_point: A single feature vector (1-D), or a 2-D array whose
            first row is the query. Only the first row's neighbors vote.
        tree: A fitted nearest-neighbor index exposing
            ``query(X, k=k) -> (distances, indices)``, e.g.
            ``sklearn.neighbors.KDTree``.
        training_labels: Indexable mapping each training-row index used by
            ``tree`` to its label.
        k: Number of neighbors that vote.

    Returns:
        The label with the most votes among the k neighbors. Ties go to the
        label that appears first in the neighbor order returned by
        ``tree.query``. With sklearn's default ``sort_results=True``, that is
        the nearer neighbor.
    """
    # KDTree.query expects a 2-D (n_queries, n_features) array, so promote a
    # single 1-D feature vector to one row. Distances are not needed here, so
    # they are discarded (this also avoids shadowing the module-level `dist`).
    _, ind = tree.query(np.atleast_2d(query_point), k=k)

    votes = {}
    first_seen = []  # labels in neighbor order, for deterministic tie-breaking
    for neighbor_idx in ind[0]:
        label = training_labels[neighbor_idx]
        if label not in votes:
            votes[label] = 0
            first_seen.append(label)
        votes[label] += 1

    # The label with the HIGHEST count wins. max() returns the first maximal
    # element, so among tied labels the one seen earliest in the neighbor
    # list is chosen.
    return max(first_seen, key=votes.get)
# Pair each feature row with its label so the shuffle keeps them aligned.
# list(...) is required on Python 3, where zip() returns an iterator that
# random.shuffle cannot shuffle in place and that cannot be sliced.
mixed_data = list(zip(features, labels))
random.shuffle(mixed_data)

# Split the shuffled pairs into training (first 70%) and testing (rest).
train_percentage = 0.7
train_count = int(round(train_percentage * len(mixed_data)))
training_data = mixed_data[:train_count]  # ==> [([x1, x2, x3, x4], label)...]
# NOTE: rebinds the module-level `features` (the full dataset until now) to
# the training subset only.
features = [x[0] for x in training_data]
testing_data = mixed_data[train_count:]  # ==> [([x1, x2, x3, x4], label)...]

# Index only the training features; test points are classified against them.
tree = KDTree(features, leaf_size=2)
training_labels = [x[1] for x in training_data]

# Compute the fraction of testing points classified correctly (accuracy, 0-1).
correct = 0
for pt in testing_data:
    result = classify_point(pt[0], tree, training_labels)
    if result == pt[1]:
        correct += 1
# print(...) with a single argument behaves the same on Python 2 and 3.
print(float(correct) / float(len(testing_data)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.