Created August 24, 2023 10:12
-
-
Save pythonlessons/0f7a08ac00c311b55bb88bcf794c573b to your computer and use it in GitHub Desktop.
transformers_nlp_data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from mltu.tensorflow.dataProvider import DataProvider | |
import numpy as np | |
def preprocess_inputs(data_batch, label_batch, *, src_tokenizer=None, tgt_tokenizer=None):
    """Tokenize a batch of source/target texts into zero-padded int64 arrays.

    Used as a DataProvider batch post-processor: it receives one batch of
    source texts and one batch of target texts and returns model-ready
    arrays in the form ``((encoder_input, decoder_input), decoder_output)``.

    Args:
        data_batch: Sequence of source texts, tokenized with ``src_tokenizer``.
        label_batch: Sequence of target texts, tokenized with ``tgt_tokenizer``.
        src_tokenizer: Object exposing ``max_length`` and
            ``texts_to_sequences(texts)``. Defaults to the module-level
            ``tokenizer`` (defined elsewhere in the file), looked up at call time.
        tgt_tokenizer: Same interface. Defaults to the module-level
            ``detokenizer`` (defined elsewhere in the file), looked up at call time.

    Returns:
        ``((encoder_input, decoder_input), decoder_output)``, all ``np.int64``.
        ``encoder_input`` has shape ``(len(data_batch), src_tokenizer.max_length)``.
        Both decoder arrays have shape
        ``(len(label_batch), tgt_tokenizer.max_length)``.
        Unused trailing positions are 0.
        ``decoder_input`` holds each target sequence without its last token.
        ``decoder_output`` holds it without its first token, so position ``i``
        of the output is the token that follows position ``i`` of the input.

    Sequences that do not fit their row are truncated to the row length,
    rather than failing with a NumPy broadcast error. Empty target sequences
    produce all-zero decoder rows.
    """
    # Resolve defaults lazily so the module-level tokenizers may be defined
    # after this function.
    if src_tokenizer is None:
        src_tokenizer = tokenizer
    if tgt_tokenizer is None:
        tgt_tokenizer = detokenizer

    src_len = src_tokenizer.max_length
    tgt_len = tgt_tokenizer.max_length

    # Allocate directly as int64 instead of zeros(...).astype(...), which
    # would allocate a float array and then copy it.
    encoder_input = np.zeros((len(data_batch), src_len), dtype=np.int64)
    decoder_input = np.zeros((len(label_batch), tgt_len), dtype=np.int64)
    decoder_output = np.zeros((len(label_batch), tgt_len), dtype=np.int64)

    data_batch_tokens = src_tokenizer.texts_to_sequences(data_batch)
    label_batch_tokens = tgt_tokenizer.texts_to_sequences(label_batch)

    for index, (data, label) in enumerate(zip(data_batch_tokens, label_batch_tokens)):
        source = data[:src_len]
        encoder_input[index, :len(source)] = source

        # Per the original comments: drop the [END] token for the decoder input
        # and the [START] token for the decoder target (assumes each target
        # sequence is wrapped in [START] ... [END]).
        shifted_in = label[:-1][:tgt_len]
        shifted_out = label[1:][:tgt_len]
        # len()-based slices stay correct for empty labels, where the
        # original's `len(label) - 1 == -1` produced a mismatched `[:-1]` slice.
        decoder_input[index, :len(shifted_in)] = shifted_in
        decoder_output[index, :len(shifted_out)] = shifted_out

    return (encoder_input, decoder_input), decoder_output
# Build the training and validation providers with one shared configuration.
# Only the wrapped dataset differs. Each provider gets its own fresh
# post-processor list, and the training provider is constructed first.
train_dataProvider, val_dataProvider = (
    DataProvider(
        split,
        batch_size=4,
        batch_postprocessors=[preprocess_inputs],
        use_cache=True,
    )
    for split in (train_dataset, val_dataset)
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.