Created
September 1, 2022 02:55
-
-
Save chenyaofo/adca43b9163f7f76c9ce5d2b6cc21c5e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
This script creates webdataset tar shards with multi-processing.
'''
import os | |
import random | |
import datetime | |
from multiprocessing import Process | |
from torchvision.datasets.folder import ImageFolder | |
from webdataset import TarWriter | |
def make_wds_shards(pattern, num_shards, num_workers, samples, map_func, **kwargs):
    """Shuffle *samples* and write them into ``num_shards`` webdataset tar shards
    using ``num_workers`` parallel writer processes.

    Args:
        pattern: printf-style shard filename pattern with one integer slot,
            e.g. ``".../%06d.tar"``.
        num_shards: total number of tar shards to create.
        num_workers: number of writer processes to spawn.
        samples: sequence of items; each item is passed through *map_func*
            to produce one webdataset record.
        map_func: callable mapping one sample to a webdataset-style dict.
        **kwargs: forwarded verbatim to ``TarWriter``.
    """
    # Shuffle a copy so the caller's list is not mutated as a side effect.
    samples = list(samples)
    random.shuffle(samples)
    # Round-robin split: shard i gets samples[i], samples[i + num_shards], ...
    samples_per_shards = [samples[i::num_shards] for i in range(num_shards)]
    shard_ids = list(range(num_shards))
    # Each worker process handles every num_workers-th shard.
    processes = [
        Process(
            target=write_partial_samples,
            args=(
                pattern,
                shard_ids[i::num_workers],
                samples_per_shards[i::num_workers],
                map_func,
                kwargs,
            ),
        )
        for i in range(num_workers)
    ]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
def write_partial_samples(pattern, shard_ids, samples, map_func, kwargs):
    """Write each assigned shard: pair every shard id with its sample list
    and delegate to ``write_samples_into_single_shard``."""
    # Renamed loop variables so the parameter `samples` is not shadowed.
    for sid, shard_samples in zip(shard_ids, samples):
        write_samples_into_single_shard(pattern, sid, shard_samples, map_func, kwargs)
def write_samples_into_single_shard(pattern, shard_id, samples, map_func, kwargs):
    """Write *samples* into the single tar shard named ``pattern % shard_id``.

    Args:
        pattern: printf-style filename pattern with one integer slot.
        shard_id: integer substituted into *pattern*.
        samples: iterable of items, each converted by *map_func*.
        map_func: callable mapping one item to a webdataset record dict.
        kwargs: dict of keyword arguments forwarded to ``TarWriter``.

    Returns:
        Total number of bytes written to the shard.
    """
    fname = pattern % shard_id
    print(f"[{datetime.datetime.now()}] start to write samples to shard {fname}")
    stream = TarWriter(fname, **kwargs)
    size = 0
    try:
        for i, item in enumerate(samples):
            size += stream.write(map_func(item))
            if i % 1000 == 0:
                print(f"[{datetime.datetime.now()}] complete to write {i:06d} samples to shard {fname}")
    finally:
        # Close even if map_func/write raises, so the tar archive is flushed
        # and the file handle is not leaked.
        stream.close()
    print(f"[{datetime.datetime.now()}] complete to write samples to shard {fname}!!!")
    return size
def main(source, dest, num_shards, num_workers):
    """Convert one ImageFolder-style directory tree into webdataset tar shards.

    Args:
        source: root directory laid out as ``<root>/<class>/<image>``.
        dest: printf-style output shard pattern, e.g. ``".../%06d.tar"``.
        num_shards: number of tar shards to create.
        num_workers: number of writer processes.
    """
    # With an identity loader, dataset[i] yields (file_path, class_index)
    # without decoding the image.
    dataset = ImageFolder(root=source, loader=lambda x: x)
    items = [dataset[i] for i in range(len(dataset))]

    def map_func(item):
        """Build one webdataset record: raw jpeg bytes plus ascii class label."""
        path, class_idx = item
        # `path` is already absolute; the original os.path.join(name) was a no-op.
        with open(path, "rb") as stream:
            image = stream.read()
        return {
            # Key is the basename without extension; assumes filenames are
            # unique across classes within one source tree — TODO confirm.
            "__key__": os.path.splitext(os.path.basename(path))[0],
            "jpg": image,
            "cls": str(class_idx).encode("ascii"),
        }

    make_wds_shards(
        pattern=dest,
        num_shards=num_shards,
        num_workers=num_workers,
        samples=items,
        map_func=map_func,
    )
if __name__ == "__main__":
    # ImageNet-C corruption directories to convert, grouped by corruption family.
    source = [
        "/home/chenyf/datasets/ImageNet-C/gaussian_noise",
        "/home/chenyf/datasets/ImageNet-C/impulse_noise",
        "/home/chenyf/datasets/ImageNet-C/shot_noise",
        # ---
        "/home/chenyf/datasets/ImageNet-C/defocus_blur",
        "/home/chenyf/datasets/ImageNet-C/glass_blur",
        "/home/chenyf/datasets/ImageNet-C/motion_blur",
        "/home/chenyf/datasets/ImageNet-C/zoom_blur",
        # ---
        "/home/chenyf/datasets/ImageNet-C/snow",
        "/home/chenyf/datasets/ImageNet-C/frost",
        "/home/chenyf/datasets/ImageNet-C/fog",
        "/home/chenyf/datasets/ImageNet-C/brightness",
        # ---
        "/home/chenyf/datasets/ImageNet-C/contrast",
        "/home/chenyf/datasets/ImageNet-C/elastic_transform",
        "/home/chenyf/datasets/ImageNet-C/pixelate",
        "/home/chenyf/datasets/ImageNet-C/jpeg_compression",
    ]
    dest_prefix = "/home/chenyf/datasets/ImageNet-C-wds/"
    for s in source:
        corruption_name = os.path.basename(s)
        for i in range(1, 6):  # severities 1..5
            print(f"[{datetime.datetime.now()}] start transfer {s}/{i}")
            out_dir = os.path.join(dest_prefix, corruption_name, f"severity={i}")
            # exist_ok=False: presumably deliberate, to refuse overwriting an
            # already-converted severity — NOTE(review): confirm this intent.
            os.makedirs(out_dir, exist_ok=False)
            main(
                source=os.path.join(s, str(i)),
                dest=os.path.join(out_dir, "%06d.tar"),
                num_shards=32,
                num_workers=4,
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.