Created July 6, 2017 10:09
-
-
Save saliksyed/44e854de1592f292e694cc22c7eb5417 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Wed Jul 5 00:19:18 2017 | |
@author: saliksyed | |
""" | |
import json | |
from sklearn.cluster import KMeans | |
# Load the route graph and give every destination country a column index,
# assigned in first-seen order, for the binary feature vectors built below.
# presumably routes.json maps a country name to a dict whose "to" entry is
# keyed by destination country names -- TODO confirm against the data file.
with open("routes.json") as routes_file:
    country_routes = json.load(routes_file)
countries_to_index = {}
idx = 0
for country in country_routes:
    for key in country_routes[country]["to"]:
        # Only a destination seen for the first time consumes a new column.
        if key not in countries_to_index:
            countries_to_index[key] = idx
            idx += 1
# After the loop idx equals the number of distinct destinations. The previous
# `idx + 1` appended a spurious, always-zero column to every feature vector.
total_countries = idx
# Encode each source country as a 0/1 row of length total_countries.
# Column countries_to_index[dest] is set to 1 when the country has a route to
# dest. `countries` keeps the row order so labels can be matched up later.
countries = list(country_routes)
vectors = []
for source in countries:
    row = [0] * total_countries
    for destination in country_routes[source]["to"]:
        row[countries_to_index[destination]] = 1
    vectors.append(row)
from sklearn import svm | |
# Load per-country metadata and build the supervised dataset.
# X receives the route vector of every country that has a known country code
# and a non-null population. y receives the matching label: 1 if that
# population exceeds POPULATION_THRESHOLD, else 0.
# presumably populations maps country codes to numbers or null -- TODO confirm.
POPULATION_THRESHOLD = 1000000
with open("country_name_to_country_code.json") as code_file:
    country_to_country_code = json.load(code_file)
with open("population_by_country_code.json") as population_file:
    populations = json.load(population_file)
# NOTE(review): gdps is loaded but not used anywhere in this script.
with open("gdp_by_country_code.json") as gdp_file:
    gdps = json.load(gdp_file)
X = []
y = []
for country, vector in zip(countries, vectors):
    code = country_to_country_code.get(country)
    if code is None:
        continue
    # .get() skips codes missing from the population data instead of raising
    # KeyError; null populations are skipped as before.
    population = populations.get(code)
    if population is None:
        continue
    X.append(vector)
    y.append(1 if population > POPULATION_THRESHOLD else 0)
# Split the labelled data into a training set and a held-out test set, so we
# can check whether the classifier works on examples it hasn't seen before.
# The split is positional, not shuffled: the first TRAIN_FRACTION of X/y is
# used for training and the remainder for testing (0.5 = half of the data).
TRAIN_FRACTION = 0.5
num_train = int(len(X) * TRAIN_FRACTION)
X_train = X[:num_train]
y_train = y[:num_train]
X_test = X[num_train:]
y_test = y[num_train:]
# Fit a support-vector classifier with scikit-learn's default settings and
# predict labels for the held-out examples.
clf = svm.SVC()
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)
print(predictions)
# Now let's check our predictions: count the held-out examples whose
# predicted label equals the true label, then print the hit rate as a
# percentage of all predictions.
correct = sum(1 for guess, truth in zip(predictions, y_test) if guess == truth)
print("Percentage correct:")
print(float(correct) / len(predictions) * 100)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.