Skip to content

Instantly share code, notes, and snippets.

@rdisipio
Created November 6, 2022 13:53
Show Gist options
  • Save rdisipio/07be6cc024af2eabf10196f4901cd8b3 to your computer and use it in GitHub Desktop.
Save rdisipio/07be6cc024af2eabf10196f4901cd8b3 to your computer and use it in GitHub Desktop.
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import layers
class QuantizedFeaturesEmbedding(layers.Layer):
    """Embed each quantized (binned) feature with its own lookup table.

    The layer holds one weight of shape ``(n_features, n_bins, embed_dim)``:
    feature ``j`` has its own table of ``n_bins`` vectors of size
    ``embed_dim``. For an input row of bin indices, the value at position
    ``j`` selects a row from table ``j``.

    Args:
        n_features: number of features (columns) in each input row; > 0.
        n_bins: number of quantization bins per feature; > 0.
        embed_dim: length of each embedding vector; > 0.
        **kwargs: forwarded to ``layers.Layer``.

    Raises:
        ValueError: if any of ``n_features``, ``n_bins``, ``embed_dim`` is
            not positive.
    """

    def __init__(self,
                 n_features,
                 n_bins,
                 embed_dim,
                 **kwargs):
        super().__init__(**kwargs)
        self.n_features = n_features
        self.n_bins = n_bins
        self.embed_dim = embed_dim
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        for arg_name, value in (("n_features", n_features),
                                ("n_bins", n_bins),
                                ("embed_dim", embed_dim)):
            if value <= 0:
                raise ValueError(f"{arg_name} must be > 0, got {value!r}")

    def build(self, input_shape=None):
        """Create the ``(n_features, n_bins, embed_dim)`` embedding weight."""
        self.embeddings = self.add_weight(
            shape=(self.n_features, self.n_bins, self.embed_dim),
            # Use the module-level ``initializers`` import.
            initializer=initializers.GlorotUniform(),
            name='quantized_features_embeddings')
        self.built = True

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "n_features": self.n_features,
            "n_bins": self.n_bins,
            "embed_dim": self.embed_dim,
        })
        return config

    def _enumerate(self, ids):
        '''Pair every bin index with the position (feature) it occurs at.

        input:  [[1,0,2,1], [0,2,1,1]]
        output: [[(0,1), (1,0), (2,2), (3,1)],
                 [(0,0), (1,2), (2,1), (3,1)]]
        so that, for the first input row:
            A[(0,1)] = feature 0, bin 1
            A[(1,0)] = feature 1, bin 0
            A[(2,2)] = feature 2, bin 2
            A[(3,1)] = feature 3, bin 1

        Args:
            ids: integer tensor of shape (batch_size, n_features).

        Returns:
            Tensor of shape (batch_size, n_features, 2) holding
            ``(feature_index, bin_index)`` pairs, usable with ``tf.gather_nd``.
        '''
        # Dynamic batch size: the static ``ids.shape[0]`` is None for
        # symbolic / unknown-batch inputs and would break ``tf.tile``.
        batch_size = tf.shape(ids)[0]
        pos = tf.range(self.n_features, dtype=ids.dtype)       # [0, 1, ..., N-1]
        pos = tf.tile(tf.expand_dims(pos, axis=0), [batch_size, 1])  # one row per sample
        tf.debugging.assert_equal(tf.shape(pos), tf.shape(ids))
        return tf.stack([pos, ids], axis=-1)

    def call(self, ids):
        '''Look up one embedding per feature.

        Input shape:
            2D tensor (or nested list) with shape (batch_size, n_features)
            holding bin indices in ``[0, n_bins)``. Non-integer inputs are
            cast to int32, so float values are truncated toward zero.
        Output shape:
            3D tensor with shape (batch_size, n_features, embed_dim)

        Raises:
            ValueError: if the input is not rank 2, or if its statically known
                last dimension differs from ``n_features``.

        Note:
            ``tf.gather_nd`` errors on out-of-range bin indices on CPU but
            yields zeros on GPU, so callers should keep ids in ``[0, n_bins)``.
        '''
        # Convert first so plain Python lists (which have no ``.shape``) work.
        # ``tf.cast`` also accepts float tensors (e.g. a default Keras Input),
        # which ``convert_to_tensor(..., dtype=tf.int32)`` would reject.
        ids = tf.cast(ids, tf.int32)
        rank = ids.shape.rank
        if rank is not None and rank != 2:
            raise ValueError(
                f"expected ids of rank 2 (batch_size, n_features), got rank {rank}")
        last_dim = ids.shape[-1]
        if last_dim is not None and last_dim != self.n_features:
            raise ValueError(
                f"expected last dimension {self.n_features}, got {last_dim}")
        idx = self._enumerate(ids)
        return tf.gather_nd(self.embeddings, idx)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment