import argparse
import multiprocessing
import os
from timeit import default_timer
from typing import Optional, Sequence, Type, Union

import numpy as np

import tiledb
from tiledb.ml.readers.pytorch import PyTorchTileDBDataLoader
from tiledb.ml.readers.tensorflow import TensorflowTileDBDataset


def rand_mask(shape: Sequence[int], sparsity: float) -> np.ndarray:
    """Return a boolean mask of the given shape with (1 - sparsity) of its cells True."""
    rnd = np.random.default_rng()
    mask = np.zeros(shape, dtype=np.bool_)
    non_zero_size = int(mask.size * (1 - sparsity))
    # Sample without replacement so exactly non_zero_size cells are set; the default
    # (sampling with replacement) may pick duplicate indices and overshoot sparsity.
    non_zero_idxs = rnd.choice(mask.size, size=non_zero_size, replace=False)
    mask.flat[non_zero_idxs] = True
    return mask


def seq_array(shape: Sequence[int], dtype: np.dtype) -> np.ndarray:
    # Every cell of row i along the first dimension holds the value i.
    return (
        np.repeat(np.arange(shape[0]), np.prod(shape[1:])).reshape(shape).astype(dtype)
    )
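
# Illustration (not part of the benchmark): on a toy shape the two helpers behave
# as follows, e.g. in an interactive session:
#
#   >>> seq_array((3, 2), np.dtype("int8"))
#   array([[0, 0],
#          [1, 1],
#          [2, 2]], dtype=int8)
#   >>> rand_mask((3, 2), 0.5).sum()  # exactly int(6 * 0.5) cells are True
#   3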


def create_array(
    folder: str,
    shape: Sequence[int],
    tiles: Sequence[int],
    attr_dtypes: Sequence[np.dtype],
    sparsity: float,
) -> None:
    assert len(shape) == len(tiles)
    assert all(shape[i] >= tiles[i] for i in range(len(shape)))
    sparse = sparsity > 0
    # Encode the array parameters in the name so repeated runs can reuse the array.
    filename = f"sparse-{sparsity}" if sparse else "dense"
    filename += f"-shape_{'_'.join(map(str, shape))}"
    filename += f"-tiles_{'_'.join(map(str, tiles))}"
    filename += f"-attrs_{'_'.join(map(str, attr_dtypes))}"
    uri = os.path.join(folder, filename)
    print(uri)
    if os.path.exists(uri):
        return
    schema = tiledb.ArraySchema(
        sparse=sparse,
        domain=tiledb.Domain(
            *[
                tiledb.Dim(
                    name=f"d{i}",
                    domain=(0, shape[i] - 1),
                    tile=tile,
                    dtype=np.int32,
                )
                for i, tile in enumerate(tiles)
            ]
        ),
        attrs=[
            tiledb.Attr(name=f"a{i}", dtype=dtype)
            for i, dtype in enumerate(attr_dtypes)
        ],
    )
    tiledb.Array.create(uri, schema)
    with tiledb.open(uri, "w") as tiledb_array:
        attr_data = [seq_array(shape, dtype) for dtype in attr_dtypes]
        # Dense arrays are written whole; sparse arrays only at the masked coordinates.
        idx = np.nonzero(rand_mask(shape, sparsity)) if sparse else slice(None)
        tiledb_array[idx] = {f"a{i}": data[idx] for i, data in enumerate(attr_data)}
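
# Example (a sketch, not invoked by the script): creating a small dense array and
# a 90%-sparse variant programmatically; "/tmp/tiledb-bench" is a hypothetical
# location, not one the gist prescribes.
#
#   create_array("/tmp/tiledb-bench", shape=(1000, 10), tiles=(100, 10),
#                attr_dtypes=[np.dtype("float32")], sparsity=0.0)
#   create_array("/tmp/tiledb-bench", shape=(1000, 10), tiles=(100, 10),
#                attr_dtypes=[np.dtype("float32")], sparsity=0.9)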


def read_dataset(
    cls: Type[Union[TensorflowTileDBDataset, PyTorchTileDBDataLoader]],
    x_uri: str,
    y_uri: str,
    batch_size: int,
    buffer_bytes: int,
    shuffle_buffer_size: int,
    prefetch: Optional[int] = None,
    num_workers: Optional[int] = None,
    sparse_layout: Optional[str] = None,
    x_key_dim: Optional[str] = None,
    y_key_dim: Optional[str] = None,
    config: Optional[tiledb.Config] = None,
) -> None:
    kwargs = dict(
        buffer_bytes=buffer_bytes,
        batch_size=batch_size,
        shuffle_buffer_size=shuffle_buffer_size,
        x_key_dim=x_key_dim,
        y_key_dim=y_key_dim,
    )
    # Forward the loader-specific options only when they were given.
    if prefetch is not None:
        kwargs["prefetch"] = prefetch
    if num_workers is not None:
        kwargs["num_workers"] = num_workers
    if sparse_layout is not None:
        kwargs["csr"] = sparse_layout == "csr"
    x_array = tiledb.open(x_uri, config=config)
    y_array = tiledb.open(y_uri, config=config)
    with x_array, y_array:
        # print(x_array.schema, x_array.nonempty_domain())
        # print(y_array.schema, y_array.nonempty_domain())
        loader = cls(x_array, y_array, **kwargs)
        # Single timed pass; bump the range to average over multiple passes.
        for _ in range(1):
            time = default_timer()
            print(sum(1 for _ in loader))
            print(f"Elapsed time={default_timer() - time:.2f}s")


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Benchmark reading row slices from a TileDB array",
    )
    subparsers = parser.add_subparsers(help="command", dest="cmd")

    parser_create = subparsers.add_parser(
        "create",
        help="create array",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser_create.add_argument("dir", help="parent directory to store the array")
    parser_create.add_argument(
        "--sparsity",
        type=float,
        default=0.0,
        help="Array sparsity in the [0, 1) range. If non-zero, the array will "
        "be sparse with the given sparsity ratio.",
    )
    parser_create.add_argument(
        "--shape",
        type=lambda s: tuple(map(int, s.split(","))),
        default=(100_000, 100, 10),
        help="comma-separated ints of the array dimension sizes",
    )
    parser_create.add_argument(
        "--tiles",
        type=lambda s: tuple(map(int, s.split(","))),
        default=(2**10, 2**5, 2**2),
        help="comma-separated ints of the array tile extents",
    )
    parser_create.add_argument(
        "--dtypes",
        type=lambda s: tuple(map(np.dtype, s.split(","))),
        default=(np.dtype("uint8"), np.dtype("float32")),
        help="comma-separated strings of the array attribute dtypes",
    )

    parser_read = subparsers.add_parser("read", help="read array")
    parser_read.add_argument("x_uri", help="Training data TileDB URI")
    parser_read.add_argument("y_uri", help="Labels TileDB URI")
    loader_type = parser_read.add_mutually_exclusive_group(required=True)
    loader_type.add_argument(
        "--tensorflow",
        dest="type",
        action="store_const",
        const=TensorflowTileDBDataset,
        help="Use TensorFlow loader",
    )
    loader_type.add_argument(
        "--pytorch",
        dest="type",
        action="store_const",
        const=PyTorchTileDBDataLoader,
        help="Use PyTorch loader",
    )
    parser_read.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=32,
        help="Size of each batch",
    )
    parser_read.add_argument(
        "-B",
        "--buffer_bytes",
        type=int,
        help="Maximum size (in bytes) of memory to allocate for reading from each array",
    )
    parser_read.add_argument(
        "-s",
        "--shuffle_buffer_size",
        type=int,
        default=0,
        help="Shuffling buffer size (or 0 for no shuffling)",
    )
    parser_read.add_argument(
        "-w",
        "--num_workers",
        type=int,
        help="Number of workers to use for data loading (PyTorch only)",
    )
    parser_read.add_argument(
        "-p",
        "--prefetch",
        type=int,
        help="Number of batches to prefetch",
    )
    parser_read.add_argument(
        "--sparse_layout",
        choices=("csr", "coo"),
        help="Sparse layout for 2D sparse arrays (PyTorch only)",
    )
    parser_read.add_argument(
        "--x_key_dim",
        help="X key dimension",
    )
    parser_read.add_argument(
        "--y_key_dim",
        help="Y key dimension",
    )
    parser_read.add_argument(
        "--memory_budget", type=int, help="sm.memory_budget config value"
    )
    parser_read.add_argument(
        "--max_incomplete_retries",
        type=int,
        help="py.max_incomplete_retries config value",
    )
    parser_read.add_argument(
        "--init_buffer_bytes",
        type=int,
        help="py.init_buffer_bytes config value",
    )
    # NOTE: --stats is parsed but not currently acted on.
    parser_read.add_argument("--stats", action="store_true", help="dump TileDB stats")

    opts = parser.parse_args()
    if opts.cmd == "create":
        create_array(
            folder=opts.dir,
            sparsity=opts.sparsity,
            shape=opts.shape,
            tiles=opts.tiles,
            attr_dtypes=opts.dtypes,
        )
    elif opts.cmd == "read":
        config = tiledb.Config()
        # config["sm.compute_concurrency_level"] = 8
        if opts.memory_budget is not None:
            config["sm.memory_budget"] = opts.memory_budget
        if opts.max_incomplete_retries is not None:
            config["py.max_incomplete_retries"] = opts.max_incomplete_retries
        if opts.init_buffer_bytes is not None:
            config["py.init_buffer_bytes"] = opts.init_buffer_bytes
        if opts.num_workers:
            # Use "forkserver" so data loading workers don't inherit forked state.
            multiprocessing.set_start_method("forkserver")
        read_dataset(
            opts.type,
            opts.x_uri,
            opts.y_uri,
            opts.batch_size,
            opts.buffer_bytes,
            opts.shuffle_buffer_size,
            opts.prefetch,
            opts.num_workers,
            opts.sparse_layout,
            opts.x_key_dim,
            opts.y_key_dim,
            config,
        )


if __name__ == "__main__":
    main()
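
# Typical command-line usage (the file name "read_bench.py" is hypothetical, as
# are the paths; the read defaults come from the argparse setup above):
#
#   python read_bench.py create /tmp/tiledb-bench --shape 100000,100,10 --tiles 1024,32,4
#   python read_bench.py create /tmp/tiledb-bench --shape 100000,10 --tiles 1024,4 --dtypes float32
#   python read_bench.py read --pytorch --batch_size 64 \
#       /tmp/tiledb-bench/dense-shape_100000_100_10-tiles_1024_32_4-attrs_uint8_float32 \
#       /tmp/tiledb-bench/dense-shape_100000_10-tiles_1024_4-attrs_float32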