Created
December 16, 2019 00:56
-
-
Save lucidrains/188f7fe11f21367e19185677d5f7a332 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class Recommend(nn.Module):
    """Score (user, item) pairs with learned embeddings and a small MLP.

    Every user and every item owns one ``dims``-wide embedding vector. For a
    given pair the two vectors are concatenated and fed through a two-layer
    network that ends in a sigmoid, so each score lies in ``[0, 1]``.

    Args:
        num_items: Number of rows in the item embedding table.
        num_users: Number of rows in the user embedding table.
        dims: Width of each embedding vector.
    """

    def __init__(self, num_items: int, num_users: int, dims: int) -> None:
        super().__init__()
        self.user_embed = nn.Embedding(num_users, dims)
        self.item_embed = nn.Embedding(num_items, dims)
        # Input is the user and item vectors side by side, hence 2 * dims.
        self.net = nn.Sequential(
            nn.Linear(2 * dims, 4 * dims),
            nn.LeakyReLU(),
            nn.Linear(4 * dims, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one score per row of ``x``.

        Args:
            x: Integer index tensor of shape ``(batch, 2)``. Column 0 is
                looked up in the user table, column 1 in the item table.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the sigmoid outputs.
        """
        user_vecs = self.user_embed(x[:, 0])
        item_vecs = self.item_embed(x[:, 1])
        pair_features = torch.cat((user_vecs, item_vecs), dim=1)
        return self.net(pair_features)
# Demo: 12 items, 1000 users, 512-wide embeddings.
r = Recommend(12, 1000, 512)

# Each row of x is a (user index, item index) pair; y holds the matching
# label for that row (1 = likes, 0 = dislikes).
x = torch.tensor([[0, 0], [0, 1], [1, 0]])
y = torch.tensor([1, 0, 1]).float()
# user 0 likes item 0
# user 0 dislikes item 1
# user 1 likes item 0

# The model emits shape (batch, 1) because its last Linear layer has a
# single output unit; drop that trailing axis so the predictions line up
# with y, which binary_cross_entropy requires.
output = r(x).squeeze(-1)
loss = F.binary_cross_entropy(output, y)
loss.backward()
# run above in training loop

r(torch.tensor([[1, 1]]))
# predict if user 1 likes item 1

# NOTE(review): `cluster` is not defined in the visible part of this file;
# presumably a placeholder for any clustering routine - confirm or supply one.
cluster(r.user_embed.weight)
cluster(r.item_embed.weight)
# cluster user and item embeddings to find similar users and items
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment