Created
October 5, 2022 15:30
-
-
Save conradry/7376e7504456e9c194638c78d009523e to your computer and use it in GitHub Desktop.
Run-length encode a zarr labelmap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Run-length encode a 3D panoptic labelmap stored in a zarr array.

The labelmap is encoded slice by slice along axis 0, every slice is checked
to round-trip exactly, and the encoded instances are tracked with an
``InstanceTracker`` and written to ``<zarr_store>/<zarr_key>.json``.
"""
import os
import argparse

import numpy as np
import zarr
from tqdm import tqdm

from empanada.inference.tracker import InstanceTracker
from empanada.inference.rle import pan_seg_to_rle_seg, rle_seg_to_pan_seg


def _parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(description='Run-length encode a zarr labelmap')
    parser.add_argument('zarr_store', type=str, metavar='zarr_store', help='Path to zarr directory')
    parser.add_argument('zarr_key', type=str, metavar='zarr_key',
                        help='Name of dataset in the zarr store (e.g. panoptic_xy)')
    parser.add_argument('label_divisor', type=int, metavar='label_divisor',
                        help='Label divisor to separate classes (max number objects 3D in napari)')
    # Optional, backward-compatible: the original script hard-coded class 1 (MitoNet).
    parser.add_argument('--class_id', type=int, default=1,
                        help='Class id of the segmented objects (default: 1, as for MitoNet)')
    # Replaces the former commented-out "seg = np.array(seg)" toggle.
    parser.add_argument('--in_memory', action='store_true',
                        help='Load the full segmentation into memory before encoding; '
                             'faster when the zarr chunks span many slices, but uses more RAM')
    return parser.parse_args(argv)


def _encode_slice(mask2d, labels, label_divisor, thing_list):
    """Run-length encode one 2D panoptic slice and verify that it round-trips.

    Args:
        mask2d: 2D panoptic label image for a single slice.
        labels: Class ids to encode.
        label_divisor: Divisor separating class id from instance id in a label.
        thing_list: Class ids treated as instance ("thing") classes.

    Returns:
        The encoding from ``pan_seg_to_rle_seg``, indexed by class id.

    Raises:
        ValueError: If decoding the encoding does not reproduce ``mask2d``
            exactly (typically because ``label_divisor`` is wrong).
    """
    rle_seg = pan_seg_to_rle_seg(
        mask2d, labels, label_divisor, thing_list, force_connected=False
    )
    # Exact comparison: labels are integer ids, and np.allclose's relative
    # tolerance would accept mismatches between large ids. An explicit raise
    # (instead of assert) keeps the check active under ``python -O``.
    if not np.array_equal(mask2d, rle_seg_to_pan_seg(rle_seg, mask2d.shape)):
        raise ValueError("RLE segmentation is wrong; are you sure label_divisor is correct?")
    return rle_seg


def main(argv=None):
    """Encode the requested zarr dataset and write the tracker JSON next to it."""
    args = _parse_args(argv)

    # load the zarr and segmentation volume (read-only)
    data = zarr.open(args.zarr_store, mode='r')
    seg = data[args.zarr_key]

    # By default slices are read lazily one at a time. Depending on how the
    # zarr array was chunked that can be slow, so --in_memory reads the whole
    # volume up front instead.
    if args.in_memory:
        seg = seg[...]

    labels = [args.class_id]
    thing_list = [args.class_id]

    tracker = InstanceTracker(
        class_id=labels[0], label_divisor=args.label_divisor,
        shape3d=seg.shape, axis='xy'
    )

    # run length encode segmentation slice by slice along axis 0
    for index in tqdm(range(len(seg))):
        # Integer indexing yields exactly one 2D slice (no squeeze needed,
        # so singleton spatial axes are preserved).
        mask2d = seg[index]
        rle_seg = _encode_slice(mask2d, labels, args.label_divisor, thing_list)
        tracker.update(rle_seg[labels[0]], index)

    # end tracking
    tracker.finish()

    # save the run length encoded segmentation to json inside the zarr store
    tracker.write_to_json(os.path.join(args.zarr_store, f'{args.zarr_key}.json'))
    print('Finished.')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment