Last active
September 27, 2016 10:17
-
-
Save gurgeh/9369573cae9c346134ed5d121be51789 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import csv | |
import sys | |
from scipy import spatial | |
def compare_year(d, y1, cluster, y2):
    """Find the cluster in year *y2* closest to cluster *cluster* of year *y1*.

    Closeness is the ``scipy.spatial.distance.cosine`` distance between
    ``d[y1][cluster]`` and each vector in ``d[y2]``.

    Returns a ``(distance, index)`` tuple for the best match in ``d[y2]``;
    when several candidates share the smallest distance, the lowest index
    wins (tuples compare element-wise). Raises ``ValueError`` if ``d[y2]``
    is empty.
    """
    candidates = (
        (spatial.distance.cosine(d[y1][cluster], other), idx)
        for idx, other in enumerate(d[y2])
    )
    return min(candidates)
def read_csv(fname):
    """Load per-year cluster vectors from the CSV file *fname*.

    The first row is treated as a header and skipped. For every following
    row, column 0 is parsed as an integer year and columns 3 onward are
    parsed as floats forming one cluster vector; columns 1-2 are ignored
    here (their meaning is not visible in this script).

    Returns a dict mapping year -> list of vectors, in file order. An empty
    file yields an empty dict; blank lines are skipped.
    """
    d = {}
    with open(fname) as inf:
        rows = csv.reader(inf)
        # next() works on Python 2.6+ and 3; the default avoids
        # StopIteration when the file has no header at all.
        next(rows, None)
        for row in rows:
            if not row:
                # csv yields [] for blank lines (e.g. a trailing newline).
                continue
            year = int(row[0])
            d.setdefault(year, []).append([float(x) for x in row[3:]])
    return d
def compare_years(fname):
    """Report, for consecutive years in *fname*, each cluster's nearest successor.

    Loads the data with ``read_csv``. For every year Y (in ascending order)
    such that Y + 1 is also present, and for every cluster index of year Y,
    prints a line ``"Y Y+1 cluster"`` followed by the ``(distance, index)``
    tuple returned by ``compare_year``.
    """
    d = read_csv(fname)
    for year in sorted(d):
        following = year + 1
        if following not in d:
            continue
        # Use this year's own cluster count: taking it from an arbitrary
        # year broke (IndexError / skipped clusters) when counts differed.
        for cluster in range(len(d[year])):
            # Single-argument print() behaves the same on Python 2 and 3.
            print("%s %s %s" % (year, following, cluster))
            print(compare_year(d, year, cluster, following))
if __name__ == '__main__':
    # Expect exactly one argument: the CSV file to analyse.
    if len(sys.argv) != 2:
        sys.exit("usage: %s FILE.csv" % sys.argv[0])
    compare_years(sys.argv[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment