Created February 22, 2020 22:15
-
-
Save khanhnamle1994/3843cff57653f457fa9a6124fade554c to your computer and use it in GitHub Desktop.
Matrix Factorization with Biases class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class MF(nn.Module):
    """Matrix factorization model with global, user, and item bias terms.

    The predicted rating for a (user u, item i) pair is

        r_hat(u, i) = bias + b_u + b_i + p_u . q_i

    NOTE(review): no ``__init__`` is visible in this snippet. The methods
    below read these attributes, which must be created elsewhere
    (confirm against the full source):

    - ``self.user`` / ``self.item``: embedding lookups for the user (P)
      and item (Q) latent-factor matrices.
    - ``self.bias_user`` / ``self.bias_item``: per-user / per-item bias
      lookups, presumably embeddings of width 1.
    - ``self.bias``: offset added to every prediction (presumably a
      learned global bias).
    - ``self.c_vector`` / ``self.c_bias``: L2 weights for the factor
      matrices and the bias tables respectively.
    """

    def forward(self, train_x):
        """Predict ratings for a batch of (user, item) index pairs.

        Defined as ``forward`` (not ``__call__``) so that ``model(x)`` goes
        through ``nn.Module.__call__`` and registered hooks still run.

        Args:
            train_x: integer index tensor; column 0 holds user indices ("u")
                and column 1 holds item indices ("i"). Other columns are
                not read.

        Returns:
            1-D tensor with one predicted rating per row of ``train_x``.
        """
        # Column 0 holds the user indices ("u"), column 1 the item indices ("i").
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        # Look up the latent factor vectors p_u and q_i for each row.
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        # Row-wise dot product p_u . q_i: the user-item interaction term.
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        # Squeeze only the trailing (width-1) dimension. A bare squeeze()
        # would also drop the batch dimension when the batch has one row.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item
        # Add the biases to the interaction to obtain the final prediction.
        return ui_interaction + biases

    def loss(self, prediction, target):
        """Return the MSE plus L2 penalties on the factor and bias tables.

        Args:
            prediction: output of ``forward``, one value per example.
            target: observed ratings with the same number of elements as
                ``prediction`` (e.g. shape ``(batch,)`` or ``(batch, 1)``).

        Returns:
            Scalar tensor:
            MSE + c_vector * (L2(P) + L2(Q)) + c_bias * (L2(b_user) + L2(b_item)).
        """
        # Reshape the target to exactly match the prediction. A bare squeeze()
        # turns a single-example target into a 0-d tensor, which mse_loss
        # broadcasts and warns about.
        loss_mse = F.mse_loss(prediction, target.reshape(prediction.shape))
        # NOTE(review): ``l2_regularize`` is a module-level helper not shown
        # in this snippet. The penalties below cover the *entire* weight
        # tables, not just the rows looked up in this batch.
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias
        # L2 regularization over the user (P) and item (Q) factor matrices.
        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        # MSE + factor regularization + bias regularization.
        total = loss_mse + prior_user + prior_item + prior_bias_user + prior_bias_item
        return total
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment