Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created April 23, 2020 12:43
Show Gist options
  • Select an option

  • Save khanhnamle1994/67568fd36dbbc683d5219bdead47dd11 to your computer and use it in GitHub Desktop.

Select an option

Save khanhnamle1994/67568fd36dbbc683d5219bdead47dd11 to your computer and use it in GitHub Desktop.
Deep Factorization Machine model
import torch
from layer import FactorizationMachine, FeaturesEmbedding, FeaturesLinear, MultiLayerPerceptron
class DeepFactorizationMachineModel(torch.nn.Module):
    """
    PyTorch implementation of DeepFM.

    The predicted score is the sigmoid of the sum of three terms:

    * a linear term computed from the raw feature indices,
    * a factorization-machine term computed from the field embeddings,
    * a multi-layer perceptron applied to the flattened field embeddings.

    Reference:
        H Guo, et al. DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, 2017.
    """

    def __init__(self, field_dims, embed_dim, mlp_dims, dropout):
        """
        :param field_dims: per-field sizes passed to ``FeaturesLinear`` and
            ``FeaturesEmbedding`` (presumably the number of distinct values of
            each field -- confirm against ``layer``); ``len(field_dims)`` is the
            number of fields.
        :param embed_dim: embedding width handed to ``FeaturesEmbedding``; it is
            multiplied by the field count to get the MLP input width.
        :param mlp_dims: forwarded unchanged to ``MultiLayerPerceptron``
            (presumably its hidden layer sizes).
        :param dropout: forwarded unchanged to ``MultiLayerPerceptron``.
        """
        super().__init__()
        # NOTE: keep this construction order. It fixes the submodule
        # registration order (state_dict keys, parameters() order) and, if these
        # layers draw random initial weights, which random draws each one gets.
        self.linear = FeaturesLinear(field_dims)
        self.fm = FactorizationMachine(reduce_sum=True)
        self.embedding = FeaturesEmbedding(field_dims, embed_dim)
        num_fields = len(field_dims)
        # Width of one sample's embeddings once all fields are laid end to end.
        self.embed_output_dim = num_fields * embed_dim
        self.mlp = MultiLayerPerceptron(self.embed_output_dim, mlp_dims, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Score a batch of samples.

        :param x: Long tensor of size ``(batch_size, num_fields)``
        :return: sigmoid of the summed logits with dimension 1 squeezed out;
            presumably of size ``(batch_size,)`` when each term is
            ``(batch_size, 1)`` -- TODO confirm against ``layer``.
        """
        field_embeddings = self.embedding(x)
        linear_term = self.linear(x)
        fm_term = self.fm(field_embeddings)
        # Flatten to one row of width embed_output_dim per sample for the MLP
        # (assumes the embedding output is (batch, num_fields, embed_dim)).
        flat_embeddings = field_embeddings.view(-1, self.embed_output_dim)
        deep_term = self.mlp(flat_embeddings)
        logits = linear_term + fm_term + deep_term
        return torch.sigmoid(logits.squeeze(1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment