transformers_nlp_data
from mltu.tensorflow.dataProvider import DataProvider
import numpy as np
def preprocess_inputs(data_batch, label_batch):
    # tokenizer, detokenizer, train_dataset and val_dataset are assumed to be
    # defined earlier in the tutorial.
    # Pre-allocate zero-padded buffers for the encoder/decoder token ids
    encoder_input = np.zeros((len(data_batch), tokenizer.max_length)).astype(np.int64)
    decoder_input = np.zeros((len(label_batch), detokenizer.max_length)).astype(np.int64)
    decoder_output = np.zeros((len(label_batch), detokenizer.max_length)).astype(np.int64)

    # Convert raw text into token id sequences
    data_batch_tokens = tokenizer.texts_to_sequences(data_batch)
    label_batch_tokens = detokenizer.texts_to_sequences(label_batch)

    # Copy each sequence into its padded row, shifting the labels by one
    # position for teacher forcing
    for index, (data, label) in enumerate(zip(data_batch_tokens, label_batch_tokens)):
        encoder_input[index][:len(data)] = data
        decoder_input[index][:len(label)-1] = label[:-1]   # Drop the [END] token
        decoder_output[index][:len(label)-1] = label[1:]   # Drop the [START] token

    return (encoder_input, decoder_input), decoder_output
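
# A quick sanity check of the teacher-forcing shift above, on a made-up token
# sequence (the ids below are hypothetical, not from the tutorial's tokenizer:
# 1 = [START], 2 = [END], 0 = padding). At every position the decoder is fed
# token t and trained to predict token t+1:
_label = [1, 7, 8, 9, 2]   # [START] w1 w2 w3 [END]
print(_label[:-1])         # [1, 7, 8, 9]  -> decoder_input,  [END] dropped
print(_label[1:])          # [7, 8, 9, 2]  -> decoder_output, [START] dropped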

train_dataProvider = DataProvider(
    train_dataset,
    batch_size=4,
    batch_postprocessors=[preprocess_inputs],
    use_cache=True
)

val_dataProvider = DataProvider(
    val_dataset,
    batch_size=4,
    batch_postprocessors=[preprocess_inputs],
    use_cache=True
)
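
# A minimal usage sketch, assuming mltu's DataProvider supports Keras
# Sequence-style batch indexing (an assumption here, not shown in the original
# snippet): fetch the first batch and inspect the array shapes.
(enc_in, dec_in), dec_out = train_dataProvider[0]
print(enc_in.shape)    # (4, tokenizer.max_length)
print(dec_in.shape)    # (4, detokenizer.max_length)
print(dec_out.shape)   # (4, detokenizer.max_length)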