Created
February 22, 2020 22:16
-
-
Save khanhnamle1994/99d9e60c3119e65a45295aa463dfe7a6 to your computer and use it in GitHub Desktop.
Matrix Factorization class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class MF(nn.Module):
    """Matrix-factorization model that scores a (user, item) pair as p_u . q_i.

    NOTE(review): this class reads ``self.user``, ``self.item`` and
    ``self.c_vector``, but nothing visible here assigns them. Presumably an
    ``__init__`` elsewhere sets ``user``/``item`` to embedding tables (the
    P and Q matrices, exposing ``.weight``) and ``c_vector`` to the L2
    regularization weight. Confirm against the full class.
    """

    def forward(self, train_x: torch.Tensor) -> torch.Tensor:
        """Return one predicted score per row of ``train_x``.

        Args:
            train_x: index tensor with at least two dimensions. Column 0
                holds user indices ("u") and column 1 holds item indices
                ("i"); no other columns are read.

        Returns:
            The sum over dim 1 of p_u * q_i, i.e. the dot product of the
            looked-up user and item vectors for each row (assuming
            ``self.user``/``self.item`` return (N, k) vectors for N indices).
        """
        # Define ``forward`` rather than overriding ``__call__``. This way
        # ``model(x)`` still goes through nn.Module.__call__, so registered
        # forward hooks keep working.
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        # Look up the latent vectors p_u and q_i. These are table lookups,
        # not fresh initializations.
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        # Dot product p_u . q_i: element-wise product summed over dim 1.
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        return ui_interaction

    def loss(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the MSE between prediction and target plus L2 penalties on P and Q.

        Args:
            prediction: scores produced by ``forward``.
            target: observed ratings R_ui, with the same number of elements
                as ``prediction``. A trailing singleton dimension, such as
                shape (N, 1), is accepted.

        Returns:
            mse(prediction, target) + c_vector * l2(P) + c_vector * l2(Q).
        """
        # Use reshape_as instead of squeeze(). squeeze() drops every size-1
        # dimension, so a batch of one would turn a (1, 1) target into a 0-d
        # tensor that no longer matches the prediction's shape (1,).
        loss_mse = F.mse_loss(prediction, target.reshape_as(prediction))
        # ``l2_regularize`` is defined elsewhere in the project. It is
        # presumably a sum of squared weights; confirm. It is applied to the
        # full user (P) and item (Q) weight tables.
        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        # Total objective: data-fit term plus both regularization terms.
        total = loss_mse + prior_user + prior_item
        return total
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment