Created
November 6, 2022 13:53
-
-
Save rdisipio/07be6cc024af2eabf10196f4901cd8b3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import initializers
class QuantizedFeaturesEmbedding(layers.Layer):
    """Embedding layer holding one lookup table per feature.

    The layer owns a single weight of shape ``(n_features, n_bins, embed_dim)``.
    For an integer input ``ids`` of shape ``(batch_size, n_features)``, the
    output at ``[b, f]`` is ``embeddings[f, ids[b, f]]``, i.e. feature ``f``
    is looked up in its own table using its bin index.

    Args:
        n_features: Number of features (columns of the input). Must be > 0.
        n_bins: Number of bins (rows of each per-feature table). Must be > 0.
        embed_dim: Size of each embedding vector. Must be > 0.
        **kwargs: Forwarded to ``layers.Layer``.

    Raises:
        ValueError: If any of the three sizes is not strictly positive.
    """

    def __init__(self,
                 n_features,
                 n_bins,
                 embed_dim,
                 **kwargs):
        super().__init__(**kwargs)
        self.n_features = n_features
        self.n_bins = n_bins
        self.embed_dim = embed_dim

        # Explicit raises instead of `assert`, which is stripped under `-O`.
        for arg_name, value in (("n_features", n_features),
                                ("n_bins", n_bins),
                                ("embed_dim", embed_dim)):
            if value <= 0:
                raise ValueError(
                    f"{arg_name} must be positive, got {value!r}")

    def build(self, input_shape=None):
        """Create the ``(n_features, n_bins, embed_dim)`` embedding weight."""
        self.embeddings = self.add_weight(
            shape=(self.n_features, self.n_bins, self.embed_dim),
            initializer=initializers.GlorotUniform(),
            name='quantized_features_embeddings')
        self.built = True

    def _enumerate(self, ids):
        '''
        Pair every bin index with the position of its feature.

        input:  [[1,0,2,1], [0,2,1,1]]
        output: [[(0,1), (1,0), (2,2), (3,1)],
                 [(0,0), (1,2), (2,1), (3,1)]]
        so that, for the first input:
        A[(0,1)] = feature 0, embedding 1
        A[(1,0)] = feature 1, embedding 0
        A[(2,2)] = feature 2, embedding 2
        A[(3,1)] = feature 3, embedding 1
        '''
        ids = tf.convert_to_tensor(ids)
        # Use the dynamic shape: the static `ids.shape[0]` is None whenever
        # the batch dimension is unknown at trace time.
        batch_size = tf.shape(ids)[0]
        pos = tf.expand_dims(
            tf.range(self.n_features, dtype=ids.dtype), axis=0)  # [[0,1,..,N-1]]
        pos = tf.tile(pos, tf.stack([batch_size, 1]))  # repeat for this batch
        return tf.stack([pos, ids], axis=-1)

    def call(self, ids):
        '''
        Look up the embedding of every (feature, bin) pair.

        Input shape:
        2D tensor with shape (batch_size, n_features)
        Output shape:
        3D tensor with shape (batch_size, n_features, embed_dim)

        Raises:
        ValueError if the statically known last input dimension differs
        from n_features.
        '''
        # Cast first so lists, arrays and non-int32 tensors are all accepted.
        ids = tf.cast(ids, tf.int32)

        static_shape = ids.shape
        if static_shape.rank is not None and static_shape.rank > 0:
            last_dim = static_shape[-1]
            # An unknown (None) last dimension cannot be checked here.
            if last_dim is not None and last_dim != self.n_features:
                raise ValueError(
                    f"Expected last input dimension {self.n_features}, "
                    f"got {last_dim}")

        idx = self._enumerate(ids)
        return tf.gather_nd(self.embeddings, idx)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'n_features': self.n_features,
            'n_bins': self.n_bins,
            'embed_dim': self.embed_dim,
        })
        return config
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment