Created
March 19, 2026 22:40
-
-
Save ebetica/c58f83ddca366d3ed48c99621a7d21a8 to your computer and use it in GitHub Desktop.
lance_mapper: parallel map over Lance datasets on SLURM (usage example)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| lance_mapper: parallel map over Lance datasets on SLURM | |
| ======================================================== | |
| This example shows how to use LanceMapper to run an embarrassingly parallel | |
| computation over a Lance dataset using SLURM job arrays. | |
| The pattern: | |
| 1. Subclass LanceMapper | |
| 2. Set key_column (unique ID column) and rows_per_shard | |
| 3. Override init() to load your model / expensive resources once per worker | |
| 4. Override map_shard() to process a shard's PyArrow table, yielding dicts | |
| The framework handles: | |
| - SLURM array job submission + progress monitoring | |
| - Preemption-safe checkpointing (partial shards resume automatically) | |
| - Work stealing (fast workers pick up unfinished shards) | |
| - Atomic writes (crash-safe parquet intermediates) | |
| - Final merge into a single Lance v2 dataset with zstd compression | |
| Install (inside the evolutionaryscale monorepo): | |
| pixi shell | |
| # lance_mapper is part of the evolutionaryscale package | |
| Example dataset (create a toy one): | |
| >>> import lance, pyarrow as pa | |
| >>> tab = pa.table({"id": range(10000), "text": [f"hello {i}" for i in range(10000)]}) | |
| >>> lance.write_dataset(tab, "/tmp/example.lance") | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| from typing import Any, Iterator | |
| import pyarrow as pa | |
| from evolutionaryscale.utils.lance_mapper import ( | |
| LanceMapper, | |
| LocalBackend, | |
| SLURMBackend, | |
| TaskGranularity, | |
| ) | |
class UpperCaseMapper(LanceMapper):
    """Toy example: uppercase every text field and add its length.

    Demonstrates the minimal LanceMapper subclass contract: set
    ``key_column`` and ``rows_per_shard``, implement ``init()`` and
    ``map_shard()``, and (optionally) wire up CLI args.
    """

    key_column = "id"     # unique-ID column; must exist in input AND every yielded batch
    rows_per_shard = 500  # rows per logical shard (tune for your workload)

    def __init__(self, prefix: str = "", **kwargs: Any):
        """Create the mapper.

        Args:
            prefix: string prepended to every uppercased text value.
            **kwargs: forwarded to ``LanceMapper`` (input_dataset, output,
                backend, shard_ids, ...).
        """
        super().__init__(**kwargs)
        self.prefix = prefix

    def init(self) -> None:
        """Called once per worker. Load models / expensive resources here.

        The return value is passed as ``ctx`` to ``map_shard()``.
        This toy example needs no per-worker state, so it returns None.
        """
        return None

    def map_shard(self, table: pa.Table, ctx: Any) -> Iterator[dict]:
        """Process one shard. Yield dict[str, list] column batches.

        - Must include ``key_column`` in every yielded dict.
        - May skip rows (they just won't appear in the output).
        - Can yield one row at a time or accumulate mini-batches.

        Args:
            table: this shard's rows as a PyArrow table.
            ctx: whatever ``init()`` returned (None here).
        """
        # Reference the declared key column instead of repeating the magic
        # string "id", so the attribute and the code can never drift apart.
        ids = table[self.key_column].to_pylist()
        texts = table["text"].to_pylist()

        # Process in mini-batches of 100 for efficiency.
        batch_size = 100
        for start in range(0, len(ids), batch_size):
            end = start + batch_size
            batch_texts = texts[start:end]
            yield {
                self.key_column: ids[start:end],
                "text_upper": [self.prefix + t.upper() for t in batch_texts],
                # NOTE: length of the ORIGINAL text — the prefix is excluded,
                # and upper-casing can change length for some Unicode chars.
                "text_len": [len(t) for t in batch_texts],
            }

    # --- CLI wiring (optional, enables `python example.py run ...`) ---

    @classmethod
    def _add_cli_args(cls, parser: argparse.ArgumentParser) -> None:
        """Add subclass-specific CLI args."""
        parser.add_argument("--prefix", default="", help="Prefix to add")
        parser.add_argument("--partition", default="lowpri", help="SLURM partition")
        parser.add_argument("--qos", default="low", help="SLURM QOS")
        parser.add_argument("--time-limit", default="4:00:00", help="SLURM time limit")

    @classmethod
    def _from_cli_args(
        cls, args: argparse.Namespace, shard_ids: list[int] | None = None
    ) -> "UpperCaseMapper":
        """Construct a mapper from parsed CLI args.

        The ``getattr`` defaults mirror the defaults in ``_add_cli_args`` so
        this also works with a hand-built Namespace missing optional flags.
        """
        return cls(
            prefix=getattr(args, "prefix", ""),
            input_dataset=args.input_dataset,
            output=args.output,
            backend=SLURMBackend(
                task_granularity=TaskGranularity.CPU,  # no GPU needed
                partition=getattr(args, "partition", "lowpri"),
                qos=getattr(args, "qos", "low"),
                time_limit=getattr(args, "time_limit", "4:00:00"),
            ),
            shard_ids=shard_ids,
        )

    def _extra_run_single_args(self) -> list[str]:
        """Extra args forwarded to each SLURM worker's run_single command."""
        return ["--prefix", self.prefix] if self.prefix else []
| # --------------------------------------------------------------------------- | |
| # Usage examples | |
| # --------------------------------------------------------------------------- | |
| # | |
| # 1. FULL SLURM PIPELINE (submit → wait → merge): | |
| # | |
| # python lance_mapper_example.py run \ | |
| # --input-dataset /path/to/input.lance \ | |
| # --output /path/to/output.lance \ | |
| # --num-workers 50 | |
| # | |
| # This submits a SLURM array job with 50 workers, shows a progress bar, | |
| # and auto-merges into a final Lance dataset when all shards complete. | |
| # | |
| # | |
| # 2. SINGLE SHARD SMOKE TEST (run on a single node, no SLURM array): | |
| # | |
| # srun -c 4 --mem=8G -p lowpri --qos=dev -t 0:30:00 \ | |
| # pixi run python lance_mapper_example.py run_single \ | |
| # --input-dataset /path/to/input.lance \ | |
| # --output /path/to/test-output.lance \ | |
| # --shard-ids 0 | |
| # | |
| # | |
| # 3. CHECK PROGRESS: | |
| # | |
| # python lance_mapper_example.py status \ | |
| # --output /path/to/output.lance | |
| # | |
| # | |
| # 4. MANUAL MERGE (if run was interrupted): | |
| # | |
| # python lance_mapper_example.py merge \ | |
| # --output /path/to/output.lance | |
| # | |
| # # Or merge even if some shards are missing: | |
| # python lance_mapper_example.py merge \ | |
| # --output /path/to/output.lance --allow-incomplete | |
| # | |
| # | |
| # 5. PROGRAMMATIC USE (no CLI): | |
| # | |
| # import lance | |
| # ds = lance.dataset("input.lance") | |
| # | |
| # mapper = UpperCaseMapper( | |
| # input_dataset="input.lance", | |
| # output="output.lance", | |
| # backend=LocalBackend(num_workers=1), # local, no SLURM | |
| # ) | |
| # mapper.run_single() # process all shards locally | |
| # result_path = mapper.merge() # merge into output.lance | |
| # result = lance.dataset(str(result_path)) | |
| # print(result.to_table().to_pandas()) | |
| # | |
| # | |
| # 6. GPU WORKLOAD (e.g., model inference): | |
| # | |
| # class FoldMapper(LanceMapper): | |
| # key_column = "header" | |
| # rows_per_shard = 1000 | |
| # | |
| # def init(self): | |
| # return load_model("checkpoint.pt") | |
| # | |
| # def map_shard(self, table, model): | |
| # for row in table.to_pylist(): | |
| # result = model.predict(row["sequence"]) | |
| # yield { | |
| # "header": [row["header"]], | |
| # "prediction": [result.to_bytes()], | |
| # "score": [float(result.score)], | |
| # } | |
| # | |
| # # Submit with GPU resources: | |
| # mapper = FoldMapper( | |
| # input_dataset="sequences.lance", | |
| # output="predictions.lance", | |
| # backend=SLURMBackend( | |
| # task_granularity=TaskGranularity.GPU, # 1 GPU, 12 CPUs, 128G RAM | |
| # partition="h100-reserved", | |
| # qos="low", | |
| # time_limit="48:00:00", | |
| # ), | |
| # ) | |
| # mapper.run(num_workers=500) | |
| # --------------------------------------------------------------------------- | |
| if __name__ == "__main__": | |
| UpperCaseMapper.cli() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment