Skip to content

Instantly share code, notes, and snippets.

@juliensimon
Last active April 24, 2017 00:07
Show Gist options
  • Save juliensimon/ed9dd3b63cb180818368e937ce0f3e44 to your computer and use it in GitHub Desktop.
Save juliensimon/ed9dd3b63cb180818368e937ce0f3e44 to your computer and use it in GitHub Desktop.
Load CIFAR-10 in NDArrays
import mxnet as mx
def buildTrainingSet(path):
    """Load the five CIFAR-10 training batches and join them into one set.

    Each of the files ``data_batch_1`` .. ``data_batch_5`` is read with
    ``extractImagesAndLabels(path, name)`` and the per-file results are
    joined, in file order, with ``mx.nd.concatenate``.

    Args:
        path: Passed through unchanged to ``extractImagesAndLabels``
            (presumably the directory containing the batch files).

    Returns:
        A ``(training_data, training_label)`` tuple: the concatenated
        image arrays and the concatenated label arrays.
    """
    batch_files = ["data_batch_%d" % index for index in range(1, 6)]

    image_parts = []
    label_parts = []
    for name in batch_files:
        imgarray, lblarray = extractImagesAndLabels(path, name)
        image_parts.append(imgarray)
        label_parts.append(lblarray)

    # Collect first and concatenate once.  This avoids truth-testing an
    # NDArray accumulator (MXNet raises ValueError for the truth value of a
    # multi-element NDArray) and avoids recopying the already-joined data
    # on every loop iteration.
    training_data = mx.nd.concatenate(image_parts)
    training_label = mx.nd.concatenate(label_parts)
    return training_data, training_label
# Location handed to the loaders, and the mini-batch size shared by both
# iterators below.
path = "cifar-10-batches-py/"
batch = 128

# Training set: all five data batches joined together, served shuffled.
training_data, training_label = buildTrainingSet(path)
train_iter = mx.io.NDArrayIter(
    data=training_data, label=training_label, batch_size=batch, shuffle=True)

# Validation set: the single "test_batch" file.
valid_data, valid_label = extractImagesAndLabels(path, "test_batch")
valid_iter = mx.io.NDArrayIter(
    data=valid_data, label=valid_label, batch_size=batch, shuffle=True)

# Sanity check of the loaded shapes.  Single-argument print(...) produces
# the same output on Python 2 and Python 3, unlike the print statement,
# which is a SyntaxError on Python 3.
print(training_data.shape)
print(training_label.shape)
print(valid_data.shape)
print(valid_label.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment