Created
October 18, 2020 20:09
-
-
Save negedng/a785b973cbe17e82ba9607111aaf4ef3 to your computer and use it in GitHub Desktop.
with modifications from https://www.tensorflow.org/tutorials/load_data/text
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter | |
# Build the vocabulary used by the token encoder: count every token that
# appears in the training split and keep only the most frequent ones.
tokenizer = tfds.features.text.Tokenizer()
vocabulary_counter = Counter()
for text_tensor, _ in ds1_train:
    # `.numpy()` pulls the raw text out of the eager tensor for tokenizing.
    vocabulary_counter.update(tokenizer.tokenize(text_tensor.numpy()))

# Keep the `max_features - 2` most common tokens.
# NOTE(review): the "- 2" presumably reserves ids for padding and
# out-of-vocabulary tokens added by the encoder -- confirm.
vocabulary = vocabulary_counter.most_common(max_features - 2)
vocabulary_set = {token for token, _count in vocabulary}
encoder = tfds.features.text.TokenTextEncoder(vocabulary_set)
# Encoding functions
def encode(text_tensor, label):
    """Encode one text example with the module-level ``encoder``.

    Args:
        text_tensor: Tensor holding the raw text; its value is read
            with ``.numpy()``.
        label: Label for the example; returned unchanged.

    Returns:
        A ``(encoded_text, label)`` pair, where ``encoded_text`` is the
        result of ``encoder.encode`` on the raw text.
    """
    raw_text = text_tensor.numpy()
    return encoder.encode(raw_text), label
def encode_map_fn(text, label):
    """Wrap ``encode`` in ``tf.py_function`` and set the output shapes.

    Args:
        text: Text tensor forwarded to ``encode``.
        label: Label tensor forwarded to ``encode``.

    Returns:
        A ``(encoded_text, label)`` pair of ``tf.int64`` tensors whose
        static shapes are set to ``[None]`` and ``[]`` respectively.
    """
    token_ids, label_out = tf.py_function(
        encode,
        inp=[text, label],
        Tout=(tf.int64, tf.int64),
    )

    # py_function leaves the static shapes of its outputs unset, and
    # `tf.data.Datasets` work best when every component has a shape,
    # so set them by hand: variable-length token ids, scalar label.
    token_ids.set_shape([None])
    label_out.set_shape([])
    return token_ids, label_out
def truncate(tokens, label):
    """Keep at most the first ``maxlen`` tokens of an example.

    Args:
        tokens: Sliceable sequence of tokens.
        label: Label for the example; returned unchanged.

    Returns:
        A ``(tokens[:maxlen], label)`` pair, where ``maxlen`` is read
        from module scope.
    """
    clipped = tokens[:maxlen]
    return clipped, label
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment