@OhadRubin
Created April 10, 2023 09:00
Uses a queue and a pool of multiprocessing workers to tokenize a tf.data dataset

import multiprocessing
import tensorflow as tf
import seqio
import numpy as np
import tqdm
from transformers import AutoTokenizer
import jax
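
# Load the GPT-NeoX tokenizer once in the parent process; the worker processes started below
# reuse it (assuming the default 'fork' start method, under which children inherit the parent's globals).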
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token


def tree_unstack(tree):
    """Split a batched pytree (e.g. a dict of stacked arrays) into a list of per-example pytrees."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    n_trees = leaves[0].shape[0]
    new_leaves = [[] for _ in range(n_trees)]
    for leaf in leaves:
        for i in range(n_trees):
            new_leaves[i].append(leaf[i])
    new_trees = [treedef.unflatten(l) for l in new_leaves]
    return new_trees
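
# A quick illustration of tree_unstack (not part of the original gist); the toy batch below is made up:
# >>> batch = {"x": np.array([[1, 2], [3, 4]]), "y": np.array([0, 1])}
# >>> tree_unstack(batch)
# [{'x': array([1, 2]), 'y': 0}, {'x': array([3, 4]), 'y': 1}]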

# PG-19 books from TFDS via a seqio data source.
datasource = seqio.TfdsDataSource(
    tfds_name="pg19:0.1.1",
)
dataset = datasource.get_dataset(split="train")


def apply_iter(dataset_iter, output_queue):
    """Worker: batch-tokenize one dataset shard and push each tokenized batch onto the queue."""
    for batch in dataset_iter.prefetch(1000).batch(100).as_numpy_iterator():
        byte_array = np.array(batch["book_text"], dtype=bytes)
        decode_utf8 = np.vectorize(lambda x: x.decode('utf-8'))
        result = decode_utf8(byte_array)
        input_ids = tokenizer.batch_encode_plus(tuple(result))['input_ids']
        batch['input_ids'] = np.array(input_ids, dtype=object)  # ragged: one list of ids per book
        output_queue.put(batch)
    output_queue.put(None)  # Sentinel value signalling the end of this worker's output


def _tokenize_dataset(dataset: tf.data.Dataset):
    """Generator: shard the dataset across worker processes and yield tokenized examples one by one."""
    processes = 10
    sharded_ds = [dataset.shard(num_shards=processes, index=i) for i in range(processes)]
    output_queue = multiprocessing.Queue()
    workers = [
        multiprocessing.Process(target=apply_iter, args=(sharded_ds[i], output_queue))
        for i in range(processes)
    ]
    for worker in workers:
        worker.start()
    completed_workers = 0
    try:
        while completed_workers < processes:
            tokenized_batch = output_queue.get()
            if tokenized_batch is None:  # Check for the sentinel value
                completed_workers += 1
            else:
                # Un-batch the pytree so the generator yields one example at a time.
                for el in tree_unstack(tokenized_batch):
                    yield el
    except KeyboardInterrupt:
        for worker in workers:
            worker.terminate()
            worker.join()
        raise
    for worker in workers:
        worker.join()


def tokenize_dataset(dataset: tf.data.Dataset):
    """Wrap the multiprocess generator in a tf.data.Dataset, adding an 'input_ids' field to the spec."""
    element_spec = dict(**dataset.element_spec)
    element_spec['input_ids'] = tf.TensorSpec(shape=[None], dtype=tf.int64)
    return tf.data.Dataset.from_generator(generator=lambda: _tokenize_dataset(dataset),
                                          output_signature=element_spec)

tokenized_dataset = tokenize_dataset(dataset)
# Drain the pipeline once; tqdm reports tokenization throughput.
for x in tqdm.tqdm(tokenized_dataset.as_numpy_iterator()):
    pass
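
# Not part of the original gist: a minimal sanity check, assuming `x` still holds the last example
# yielded by the loop above (i.e. the dataset is non-empty). It decodes the first few token ids
# back to text with the same GPT-NeoX tokenizer.
print(tokenizer.decode(x['input_ids'][:32]))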