Created
May 30, 2024 20:39
-
-
Save cloneofsimo/cefe61fb4b437ce366e4f25f37fbb7b8 to your computer and use it in GitHub Desktop.
lance dataset concurrent writes?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import lance | |
| import pyarrow as pa | |
| import numpy as np | |
| import time | |
| import os | |
| import multiprocessing as mp | |
def producer(N=1):
    """Yield exactly one synthetic RecordBatch holding ``N`` rows.

    Columns, in order:
      * ``vae_256x256_latents``: N random float32 vectors of length 32*32*4
      * ``caption``: the constant string "A cat sitting on a bed"
      * ``uid``: the constant string "cat_on_bed"
      * ``t5_xl_embeddings``: N random uint8 vectors of length 2048*256

    Random values come from the global ``np.random`` state. Nothing is
    generated until the generator is first advanced.
    """
    latents = [np.random.rand(32 * 32 * 4).astype(np.float32) for _ in range(N)]
    captions = ["A cat sitting on a bed"] * N
    uids = ["cat_on_bed"] * N
    embeddings = [
        np.random.randint(0, 256, 2048 * 256).astype(np.uint8) for _ in range(N)
    ]
    names = ["vae_256x256_latents", "caption", "uid", "t5_xl_embeddings"]
    columns = [pa.array(values) for values in (latents, captions, uids, embeddings)]
    yield pa.RecordBatch.from_arrays(columns, names)
def initialize_lance(out_root):
    """Create a Lance dataset at ``out_root`` and benchmark writes into it.

    Steps:
      1. Create the dataset with one synthetic row (``mode='overwrite'``
         replaces anything already at ``out_root``), print it, then append
         one more row.
      2. Test 2: append a 1000-row batch from this process and time it.
      3. Test 3: for 2, 4 and 8 worker processes, have every worker write the
         same 1000-row batch through ``write_to_dataset``. Time each round,
         then report the latest dataset version and row count.

    Args:
        out_root: Path/URI of the Lance dataset.
    """
    schema = pa.schema([
        pa.field("vae_256x256_latents", pa.list_(pa.float32())),
        pa.field("caption", pa.string()),
        pa.field("uid", pa.string()),
        pa.field("t5_xl_embeddings", pa.list_(pa.uint8())),
    ])
    dataset = lance.write_dataset(producer(), out_root, schema=schema, mode='overwrite')
    print(f"Initialized Lance at '{out_root}'")
    print(dataset.to_table().to_pandas())
    dataset = lance.write_dataset(producer(), out_root, schema=schema, mode='append')

    # Disabled experiment, kept for reference:
    # print("---Test 1: deleted Lance")
    # dataset.delete("uid = 'cat_on_bed'")
    # print(dataset.to_table())

    print("---Test 2. Bulk write via single process, checking write speed.")
    N = 1000
    # Reuse producer() so this batch cannot drift from the schema above.
    stuff_to_write = next(producer(N))
    t0 = time.perf_counter()
    dataset = lance.write_dataset(stuff_to_write, out_root, schema=schema, mode='append')
    print(f"Time taken: {time.perf_counter() - t0:.2f} seconds")
    print(dataset.to_table().to_pandas())

    # Lance has already been used in this process, so its multithreaded native
    # runtime is running. fork() after threads exist is unsafe. "spawn" starts
    # each worker in a fresh interpreter instead.
    ctx = mp.get_context("spawn")
    for numproc in [2, 4, 8]:
        print(f"---Test 3: Num Procs {numproc}. Bulk write via multiple processes, checking write speed.")
        with ctx.Pool(numproc) as pool:
            # Time the map only, so pool creation is not counted. Spawned
            # workers may still be importing when the first tasks arrive.
            t0 = time.perf_counter()
            pool.map(write_to_dataset, [(stuff_to_write, out_root, schema)] * numproc)
            elapsed = time.perf_counter() - t0
        print(f"Time taken: {elapsed:.2f} seconds")
        # `dataset` is pinned to the version committed in Test 2. Reopen it so
        # the workers' commits are visible. Counting rows avoids loading the
        # whole (large) table into memory.
        latest = lance.dataset(out_root)
        print(f"Version {latest.version}: {latest.count_rows()} rows")
def write_to_dataset(args):
    """Append one RecordBatch to an existing Lance dataset (pool worker).

    Args:
        args: ``(stuff_to_write, out_root, schema)`` packed into one tuple so
            this function can be used directly with ``Pool.map``.

    Returns:
        None. The dataset handle is deliberately not returned, because
        ``Pool.map`` would have to pickle it back to the parent.
    """
    stuff_to_write, out_root, schema = args
    # Plain append, the same as the single-process test. A merge_insert keyed
    # on "uid" with only when_not_matched_insert_all() would insert nothing
    # here: every generated row has uid "cat_on_bed", and that uid is already
    # in the dataset, so the benchmark would time zero writes.
    lance.write_dataset(stuff_to_write, out_root, schema=schema, mode='append')
if __name__ == "__main__":
    # Command-line entry point: a single required --out_root option selects
    # where the Lance dataset is created and benchmarked.
    import argparse

    cli = argparse.ArgumentParser(description="Initialize Lance database.")
    cli.add_argument(
        "--out_root",
        required=True,
        type=str,
        help="Output root directory for Lance.",
    )
    options = cli.parse_args()
    initialize_lance(options.out_root)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment