Skip to content

Instantly share code, notes, and snippets.

@ebetica
Created March 19, 2026 22:40
Show Gist options
  • Select an option

  • Save ebetica/c58f83ddca366d3ed48c99621a7d21a8 to your computer and use it in GitHub Desktop.
lance_mapper: parallel map over Lance datasets on SLURM (usage example)
"""
lance_mapper: parallel map over Lance datasets on SLURM
========================================================
This example shows how to use LanceMapper to run an embarrassingly parallel
computation over a Lance dataset using SLURM job arrays.
The pattern:
1. Subclass LanceMapper
2. Set key_column (unique ID column) and rows_per_shard
3. Override init() to load your model / expensive resources once per worker
4. Override map_shard() to process a shard's PyArrow table, yielding dicts
The framework handles:
- SLURM array job submission + progress monitoring
- Preemption-safe checkpointing (partial shards resume automatically)
- Work stealing (fast workers pick up unfinished shards)
- Atomic writes (crash-safe parquet intermediates)
- Final merge into a single Lance v2 dataset with zstd compression
Install (inside the evolutionaryscale monorepo):
pixi shell
# lance_mapper is part of the evolutionaryscale package
Example dataset (create a toy one):
>>> import lance, pyarrow as pa
>>> tab = pa.table({"id": range(10000), "text": [f"hello {i}" for i in range(10000)]})
>>> lance.write_dataset(tab, "/tmp/example.lance")
"""
from __future__ import annotations
import argparse
from typing import Any, Iterator
import pyarrow as pa
from evolutionaryscale.utils.lance_mapper import (
LanceMapper,
LocalBackend,
SLURMBackend,
TaskGranularity,
)
class UpperCaseMapper(LanceMapper):
    """Toy example mapper: uppercase each text field and record its length."""

    # Unique-ID column; must exist in the input and in every yielded batch.
    key_column = "id"
    # Rows per logical shard (tune for your workload).
    rows_per_shard = 500

    def __init__(self, prefix: str = "", **kwargs: Any):
        super().__init__(**kwargs)
        self.prefix = prefix

    def init(self) -> None:
        """Per-worker setup hook; its return value becomes `ctx` in map_shard().

        Load models / expensive resources here. This toy mapper needs
        nothing, so it returns None.
        """
        return None

    def map_shard(self, table: pa.Table, ctx: Any) -> Iterator[dict]:
        """Transform one shard, yielding dict[str, list] mini-batches.

        Every yielded dict must include key_column. Rows may be skipped
        (they simply won't appear in the output), and batches may be any
        size — here we emit mini-batches of 100 rows for efficiency.
        """
        all_ids = table["id"].to_pylist()
        all_texts = table["text"].to_pylist()
        chunk = 100  # mini-batch size
        offset = 0
        while offset < len(all_ids):
            id_slice = all_ids[offset : offset + chunk]
            text_slice = all_texts[offset : offset + chunk]
            offset += chunk
            yield {
                "id": id_slice,
                "text_upper": [self.prefix + s.upper() for s in text_slice],
                "text_len": [len(s) for s in text_slice],
            }

    # --- CLI wiring (optional, enables `python example.py run ...`) ---
    @classmethod
    def _add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
        """Register subclass-specific command-line options."""
        for flag, default, help_text in (
            ("--prefix", "", "Prefix to add"),
            ("--partition", "lowpri", "SLURM partition"),
            ("--qos", "low", "SLURM QOS"),
            ("--time-limit", "4:00:00", "SLURM time limit"),
        ):
            parser.add_argument(flag, default=default, help=help_text)

    @classmethod
    def _from_cli_args(
        cls, args: argparse.Namespace, shard_ids: list[int] | None = None
    ) -> "UpperCaseMapper":
        """Build a mapper instance from parsed CLI arguments."""
        backend = SLURMBackend(
            task_granularity=TaskGranularity.CPU,  # no GPU needed
            partition=getattr(args, "partition", "lowpri"),
            qos=getattr(args, "qos", "low"),
            time_limit=getattr(args, "time_limit", "4:00:00"),
        )
        return cls(
            prefix=getattr(args, "prefix", ""),
            input_dataset=args.input_dataset,
            output=args.output,
            backend=backend,
            shard_ids=shard_ids,
        )

    def _extra_run_single_args(self) -> list[str]:
        """Extra args forwarded to each SLURM worker's run_single command."""
        if not self.prefix:
            return []
        return ["--prefix", self.prefix]
# ---------------------------------------------------------------------------
# Usage examples
# ---------------------------------------------------------------------------
#
# 1. FULL SLURM PIPELINE (submit → wait → merge):
#
# python lance_mapper_example.py run \
# --input-dataset /path/to/input.lance \
# --output /path/to/output.lance \
# --num-workers 50
#
# This submits a SLURM array job with 50 workers, shows a progress bar,
# and auto-merges into a final Lance dataset when all shards complete.
#
#
# 2. SINGLE SHARD SMOKE TEST (run on a single node, no SLURM array):
#
# srun -c 4 --mem=8G -p lowpri --qos=dev -t 0:30:00 \
# pixi run python lance_mapper_example.py run_single \
# --input-dataset /path/to/input.lance \
# --output /path/to/test-output.lance \
# --shard-ids 0
#
#
# 3. CHECK PROGRESS:
#
# python lance_mapper_example.py status \
# --output /path/to/output.lance
#
#
# 4. MANUAL MERGE (if run was interrupted):
#
# python lance_mapper_example.py merge \
# --output /path/to/output.lance
#
# # Or merge even if some shards are missing:
# python lance_mapper_example.py merge \
# --output /path/to/output.lance --allow-incomplete
#
#
# 5. PROGRAMMATIC USE (no CLI):
#
# import lance
# ds = lance.dataset("input.lance")
#
# mapper = UpperCaseMapper(
# input_dataset="input.lance",
# output="output.lance",
# backend=LocalBackend(num_workers=1), # local, no SLURM
# )
# mapper.run_single() # process all shards locally
# result_path = mapper.merge() # merge into output.lance
# result = lance.dataset(str(result_path))
# print(result.to_table().to_pandas())
#
#
# 6. GPU WORKLOAD (e.g., model inference):
#
# class FoldMapper(LanceMapper):
# key_column = "header"
# rows_per_shard = 1000
#
# def init(self):
# return load_model("checkpoint.pt")
#
# def map_shard(self, table, model):
# for row in table.to_pylist():
# result = model.predict(row["sequence"])
# yield {
# "header": [row["header"]],
# "prediction": [result.to_bytes()],
# "score": [float(result.score)],
# }
#
# # Submit with GPU resources:
# mapper = FoldMapper(
# input_dataset="sequences.lance",
# output="predictions.lance",
# backend=SLURMBackend(
# task_granularity=TaskGranularity.GPU, # 1 GPU, 12 CPUs, 128G RAM
# partition="h100-reserved",
# qos="low",
# time_limit="48:00:00",
# ),
# )
# mapper.run(num_workers=500)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Dispatch to the framework's CLI entry point — per the usage examples
    # above, this exposes subcommands such as run / run_single / status / merge.
    UpperCaseMapper.cli()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment