Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created February 22, 2020 22:03
Show Gist options
  • Select an option

  • Save khanhnamle1994/f2c29799f32e469f8b6c8e96ebf6a129 to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/f2c29799f32e469f8b6c8e96ebf6a129 to your computer and use it in GitHub Desktop.
Matrix Factorization with Mixture of Tastes class
import torch
from torch import nn
import torch.nn.functional as F
class MF(nn.Module):
    """Matrix factorization with a mixture-of-tastes attention over users.

    Each user has several "taste" vectors plus matching attention vectors;
    the item vector attends over the tastes to form a single weighted user
    preference, which is dotted with the item vector and combined with
    global/user/item biases.

    NOTE(review): the parameter attributes read below (``item``,
    ``bias_user``, ``bias_item``, ``bias``, ``user_taste``,
    ``user_attention``, ``c_bias``, ``c_vector``) are created in an
    ``__init__`` not shown in this chunk — confirm their shapes there.
    ``user_taste`` / ``user_attention`` are indexed as
    ``[user_id]`` -> (batch, k, num_tastes), presumably stored as
    (num_users, k, num_tastes); verify against the constructor.
    """

    def forward(self, train_x):
        """Predict a score for each (user, item) index pair in the batch.

        Defined as ``forward`` (not an override of ``__call__``) so that
        ``nn.Module``'s call machinery — hooks, etc. — keeps working;
        callers still invoke the model as ``model(train_x)``.

        Args:
            train_x: integer tensor of shape (batch, 2); column 0 is the
                user index, column 1 is the item index.

        Returns:
            1-D tensor of predicted scores, one per input row.
        """
        # These are the user and item indices.
        user_id = train_x[:, 0]
        item_id = train_x[:, 1]

        # Item latent vectors: (batch, k).
        vector_item = self.item(item_id)

        # Scalar biases: global + per-user + per-item.
        # squeeze(-1) rather than squeeze() so a batch of size 1 keeps
        # its batch dimension (bare squeeze would drop it too).
        bias_user = self.bias_user(user_id).squeeze(-1)
        bias_item = self.bias_item(item_id).squeeze(-1)
        biases = self.bias + bias_user + bias_item

        # Per-user taste and attention tensors, selected by user index.
        user_taste = self.user_taste[user_id]
        user_attention = self.user_attention[user_id]

        # Broadcast the item vector across the taste dimension, then
        # softmax over the latent dimension (dim=1) to get attention.
        vector_itemx = vector_item.unsqueeze(2).expand_as(user_attention)
        attention = F.softmax(user_attention * vector_itemx, dim=1)
        attentionx = attention.sum(2).unsqueeze(2).expand_as(user_attention)

        # Attention-weighted mixture of the user's tastes: (batch, k).
        weighted_preference = (user_taste * attentionx).sum(2)

        # Dot product of weighted preference and item vector, per row.
        dot = (weighted_preference * vector_item).sum(1)

        # Final prediction: dot product plus the combined biases.
        return dot + biases

    def loss(self, prediction, target):
        """Return MSE loss plus L2 regularization of all parameter groups.

        Args:
            prediction: model outputs (any singleton dims are squeezed).
            target: ground-truth scores, same number of elements.

        Returns:
            Scalar tensor: MSE + c_bias * (user/item bias L2)
            + c_vector * (taste/attention/item L2).
        """
        # Mean squared error between target and prediction.
        loss_mse = F.mse_loss(prediction.squeeze(), target.squeeze())

        # L2 penalties on the user/item bias tables, scaled by c_bias.
        # NOTE(review): l2_regularize is defined elsewhere in this file.
        prior_bias_user = l2_regularize(self.bias_user.weight) * self.c_bias
        prior_bias_item = l2_regularize(self.bias_item.weight) * self.c_bias

        # L2 penalties on the taste / attention tensors, scaled by c_vector.
        prior_taste = l2_regularize(self.user_taste) * self.c_vector
        prior_attention = l2_regularize(self.user_attention) * self.c_vector

        # L2 penalty on the item embedding matrix.
        prior_item = l2_regularize(self.item.weight) * self.c_vector

        # Total: MSE + all regularization terms.
        total = (loss_mse + prior_bias_item + prior_bias_user
                 + prior_taste + prior_attention + prior_item)
        return total
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment