Created February 22, 2020 22:12
Save khanhnamle1994/d98a33c1fef277b91790db28544889ee to your computer and use it in GitHub Desktop.
Matrix Factorization with Side Features class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class MF(nn.Module):
    """Matrix factorization with an occupation side feature.

    For each input row (user u, item i, occupation o) the model predicts

        p_u . q_i + p_u . r_o + b + b_u + b_i

    where p_u, q_i and r_o are rows of the ``user``, ``item`` and ``occu``
    lookup tables, b is ``bias``, and b_u, b_i are rows of ``bias_user`` and
    ``bias_item``.

    NOTE(review): no constructor appears in this excerpt. These methods expect
    the instance to already provide:

    - ``user``, ``item``, ``occu``, ``bias_user`` and ``bias_item``. These are
      presumably ``nn.Embedding`` modules, since they are called with index
      tensors and expose ``.weight``.
    - ``bias``.
    - the regularization weights ``c_vector`` and ``c_bias``.

    ``loss`` also calls a module-level ``l2_regularize`` that is not shown
    here.
    """

    def forward(self, train_x):
        """Predict one score per row of ``train_x``.

        This is implemented as ``forward`` rather than as an override of
        ``__call__``. That way ``model(train_x)`` still dispatches through
        ``nn.Module.__call__``, which runs any registered forward hooks; a
        ``__call__`` override silently skipped them.

        :param train_x: 2-D tensor of indices. Column 0 holds user indices,
            column 1 holds item indices and column 3 holds occupation
            indices; column 2 is not read here.
        :return: tensor with one prediction per row. It is 1-D, assuming the
            vector lookups return (batch, k).
        """
        user_id = train_x[:, 0]  # "u"
        item_id = train_x[:, 1]  # "i"
        occu_id = train_x[:, 3]  # "o"

        # Latent vectors p_u, q_i and r_o for every row in the batch.
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        vector_occu = self.occu(occu_id)

        # Row-wise dot products: p_u . q_i (user-item) and p_u . r_o (user-occupation).
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        uo_interaction = torch.sum(vector_user * vector_occu, dim=1)

        # Bias lookups presumably come back as (batch, 1). Squeeze only that
        # trailing axis so a batch of one stays 1-D; a bare .squeeze() would
        # collapse it to a 0-d scalar.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item

        return ui_interaction + uo_interaction + biases

    def loss(self, prediction, target):
        """Return the MSE between prediction and target plus L2 penalties.

        ``c_bias`` weights the penalties on the user and item bias tables.
        ``c_vector`` weights the penalties on the user, item and occupation
        tables. The global ``bias`` term is not penalized.

        :param prediction: model output, e.g. ``model(train_x)``.
        :param target: ground-truth scores for the same rows.
        :return: total loss (MSE plus all penalties).
        """
        # Squeeze both sides so a trailing singleton axis on either one does
        # not trigger broadcasting inside mse_loss.
        loss_mse = F.mse_loss(prediction.squeeze(), target.squeeze())

        # The penalties use the full .weight tables, not only the rows looked
        # up in this batch.
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias
        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        prior_occu = l2_regularize(self.occu.weight) * self.c_vector

        total = loss_mse + prior_user + prior_item + prior_bias_item + prior_bias_user + prior_occu
        return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.