Skip to content

Instantly share code, notes, and snippets.

@khuangaf
Created September 3, 2018 12:11
Show Gist options
  • Save khuangaf/41e3de743e4dcd9f8748cde47a6cd434 to your computer and use it in GitHub Desktop.
Save khuangaf/41e3de743e4dcd9f8748cde47a6cd434 to your computer and use it in GitHub Desktop.
def forward(self, user_indices, item_indices, titles):
    """Score (user, item) pairs by fusing an MF branch and an MLP branch.

    The MF branch multiplies the user and item MF embeddings element-wise.
    The MLP branch concatenates the user and item MLP embeddings and passes
    them through each layer of ``self.fc_layers``, each followed by ReLU and
    dropout. The two branch outputs are concatenated, projected by
    ``self.logits`` and passed through ``self.sigmoid``.

    Args:
        user_indices: Indices into the user embedding tables
            (presumably a LongTensor of shape (batch,) -- TODO confirm).
        item_indices: Indices into the item embedding tables, aligned
            element-wise with ``user_indices``.
        titles: Accepted for interface compatibility; not used here.

    Returns:
        ``self.sigmoid(self.logits(...))`` applied to the fused vector of
        each pair.
    """
    # Dropout is applied functionally and gated on ``self.training`` so that
    # ``model.eval()`` disables it. Instantiating ``torch.nn.Dropout`` inside
    # forward creates a fresh module that is always in training mode, which
    # made inference outputs random.
    training = self.training

    # MF branch: element-wise product of the MF embeddings.
    mf_vector = torch.mul(
        self.embedding_user_mf(user_indices),
        self.embedding_item_mf(item_indices),
    )
    mf_vector = torch.nn.functional.dropout(
        mf_vector, p=self.config.dropout_rate_mf, training=training
    )

    # MLP branch: concatenated MLP embeddings through every fully connected layer.
    mlp_vector = torch.cat(
        [self.embedding_user_mlp(user_indices), self.embedding_item_mlp(item_indices)],
        dim=-1,
    )
    for layer in self.fc_layers:
        mlp_vector = torch.relu(layer(mlp_vector))
        mlp_vector = torch.nn.functional.dropout(
            mlp_vector, p=self.config.dropout_rate_mlp, training=training
        )

    # Fuse both branches and produce the final score.
    vector = torch.cat([mlp_vector, mf_vector], dim=-1)
    logits = self.logits(vector)
    return self.sigmoid(logits)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment