Last active
July 17, 2019 18:31
-
-
Save amohant4/b47a6563637c6a8bbc133b5f5bbce4c9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx | |
from mxnet import init | |
# Create a mxnet symbol for the graph ~~~ | |
def create_net_moduleAPI():
    """Build the LeNet-5 symbolic graph for MXNet's Module API.

    The network is the classic LeNet stack: two conv/relu/max-pool
    stages followed by three fully-connected layers, ending in a
    softmax output (which also defines the 'softmax_label' input).

    Returns:
        mx.sym.Symbol: the output symbol of the assembled network.
    """
    data = mx.sym.Variable('data')

    # Stage 1: 6 feature maps, 5x5 receptive field, 2x2 max-pool.
    conv1 = mx.sym.Convolution(data, name='conv1', num_filter=6, kernel=(5, 5))
    relu1 = mx.sym.Activation(conv1, name='conv1_relu', act_type="relu")
    pool1 = mx.sym.Pooling(relu1, name='maxpool1', pool_type='max',
                           kernel=(2, 2), stride=(2, 2))

    # Stage 2: 16 feature maps, same conv/relu/pool pattern.
    conv2 = mx.sym.Convolution(pool1, name='conv2', num_filter=16, kernel=(5, 5))
    relu2 = mx.sym.Activation(conv2, name='conv2_relu', act_type="relu")
    pool2 = mx.sym.Pooling(relu2, name='maxpool2', pool_type='max',
                           kernel=(2, 2), stride=(2, 2))

    # Classifier head: 120 -> 84 -> 10 (one logit per MNIST digit class).
    fc1 = mx.sym.FullyConnected(pool2, name='fc1', num_hidden=120)
    act1 = mx.sym.Activation(fc1, name='fc1relu', act_type="relu")
    fc2 = mx.sym.FullyConnected(act1, name='fc2', num_hidden=84)
    act2 = mx.sym.Activation(fc2, name='fc2relu', act_type="relu")
    fc3 = mx.sym.FullyConnected(act2, name='fc3', num_hidden=10)

    # SoftmaxOutput both applies softmax and attaches the loss/label input.
    return mx.sym.SoftmaxOutput(fc3, name='softmax')
# Wrap the symbolic graph in a Module, which manages parameters,
# optimizer state, and the forward/backward execution.
mod = mx.mod.Module(symbol=create_net_moduleAPI(),  # symbol of the graph
                    context=mx.cpu(),               # mx.gpu() if you have GPUs
                    data_names=['data'],            # input symbol name
                    label_names=['softmax_label'])  # '_label' is appended by mxnet

# Fetch MNIST via MXNet's bundled utility and wrap it in iterators.
mnist = mx.test_utils.get_mnist()
batch_size = 256
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'],
                               batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

# Bind the module to the iterator's own shape descriptions (infers node shapes).
# Taking the shapes from the iterator instead of hard-coding (256,1,28,28)
# keeps the bind consistent if batch_size above is ever changed.
mod.bind(data_shapes=train_iter.provide_data,
         label_shapes=train_iter.provide_label)

mod.init_params(initializer=init.Xavier())  # initialize parameters
mod.init_optimizer(optimizer='sgd',
                   optimizer_params=(('learning_rate', 0.1), ))  # initialize optimizer
metric = mx.metric.create('acc')  # accuracy metric to track training progress

# Training loop: one pass over the data per epoch, accumulating accuracy.
for epoch in range(5):
    train_iter.reset()   # restart the data iterator for this epoch
    metric.reset()       # accumulate accuracy freshly for this epoch
    for batch in train_iter:
        mod.forward(batch, is_train=True)       # forward pass on the batch
        mod.update_metric(metric, batch.label)  # update accuracy on the batch
        mod.backward()                          # backprop the error gradient
        mod.update()                            # apply the optimizer step
    print('Epoch %d, Training %s' % (epoch, metric.get()))  # print accuracy
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment