@tmbdev
Created July 8, 2021 07:29
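# -*- Python -*-
# Read a WebDataset shard (a tar stream) from stdin, apply ImageNet-style
# augmentation and normalization to each sample's JPEG image, and write the
# samples, each with an added float16 "npy" tensor, as a new shard to stdout.
import sys
from torchvision import transforms
import webdataset as wds

# Per-channel ImageNet mean and standard deviation (RGB order).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# ImageNet training augmentation: random crop resized to 224x224, random
# horizontal flip, conversion to a CHW float tensor in [0, 1], normalization.
augment = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# "-" means stdin for the reader and stdout for the writer; decode("pil")
# decodes image fields such as "jpg" into PIL images.
dataset = wds.WebDataset("-").decode("pil")
sink = wds.TarWriter("-")
for sample in dataset:
    # Log progress to stderr so that stdout carries only the tar stream.
    print(sample.get("__key__"), file=sys.stderr)
    # Store the augmented 3x224x224 tensor at half precision to save space.
    sample["npy"] = augment(sample["jpg"]).numpy().astype("float16")
    sink.write(sample)
# Close the writer so the tar end-of-archive blocks are written and the
# output shard is complete.
sink.close()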
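For readers without PyTorch at hand, the deterministic tail of the `augment` pipeline (`ToTensor()`, `normalize`, then `.numpy().astype("float16")`) can be approximated in NumPy as below. This is a minimal sketch, not torchvision's implementation: it assumes the input is an already cropped and flipped RGB image given as an (H, W, 3) uint8 array, and the function name `to_normalized_chw` is hypothetical.

```python
import numpy as np

# ImageNet per-channel statistics, copied from the script's Normalize call
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(img_hwc: np.ndarray) -> np.ndarray:
    """Approximate ToTensor() + Normalize(MEAN, STD) + .astype("float16").

    img_hwc: hypothetical input, an RGB image of shape (H, W, 3), dtype uint8.
    Returns a float16 array of shape (3, H, W).
    """
    if img_hwc.ndim != 3 or img_hwc.shape[2] != 3:
        raise ValueError("expected an RGB image of shape (H, W, 3)")
    # ToTensor: move channels first (HWC -> CHW) and scale [0, 255] to [0, 1]
    chw = img_hwc.astype(np.float32).transpose(2, 0, 1) / 255.0
    # Normalize: subtract each channel's mean and divide by its std
    normalized = (chw - MEAN[:, None, None]) / STD[:, None, None]
    # Half precision halves the storage size of the "npy" field
    return normalized.astype(np.float16)


if __name__ == "__main__":
    # A mid-gray 224x224 image, the size RandomResizedCrop(224) produces
    gray = np.full((224, 224, 3), 128, dtype=np.uint8)
    out = to_normalized_chw(gray)
    print(out.shape, out.dtype)  # (3, 224, 224) float16
```

The result has the same layout as the script's "npy" field: channels first, three channels, 224x224 after `RandomResizedCrop(224)`. The random crop and flip are deliberately left out, since they are the part torchvision is doing the real work for.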
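Reading shards from stdin and writing them to stdout works because a WebDataset shard is an ordinary tar stream in which the files of one sample sit next to each other and share a key; by WebDataset's usual naming convention the key is the file name up to the first dot (for example `s1.jpg` and `s1.npy`). The stdlib-only sketch below groups such a stream into sample dicts to illustrate that layout. `split_key`, `iter_samples`, and `make_demo_shard` are hypothetical helpers, not webdataset's API, and they skip its decoding and error handling.

```python
import io
import re
import tarfile

# Key = path up to the first dot of the final path component (assumed convention),
# e.g. "dir/sample001.seg.png" -> ("dir/sample001", "seg.png")
_KEY_EXT = re.compile(r"^((?:.*/|)[^./]+)\.([^/]*)$")


def split_key(name):
    """Split a tar member name into (key, extension), or (None, None) if no dot."""
    m = _KEY_EXT.match(name)
    if m is None:
        return None, None
    return m.group(1), m.group(2)


def iter_samples(fileobj):
    """Yield dicts like {"__key__": key, "jpg": bytes, "npy": bytes} from a tar stream.

    Assumes the members of one sample are stored consecutively in the tar.
    """
    # "r|*" reads the tar as a forward-only stream, so stdin would also work
    with tarfile.open(fileobj=fileobj, mode="r|*") as tar:
        current = None
        for member in tar:
            if not member.isfile():
                continue
            key, ext = split_key(member.name)
            if key is None:
                continue
            # In stream mode the current member must be read before advancing
            data = tar.extractfile(member).read()
            if current is None or current["__key__"] != key:
                if current is not None:
                    yield current
                current = {"__key__": key}
            current[ext] = data
        if current is not None:
            yield current


def make_demo_shard(files):
    """Build an in-memory, uncompressed tar from (name, bytes) pairs."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for name, data in files:
            info = tarfile.TarInfo(name)
            info.size = len(data)
            tar.addfile(info, io.BytesIO(data))
    buf.seek(0)
    return buf


if __name__ == "__main__":
    shard = make_demo_shard([("s1.jpg", b"J1"), ("s1.npy", b"N1"), ("s2.jpg", b"J2")])
    for sample in iter_samples(shard):
        print(sample)
```

The shard the script writes has the same layout, with one extra `.npy` member per key, so it can be read back with `wds.WebDataset` in the same way as the input.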