Skip to content

Instantly share code, notes, and snippets.

@rsokl
Created January 24, 2025 16:39
Show Gist options
  • Save rsokl/707ec9a25407fce2ed54bda7ec98691c to your computer and use it in GitHub Desktop.
Save rsokl/707ec9a25407fce2ed54bda7ec98691c to your computer and use it in GitHub Desktop.
mnist.py
# pip install datasets
from typing import Optional

import numpy as np
from datasets import load_dataset
def zero_or_one_dataset(
    split: str = "train",
    batch_size: int = 4,
    subsample_size: Optional[int] = None,
):
    """Yield an endless stream of batches of MNIST digits labelled 0 or 1.

    The requested MNIST split is iterated over and over. Only examples whose
    label is 0 or 1 are kept; each kept image is flattened to 1-D and then
    shifted by ``31.0`` and divided by ``77.0``.

    The generator never finishes on its own -- the caller decides when to stop
    drawing batches. A partial batch left over at the end of one pass is not
    dropped: it is completed with the first examples of the next pass.

    Because this is a generator function, the argument checks below run when
    the first batch is requested, not when the function is called.

    Args:
        split: Which MNIST split to read, ``"train"`` or ``"test"``.
        batch_size: Number of examples per yielded batch. Must be at least 1.
        subsample_size: If given, every pass stops as soon as this many 0/1
            examples have been consumed, so only that leading subset of the
            split is ever used. ``None`` uses every 0/1 example.

    Yields:
        A tuple ``(images, labels)``. ``images`` is a floating-point array
        with one flattened, rescaled image per row (``batch_size`` rows);
        ``labels`` is a 1-D array holding the ``batch_size`` matching labels.

    Raises:
        AssertionError: If ``split`` is not ``"train"``/``"test"``, or if
            ``subsample_size`` exceeds the size recorded for that split.
        ValueError: If ``batch_size`` is less than 1.
    """
    # The two pre-existing checks stay as asserts so callers keep seeing
    # AssertionError for them.
    assert split in ("test", "train"), f"unknown split: {split!r}"

    # A batch size below 1 could never be reached by len(img_batch), which
    # would make the loop below spin forever without yielding anything.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    if subsample_size is None:
        max_per_pass = float("inf")  # no cap: use every 0/1 example
    else:
        # Number of 0/1-labelled examples available in each split.
        dataset_size_by_split = {"train": 12665, "test": 2115}
        assert subsample_size <= dataset_size_by_split[split], (
            f"subsample_size={subsample_size} exceeds the "
            f"{dataset_size_by_split[split]} examples in split {split!r}"
        )
        max_per_pass = subsample_size

    # Load once up front; the same dataset object is re-iterated on each pass.
    dataset = load_dataset("mnist")[split]

    # Deliberately created outside the loop so leftovers survive across passes.
    img_batch, label_batch = [], []
    while True:
        num_used = 0  # 0/1 examples consumed during the current pass
        for example in dataset:
            img, label = example["image"], example["label"]
            if label not in {0, 1}:
                continue
            num_used += 1
            img_batch.append(np.array(img).reshape(-1))
            label_batch.append(label)
            if len(img_batch) == batch_size:
                # NOTE(review): 31.0 / 77.0 look like an approximate pixel
                # mean / standard deviation used for normalization -- confirm.
                images = (np.stack(img_batch) - 31.0) / 77.0
                yield images, np.array(label_batch)
                img_batch, label_batch = [], []
            if num_used >= max_per_pass:
                break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment