Skip to content

Instantly share code, notes, and snippets.

@CMCDragonkai
Last active June 16, 2021 03:49
Show Gist options
  • Save CMCDragonkai/832b37432c00f494f0cd43fbe393dd21 to your computer and use it in GitHub Desktop.
Save CMCDragonkai/832b37432c00f494f0cd43fbe393dd21 to your computer and use it in GitHub Desktop.
CIFAR10 Loading #python
import os
import pickle
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
# The CIFAR10 environment variable is expected to point to a directory laid out like this:
# .
# ├── batches.meta
# ├── data_batch_1
# ├── data_batch_2
# ├── data_batch_3
# ├── data_batch_4
# ├── data_batch_5
# ├── readme.html
# └── test_batch
def load_cifar10(
    cifar_dir: Optional[Union[str, os.PathLike]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load all pickled CIFAR-10 batches (train and test) from one directory.

    Every file matching ``*_batch*`` in the dataset directory is loaded in
    sorted filename order, so the result is deterministic: ``data_batch_1``
    ... ``data_batch_5`` followed by ``test_batch``.

    Args:
        cifar_dir: Directory holding the batch files and ``batches.meta``.
            When omitted (``None``), the ``CIFAR10`` environment variable is
            used instead.

    Returns:
        ``(images, class_ids, class_names)`` where ``images`` has shape
        ``(N, 32, 32, 3)`` (channels last), ``class_ids`` has shape ``(N,)``
        with one label per image, and ``class_names`` holds the decoded
        ``label_names`` entries in the order stored in ``batches.meta``.

    Raises:
        KeyError: If ``cifar_dir`` is omitted and ``CIFAR10`` is not set.
        FileNotFoundError: If the directory contains no ``*_batch*`` files.
    """
    if cifar_dir is None:
        cifar_dir = os.environ.get("CIFAR10", None)
        if cifar_dir is None:
            raise KeyError(
                "CIFAR10 environment variable must exist for tests to run!"
            )
    cifar_dir_path = Path(cifar_dir)

    # glob() yields entries in arbitrary, filesystem-dependent order; sort so
    # the sample order (and any index-based train/test split) is reproducible.
    batch_paths = sorted(cifar_dir_path.glob("*_batch*"))
    if not batch_paths:
        # Without this, np.concatenate([]) fails with an unhelpful ValueError.
        raise FileNotFoundError(
            f"No CIFAR-10 batch files matching '*_batch*' found in {cifar_dir_path}"
        )

    class_ids: List[int] = []
    images: List[np.ndarray] = []
    for data_path in batch_paths:
        # NOTE: pickle can execute arbitrary code on load; only point this at
        # trusted dataset files.
        with open(data_path, mode="rb") as f:
            batch = pickle.load(f, encoding="bytes")
        class_ids.extend(batch[b"labels"])
        # Each flat row is reshaped to (3, 32, 32) channel-first, then the
        # channel axis is moved last to get (N, 32, 32, 3).
        images.append(
            np.transpose(batch[b"data"].reshape(-1, 3, 32, 32), (0, 2, 3, 1))
        )
    class_ids_ = np.asarray(class_ids)
    images_ = np.concatenate(images)

    with open(cifar_dir_path / "batches.meta", mode="rb") as f:
        meta = pickle.load(f, encoding="bytes")
    class_names = np.asarray([n.decode("utf-8") for n in meta[b"label_names"]])
    return (images_, class_ids_, class_names)
# Smoke check: load the full dataset and print each array's shape, one per line.
# With the complete CIFAR-10 release the output is expected to be:
#   (60000, 32, 32, 3)
#   (60000,)
#   (10,)
images, class_ids, class_names = load_cifar10()
print(*(arr.shape for arr in (images, class_ids, class_names)), sep="\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment