Uses a queue and a pool of multiprocessing workers to tokenize a tf.data dataset.
import multiprocessing
import tensorflow as tf
import seqio
import numpy as np
import tqdm
from transformers import AutoTokenizer
import jax

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# GPT-NeoX ships without a pad token; reuse the eos token for padding.
tokenizer.pad_token = tokenizer.eos_token
def tree_unstack(tree):
    """Split a pytree of stacked (batched) leaves into a list of per-example pytrees."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n_trees = leaves[0].shape[0]
    new_leaves = [[] for _ in range(n_trees)]
    for leaf in leaves:
        for i in range(n_trees):
            new_leaves[i].append(leaf[i])
    new_trees = [treedef.unflatten(l) for l in new_leaves]
    return new_trees
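# For illustration (hypothetical shapes): tree_unstack({"a": np.zeros((3, 5)), "b": np.zeros(3)})
# returns a list of 3 dicts, each holding one row: {"a": shape (5,), "b": scalar}.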
datasource = seqio.TfdsDataSource(
    tfds_name="pg19:0.1.1",
)
dataset = datasource.get_dataset(split="train")
def apply_iter(dataset_iter, output_queue):
    for batch in dataset_iter.prefetch(1000).batch(100).as_numpy_iterator():
        byte_array = np.array(batch["book_text"], dtype=bytes)
        decode_utf8 = np.vectorize(lambda x: x.decode('utf-8'))
        result = decode_utf8(byte_array)
        input_ids = tokenizer.batch_encode_plus(tuple(result))['input_ids']
        # The token sequences are ragged, so store them as an object array.
        batch['input_ids'] = np.array(input_ids, dtype=object)
        output_queue.put(batch)
    output_queue.put(None)  # Add a sentinel value to signal the end of the worker's output
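# Each worker streams tokenized batches onto the shared queue and finishes with a single
# None sentinel, so the consumer can count how many shards are done. Note (assumption):
# the workers inherit the module-level tokenizer and their dataset shard from the parent,
# which relies on the default "fork" start method on Linux.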
def _tokenize_dataset(dataset: tf.data.Dataset):
    processes = 10
    sharded_ds = [dataset.shard(num_shards=processes, index=i) for i in range(processes)]
    output_queue = multiprocessing.Queue()
    workers = [
        multiprocessing.Process(target=apply_iter, args=(sharded_ds[i], output_queue))
        for i in range(processes)
    ]
    for worker in workers:
        worker.start()
    completed_workers = 0
    try:
        while completed_workers < processes:
            tokenized_batch = output_queue.get()
            if tokenized_batch is None:  # Check for sentinel value
                completed_workers += 1
            else:
                for el in tree_unstack(tokenized_batch):
                    yield el
    except KeyboardInterrupt:
        for worker in workers:
            worker.terminate()
            worker.join()
        raise
    for worker in workers:
        worker.join()
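# Examples are yielded in whatever order batches arrive on the queue, so the output order
# differs from the source dataset order; that is usually acceptable for pretraining data.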
def tokenize_dataset(dataset: tf.data.Dataset):
    element_spec = dict(**dataset.element_spec)
    element_spec['input_ids'] = tf.TensorSpec(shape=[None], dtype=tf.int64)
    return tf.data.Dataset.from_generator(generator=lambda: _tokenize_dataset(dataset),
                                          output_signature=element_spec)

tokenized_dataset = tokenize_dataset(dataset)
for x in tqdm.tqdm(tokenized_dataset.as_numpy_iterator()):
    pass
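One way to consume the result downstream (a sketch, not part of the original gist): keep only the token ids and group them into padded batches. The batch size of 8 is an arbitrary placeholder, and the pad value reuses the pad/eos id set on the tokenizer above.

# Hypothetical downstream usage: pad-batch the ragged token sequences.
train_batches = (
    tokenized_dataset
    .map(lambda x: x["input_ids"])
    .padded_batch(8, padded_shapes=[None],
                  padding_values=tf.constant(tokenizer.pad_token_id, tf.int64))
)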