@cloneofsimo
Created May 30, 2024 20:39
lance dataset concurrent writes?
import lance
import pyarrow as pa
import numpy as np
import time
import os
import multiprocessing as mp


def producer(N=1):
    # Yield a single RecordBatch of N synthetic rows matching the dataset schema.
    yield pa.RecordBatch.from_arrays([
        pa.array([np.random.rand(32 * 32 * 4).astype(np.float32) for _ in range(N)]),
        pa.array(["A cat sitting on a bed" for _ in range(N)]),
        pa.array(["cat_on_bed" for _ in range(N)]),
        pa.array([np.random.randint(0, 256, 2048 * 256).astype(np.uint8) for _ in range(N)])
    ], ["vae_256x256_latents", "caption", "uid", "t5_xl_embeddings"])


def initialize_lance(out_root):
    schema = pa.schema([
        pa.field("vae_256x256_latents", pa.list_(pa.float32())),
        pa.field("caption", pa.string()),
        pa.field("uid", pa.string()),
        pa.field("t5_xl_embeddings", pa.list_(pa.uint8())),
    ])

    # Create (or overwrite) the dataset with one synthetic row.
    dataset = lance.write_dataset(producer(), out_root, schema=schema, mode='overwrite')
    print(f"Initialized Lance at '{out_root}'")
    print(dataset.to_table().to_pandas())

    # Append a second single-row batch.
    dataset = lance.write_dataset(producer(), out_root, schema=schema, mode='append')

    # print("---Test 1: deleted Lance")
    # dataset.delete("uid = 'cat_on_bed'")
    # print(dataset.to_table())

    print("---Test 2: Bulk write via single process, checking write speed.")
    N = 1000
    stuff_to_write = pa.RecordBatch.from_arrays([
        pa.array([np.random.rand(32 * 32 * 4).astype(np.float32) for _ in range(N)]),
        pa.array(["A cat sitting on a bed" for _ in range(N)]),
        pa.array(["cat_on_bed" for _ in range(N)]),
        pa.array([np.random.randint(0, 256, 2048 * 256).astype(np.uint8) for _ in range(N)])
    ], ["vae_256x256_latents", "caption", "uid", "t5_xl_embeddings"])

    t0 = time.time()
    dataset = lance.write_dataset(stuff_to_write, out_root, schema=schema, mode='append')
    print(f"Time taken: {time.time() - t0:.2f} seconds")
    print(dataset.to_table().to_pandas())

    for numproc in [2, 4, 8]:
        print(f"---Test 3: Num Procs {numproc}. Bulk write via multiple processes, checking write speed.")
        t0 = time.time()
        with mp.Pool(numproc) as pool:
            pool.map(write_to_dataset, [(stuff_to_write, out_root, schema)] * numproc)
        print(f"Time taken: {time.time() - t0:.2f} seconds")
        # Re-open the dataset so the latest version written by the workers is visible.
        dataset = lance.dataset(out_root)
        print(dataset.to_table())


def write_to_dataset(args):
    # Worker: open the dataset and upsert rows keyed on "uid",
    # inserting only rows whose uid is not already present.
    stuff_to_write, out_root, schema = args
    dataset = lance.dataset(out_root)
    dataset.merge_insert("uid").when_not_matched_insert_all().execute(stuff_to_write)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Initialize Lance database.")
    parser.add_argument(
        "--out_root",
        type=str,
        required=True,
        help="Output root directory for Lance."
    )
    args = parser.parse_args()
    initialize_lance(args.out_root)
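
Example invocation (a minimal sketch; the filename lance_concurrent_writes.py and the output path are hypothetical, and pylance, pyarrow, and numpy are assumed to be installed):

    python lance_concurrent_writes.py --out_root ./lance_test_db

Afterwards the dataset can be reopened with lance.dataset("./lance_test_db") to inspect the rows written by the single-process and multi-process tests.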