Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created February 22, 2020 22:15
Show Gist options
  • Select an option

  • Save khanhnamle1994/3843cff57653f457fa9a6124fade554c to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/3843cff57653f457fa9a6124fade554c to your computer and use it in GitHub Desktop.
Matrix Factorization with Biases class
import torch
from torch import nn
import torch.nn.functional as F
class MF(nn.Module):
    """Matrix factorization rating model with global, user and item biases.

    Prediction for a (user u, item i) pair::

        r_hat[u, i] = bias + b_u + b_i + p_u . q_i

    NOTE(review): the attributes used below (``user``, ``item``,
    ``bias_user``, ``bias_item``, ``bias``, ``c_vector``, ``c_bias``) are not
    created in this class body. They are presumably set by an ``__init__``
    defined elsewhere (embedding tables, a global bias parameter and the two
    regularization weights) -- confirm against the full model definition.
    """

    def forward(self, train_x):
        """Predict ratings for a batch of (user, item) index pairs.

        Implemented as ``forward`` rather than overriding ``__call__`` so that
        ``model(train_x)`` goes through ``nn.Module.__call__`` and any
        registered forward / pre-forward hooks are actually run.

        Args:
            train_x: index tensor whose column 0 holds user indices and whose
                column 1 holds item indices (assumed shape ``(batch, >=2)``).

        Returns:
            1-D tensor of shape ``(batch,)``, one prediction per row.
        """
        # Column 0 -> user indices ("u"), column 1 -> item indices ("i").
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]
        # Look up the latent factor vectors p_u and q_i for each row.
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        # Row-wise dot product p_u . q_i -> shape (batch,).
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        # Per-user / per-item bias lookups. squeeze(-1) drops only the
        # trailing axis, so a batch of one stays shape (1,) instead of
        # collapsing to a 0-d tensor as a bare squeeze() would.
        # Assumes the bias embeddings have width 1 -- confirm.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item
        # Final prediction: interaction term plus global/user/item biases.
        return ui_interaction + biases

    def loss(self, prediction, target):
        """Return MSE plus L2 penalties on factor and bias tables.

        Args:
            prediction: 1-D tensor produced by ``forward``.
            target: observed ratings, e.g. shape ``(batch,)`` or
                ``(batch, 1)``; it is flattened to match ``prediction``.

        Returns:
            The sum of the MSE term and four regularization terms. Factor
            tables are weighted by ``self.c_vector`` and bias tables by
            ``self.c_bias``.

        NOTE(review): ``l2_regularize`` is not defined in this class; it is
        presumably a module-level helper returning a scalar penalty for a
        weight tensor -- confirm.
        """
        # Flatten the target so its shape matches the 1-D prediction even
        # when batch == 1. A bare squeeze() would give a 0-d target there,
        # and mse_loss would warn about mismatched input/target sizes.
        loss_mse = F.mse_loss(prediction, target.reshape(-1))
        # L2 penalty on the user and item bias tables.
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias
        # L2 penalty on the user (P) and item (Q) factor tables.
        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector
        # Total objective: data fit + factor regularization + bias regularization.
        total = loss_mse + prior_user + prior_item + prior_bias_user + prior_bias_item
        return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment