Skip to content

Instantly share code, notes, and snippets.

@epwalsh
Last active May 22, 2025 15:59
Show Gist options
  • Save epwalsh/8d1f11b895bf05bdb34a40be7ded8ae4 to your computer and use it in GitHub Desktop.
Save epwalsh/8d1f11b895bf05bdb34a40be7ded8ae4 to your computer and use it in GitHub Desktop.
Random batch generator
from typing import Generator
import numpy as np
def generate_batch(
    *,
    vocab_size: int,
    sequence_length: int,
    num_instances: int,
    rng: np.random.Generator,
    dtype=np.uint32,
) -> np.ndarray:
    """Build a batch whose rows are runs of consecutive token IDs.

    Every row has the form ``start, start + 1, ..., start + sequence_length - 1``
    where ``start`` is drawn uniformly from all offsets that keep the whole
    row inside ``[0, vocab_size)``.

    Args:
        vocab_size: Exclusive upper bound on the generated IDs.
        sequence_length: Number of IDs in each row.
        num_instances: Number of rows in the batch.
        rng: NumPy generator the start offsets are drawn from (its state is
            advanced by this call).
        dtype: Integer dtype used for both the positions and the random
            offsets, and therefore for the result.

    Returns:
        An array of shape ``(num_instances, sequence_length)``.

    Raises:
        ValueError: If ``sequence_length`` is larger than ``vocab_size``, so
            that no run of that length fits in the vocabulary.
    """
    if sequence_length > vocab_size:
        raise ValueError(
            f"sequence_length ({sequence_length}) must not exceed "
            f"vocab_size ({vocab_size})"
        )

    # Shape (1, sequence_length); broadcasting against the (num_instances, 1)
    # offsets below yields one shifted copy per instance.
    positions = np.arange(0, sequence_length, dtype=dtype).reshape(1, -1)

    # `high` is exclusive in Generator.integers, so add 1 to make the last
    # valid start (vocab_size - sequence_length) reachable. Without it the
    # ID vocab_size - 1 is never produced and vocab_size == sequence_length
    # fails with "low >= high".
    start_offsets = rng.integers(
        0,
        vocab_size - sequence_length + 1,
        size=(num_instances, 1),
        dtype=dtype,
    )
    return positions + start_offsets
def generate_batches(
    *,
    global_seed: int,
    local_data_parallel_rank: int,
    vocab_size: int,
    sequence_length: int,
    num_local_instances: int,
    total_batches: int,
    dtype=np.uint32,
) -> Generator[np.ndarray, None, None]:
    """Lazily yield ``total_batches`` random batches for one data-parallel rank.

    A single NumPy generator is seeded with
    ``global_seed + local_data_parallel_rank`` and shared by every batch, so
    successive batches continue the same random stream.

    Args:
        global_seed: Base seed.
        local_data_parallel_rank: Rank added to ``global_seed`` to form the
            seed used here.
        vocab_size: Forwarded to ``generate_batch``.
        sequence_length: Forwarded to ``generate_batch``.
        num_local_instances: Forwarded to ``generate_batch`` as
            ``num_instances``.
        total_batches: How many batches to yield.
        dtype: Forwarded to ``generate_batch``.

    Yields:
        The array returned by each ``generate_batch`` call.
    """
    # NOTE(review): the seed is a plain sum, so (seed=1, rank=0) and
    # (seed=0, rank=1) produce the same stream.
    stream_seed = global_seed + local_data_parallel_rank
    shared_rng = np.random.default_rng(stream_seed)

    # The arguments are identical for every batch; only the RNG state moves.
    batch_kwargs = dict(
        vocab_size=vocab_size,
        sequence_length=sequence_length,
        num_instances=num_local_instances,
        rng=shared_rng,
        dtype=dtype,
    )

    remaining = total_batches
    while remaining > 0:
        yield generate_batch(**batch_kwargs)
        remaining -= 1
if __name__ == "__main__":
    # Demo: print four small batches (2 rows of 4 IDs each, vocabulary of 32)
    # generated for rank 0 with a fixed seed.
    demo_stream = generate_batches(
        global_seed=2132,
        local_data_parallel_rank=0,
        vocab_size=32,
        sequence_length=4,
        num_local_instances=2,
        total_batches=4,
    )
    for i, batch in enumerate(demo_stream):
        print(f"batch {i}:")
        print(batch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment