Last active
January 24, 2025 04:29
-
-
Save YodaEmbedding/8803d95de072f12b4ff14ffd2b5bd7e5 to your computer and use it in GitHub Desktop.
Pregenerated Vimeo90K NumPy memmap dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Download and extract the Vimeo90k dataset first: | |
mkdir -p vimeo90k | |
cd vimeo90k | |
wget http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip | |
unzip vimeo_triplet.zip | |
wget http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip | |
unzip vimeo_septuplet.zip | |
cd .. | |
Then, run one of the following: | |
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=image --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy" | |
python generate_vimeo90k_npy_dataset.py --tuplet=3 --mode=video --indir="vimeo90k/vimeo_triplet" --outdir="vimeo90k/vimeo_triplet_npy_video" | |
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=image --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy" | |
python generate_vimeo90k_npy_dataset.py --tuplet=7 --mode=video --indir="vimeo90k/vimeo_septuplet" --outdir="vimeo90k/vimeo_septuplet_npy_video" | |
If the mode is "image", each frame is treated separately, and may | |
undergo different transformations. | |
If the mode is "video", all frames undergo the same transformation. | |
""" | |
import argparse | |
from pathlib import Path | |
import numpy as np | |
import torch | |
from compressai.datasets import Vimeo90kDataset | |
from torch.utils.data import DataLoader | |
from torchvision import transforms | |
# Side length (in pixels) of the square patch cropped from each frame.
PATCH_LENGTH = 256

# (height, width) size passed to the torchvision crop transforms.
PATCH_SIZE = (PATCH_LENGTH, PATCH_LENGTH)

# Maps dataset split name -> output .npy file stem.
FILENAMES = {
    "train": "training",
    "valid": "validation",
}
def get_dataset(dataset_path, split, tuplet, mode):
    """Construct a Vimeo90k dataset and its DataLoader.

    Args:
        dataset_path: Root directory of the extracted Vimeo90k dataset.
        split: "train" (random crops) or "valid" (deterministic center crops).
        tuplet: Number of frames per sample (3 for triplet, 7 for septuplet).
        mode: "image" (each frame transformed independently) or "video"
            (all frames of a sample share the same transformation).

    Returns:
        A ``(dataset, loader)`` tuple.

    Raises:
        ValueError: If ``mode`` is not "image" or "video".
    """
    # Fail fast on a bad mode. The original nested-ternary yielded a None
    # transform that only crashed later, inside transforms.Compose, with an
    # obscure TypeError.
    if mode not in ("image", "video"):
        raise ValueError(f'mode must be "image" or "video", got {mode!r}')

    # Random crops augment training; center crops keep validation stable.
    crop = (
        transforms.RandomCrop(PATCH_SIZE)
        if split == "train"
        else transforms.CenterCrop(PATCH_SIZE)
    )

    # Reorder channel-first tensors to channel-last so the memmap stores
    # frames in uint8 HWC layout.
    if mode == "image":
        chw_to_hwc = lambda x: x.permute(1, 2, 0)  # CHW -> HWC
    else:
        chw_to_hwc = lambda x: x.permute(0, -2, -1, -3)  # TCHW -> THWC

    transform = transforms.Compose(
        [
            crop,
            chw_to_hwc,
        ]
    )
    dataset = Vimeo90kDataset(
        root=dataset_path,
        transform=transform,
        split=split,
        tuplet=tuplet,
        # The following parameters are experimental.
        # Old versions of CompressAI do not have these,
        # and behave as if mode="image".
        mode=mode,
        transform_frame=transforms.ToTensor(),  # NOTE: Converts HWC -> CHW.
    )
    # shuffle=False keeps output order reproducible across runs.
    loader = DataLoader(dataset, batch_size=16, shuffle=False, num_workers=8)
    return dataset, loader
def generate_npy_dataset(indir, outdir, split, tuplet, mode, epochs):
    """Write one split of the Vimeo90k dataset into a uint8 ``.npy`` memmap.

    Args:
        indir: Root directory of the extracted Vimeo90k dataset.
        outdir: Output directory; created if it does not exist.
        split: "train" or "valid" (selects the output filename via FILENAMES).
        tuplet: Number of frames per sample.
        mode: "image" (one frame per record) or "video" (one tuplet per record).
        epochs: Number of passes over the dataset to record; for "train",
            each pass yields different random crops.

    Raises:
        ValueError: If ``mode`` is not "image" or "video".
    """
    dataset, loader = get_dataset(indir, split, tuplet, mode)
    out_filepath = Path(f"{outdir}/{FILENAMES[split]}.npy")
    # parents=True so a nested outdir (e.g. "a/b/c") works; the original
    # mkdir(exist_ok=True) failed when intermediate directories were missing.
    out_filepath.parent.mkdir(parents=True, exist_ok=True)
    print(f"Writing to {out_filepath}...")
    if mode == "image":
        shape = (epochs * len(dataset), *PATCH_SIZE, 3)
    elif mode == "video":
        shape = (epochs * len(dataset), tuplet, *PATCH_SIZE, 3)
    else:
        # Previously an unknown mode left `shape` unbound (NameError below).
        raise ValueError(f'mode must be "image" or "video", got {mode!r}')
    x_out = np.memmap(out_filepath, dtype="uint8", mode="w+", shape=shape)
    offset = 0
    for epoch in range(epochs):
        for i, x in enumerate(loader):
            # Frames arrive as floats in [0, 1]; store as uint8 in [0, 255].
            x = (x * 255).to(torch.uint8)
            print(
                f"{split} | "
                f"{epoch} / {epochs} epochs | "
                f"{offset:6d} / {len(dataset)} items | "
                f"{i:6d} / {len(loader)} batches | "
                # For ensuring that random output is stable:
                f"checksum: {x.min():3.0f} {x.max():3.0f} {x.to(float).mean():3.0f}"
            )
            x_out[offset : offset + len(x)] = x.numpy()
            offset += len(x)
    # Flush dirty pages to disk, then release the memmap handle.
    x_out.flush()
    del x_out
def parse_args():
    """Parse the command-line options for dataset generation.

    Returns:
        The parsed ``argparse.Namespace`` with indir/outdir/tuplet/mode/
        seed/epochs attributes.
    """
    parser = argparse.ArgumentParser(description="Generate Vimeo90k dataset")
    add = parser.add_argument
    add("--indir", default="vimeo90k/vimeo_triplet")
    add("--outdir", default="vimeo90k/vimeo_triplet_npy")
    add("--tuplet", type=int, default=3)
    add("--mode", default="image", choices=["image", "video"])
    add("--seed", type=int, default=1234)
    add("--epochs", type=int, default=1)
    return parser.parse_args()
def main():
    """Entry point: seed the RNG and generate both dataset splits."""
    print(__doc__)
    args = parse_args()
    # Fix the torch RNG so RandomCrop positions are reproducible.
    torch.manual_seed(args.seed)
    # All options except the split are shared between the two calls.
    common = dict(
        indir=args.indir,
        outdir=args.outdir,
        tuplet=args.tuplet,
        mode=args.mode,
        epochs=args.epochs,
    )
    for split in ("train", "valid"):
        generate_npy_dataset(split=split, **common)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment