Skip to content

Instantly share code, notes, and snippets.

@a-agmon
Created September 21, 2020 06:09
Show Gist options
  • Save a-agmon/9f2029a83e7666085e01a89809faccbf to your computer and use it in GitHub Desktop.
Save a-agmon/9f2029a83e7666085e01a89809faccbf to your computer and use it in GitHub Desktop.
def user_embedding_model(embedding_size=50, *, n_users=None, n_items=None):
    """Build and compile a user/item embedding model scored by a dot product.

    A user id and an item id are each looked up in their own embedding
    table. The two vectors are combined with a normalized dot product
    (i.e. cosine similarity), and that single value is passed through a
    one-unit sigmoid layer to yield a score in (0, 1).

    Args:
        embedding_size: Dimensionality of both embedding spaces.
        n_users: Number of rows in the user embedding table. When omitted,
            ``len(user_index)`` is used.
        n_items: Number of rows in the item embedding table. When omitted,
            ``len(item_index)`` is used.

    Returns:
        A Keras ``Model`` taking the inputs ``[user, item]``, compiled with
        the Adam optimizer, binary cross-entropy loss and an accuracy metric.
    """
    # Fall back to the module-level lookup tables so existing zero/one-arg
    # calls keep working. NOTE(review): `user_index` / `item_index` are not
    # defined in this block -- assumed to be sized collections defined
    # elsewhere in the file; confirm.
    if n_users is None:
        n_users = len(user_index)
    if n_items is None:
        n_items = len(item_index)

    # Each input carries a single integer id per sample.
    user = Input(name='user', shape=(1,))
    item = Input(name='item', shape=(1,))

    # Shape after embedding: (None, 1, embedding_size)
    user_embedding = Embedding(
        name='user_embedding',
        input_dim=n_users,
        output_dim=embedding_size,
    )(user)

    # Shape after embedding: (None, 1, embedding_size)
    item_embedding = Embedding(
        name='item_embedding',
        input_dim=n_items,
        output_dim=embedding_size,
    )(item)

    # Dot product along the embedding axis; normalize=True L2-normalizes
    # both operands first, so the result is a cosine similarity.
    # Shape: (None, 1, 1)
    merged = Dot(name='dot_product', normalize=True, axes=2)(
        [user_embedding, item_embedding]
    )

    # Collapse to one number per sample. Shape: (None, 1)
    merged = Reshape(target_shape=(1,))(merged)

    # Learnable scale and bias on the similarity, squashed into (0, 1).
    merged = Dense(1, activation='sigmoid')(merged)

    model = Model(inputs=[user, item], outputs=merged)
    model.compile(
        optimizer='Adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return model
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment