Last active
May 22, 2025 15:59
-
-
Save epwalsh/8d1f11b895bf05bdb34a40be7ded8ae4 to your computer and use it in GitHub Desktop.
Random batch generator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Generator | |
import numpy as np | |
def generate_batch(
    *,
    vocab_size: int,
    sequence_length: int,
    num_instances: int,
    rng: np.random.Generator,
    dtype=np.uint32,
) -> np.ndarray:
    """Build one batch of rows of consecutive token ids at random offsets.

    Each row has the form ``[s, s + 1, ..., s + sequence_length - 1]`` where
    the start ``s`` is drawn independently per row from ``rng``, uniformly
    over ``[0, vocab_size - sequence_length]`` (both ends included), so every
    id in ``[0, vocab_size - 1]`` can appear.

    Args:
        vocab_size: Number of distinct token ids; ids lie in
            ``[0, vocab_size - 1]``.
        sequence_length: Number of ids per row. Must not exceed
            ``vocab_size``.
        num_instances: Number of rows in the batch.
        rng: NumPy random generator the start offsets are drawn from; its
            state is advanced by this call.
        dtype: Integer dtype used for both the ramp and the sampled offsets.

    Returns:
        An array of shape ``(num_instances, sequence_length)``.

    Raises:
        ValueError: If ``sequence_length`` is greater than ``vocab_size``.
    """
    if sequence_length > vocab_size:
        raise ValueError(
            f"sequence_length ({sequence_length}) must not exceed "
            f"vocab_size ({vocab_size})"
        )

    # One ramp 0..sequence_length-1, repeated along axis 0 for every row.
    batch = (
        np.arange(0, sequence_length, dtype=dtype)
        .reshape(1, -1)
        .repeat(num_instances, 0)
    )
    # endpoint=True makes the upper bound inclusive: the last valid start is
    # vocab_size - sequence_length. An exclusive bound would make the final
    # token id unreachable and would reject sequence_length == vocab_size.
    start_offsets = rng.integers(
        0,
        vocab_size - sequence_length,
        size=(num_instances, 1),
        dtype=dtype,
        endpoint=True,
    )
    # The (num_instances, 1) offsets broadcast across each row of the ramp.
    return batch + start_offsets
def generate_batches(
    *,
    global_seed: int,
    local_data_parallel_rank: int,
    vocab_size: int,
    sequence_length: int,
    num_local_instances: int,
    total_batches: int,
    dtype=np.uint32,
) -> Generator[np.ndarray, None, None]:
    """Lazily yield ``total_batches`` batches from a rank-specific generator.

    A single NumPy generator is seeded with
    ``global_seed + local_data_parallel_rank`` and reused for every batch, so
    the whole sequence of batches is determined by that sum.

    Args:
        global_seed: Base seed.
        local_data_parallel_rank: Added to ``global_seed`` to form the seed.
        vocab_size: Forwarded to ``generate_batch``.
        sequence_length: Forwarded to ``generate_batch``.
        num_local_instances: Rows per batch; forwarded as ``num_instances``.
        total_batches: Number of batches to yield.
        dtype: Forwarded to ``generate_batch``.

    Yields:
        The array returned by ``generate_batch`` for each batch.
    """
    # NOTE(review): because the seed is a plain sum, any (seed, rank) pairs
    # with equal sums share one random stream -- confirm this is acceptable.
    rank_rng = np.random.default_rng(global_seed + local_data_parallel_rank)

    # The arguments are identical for every batch; only the RNG state moves.
    batch_kwargs = dict(
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        num_instances=num_local_instances,
        rng=rank_rng,
        dtype=dtype,
    )
    for _batch_index in range(total_batches):
        yield generate_batch(**batch_kwargs)
if __name__ == "__main__":
    # Demo: print four small batches (2 rows of 4 ids, vocabulary of 32)
    # produced for rank 0 with a fixed seed.
    demo_batches = generate_batches(
        global_seed=2132,
        local_data_parallel_rank=0,
        vocab_size=32,
        sequence_length=4,
        num_local_instances=2,
        total_batches=4,
    )
    for i, batch in enumerate(demo_batches):
        print(f"batch {i}:")
        print(batch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment