
@Quentin-Anthony
Created July 16, 2025 17:45

How Megatron Builds and Pulls from Datasets

Below I describe how Megatron statically builds datasets, and then how models pull from those datasets at training time. In order:

  1. How GPT datasets are produced inside Megatron‑Core;
  2. Exactly what a training step receives (__getitem__ --> DataLoader --> model);
  3. How to host the finished .bin / .idx pair in an S3‑compatible bucket and stream it lazily during training. I think this is the desired end-state for templar's training needs.

1. Dataset-Builder Overview

Megatron builds datasets statically at the start of training. It uses user arguments (sequence length, batch size, dataset seed, number of iterations) to build the dataset within a few minutes on the root rank. During training, all torch DataLoader workers pull these complete samples (pre-tokenized, chopped to seqlen, contiguous, etc.), so there is effectively no preprocessing cost at training time. We avoid being bottlenecked by the load path (NFS to CPU RAM to GPU VRAM) by running parallel DataLoader workers with a prefetch factor. For example, with 8 workers and a prefetch factor of 2, each GPU rank has 8 worker processes that together keep 16 batches (2 per worker) loaded ahead of the training loop.
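To make the prefetch arithmetic concrete, here is a small sketch of how much work one rank's DataLoader keeps in flight. It assumes PyTorch's documented semantics, where prefetch_factor is the number of batches each worker loads in advance. The helper and its names are hypothetical and only do the arithmetic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PrefetchBudget:
    worker_processes: int   # DataLoader workers are processes, not threads
    batches_in_flight: int  # batches loaded ahead of the training loop
    samples_in_flight: int


def prefetch_budget(num_workers: int, prefetch_factor: int, batch_size: int) -> PrefetchBudget:
    """Hypothetical helper: in-flight work for ONE rank's loader, i.e. for
    DataLoader(ds, batch_size=batch_size, num_workers=num_workers,
               prefetch_factor=prefetch_factor, pin_memory=True)
    """
    if num_workers <= 0:
        raise ValueError("prefetch_factor only applies when num_workers > 0")
    # Each worker keeps `prefetch_factor` whole batches queued.
    batches = num_workers * prefetch_factor
    return PrefetchBudget(num_workers, batches, batches * batch_size)


budget = prefetch_budget(num_workers=8, prefetch_factor=2, batch_size=8)
print(budget)
# PrefetchBudget(worker_processes=8, batches_in_flight=16, samples_in_flight=128)
```

The budget is per rank. Every GPU rank runs its own DataLoader with its own worker processes.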

First off, the primary classes involved are:

  • IndexedDatasetBuilder
    • Flattens already-tokenized sequences into a binary blob (.bin).
    • Writes a companion index (.idx) with:
      • byte pointer per sequence
      • sequence length
      • document boundaries (and optional “mode” byte).
  • GPTDataset
    • Wraps one IndexedDataset, restricts it to a (train/valid/test) slice, and converts variable‑length documents into a stream of fixed‑length training samples with causal masks, labels, etc.
  • BlendedDataset
    • Stitches several GPTDatasets together according to user‑specified weights/ratios so that a single __getitem__ yields a virtual mixture of sources.
  • BlendedMegatronDatasetBuilder
    • Decides, per split, whether to build a raw GPTDataset or a BlendedDataset; builds everything once on rank‑0, caches the indices, then re‑opens them mmap‑style on all ranks.
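To illustrate the .bin/.idx split, here is a minimal stdlib-only sketch. Token IDs are flattened into one binary file, and a small index records each sequence's byte pointer and length, so any sequence can be read with a single seek.

The following are assumptions for illustration only:

  • the file layout;
  • the function names write_indexed and read_sequence;
  • the int32 token width.

Megatron's real .idx, written by IndexedDatasetBuilder, also carries a header (including the token dtype), document boundaries and optional modes.

```python
import struct
from array import array
from typing import List, Sequence

# Toy layout (an assumption, NOT Megatron's exact format):
#   .bin : int32 token IDs of all sequences, back to back
#   .idx : uint64 count | int64 byte pointers[count] | int32 lengths[count]
# Everything uses native byte order.


def write_indexed(prefix: str, sequences: Sequence[Sequence[int]]) -> None:
    """Flatten sequences into <prefix>.bin and record where each one lives in <prefix>.idx."""
    pointers: List[int] = []
    lengths: List[int] = []
    byte_pos = 0
    with open(prefix + ".bin", "wb") as bin_f:
        for seq in sequences:
            tokens = array("i", seq)                # int32 on mainstream platforms
            tokens.tofile(bin_f)
            pointers.append(byte_pos)               # byte offset where this sequence starts
            lengths.append(len(tokens))             # length in tokens, not bytes
            byte_pos += len(tokens) * tokens.itemsize
    with open(prefix + ".idx", "wb") as idx_f:
        idx_f.write(struct.pack("=Q", len(sequences)))
        array("q", pointers).tofile(idx_f)
        array("i", lengths).tofile(idx_f)


def read_sequence(prefix: str, i: int) -> List[int]:
    """Random access: consult the index, then read only sequence i's bytes."""
    with open(prefix + ".idx", "rb") as idx_f:
        (count,) = struct.unpack("=Q", idx_f.read(8))
        pointers = array("q")
        pointers.fromfile(idx_f, count)
        lengths = array("i")
        lengths.fromfile(idx_f, count)
    tokens = array("i")
    with open(prefix + ".bin", "rb") as bin_f:
        bin_f.seek(pointers[i])
        tokens.fromfile(bin_f, lengths[i])
    return tokens.tolist()
```

The index stays small (a few bytes per sequence), while the .bin can be arbitrarily large and is only ever read in the ranges a sample needs.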

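helpers.cpp::build_sample_idx walks the shuffled document index once. It writes num_samples + 1 sample boundaries, each a (position in document_index, token offset) pair, and sample i spans boundary i up to boundary i + 1.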

And here's simplified pseudocode for the C++ builder if you want more detail:

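doc_i = 0; offset = 0;
sample_idx[0] = (doc_i, offset);      // start of sample 0
for s in [1 … num_samples] {
  remaining = seq_length + extra;     // extra = add_extra_token_to_sequence (0 or 1)
  while (remaining > 0) {
      doc_len = sizes[document_index[doc_i]] - offset;   // tokens left in current doc
      remaining -= doc_len;
      if (remaining <= 0) {
          // sample ends inside this doc; back up `extra` tokens so the
          // next sample starts on this sample's final (label-only) token
          offset += doc_len + remaining - extra;
      } else if (doc_i == len(document_index) - 1) {
          // ran out of documents: only allowed while building the last sample
          offset = sizes[document_index[doc_i]] - extra;
          break;
      } else {
          ++doc_i; offset = 0;          // continue into the next doc
      }
  }
  sample_idx[s] = (doc_i, offset);  // start of next sample
}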
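A direct Python port of the pseudocode above is handy for checking boundaries on tiny inputs. It includes the end-of-corpus guard. It is a sketch, not Megatron's code:

  • the function name only mirrors the C++ helper;
  • it returns a list of tuples rather than a NumPy array;
  • it makes no attempt at the C++ version's performance.

```python
from typing import List, Sequence, Tuple


def build_sample_idx(
    sizes: Sequence[int],
    document_index: Sequence[int],
    seq_length: int,
    num_samples: int,
    extra: int = 1,
) -> List[Tuple[int, int]]:
    """Return num_samples + 1 (doc_i, offset) boundaries.

    doc_i is a position in document_index (not a raw doc ID) and offset is a
    token offset inside that document. Consecutive samples share `extra` tokens.
    """
    doc_i, offset = 0, 0
    sample_idx = [(doc_i, offset)]                        # start of sample 0
    for s in range(1, num_samples + 1):
        remaining = seq_length + extra
        while remaining > 0:
            doc_len = sizes[document_index[doc_i]] - offset   # tokens left in this doc
            remaining -= doc_len
            if remaining <= 0:
                # Sample ends inside this doc; back up `extra` so the next
                # sample starts on this sample's last (label-only) token.
                offset += doc_len + remaining - extra
            elif doc_i == len(document_index) - 1:
                # Out of documents: only allowed while building the last sample.
                assert s == num_samples, "not enough tokens for num_samples"
                offset = sizes[document_index[doc_i]] - extra
                break
            else:
                doc_i, offset = doc_i + 1, 0              # spill into the next document
        sample_idx.append((doc_i, offset))
    return sample_idx


# Three documents of 5, 3 and 4 tokens, seq_length=4, one extra token.
print(build_sample_idx([5, 3, 4], [0, 1, 2], seq_length=4, num_samples=2))
# [(0, 0), (0, 4), (2, 0)]
```

The example produces two samples:

  • Sample 0 is doc 0, tokens 0..4.
  • Sample 1 is doc 0 token 4, all of doc 1, then doc 2 token 0.

Each sample has 5 = seq_length + 1 tokens, and the two overlap by exactly one token (doc 0 token 4).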

2. How GPTDataset Builds Samples

2.1 Index Arrays

Megatron uses a few arrays to manage indices at the document-level and sample-level:

  • document_index (Shape: [num_epochs × num_docs])
    • List of doc-IDs, repeated once per epoch and globally shuffled. If the final epoch is only partially used, it can be shuffled separately so its documents are not mixed into the earlier full epochs.
  • sample_index (Shape: [num_samples + 1, 2])
    • For every sample boundary, stores (doc_idx, offset). Here doc_idx is a position in document_index and offset is a token offset within that document. Sample i spans sample_index[i] up to sample_index[i + 1].
    • Built by the C++ helper helpers.cpp::build_sample_idx.
  • shuffle_index (Shape: [num_samples])
    • A permutation of [0 … num_samples-1]. __getitem__(i) reads sample shuffle_index[i], so access order is randomized without reshuffling the underlying tokens.
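The two shuffled arrays can be sketched with the standard library. The sketch makes the following assumptions and simplifications:

  • Python's random.Random stands in for the seeded numpy.random.RandomState Megatron uses, so the actual orderings differ.
  • The separate-final-epoch handling is omitted.
  • The function names are hypothetical.

sample_index sits between the two arrays and is produced by the build_sample_idx pseudocode above.

```python
import random
from typing import List, Sequence


def build_document_index(documents: Sequence[int], num_epochs: int, rng: random.Random) -> List[int]:
    """Each doc ID of this split once per epoch, flattened and shuffled globally."""
    doc_index = [doc for _ in range(num_epochs) for doc in documents]
    rng.shuffle(doc_index)
    return doc_index


def build_shuffle_index(num_samples: int, rng: random.Random) -> List[int]:
    """A permutation of [0, num_samples): __getitem__(i) serves sample shuffle_index[i]."""
    shuffle_index = list(range(num_samples))
    rng.shuffle(shuffle_index)
    return shuffle_index


rng = random.Random(42)                      # stand-in for the dataset seed
document_index = build_document_index(range(3), num_epochs=2, rng=rng)
shuffle_index = build_shuffle_index(5, rng)
```

Because both arrays come from one seeded generator, the same seed and sizes always reproduce the same indices. This is what lets every rank, or every run, agree on the sample order.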

The dataset is first built on the root rank, which:

  • Generates the above index arrays and writes them to path_to_cache/<hash>-GPTDataset-{train|valid|test}-{document|sample|shuffle}_index.npy, where <hash> is a hash of the dataset configuration.
  • Lets subsequent runs and all other ranks simply memory-map these files (mmap_mode="r"), so the work is not duplicated and RAM usage doesn't explode.
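Megatron loads the cached files with numpy.load(..., mmap_mode="r") and names them after a configuration hash (the <hash> in the filename above). The sketch below reproduces that build-once / map-everywhere pattern with the standard library only. The following are all assumptions:

  • the exact fields hashed;
  • the raw int64 file format (instead of .npy);
  • the function names cache_path and build_or_load.

```python
import array
import hashlib
import json
import mmap
import os
from typing import Callable, Iterable


def cache_path(cache_dir: str, config: dict, split: str, kind: str) -> str:
    """Name an index file after a hash of the settings that determine its contents."""
    desc = json.dumps(config, sort_keys=True)          # key order must not change the hash
    digest = hashlib.md5(desc.encode("utf-8")).hexdigest()
    # Hypothetical suffix: a raw int64 file rather than Megatron's .npy
    return os.path.join(cache_dir, f"{digest}-GPTDataset-{split}-{kind}_index.i64")


def build_or_load(path: str, build: Callable[[], Iterable[int]], is_builder: bool) -> memoryview:
    """The builder rank writes the file once; every rank then maps it read-only."""
    if is_builder and not os.path.exists(path):
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            array.array("q", build()).tofile(f)
        os.replace(tmp, path)        # atomic rename: readers never see a half-written file
    # (Real code needs a distributed barrier here so non-builders wait for the file.)
    with open(path, "rb") as f:
        mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    return memoryview(mapped).cast("q")  # int64 view; the OS pages data in lazily
```

Changing any hashed setting (seed, seqlen, split, ...) yields a new filename, so stale indices are never reused by accident.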

3. What __getitem__ Returns

Assume sequence_length = 2048 and add_extra_token_to_sequence = True (the default), so each sample is read as 2049 tokens.

  • tokens (dtype: torch.int64, Shape: [2048])
    • Input IDs (last token removed when +1 extra token is used).
  • labels (dtype: torch.int64, Shape: [2048])
    • Next‑token targets (first token removed, padding --> pad_id).
  • attention_mask* (dtype: torch.bool, Shape: [1, 2048, 2048])
    • True = masked. The causal mask is True strictly above the diagonal. If reset_attention_mask=True, it also blocks attention across EOD boundaries.
  • loss_mask (dtype: torch.float32, Shape: [2048])
    • 1.0 except for:
      • padding positions
      • positions whose input token is the EOD token (if eod_mask_loss=True).
  • position_ids (dtype: torch.int64, Shape: [2048])
    • Typically [0…2047]; restarted from 0 after every EOD if reset_position_ids=True.

* attention_mask is emitted only when create_attention_mask=True (default).
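The field rules above can be sketched with plain lists.

  • Hypothetical: the function make_sample and the example token IDs (eod_id=0, pad_id=-1).
  • Assumed: add_extra_token_to_sequence=True.
  • Omitted: reset_attention_mask and the conversion to torch tensors.

```python
from typing import Dict, List


def make_sample(
    text: List[int],
    eod_id: int,
    pad_id: int,
    reset_position_ids: bool = False,
    eod_mask_loss: bool = False,
) -> Dict[str, list]:
    """List-based toy of the per-sample fields; `text` holds seq_length + 1 tokens."""
    tokens, labels = text[:-1], text[1:]     # drop last / drop first token
    n = len(tokens)

    loss_mask = [1.0] * n
    for i in range(n):
        if labels[i] == pad_id:
            loss_mask[i] = 0.0                # never train on padding targets
        if eod_mask_loss and tokens[i] == eod_id:
            loss_mask[i] = 0.0                # skip the prediction made from an EOD input

    position_ids = list(range(n))
    if reset_position_ids:
        doc_start = 0
        for i in range(n):
            position_ids[i] = i - doc_start
            if tokens[i] == eod_id:
                doc_start = i + 1             # the next token starts a new document at 0

    # attention_mask[q][k] is True where query q may NOT attend to key k (future tokens)
    attention_mask = [[k > q for k in range(n)] for q in range(n)]
    return {
        "tokens": tokens,
        "labels": labels,
        "loss_mask": loss_mask,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }


sample = make_sample([5, 6, 0, 7, 8], eod_id=0, pad_id=-1,
                     reset_position_ids=True, eod_mask_loss=True)
```

Here the EOD at input position 2 closes the first document. Its loss is masked, and the position IDs restart at 0 on the following token.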

torch.utils.data.default_collate simply stacks these into batch tensors:

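import torch
from torch.utils.data import DataLoader

device = torch.device("cuda")
# train_ds comes from BlendedMegatronDatasetBuilder.build() (section 5.2).
# shuffle=False: GPTDataset already randomizes order via shuffle_index.
loader = DataLoader(train_ds, batch_size=8, pin_memory=True, shuffle=False)
batch  = next(iter(loader))

tokens       = batch["tokens"].to(device, non_blocking=True)        # [B, 2048]
labels       = batch["labels"].to(device, non_blocking=True)        # [B, 2048]
loss_mask    = batch["loss_mask"].to(device, non_blocking=True)     # [B, 2048]
position_ids = batch["position_ids"].to(device, non_blocking=True)  # [B, 2048]
attn_mask    = batch.get("attention_mask")                          # [B, 1, 2048, 2048] or absent
if attn_mask is not None:
    attn_mask = attn_mask.to(device, non_blocking=True)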

4. Loading Samples During Training

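First off, multiple epochs are not handled by rebuilding anything at runtime. document_index already repeats the corpus once per epoch, and the sampler simply walks through [0, num_samples). Because the indices are never rebuilt during training, the sample order (and therefore any replay) is fixed up front by the seed and sizes, and we have no runtime control over it.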

Each torch dataloader worker fetches the following:

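global_idx                                        # from the DataLoader / sampler
shuffled = shuffle_index[global_idx]
(doc_beg, off_beg) = sample_index[shuffled]       # where this sample starts
(doc_end, off_end) = sample_index[shuffled + 1]   # where the next sample starts
tokens = concat(dataset.get(document_index[i], ...) for i in doc_beg..doc_end)
         # L+1 tokens total; numpy views or S3 reads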
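Because a sample can span several documents, the read walks from the sample's start boundary to the next sample's boundary. Below is a runnable sketch of that lookup, with these assumptions:

  • Plain lists stand in for the .bin reader.
  • fetch_sample is a hypothetical name.
  • The hard-coded sample_index is what the build_sample_idx pseudocode above yields for documents of 5, 3 and 4 tokens with seq_length=4.

```python
from typing import List, Sequence, Tuple


def fetch_sample(
    idx: int,
    docs: Sequence[Sequence[int]],
    document_index: Sequence[int],
    sample_index: Sequence[Tuple[int, int]],
    shuffle_index: Sequence[int],
    extra: int = 1,
) -> List[int]:
    """Return the seq_length + extra token window for DataLoader index idx."""
    s = shuffle_index[idx]                       # random sample chosen for this index
    doc_beg, off_beg = sample_index[s]           # where the sample starts
    doc_end, off_end = sample_index[s + 1]       # where the next sample starts
    out: List[int] = []
    for i in range(doc_beg, doc_end + 1):
        doc = docs[document_index[i]]            # stands in for dataset.get(...)
        start = off_beg if i == doc_beg else 0
        stop = off_end + extra if i == doc_end else len(doc)
        out.extend(doc[start:stop])
    return out


docs = [[10, 11, 12, 13, 14], [20, 21, 22], [30, 31, 32, 33]]
document_index = [0, 1, 2]
sample_index = [(0, 0), (0, 4), (2, 0)]
shuffle_index = [1, 0]
first = fetch_sample(0, docs, document_index, sample_index, shuffle_index)
second = fetch_sample(1, docs, document_index, sample_index, shuffle_index)
```

first is [14, 20, 21, 22, 30], which spans three documents. second is [10, 11, 12, 13, 14]. The shared token 14 is the one-token overlap created by the extra token.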

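IndexedDataset picks a bin reader once, when it is constructed, and every dataset.get call goes through it:

  • _MMapBinReader: used for local files with mmap=True (the default). Memory-maps the .bin once and returns zero-copy views.
  • _FileBinReader: used for local files with mmap=False. Opens, seeks, and reads on each call (slower, but avoids mapping the whole file into virtual memory).
  • _S3BinReader: used when the path starts with s3://. Keeps one in-memory chunk of the .bin. On a cache miss it issues a single ranged GetObject starting at (offset // chunk_size) × chunk_size and ending at max(that start + chunk_size, offset + len). A read that straddles a chunk boundary is therefore still served by one request.
  • _MultiStorageClientBinReader: used when the path starts with msc://profile/.... Same idea, implemented via the NVIDIA MSC client.

The index (.idx) is always read from local disk. For remote datasets it is first downloaded to object_storage_cache_path/<bucket>/<key>.idx.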
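The single-chunk caching described for _S3BinReader can be sketched without any S3 dependency:

  • ChunkCachedReader is a hypothetical name.
  • The fetch callback stands in for boto3's get_object(..., Range="bytes=first-last"). HTTP byte ranges are inclusive at both ends.
  • The chunk size is an arbitrary example value.

```python
from typing import Callable, Optional


class ChunkCachedReader:
    """Hypothetical single-chunk range cache modeled on the behavior above."""

    def __init__(self, fetch: Callable[[int, int], bytes], chunk_nbytes: int) -> None:
        self._fetch = fetch                # fetch(first, last) -> bytes, inclusive range
        self._chunk = chunk_nbytes
        self._start = 0                    # first cached byte
        self._end = 0                      # one past the last cached byte
        self._cache: Optional[bytes] = None
        self.requests = 0                  # number of remote round trips so far

    def read(self, offset: int, size: int) -> bytes:
        hit = (
            self._cache is not None
            and offset >= self._start
            and offset + size <= self._end
        )
        if not hit:
            start = (offset // self._chunk) * self._chunk    # align down to a chunk boundary
            end = max(start + self._chunk, offset + size)     # grow if the read straddles chunks
            self._cache = self._fetch(start, end - 1)
            self._start, self._end = start, end
            self.requests += 1
        assert self._cache is not None
        rel = offset - self._start
        return self._cache[rel : rel + size]


blob = bytes(range(256)) * 4                      # a fake 1 KiB object
reader = ChunkCachedReader(lambda first, last: blob[first : last + 1], chunk_nbytes=256)
```

Sequential reads inside one chunk cost a single request. This is why sample reads, which are mostly small and nearby, stay cheap even against a remote bucket.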

5. Hosting on Cloud Bucket

With all this in mind, here's how I think we can use the existing Megatron pipeline to stream samples live from a cloud bucket like S3:

5.1 Offline Dataset Build

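Before sharing the dataset with miners, we build it statically offline. The corpus is tokenized and flattened into a .bin/.idx pair, which we upload to a bucket so miners never have to tokenize or preprocess anything themselves. The sample/shuffle indices are derived deterministically from the .idx plus the shared settings (seed, seqlen, split, etc.), so every miner rebuilds identical ones at startup within minutes.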

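# Tokenize the corpus into /tmp/my_gpt_data.bin and /tmp/my_gpt_data.idx
# (build_indexed_dataset.py stands in for the tokenization script; Megatron-LM ships tools/preprocess_data.py)
python build_indexed_dataset.py \
  --in  my_corpus.txt \
  --out /tmp/my_gpt_data

# Upload both halves of the pair to the S3-compatible store (here Cloudflare R2)
aws --endpoint-url https://<accountid>.r2.cloudflarestorage.com \
    s3 cp /tmp/my_gpt_data.bin s3://mybucket/my_gpt_data.bin
aws --endpoint-url https://<accountid>.r2.cloudflarestorage.com \
    s3 cp /tmp/my_gpt_data.idx s3://mybucket/my_gpt_data.idx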

5.2 Miners pulling from bucket during training

We would share the bucket with miners along with any specific settings (e.g. tokenizer we used, seqlen, seed, etc):

from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)

cfg = GPTDatasetConfig(
    random_seed = 42,
    sequence_length = 2048,
    blend   = (["s3://mybucket/my_gpt_data"], None),  # .bin/.idx prefix, no weights
    split   = "98,1,1",
    object_storage_cache_path = "/local/cache/idx",   # where the .idx is cached locally
    tokenizer = my_tokenizer,       # the same MegatronTokenizer used to build the .bin
    # GPTDatasetConfig requires these to be set explicitly (example values)
    reset_position_ids = False,
    reset_attention_mask = False,
    eod_mask_loss = False,
)

sizes = [None, 5000, 5000]   # train full epoch, small valid/test
builder = BlendedMegatronDatasetBuilder(
    cls = GPTDataset,
    sizes = sizes,
    is_built_on_rank = lambda: True,   # build everywhere
    config = cfg,
)

train_ds, val_ds, test_ds = builder.build()

We may also have to share bucket credentials, which miners set as environment variables for boto3 (the S3 client behind _S3BinReader). For example:

export AWS_ACCESS_KEY_ID="<r2-access-key>"
export AWS_SECRET_ACCESS_KEY="<r2-secret>"
export AWS_DEFAULT_REGION="auto"
export AWS_ENDPOINT_URL_S3="https://<accountid>.r2.cloudflarestorage.com"   # boto3's S3-specific endpoint variable

Appendix A – class to file mapping

Below is a lookup of the classes, helpers, and C++ symbols referenced above, and the file in Megatron-Core's datasets directory that defines each one.

Concept: Python / C++ symbol(s) --> defining file

  • Dataset .bin / .idx: IndexedDatasetBuilder, IndexedDataset --> indexed_dataset.py
    • Index writer/reader internals: _IndexWriter, _IndexReader, DType --> indexed_dataset.py
    • Bin readers (local & remote): _MMapBinReader, _FileBinReader, _S3BinReader, _MultiStorageClientBinReader --> indexed_dataset.py
  • C++ helpers: build_sample_idx_int32, build_sample_idx_int64, build_blending_indices, build_exhaustive_blending_indices, build_mapping, build_blocks_mapping --> helpers.cpp (Python bindings re-exported via helpers.py)
  • GPT dataset: GPTDataset, GPTDatasetConfig --> gpt_dataset.py
    • Mock variant: MockGPTDataset, MockGPTLowLevelDataset --> gpt_dataset.py
  • Mask-LM base class: MaskedWordPieceDataset, MaskedWordPieceDatasetConfig --> masked_dataset.py
  • Dataset blend & split orchestration:
    • BlendedDataset --> blended_dataset.py
    • BlendedMegatronDatasetBuilder --> blended_megatron_dataset_builder.py
    • BlendedMegatronDatasetConfig --> blended_megatron_dataset_config.py
  • Shared utils: Split enum, normalize, get_blend_from_list, compile_helpers --> utils.py
  • Tokenizer abstraction: MegatronTokenizer --> megatron_tokenizer.py
  • Object-storage setup: ObjectStorageConfig, helpers such as is_object_storage_path, _S3BinReader (Python side) --> object_storage_utils.py (legacy alias: utils_object_storage.py)