Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created February 22, 2020 22:12
Show Gist options
  • Select an option

  • Save khanhnamle1994/d98a33c1fef277b91790db28544889ee to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/d98a33c1fef277b91790db28544889ee to your computer and use it in GitHub Desktop.
Matrix Factorization with Side Features class
import torch
from torch import nn
import torch.nn.functional as F
class MF(nn.Module):
    """Matrix factorization with an occupation side feature.

    For each input row the prediction is::

        bias + bias_user[u] + bias_item[i] + (p_u . q_i) + (p_u . r_o)

    where ``u``, ``i`` and ``o`` are the user, item and occupation indices.

    No constructor is visible in this snippet; the methods below read the
    following attributes, which must be set elsewhere:

    * ``user``, ``item``, ``occu`` -- lookups called with an index tensor
      that also expose ``.weight`` (presumably ``nn.Embedding`` layers of
      equal width -- confirm against the constructor).
    * ``bias_user``, ``bias_item`` -- same kind of lookup (presumably one
      value per id -- confirm against the constructor).
    * ``bias`` -- global bias term added to every prediction.
    * ``c_bias``, ``c_vector`` -- multipliers applied to the regularization
      terms in :meth:`loss`.
    """

    # Implemented as ``forward`` rather than overriding ``__call__``:
    # ``nn.Module.__call__`` dispatches here and also runs any registered
    # hooks, so ``model(train_x)`` keeps working exactly as before.
    def forward(self, train_x):
        """Return one predicted value per row of ``train_x``.

        Args:
            train_x: Index tensor with at least four columns. Column 0 holds
                user indices, column 1 item indices and column 3 occupation
                indices. Column 2 is not read here.

        Returns:
            Tensor of predictions, one per row.
        """
        # These are the user indices, and correspond to the "u" variable
        user_id = train_x[:, 0]
        # These are the item indices, and correspond to the "i" variable
        item_id = train_x[:, 1]
        # These are the occupation indices, and correspond to the "o" variable
        occu_id = train_x[:, 3]

        # Look up p_u, q_i and r_o for every row
        vector_user = self.user(user_id)
        vector_item = self.item(item_id)
        vector_occu = self.occu(occu_id)

        # User-item interaction: row-wise dot product p_u . q_i
        ui_interaction = torch.sum(vector_user * vector_item, dim=1)
        # User-occupation interaction: row-wise dot product p_u . r_o
        uo_interaction = torch.sum(vector_user * vector_occu, dim=1)

        # Drop only the trailing axis of the bias lookups; a bare squeeze()
        # would also remove the batch axis when the batch has a single row.
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item

        # Final prediction: both interactions plus all bias terms
        prediction = ui_interaction + uo_interaction + biases
        return prediction

    def loss(self, prediction, target):
        """Return the MSE between prediction and target plus L2 penalties.

        The penalties are computed with ``l2_regularize`` (defined outside
        this class) over the user/item bias weights, scaled by ``c_bias``,
        and over the user, item and occupation weights, scaled by
        ``c_vector``.

        Args:
            prediction: Model output, e.g. from :meth:`forward`.
            target: Ground-truth values; both tensors are squeezed before
                the comparison.

        Returns:
            The total loss: MSE plus all five regularization terms.
        """
        # Mean squared error between target and prediction
        loss_mse = F.mse_loss(prediction.squeeze(), target.squeeze())

        # L2 regularization over the user and item biases
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias

        # L2 regularization over the user (P) and item (Q) matrices
        prior_user = l2_regularize(self.user.weight) * self.c_vector
        prior_item = l2_regularize(self.item.weight) * self.c_vector

        # L2 regularization over the occupation (R) matrix
        prior_occu = l2_regularize(self.occu.weight) * self.c_vector

        # MSE + vector regularization + bias regularization + occupation
        # regularization
        total = (
            loss_mse
            + prior_user
            + prior_item
            + prior_bias_item
            + prior_bias_user
            + prior_occu
        )
        return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment