
@franckjay
Created March 25, 2024 02:59
Forward pass for multiple user and item embeddings
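```python
import torch


def forward(self, x, u_cats, i_cats):
    """
    Forward pass (a method of a torch.nn.Module subclass).
    :param x: Float tensor of continuous features, shape (batch_size, n_float_features)
    :param u_cats: User index tensor, shape (batch_size,) or (batch_size, n_user_cols)
    :param i_cats: Item index tensor, shape (batch_size,) or (batch_size, n_item_cols)
    :return: Predictions for this batch
    """
    curr_batch_size = len(u_cats)
    # Look up the user and item embedding vectors for each index
    u_embs = self.user_embedding(u_cats.long())
    i_embs = self.item_embedding(i_cats.long())
    # Flatten to (batch_size, n_cols * emb_dim) so each row holds one sample
    u_embs = u_embs.view(curr_batch_size, -1)
    i_embs = i_embs.view(curr_batch_size, -1)
    # Concatenate embeddings and float features along the feature axis (dim 1)
    x = torch.cat([u_embs, i_embs, x], 1)
    return self.layers(x)
```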
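The method above only works inside a module that defines `self.user_embedding`, `self.item_embedding`, and `self.layers`. The following is a minimal sketch of one such module, assuming `nn.Embedding` lookup tables and a small MLP head. The class name `EmbeddingMLP`, its constructor parameters (`n_user_cols`, `n_item_cols`, `hidden`, etc.), the layer sizes, and the single-output head are all hypothetical, not taken from the original gist.

```python
import torch
import torch.nn as nn


class EmbeddingMLP(nn.Module):
    """Hypothetical module whose attributes match the forward() shown above."""

    def __init__(self, n_users, n_items, n_float, emb_dim=8,
                 n_user_cols=1, n_item_cols=1, hidden=32):
        super().__init__()
        # Assumption: all user-index columns share one table, likewise for items
        self.user_embedding = nn.Embedding(n_users, emb_dim)
        self.item_embedding = nn.Embedding(n_items, emb_dim)
        # Input width = flattened user + item embeddings + continuous features
        in_features = (n_user_cols + n_item_cols) * emb_dim + n_float
        # Assumed MLP head producing one score per sample
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x, u_cats, i_cats):
        curr_batch_size = len(u_cats)
        # (B,) -> (B, emb_dim); (B, k) -> (B, k, emb_dim) -> (B, k * emb_dim)
        u_embs = self.user_embedding(u_cats.long()).view(curr_batch_size, -1)
        i_embs = self.item_embedding(i_cats.long()).view(curr_batch_size, -1)
        x = torch.cat([u_embs, i_embs, x], 1)
        return self.layers(x)


# Usage: 4 samples, 2 user-index columns, 1 item-index column, 3 float features
torch.manual_seed(0)
model = EmbeddingMLP(n_users=10, n_items=20, n_float=3, emb_dim=8,
                     n_user_cols=2, n_item_cols=1)
u_cats = torch.randint(0, 10, (4, 2))
i_cats = torch.randint(0, 20, (4,))
x = torch.randn(4, 3)
preds = model(x, u_cats, i_cats)
print(preds.shape)  # torch.Size([4, 1])
```

The first `Linear` layer's input width must equal the flattened embedding width plus the number of float features (here 2·8 + 1·8 + 3 = 27), which is why this sketch derives `in_features` from the column counts instead of hard-coding it.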