Created December 19, 2017 18:23
shitty lookup layer
import keras
import tensorflow


class TokenizeLookupLayer(keras.layers.Layer):
    """
    Layer that encapsulates the following:
    - Tokenizing sentences on spaces (tf.string_split's default delimiter)
    - Looking up each token's index in a given word -> index vocabulary map
    - Resetting the shape of the above to batch_size x pad_len
      (using dark magic, i.e. pad-then-slice; see call())

    # Input Shape
    2D string tensor with shape `(batch_size, 1)`

    # Output Shape
    2D int32 tensor with shape `(batch_size, pad_len)`
    """
    def __init__(self, word_ind_map, pad_len, pad_value=0, oov_value=1, **kwargs):
        super(TokenizeLookupLayer, self).__init__(**kwargs)
        self.input_spec = keras.engine.InputSpec(ndim=2, dtype='string')
        self.pad_len = pad_len
        self.pad_value = pad_value
        self.oov_value = oov_value
        self.word_ind_map = word_ind_map

    def get_config(self):
        config = {
            'word_ind_map': self.word_ind_map,
            'pad_len': self.pad_len,
            'pad_value': self.pad_value,
            'oov_value': self.oov_value,
        }
        base_config = super(TokenizeLookupLayer, self).get_config()
        config.update(base_config)
        return config

    def build(self, input_shape):
        # Build a static string -> index hash table from the vocabulary map;
        # any word not in the map looks up to oov_value.
        keys, values = zip(*self.word_ind_map.items())
        self.lookup_tab = tensorflow.contrib.lookup.HashTable(
            tensorflow.contrib.lookup.KeyValueTensorInitializer(keys, values),
            default_value=self.oov_value)
        try:
            tensorflow.tables_initializer().run(session=keras.backend.get_session())
        except tensorflow.errors.FailedPreconditionError:
            # TODO(ZJ) this is probably wrong?: DS-209
            pass
        super(TokenizeLookupLayer, self).build(input_shape)

    def call(self, str_inp):
        # no name supported for this op?!
        tokenized_inp = tensorflow.string_split(
            tensorflow.squeeze(str_inp, axis=1))
        sparse_inp_lookedup = self.lookup_tab.lookup(
            tokenized_inp,
            name='lookup'
        )
        # This could be batch_size x max_seq_len_in_batch, and
        # max_seq_len_in_batch bears no relation to pad_len, but we need to
        # get it out as pad_len.
        dense_inp = tensorflow.sparse_tensor_to_dense(
            sparse_inp_lookedup,
            default_value=self.pad_value,
            name='dense'
        )
        # So essentially: append pad_len more pad_values to the end of each row...
        pad_full = tensorflow.pad(
            dense_inp,
            paddings=tensorflow.constant([[0, 0], [0, self.pad_len]]),
            mode='CONSTANT',
            constant_values=self.pad_value,
            name='pad'
        )
        # ...then limit the second dimension to exactly pad_len.
        # E.g. with pad_len=5 and a longest-in-batch sentence of 3 tokens:
        # dense_inp is batch_size x 3, pad_full is batch_size x 8, and
        # sliced is batch_size x 5.
        sliced = pad_full[:, :self.pad_len]
        return sliced

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.pad_len)
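For context, here's a minimal sketch of how the layer might be wired up end to end. This is illustration only, not part of the gist: the toy vocabulary, pad_len, and model below are made up, and it assumes TF 1.x-era Keras where Input accepts dtype='string'.

import numpy
import keras

# Hypothetical toy vocabulary; 0 is reserved for padding, 1 for OOV.
vocab = {'the': 2, 'cat': 3, 'sat': 4}

inp = keras.layers.Input(shape=(1,), dtype='string')
ids = TokenizeLookupLayer(vocab, pad_len=5)(inp)
model = keras.models.Model(inputs=inp, outputs=ids)

sentences = numpy.array([['the cat sat'], ['the dog sat']])
print(model.predict(sentences))
# Expected, roughly: [[2 3 4 0 0], [2 1 4 0 0]] -- 'dog' hits the OOV index.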
Howdy! I forked this and added a regexp replace before the split, allowing for regexp-based tokenization instead of just delim-based tokenization, in case it's useful to you:
https://gist.github.com/soaxelbrooke/246959a7290313fb22be021d9c82a394
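The gist of the idea is to rewrite the strings before string_split so that token boundaries become spaces. A rough sketch of that preprocessing step -- my own illustration, not the fork's actual code -- assuming a TF 1.x version that provides tensorflow.regex_replace:

# Inside call(), before tensorflow.string_split: surround anything that is
# neither a word character nor whitespace with spaces, so e.g. punctuation
# becomes its own token once the space-split runs.
cleaned = tensorflow.regex_replace(
    tensorflow.squeeze(str_inp, axis=1),
    r'([^\w\s])', r' \1 ')
tokenized_inp = tensorflow.string_split(cleaned)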