Last active
June 3, 2020 08:39
-
-
Save oscar-defelice/0ba2c81f1f8c55f9535f7af636885379 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_triplets_hard(batch_size, X_usr, X_item, df, return_cache = False):
    """
    Returns the list of three arrays (anchor, positive, hard negative) to feed the model.

    For every row of the batch a random user is drawn as anchor and one of its
    positive items is drawn at random. The negative is the candidate negative
    item with the smallest ``cosine_dist`` to that positive ("hard" negative
    mining).

    Parameters
    ----------
    batch_size : int
        size of the batch.
    X_usr : numpy array of shape (n_users, n_user_features)
        array of user metadata; row ``k - 1`` is used for user id ``k``.
    X_item : numpy array of shape (n_items, n_item_features)
        array of item metadata; row ``k - 1`` is used for item id ``k``.
    df : Pandas DataFrame
        dataframe containing user-item ratings. Its index values are the pool
        anchors are drawn from.
    return_cache : bool
        parameter to trigger whether we want the list of ids corresponding to
        triplets.
        default: False

    Returns
    -------
    triplets : list of numpy arrays
        list containing 3 tensors A,P,N corresponding to:
            - Anchor A : (batch_size, n_user_features)
            - Positive P : (batch_size, n_item_features)
            - Negative N : (batch_size, n_item_features)
        A row of P or N is left all zeros when no positive/negative item was
        available for that anchor.
    cache : dict, only returned when ``return_cache`` is True
        ``{'users': [...], 'positive': [...], 'negative': [...]}`` holding the
        ids used for each row; id 0 marks a missing positive/negative.
    """
    n_user_features = X_usr.shape[1]
    n_item_features = X_item.shape[1]
    user_list = list(df.index.values)

    # Rows are pre-filled with zeros, so a missing item (id 0) needs no write.
    anchors = np.zeros((batch_size, n_user_features))
    positives = np.zeros((batch_size, n_item_features))
    negatives = np.zeros((batch_size, n_item_features))

    user_ids = []
    p_ids = []
    n_ids = []

    for i in range(batch_size):
        # pick one random user for anchor
        anchor_id = random.choice(user_list)
        # all possible positive/negative samples for selected anchor
        p_item_ids = get_pos(df, anchor_id)
        n_item_ids = get_neg(df, anchor_id)

        # pick one of the positive ids; 0 is the "no positive" sentinel
        try:
            positive_id = random.choice(p_item_ids)
        except IndexError:
            positive_id = 0
        negative_id = _pick_hard_negative(X_item, positive_id, n_item_ids)

        user_ids.append(anchor_id)
        p_ids.append(positive_id)
        n_ids.append(negative_id)

        # ids index the metadata arrays with a -1 offset
        anchors[i, :] = X_usr[anchor_id - 1]
        if positive_id != 0:
            positives[i, :] = X_item[positive_id - 1]
        if negative_id != 0:
            negatives[i, :] = X_item[negative_id - 1]

    triplets = [anchors, positives, negatives]
    if return_cache:
        cache = {'users': user_ids, 'positive': p_ids, 'negative': n_ids}
        return triplets, cache
    return triplets


def _pick_hard_negative(X_item, positive_id, n_item_ids):
    """Return the negative id with the smallest ``cosine_dist`` to the positive.

    Falls back to a random negative when there is no positive to compare
    against (``positive_id == 0``) or the hard pick fails. Returns 0 when
    there are no negative candidates at all.
    """
    # Without a positive there is no reference vector. Using
    # X_item[positive_id - 1] here would silently compare against the last
    # item (index -1), so skip the hard search entirely.
    if positive_id != 0:
        try:
            distances = [cosine_dist(X_item[positive_id - 1], X_item[k - 1])
                         for k in n_item_ids]
            return n_item_ids[np.argmin(distances)]
        except (ValueError, IndexError, KeyError):
            # ValueError: np.argmin on an empty candidate list.
            # IndexError/KeyError: an id that cannot be looked up.
            pass
    try:
        return random.choice(n_item_ids)
    except IndexError:
        # no negative candidates for this anchor
        return 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment