Skip to content

Instantly share code, notes, and snippets.

@negedng
Created October 18, 2020 21:56
Show Gist options
  • Save negedng/5459d605d6e7908743ed1808471c8423 to your computer and use it in GitHub Desktop.
Save negedng/5459d605d6e7908743ed1808471c8423 to your computer and use it in GitHub Desktop.
# Encoding the data to integer token ids
def encode(examples, text_column='text', encoder=None):
    """Tokenize a batch of raw texts into lists of integer token ids.

    Written for ``Dataset.map(encode, batched=True)``: *examples* is a batch
    mapping whose *text_column* entry holds the texts of the batch.

    Args:
        examples: batch mapping; ``examples[text_column]`` is passed as-is to
            ``encoder.encode_batch``.
        text_column: key of the text field (defaults to ``'text'``).
        encoder: object with an ``encode_batch`` method whose results expose
            ``.ids``; defaults to the module-level ``tokenizer`` defined
            elsewhere in this script.

    Returns:
        ``{'tokens': [...]}`` with one list of token ids per input text, in
        input order.
    """
    if encoder is None:
        # Resolved at call time so the script-level tokenizer can be (re)bound later.
        encoder = tokenizer
    tokens = [enc.ids for enc in encoder.encode_batch(examples[text_column])]
    return {'tokens': tokens}
def _add_tokens_and_labels(batch):
    """Tokenize a batch and copy its ``label`` column into a new ``labels`` column.

    Doing both in one ``map`` callback means each split is mapped once
    instead of twice (tokenize pass + label-copy pass).
    """
    # Copy rather than rename: the original 'label' column is kept as-is.
    return {**encode(batch), 'labels': batch['label']}


ds2_train = ds2_train.map(_add_tokens_and_labels, batched=True)
ds2_val = ds2_val.map(_add_tokens_and_labels, batched=True)
# Format to TensorFlow Dataset
def _to_tf_dataset(hf_ds):
    """Convert a tokenized dataset split into a ``tf.data.Dataset``.

    Switches *hf_ds* (in place) to TensorFlow formatting restricted to the
    ``tokens`` and ``labels`` columns. Token sequences are padded with 0 or
    truncated to ``maxlen`` (a module-level setting defined elsewhere in this
    script).

    Returns:
        ``tf.data.Dataset`` of ``({'tokens': ...}, labels)`` pairs, one per row.
    """
    hf_ds.set_format(type='tensorflow', columns=['tokens', 'labels'])
    tokens = hf_ds['tokens']
    if not isinstance(tokens, tf.RaggedTensor):
        # NOTE(review): the formatter may hand back a dense tf.Tensor (e.g. when all
        # rows have equal length); a dense tensor has no .to_tensor(), so lift it.
        tokens = tf.RaggedTensor.from_tensor(tokens)
    features = {'tokens': tokens.to_tensor(default_value=0, shape=[None, maxlen])}
    return tf.data.Dataset.from_tensor_slices((features, hf_ds['labels']))


ds2_train = _to_tf_dataset(ds2_train)
ds2_val = _to_tf_dataset(ds2_val)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment