Created
May 13, 2018 08:06
-
-
Save wkcn/b17dd1bc01c363fe2f244eaa29ceb94a to your computer and use it in GitHub Desktop.
mnist_count_time
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Train a 3-layer MLP on MNIST with MXNet, timing stages via a custom op.

Requires the sibling module ``count_time`` to register the ``CountTimeOP``
custom operator (mx.operator.CustomOp). Each ``mx.sym.Custom`` call below
is an identity pass-through on the data that records a timestamp, so the
elapsed time of each network stage can be measured during training.
"""
import mxnet as mx
import count_time  # registers the 'CountTimeOP' custom operator used below

mnist = mx.test_utils.get_mnist()

# Fix the seed
mx.random.seed(42)

# Set the compute context, GPU is available otherwise CPU
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()

batch_size = 100
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
# First timing checkpoint (first=True starts the clock); returns the
# unmodified data plus a time symbol `t` threaded through later checkpoints.
# NOTE(review): exact return contract comes from count_time — confirm there.
data, t = mx.sym.Custom(op_type='CountTimeOP', data=data, first=True, cname='t0')

# Flatten the data from 4-D shape into 2-D (batch_size, num_channel*width*height)
data = mx.sym.flatten(data=data)

# The first fully-connected layer and the corresponding activation function
fc1 = mx.sym.FullyConnected(data=data, num_hidden=128)
act1 = mx.sym.Activation(data=fc1, act_type="relu")
# Timing checkpoint after layer 1 (accumulates against previous `t`).
act1, t = mx.sym.Custom(op_type='CountTimeOP', data=act1, t=t, first=False, cname='t1')

# The second fully-connected layer and the corresponding activation function
fc2 = mx.sym.FullyConnected(data=act1, num_hidden=64)
act2 = mx.sym.Activation(data=fc2, act_type="relu")
# Timing checkpoint after layer 2.
act2, t = mx.sym.Custom(op_type='CountTimeOP', data=act2, t=t, first=False, cname='t2')

# MNIST has 10 classes
fc3 = mx.sym.FullyConnected(data=act2, num_hidden=10)
# Softmax with cross entropy loss
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

import logging
logging.getLogger().setLevel(logging.DEBUG)  # logging to stdout

# create a trainable module on compute context
mlp_model = mx.mod.Module(symbol=mlp, context=ctx)
mlp_model.fit(train_iter,                # train data
              eval_data=val_iter,        # validation data
              optimizer='sgd',           # use SGD to train
              optimizer_params={'learning_rate': 0.1},  # use fixed learning rate
              eval_metric='acc',         # report accuracy during training
              batch_end_callback=mx.callback.Speedometer(batch_size, 100),  # output progress for each 100 data batches
              num_epoch=10)              # train for at most 10 dataset passes
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment