Skip to content

Instantly share code, notes, and snippets.

@juliensimon
Last active April 28, 2017 12:36
Show Gist options
  • Save juliensimon/e99c7842bb61926282cd1bb03dc4d2e5 to your computer and use it in GitHub Desktop.
Save juliensimon/e99c7842bb61926282cd1bb03dc4d2e5 to your computer and use it in GitHub Desktop.
Build a new ResNeXt-110 on CIFAR-10 with AdaDelta
import mxnet as mx
import numpy as np
import cv2, cPickle, logging
from symbols import resnext
# Configure the root logger at DEBUG level so log output produced during
# the training run below is printed.
logging.basicConfig(level=logging.DEBUG)

# Directory of the raw pickled CIFAR-10 batches and the training-set size.
# NOTE(review): neither name is referenced by the rest of this script --
# the iterators read the .rec files directly. Kept for compatibility.
path, examples = "cifar-10-batches-py/", 50000

# Mini-batch size used by both record iterators, and the number of
# training epochs passed to fit().
batch, epochs = 128, 300
# Settings common to the training and validation record iterators: the
# input/label names, the batch size, and the 3x32x32 image shape. Only the
# .rec file and the shuffle flag differ between the two.
_shared_iter_kwargs = dict(
    data_name="data",
    label_name="softmax_label",
    batch_size=batch,
    data_shape=(3, 32, 32),
)

# Training data; shuffle=True asks the iterator to shuffle the records.
train_iter = mx.io.ImageRecordIter(
    path_imgrec="cifar10_train.rec", shuffle=True, **_shared_iter_kwargs)

# Validation data; no shuffle argument is passed, so MXNet's default applies.
valid_iter = mx.io.ImageRecordIter(
    path_imgrec="cifar10_val.rec", **_shared_iter_kwargs)

# The helper dict was only needed to build the iterators.
del _shared_iter_kwargs
# Use ResNext-110.
# NOTE(review): the positional arguments presumably map to
# (num_classes=10, num_layers=110, image_shape="3,32,32"), as in MXNet's
# example symbols/resnext.py -- confirm against the local `symbols` package.
sym = resnext.get_symbol(10, 110, "3,32,32")

# Data-parallel training across four GPUs (gpu(0) .. gpu(3)).
num_gpus = 4
mod = mx.mod.Module(symbol=sym, context=[mx.gpu(i) for i in range(num_gpus)])

# NOTE(review): the network above is 110 layers deep, but this prefix says
# "101". It is kept unchanged so the output file names are the same as
# before. MXNet derives the symbol/params file names from this prefix.
checkpoint_prefix = "resnext-101-symbol"

# Write intermediate checkpoints during the long run, so a crash or
# preemption does not discard every completed epoch.
checkpoint_period = 10

# fit() binds the module and initializes its parameters itself. Passing the
# initializer here replaces the separate bind()/init_params() calls, which
# fit() would otherwise skip with "already bound/initialized" warnings.
mod.fit(train_iter,
        eval_data=valid_iter,
        optimizer='adadelta',
        initializer=mx.init.Xavier(),
        num_epoch=epochs,
        epoch_end_callback=mx.callback.do_checkpoint(
            checkpoint_prefix, period=checkpoint_period))

# Always save the final weights, even when `epochs` is not a multiple of
# checkpoint_period.
mod.save_checkpoint(checkpoint_prefix, epochs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment