I describe below how Megatron statically builds datasets, and then how models pull from those datasets at training time. In order:
- How GPT datasets are produced inside Megatron‑Core;
- Exactly what a training step receives (`__getitem__` --> DataLoader --> model);
- How to host the finished `.bin`/`.idx` pair in an S3‑compatible bucket and stream it lazily during training. I think this is the desired end-state for templar's training needs.
Megatron builds datasets statically at the start of training. It uses user arguments (sequence length, batch size, dataset seed, number of iterations) to build that dataset within a few minutes on the root rank. During training, all torch dataloader workers pull these complete samples (pre-tokenized, chopped to seqlen, contiguous, etc.) so that our preprocessing time is zero. We avoid being bottlenecked by the load time from NFS, to CPU RAM, to GPU VRAM via parallel workers under the torch dataloader and a prefetch factor (i.e. with 8 workers and a prefetch factor of 2, there are 8 CPU worker processes feeding each GPU, each keeping 2 batches staged ahead of time).
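As a reference point, here's a minimal sketch of those dataloader settings in plain PyTorch (`train_ds` stands in for any of the Megatron datasets built later in this doc; the numbers mirror the example above):

from torch.utils.data import DataLoader

# 8 worker processes, each keeping 2 batches staged ahead of the GPU;
# pin_memory speeds up the host-to-device copy of each prefetched batch.
loader = DataLoader(
    train_ds,
    batch_size=8,
    num_workers=8,
    prefetch_factor=2,   # batches pre-loaded per worker (requires num_workers > 0)
    pin_memory=True,
    shuffle=False,       # Megatron already shuffles via its own shuffle index
)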
First off, the primary classes involved are:
- `IndexedDatasetBuilder`
  - Flattens already‑tokenised sequences into a binary blob (`.bin`).
  - Writes a companion index (`.idx`) with:
    - byte pointer per sequence
    - sequence length
    - document boundaries (and optional “mode” byte).
  - (A short read‑back sketch follows this list.)
- `GPTDataset`
  - Wraps one `IndexedDataset`, restricts it to a (train/valid/test) slice, and converts variable‑length documents into a stream of fixed‑length training samples with causal masks, labels, etc.
- `BlendedDataset`
  - Stitches several `GPTDataset`s together according to user‑specified weights/ratios so that a single `__getitem__` yields a virtual mixture of sources.
- `BlendedMegatronDatasetBuilder`
  - Decides, per split, whether to build a raw `GPTDataset` or a `BlendedDataset`; builds everything once on rank‑0, caches the indices, then re‑opens them mmap‑style on all ranks.
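As promised above, a short read‑back sketch showing how the `.bin`/`.idx` pair is consumed. This assumes the `IndexedDataset` constructor and `get()` signature found in recent Megatron‑Core; treat it as illustrative rather than exact:

from megatron.core.datasets.indexed_dataset import IndexedDataset

# Open an existing .bin/.idx pair by its path prefix (no extension).
ds = IndexedDataset("/tmp/my_gpt_data", mmap=True)

print(len(ds))                           # number of sequences recorded in the .idx
print(ds.sequence_lengths[:5])           # per-sequence lengths from the index
tokens = ds.get(0, offset=0, length=16)  # first 16 tokens of document 0, served from the .bin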
`helpers.cpp::build_sample_idx` walks the flattened token stream once and writes a prefix array of length `num_samples + 1`.
Here's some simplified pseudocode of the C++ sample-index builder, if you want more detail:
doc_i = 0; offset = 0;
sample_idx[0] = (doc_i, offset);                         // start of sample 0
for s in [1 … num_samples] {
    remaining = seq_length + extra;                      // extra = add_extra_token_to_sequence (usually 1)
    while (remaining > 0) {
        doc_len = sizes[document_index[doc_i]] - offset; // tokens left in the current document
        remaining -= doc_len;
        if (remaining <= 0) {
            // sample ends inside this document: advance offset to the split point,
            // stepping back by the overshoot and by the extra token
            offset += doc_len + remaining - extra;
        } else {
            ++doc_i; offset = 0;                         // document exhausted, move to the next one
        }
    }
    sample_idx[s] = (doc_i, offset);                     // start of sample s
}
Megatron uses a few arrays to manage indices at the document level and the sample level:
- `document_index` (shape: `[num_epochs × num_docs]`)
  - List of doc‑IDs, repeated across epochs, globally shuffled (optionally leaves the final partial epoch un‑shuffled for size control).
- `sample_index` (shape: `[num_samples + 1, 2]`)
  - For every sample boundary, stores `(doc_idx, offset)` into the flattened token stream.
  - Built by the C++ helper `helpers.cpp::build_sample_idx`.
- `shuffle_index` (shape: `[num_samples]`)
  - A permutation of `[0 … num_samples‑1]`; makes every `__getitem__` random without reshuffling tensors.
We first build the dataset on the root rank:
- Rank 0 generates the above index arrays and writes them under `path_to_cache/<hash>-GPTDataset-{train|valid|test}-{document|sample|shuffle}_index.npy`.
- Subsequent runs and other ranks just memory‑map these files (`mmap_mode="r"`), so work is not duplicated and RAM doesn't explode.
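For illustration, re-opening the cached indices on another rank amounts to something like this (the `<hash>` prefix is a placeholder for the config hash Megatron computes):

import numpy as np

prefix = "path_to_cache/<hash>-GPTDataset-train"

# Lazily memory-map the cached arrays; pages are only read when touched.
document_index = np.load(f"{prefix}-document_index.npy", allow_pickle=True, mmap_mode="r")
sample_index   = np.load(f"{prefix}-sample_index.npy",   allow_pickle=True, mmap_mode="r")
shuffle_index  = np.load(f"{prefix}-shuffle_index.npy",  allow_pickle=True, mmap_mode="r")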
Next, what a training step actually receives. Assume `sequence_length = 2048` and `add_extra_token_to_sequence = 1` (the default). Each `__getitem__` returns a dict with:
- `tokens` (dtype: `torch.int64`, shape: `[2048]`)
  - Input IDs (last token removed when the `+1` extra token is used).
- `labels` (dtype: `torch.int64`, shape: `[2048]`)
  - Next‑token targets (first token removed, padding --> `pad_id`).
- `attention_mask`* (dtype: `torch.bool`, shape: `[1, 2048, 2048]`)
  - True = masked. Upper‑triangular causal mask, optionally trimmed at EOD boundaries.
- `loss_mask` (dtype: `torch.float32`, shape: `[2048]`)
  - 1.0 except for:
    - padding
    - `eod_token` (if `eod_mask_loss=True`).
- `position_ids` (dtype: `torch.int64`, shape: `[2048]`)
  - Typically `[0…2047]`; re‑zeroed after every EOD if `reset_position_ids`.

* `attention_mask` is emitted only when `create_attention_mask=True` (default).
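A quick sanity-check sketch of the token/label shift described above (`train_ds` is any built `GPTDataset`; this mirrors the field descriptions rather than quoting Megatron code):

sample = train_ds[0]
print(sample["tokens"].shape, sample["labels"].shape)  # torch.Size([2048]) torch.Size([2048])

# tokens and labels are two views of the same 2049-token window:
# tokens drops the last token, labels drops the first, so for an
# unpadded sample labels are simply tokens shifted left by one.
assert bool((sample["tokens"][1:] == sample["labels"][:-1]).all())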
`torch.utils.data.default_collate` simply stacks these into batch tensors:
import torch
from torch.utils.data import DataLoader

device = torch.device("cuda")

loader = DataLoader(train_ds, batch_size=8, pin_memory=True, shuffle=False)
batch = next(iter(loader))

tokens       = batch["tokens"].to(device)        # [B, 2048]
labels       = batch["labels"].to(device)        # [B, 2048]
loss_mask    = batch["loss_mask"].to(device)     # [B, 2048]
position_ids = batch["position_ids"].to(device)  # [B, 2048]
attn_mask    = batch.get("attention_mask", None) # [B, 1, 2048, 2048]
A note on epochs: multiple epochs are handled by resetting the iterator to the start. Indices are not rebuilt, meaning we have no control over replay.
Each torch dataloader worker fetches the following:
global_idx                                  # sample index from the DataLoader loop
shuffled = shuffle_index[global_idx]        # de-reference the shuffle permutation
(doc_i, offset) = sample_index[shuffled]    # where this sample starts in the token stream
tokens = dataset.get(doc_i, offset, L + 1)  # L = sequence_length; numpy view or S3 read
`dataset.get` chooses the right `BinReader` implementation:
- `_MMapBinReader`: Used for local files with `mmap=True` (default). Memory-maps the `.bin` once, returns views.
- `_FileBinReader`: Used for local files with `mmap=False`. Opens + seeks + reads on each call (slower, but less VM usage).
- `_S3BinReader`: Used when the path starts with `s3://`. Keeps an in-mem chunk cache. On a miss, issues one `GetObject` `Range` request covering `(offset // chunk_size) × chunk_size … offset + len`.
- `_MultiStorageClientBinReader`: Used when the path starts with `msc://profile/...`. Same idea, implemented via the NVIDIA MSC client.

The index (`.idx`) always resides locally, under `object_storage_cache_path/<bucket>/<key>.idx`.
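For intuition, here's a minimal sketch of the kind of ranged read the S3 reader issues on a cache miss (plain boto3, not the actual `_S3BinReader` implementation; bucket, key, and chunk size are placeholders):

import boto3

s3 = boto3.client("s3", endpoint_url="https://<accountid>.r2.cloudflarestorage.com")

def read_bin_range(offset: int, length: int, chunk_size: int = 256 * 1024) -> bytes:
    # Round the request start down to a chunk boundary, mirroring the chunk cache,
    # then slice out exactly the bytes the sample needs.
    start = (offset // chunk_size) * chunk_size
    end = offset + length - 1  # HTTP Range is inclusive
    resp = s3.get_object(
        Bucket="mybucket",
        Key="my_gpt_data.bin",
        Range=f"bytes={start}-{end}",
    )
    blob = resp["Body"].read()
    return blob[offset - start : offset - start + length]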
With all this in mind, here's how I think we can use the existing Megatron pipeline to live-stream samples from a cloud bucket like S3:
Before sharing the dataset with miners, we first statically build the dataset (tokenized `.bin`/`.idx`, sample indices, etc.). We upload this prebuilt dataset into a bucket so that miners don't have to do the preprocessing themselves after pulling from it.
# Produces .bin and .idx
python build_indexed_dataset.py \
--in my_corpus.txt \
--out /tmp/my_gpt_data
# Upload to S3-compatible store
aws --endpoint-url https://<accountid>.r2.cloudflarestorage.com \
s3 cp /tmp/my_gpt_data.bin s3://mybucket/my_gpt_data.bin
aws --endpoint-url https://<accountid>.r2.cloudflarestorage.com \
s3 cp /tmp/my_gpt_data.idx s3://mybucket/my_gpt_data.idx
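For reference, the (hypothetical) `build_indexed_dataset.py` above would boil down to something like this, assuming the `IndexedDatasetBuilder` API in recent Megatron‑Core (`my_tokenizer` is a placeholder):

import numpy as np
import torch
from megatron.core.datasets.indexed_dataset import IndexedDatasetBuilder

builder = IndexedDatasetBuilder("/tmp/my_gpt_data.bin", dtype=np.int32)

with open("my_corpus.txt") as f:
    for line in f:
        ids = my_tokenizer.tokenize(line)        # tokenise one document
        builder.add_item(torch.IntTensor(ids))   # append its tokens to the .bin
        builder.end_document()                   # record the document boundary
builder.finalize("/tmp/my_gpt_data.idx")         # write the companion .idx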
We would share the bucket with miners along with any specific settings (e.g. the tokenizer we used, seqlen, seed, etc.). They then build the datasets against the bucket directly:
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig
from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)

cfg = GPTDatasetConfig(
    random_seed               = 42,
    sequence_length           = 2048,
    blend                     = (["s3://mybucket/my_gpt_data"], None),
    split                     = "98,1,1",
    object_storage_cache_path = "/local/cache/idx",
    tokenizer                 = my_tokenizer,
    # sample-construction flags referenced above
    reset_position_ids        = False,
    reset_attention_mask      = False,
    eod_mask_loss             = False,
)

sizes = [None, 5000, 5000]   # train: full epoch; small valid/test

builder = BlendedMegatronDatasetBuilder(
    cls              = GPTDataset,
    sizes            = sizes,
    is_built_on_rank = lambda: True,   # build everywhere
    config           = cfg,
)
train_ds, val_ds, test_ds = builder.build()
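If everything is wired correctly, pulling a sample lazily range-reads from the bucket; the first access also downloads the `.idx` into `object_storage_cache_path`. A tiny sketch:

sample = train_ds[0]          # triggers ranged GetObject calls against the .bin
print(sample["tokens"][:10])  # first few input IDs of the first (shuffled) sample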
We may have to share bucket credentials, which miners set as env vars. For example:
export AWS_ACCESS_KEY_ID="<r2-access-key>"
export AWS_SECRET_ACCESS_KEY="<r2-secret>"
export AWS_DEFAULT_REGION="auto"
export AWS_S3_ENDPOINT_URL="https://<accountid>.r2.cloudflarestorage.com"
Below is a lookup showing every class / helper / C++ symbol referenced above and the file that defines it in Megatron‑Core's `datasets/` source tree.
| Concept mentioned above | Python / C++ symbol | Defined in |
|---|---|---|
| dataset bin | `IndexedDatasetBuilder`, `IndexedDataset` | `indexed_dataset.py` |
| • Index writer/reader internals | `_IndexWriter`, `_IndexReader`, `DType` | `indexed_dataset.py` |
| • Bin readers (local & remote) | `_MMapBinReader`, `_FileBinReader`, `_S3BinReader`, `_MultiStorageClientBinReader` | `indexed_dataset.py` |
| C++ helpers | `build_sample_idx_int32`, `build_sample_idx_int64`, `build_blending_indices`, `build_exhaustive_blending_indices`, `build_mapping`, `build_blocks_mapping` | `helpers.cpp` (Python bindings re‑exported via `helpers.py`) |
| GPT dataset | `GPTDataset`, `GPTDatasetConfig` | `gpt_dataset.py` |
| • Mock variant | `MockGPTDataset`, `MockGPTLowLevelDataset` | `gpt_dataset.py` |
| Mask‑LM base class | `MaskedWordPieceDataset`, `MaskedWordPieceDatasetConfig` | `masked_dataset.py` |
| dataset blend & split orchestration | `BlendedDataset` | `blended_dataset.py` |
| | `BlendedMegatronDatasetBuilder` | `blended_megatron_dataset_builder.py` |
| | `BlendedMegatronDatasetConfig` | `blended_megatron_dataset_config.py` |
| Shared utils | `Split` enum, `normalize`, `get_blend_from_list`, `compile_helpers` | `utils.py` |
| Tokenizer abstraction | `MegatronTokenizer` | `megatron_tokenizer.py` |
| Object‑storage setup | `ObjectStorageConfig`, helpers such as `is_object_storage_path`, `_S3BinReader` (Python side) | `object_storage_utils.py` (legacy alias `utils_object_storage.py`) |