Created February 22, 2020 22:03
Save khanhnamle1994/f2c29799f32e469f8b6c8e96ebf6a129 to your computer and use it in GitHub Desktop.
Matrix Factorization with Mixture of Tastes class
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
class MF(nn.Module):
    """Matrix factorization whose users are modelled as a mixture of tastes.

    Each user owns ``T`` taste vectors and ``T`` attention vectors. For a
    given item, the attention vectors are dotted with the item vector to give
    one logit per taste; a softmax over those logits weights the user's taste
    vectors, and the weighted taste is dotted with the item vector. The global,
    user and item biases are then added to produce the prediction.

    NOTE(review): ``__init__`` is not part of this excerpt. The methods below
    read ``self.item``, ``self.bias_user`` and ``self.bias_item`` (called like
    embedding modules and exposing ``.weight`` -- presumably ``nn.Embedding``),
    ``self.bias``, ``self.user_taste`` and ``self.user_attention`` (indexed by
    user id; presumably shaped (n_users, k, T) -- confirm), plus the penalty
    weights ``self.c_bias`` and ``self.c_vector``. ``loss`` also calls a
    module-level ``l2_regularize`` helper that is defined elsewhere.
    """

    def forward(self, train_x):
        """Predict ratings for a batch of (user, item) index pairs.

        Defined as ``forward`` (not an ``__call__`` override) so that
        ``model(train_x)`` goes through ``nn.Module.__call__`` and any
        registered hooks still fire.

        Args:
            train_x: integer tensor; column 0 holds user indices and column 1
                holds item indices (presumably shape (batch, 2)).

        Returns:
            One prediction per row of ``train_x``.
        """
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]

        # Item embeddings, (batch, k).
        vector_item = self.item(item_id)

        # squeeze(-1) drops only the trailing size-1 axis, so a batch of one
        # keeps its batch dimension.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item

        # Per-user taste and attention matrices, (batch, k, T).
        user_taste = self.user_taste[user_id]
        user_attention = self.user_attention[user_id]

        # One attention logit per taste: the dot product of that taste's
        # attention vector with the item vector, i.e. a reduction over the
        # embedding axis (dim 1), leaving (batch, T).
        logits = (user_attention * vector_item.unsqueeze(2)).sum(1)
        # Normalize across tastes so each user's taste weights sum to one
        # for this item.
        attention = F.softmax(logits, dim=1)

        # Attention-weighted sum of the taste vectors, (batch, k).
        weighted_preference = (user_taste * attention.unsqueeze(1)).sum(2)

        # Dot product of the mixed taste with the item vector, plus biases.
        dot = (weighted_preference * vector_item).sum(1)
        prediction = dot + biases
        return prediction

    def loss(self, prediction, target):
        """Return the MSE between prediction and target plus L2 penalties.

        The user/item bias tables are penalized with weight ``self.c_bias``;
        the user taste, user attention and item embedding tables are
        penalized with weight ``self.c_vector``.

        Args:
            prediction: model output for a batch.
            target: ground-truth values for the same batch.

        Returns:
            Scalar tensor: MSE + all regularization terms.
        """
        loss_mse = F.mse_loss(prediction.squeeze(), target.squeeze())

        # L2 penalties on the user and item bias tables.
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias

        # L2 penalties on the user taste and user attention tensors.
        prior_taste = l2_regularize(self.user_taste) * self.c_vector
        prior_attention = l2_regularize(self.user_attention) * self.c_vector

        # L2 penalty on the item embedding table.
        prior_item = l2_regularize(self.item.weight) * self.c_vector

        total = (loss_mse + prior_bias_item + prior_bias_user
                 + prior_taste + prior_attention + prior_item)
        return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.