Last active: April 28, 2021 08:31
-
-
Save fzliu/64821d31816bce595a4bbd98588b37f5 to your computer and use it in GitHub Desktop.
Create image data from raw CIFAR-10 files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
make_cifar10.py: Create training data from raw CIFAR-10 batches. | |
""" | |
import pickle | |
import glob | |
import os | |
import numpy as np | |
from skimage.io import imsave | |
# Output directory that receives one image file per CIFAR-10 example.
PIXELS_DIR: str = "pixel_data"
# Output text file holding one "<image filename> <label>" line per example.
LABEL_FILE: str = "labels.txt"
def unpack_file(fname):
    """
    Deserialize one raw CIFAR-10 batch file and return the unpickled object.

    The file is read in binary mode and loaded with ``encoding="bytes"``, so
    any Python 2 ``str`` objects in the pickle (e.g. dictionary keys) come
    back as ``bytes`` rather than being decoded to text.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only call this
    on batch files obtained from a trusted source.
    """
    with open(fname, "rb") as handle:
        return pickle.load(handle, encoding="bytes")
def save_as_image(img_flat, fname):
    """
    Write one flattened CIFAR-10 row to ``os.path.join(PIXELS_DIR, fname)``.

    ``img_flat`` holds three consecutive 1024-entry colour planes in red,
    green, blue order. Each plane is reshaped (row-major) into a 32x32 grid,
    and the planes are stacked along a new last axis. The resulting
    (32, 32, 3) array is passed to ``imsave``, which chooses the image format.
    """
    side = 32
    plane = side * side
    # Slice out each colour plane in R, G, B order and restore its 2-D shape.
    channels = [
        img_flat[c * plane:(c + 1) * plane].reshape((side, side))
        for c in range(3)
    ]
    imsave(os.path.join(PIXELS_DIR, fname), np.dstack(channels))
def main():
    """
    Entry point: convert the raw CIFAR-10 batches in the current directory
    into individual image files plus a labels file.

    Every file matching ``*_batch*`` is unpacked with ``unpack_file``. Each
    row of its ``b"data"`` array is written to ``PIXELS_DIR`` via
    ``save_as_image``, under the decoded name from ``b"filenames"``.
    ``LABEL_FILE`` then receives one ``"<filename> <label>"`` line per image.
    If the same filename appears more than once, the later entry's label wins.

    Raises:
        ValueError: if a batch's data, filename and label arrays differ in
            length.
    """
    # imsave does not create missing directories, so make sure the output
    # directory exists before the first image is written.
    os.makedirs(PIXELS_DIR, exist_ok=True)

    labels = {}
    # Use "data_batch_*" for just the training set. Sorting makes the output
    # order (and which entry wins on a duplicate filename) independent of
    # the filesystem's directory-listing order.
    for batch_fname in sorted(glob.glob("*_batch*")):
        data = unpack_file(batch_fname)
        images = data[b"data"]
        filenames = data[b"filenames"]
        batch_labels = data[b"labels"]
        # Iterate over the rows actually present instead of assuming 10000,
        # but refuse to silently drop rows if the arrays disagree.
        if not len(images) == len(filenames) == len(batch_labels):
            raise ValueError(
                "{0}: data/filenames/labels length mismatch".format(batch_fname)
            )
        for img_flat, raw_name, label in zip(images, filenames, batch_labels):
            img_fname = raw_name.decode()
            # save the image and store the label
            save_as_image(img_flat, img_fname)
            labels[img_fname] = label

    # Write out the labels file. The encoding is explicit because filenames
    # were decoded as UTF-8 above.
    with open(LABEL_FILE, "w", encoding="utf-8") as f:
        f.writelines(
            "{0} {1}\n".format(img_fname, label)
            for (img_fname, label) in labels.items()
        )


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment