Last active
August 29, 2015 14:20
-
-
Save atqamar/4b469c302f6324c1ab9a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import numpy.linalg as la | |
from annoy import AnnoyIndex | |
from spotify.rec_sys_batch.item_maps import ItemMaps | |
from spotify.util.uri import GidURI, decode_uuid4 | |
# Build identifier; it selects both the on-disk directory and the item maps.
BUILD = '2015-W17'
ROOT = '/var/rec-sys/%s' % BUILD

# Annoy index created with 40 as its vector size, loaded from the build dir.
tree = AnnoyIndex(40)
tree.load('%s/word2vec2-track-40.tree' % ROOT)


def gid(u):
    '''Take the third ':'-separated field of *u*, run it through
    decode_uuid4 and strip the dashes from the result.
    '''
    return decode_uuid4(u.split(':')[2]).replace('-', '')


def uri(g):
    '''Build a 'track' GidURI from *g* and return its URI form.'''
    return GidURI('track', g).to_uri()


# Track <-> track id mappings for the same build; opened once at import time.
maps = ItemMaps(build=BUILD, item_type_a='track', item_type_b='track')
maps.open()
def vec_sim(vec_a, vec_b):
    '''Return the cosine similarity of two vectors.

    Both inputs are normalised to unit length on private copies, so the
    caller's arrays are never modified (an in-place ``/=`` would rewrite
    them and would also reject plain lists or integer arrays).

    Args:
        vec_a: 1-D sequence or numpy array of numbers.
        vec_b: 1-D sequence or numpy array of numbers, same length as vec_a.

    Returns:
        The dot product of the two unit vectors. The result is nan when
        either input has zero norm (division by zero is not special-cased).
    '''
    unit_a = np.asarray(vec_a, dtype=float)
    unit_b = np.asarray(vec_b, dtype=float)
    # Plain division allocates new arrays; the inputs stay untouched.
    unit_a = unit_a / la.norm(unit_a)
    unit_b = unit_b / la.norm(unit_b)
    return np.dot(unit_a, unit_b)
def is_zero(a):
    '''Tell whether the vector stored in the index for item *a* is all zeros.

    Returns True when every component equals 0, False otherwise.
    '''
    components = np.array(tree.get_item_vector(a))
    return bool((components == 0).all())
def item_vec(a):
    '''Fetch the index vector of item *a* as a list.

    Returns None (after printing a notice) when the stored vector is all
    zeros; otherwise a list of the vector's components.
    '''
    components = np.array(tree.get_item_vector(a))
    if (components == 0).all():
        print('Got 0 vector')
        return None
    return list(components)
def item_nn(a, n):
    '''Return up to *n* nearest neighbours of item *a* from the index.

    Returns None when item_vec reports no usable vector for *a*.
    '''
    query = item_vec(a)
    if query is not None:
        # Always ask the index for 3000 neighbours and keep the first n
        # (presumably over-fetching to improve recall - TODO confirm).
        neighbours = tree.get_nns_by_vector(query, 3000)
        return neighbours[:n]
    return None
def item_sim(a, b):
    '''Return vec_sim of the index vectors of items *a* and *b*.

    Prints a notice and returns None if either vector is all zeros.
    '''
    first = np.array(tree.get_item_vector(a))
    second = np.array(tree.get_item_vector(b))
    # Both vectors need at least one non-zero component to be comparable.
    if np.any(first != 0) and np.any(second != 0):
        return vec_sim(first, second)
    print('Got 0 vector')
    return None
def draw_path(start_uri, end_uri, k):
    '''Greedily walk the track index from one track towards another.

    From the current stop, take its k nearest neighbours (item_nn) and move
    to the not-yet-visited one that is most similar to the end track
    (item_sim). The walk ends as soon as the end track appears among the
    current stop's neighbours.

    Args:
        start_uri: URI of the first track, e.g. 'spotify:track:...'.
        end_uri: URI of the destination track.
        k: number of nearest neighbours examined at every stop.

    Returns:
        A list of (item_id, similarity_to_end) tuples that starts with the
        start track and ends with (end_id, 1.0). None is returned when
        either endpoint has a zero vector, or when the walk reaches a stop
        with no usable unvisited neighbour (dead end).
    '''
    start_id = int(maps.a_gid_to_a_id(gid(start_uri)))
    end_id = int(maps.a_gid_to_a_id(gid(end_uri)))
    if is_zero(start_id) or is_zero(end_id):
        print('either start or end is a zero vector')
        return None

    path = [(start_id, item_sim(start_id, end_id))]
    # Set of ids already on the path: O(1) membership instead of rebuilding
    # a list from `path` on every iteration.
    visited = {start_id}
    stop = start_id
    n_stops = 0
    while True:
        # item_nn yields None for a zero-vector item; treat that as
        # "no neighbours" instead of failing on the membership test below.
        candidates = item_nn(stop, k) or []
        if end_id in candidates:
            path.append((end_id, 1.0))
            print('Done!')
            return path

        scored = []
        for cand in candidates:
            if cand in visited:
                continue
            sim = item_sim(cand, end_id)
            # item_sim returns None for zero vectors; such items can be
            # neither ranked nor used as the next stop.
            if sim is not None:
                scored.append((sim, cand))

        if not scored:
            # Every neighbour is already on the path (or unusable).
            print('dead end: no unvisited neighbours left')
            return None

        # Highest similarity wins; ties go to the larger id, which is what
        # a descending sort of (similarity, id) tuples would pick first.
        best_sim, best_id = max(scored)
        path.append((best_id, best_sim))
        visited.add(best_id)
        stop = best_id
        print('%s similarity to end %s' % (n_stops, best_sim))
        n_stops += 1
def radio(start_uri, end_uri, k):
    '''Print the URI of every track on the path from start_uri to end_uri.

    Known limitation: the similarity to the end track does not necessarily
    increase from one printed track to the next.

    Args:
        start_uri: URI of the first track.
        end_uri: URI of the destination track.
        k: number of nearest neighbours draw_path examines per stop.

    Returns:
        None. Nothing is printed here when draw_path yields no path.
    '''
    path = draw_path(start_uri, end_uri, k)
    if path is None:
        # draw_path returns None when it cannot build a path; iterating
        # over that would raise TypeError.
        return
    for t_id, _ in path:
        # Use the first gid that the maps list for this item id.
        t_gid = maps.a_id_to_all_a_gids(str(t_id))[0]
        print(uri(t_gid))
def monotone_radio(start_uri, end_uri, k):
    '''Print track URIs on the path from start_uri to end_uri, keeping only
    tracks that are strictly more similar to the end track than every track
    printed before them (i.e. distance to the end decreases monotonically).

    Args:
        start_uri: URI of the first track.
        end_uri: URI of the destination track.
        k: number of nearest neighbours draw_path examines per stop.

    Returns:
        None. Nothing is printed here when draw_path yields no path.
    '''
    path = draw_path(start_uri, end_uri, k)
    if path is None:
        # draw_path returns None when it cannot build a path; iterating
        # over that would raise TypeError.
        return
    best_sim = float('-inf')
    for t_id, t_sim in path:
        # Skip tracks that do not improve on the best similarity so far
        # (a missing similarity can never count as an improvement).
        if t_sim is None or t_sim <= best_sim:
            continue
        best_sim = t_sim
        # Resolve the URI only for tracks that are actually emitted.
        t_gid = maps.a_id_to_all_a_gids(str(t_id))[0]
        print(uri(t_gid))
# Demo run: endpoints of the walk and the neighbourhood size passed as k.
START = 'spotify:track:2Foc5Q5nqNiosCNqttzHof'  # get lucky - dp
END = 'spotify:track:5t0E9V1RiHBflzs71pfGGG'  # piano sonata no. 14 - beethoven
K = 20

# Print the raw path first, then the monotone variant.
radio(START, END, K)
monotone_radio(START, END, K)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment