# Example usage of Gluon in MXNet with a LeNet test case on FashionMNIST
import time

import mxnet as mx
from mxnet import gluon, init
from mxnet.gluon import nn
from mxnet.gluon.data.vision import datasets, transforms
def create_lenet_using_sequential():
    """
    Method to return a LeNet built with nn.Sequential from
    MXNet. nn.Sequential is a subclass of nn.Block.
    Arguments: None
    Returns: An instance of type Sequential.
    """
    net = nn.Sequential()
    net.add(
        nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2),
        nn.Dense(120, activation='relu'),
        nn.Dense(84, activation='relu'),
        nn.Dense(10))
    return net
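
# Quick sanity check of the Sequential variant: a minimal sketch, assuming
# single-channel 28x28 inputs as in (Fashion-)MNIST. Gluon creates parameters
# lazily, so shapes are only known after the first forward pass.
seq_net = create_lenet_using_sequential()
seq_net.initialize()
print(seq_net(mx.nd.ones((1, 1, 28, 28))).shape)  # -> (1, 10)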
class MyLeNet(nn.Block):
    """
    Custom class implementing LeNet.
    This implementation is very flexible.
    Things to remember:
      - this class is a subclass of nn.Block
      - you need to call __init__ of nn.Block in the constructor
      - __init__ defines all the layers in the network
      - forward defines the forward function of the network
    """
    def __init__(self, **kwargs):
        super(MyLeNet, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(channels=6, kernel_size=5, activation='relu')
        self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
        self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
        self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)  # was overwriting pool1 in the original
        self.fc1 = nn.Dense(120, activation='relu')
        self.fc2 = nn.Dense(84, activation='relu')
        self.fc3 = nn.Dense(10)

    def forward(self, x):
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        x = self.fc1(x)  # Dense flattens the input automatically
        x = self.fc2(x)
        return self.fc3(x)
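
# The same network can also be written as a HybridBlock so it can be compiled
# into a static graph with net.hybridize(). A minimal sketch, not used in the
# training below (this variant is an addition, not part of the original gist):
class MyHybridLeNet(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(MyHybridLeNet, self).__init__(**kwargs)
        self.conv1 = nn.Conv2D(channels=6, kernel_size=5, activation='relu')
        self.pool1 = nn.MaxPool2D(pool_size=2, strides=2)
        self.conv2 = nn.Conv2D(channels=16, kernel_size=3, activation='relu')
        self.pool2 = nn.MaxPool2D(pool_size=2, strides=2)
        self.fc1 = nn.Dense(120, activation='relu')
        self.fc2 = nn.Dense(84, activation='relu')
        self.fc3 = nn.Dense(10)

    def hybrid_forward(self, F, x):  # F is mx.nd or mx.sym depending on mode
        x = self.pool1(self.conv1(x))
        x = self.pool2(self.conv2(x))
        return self.fc3(self.fc2(self.fc1(x)))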
# Training dataset creation ~~~
mnist_train = datasets.FashionMNIST(train=True)
# Input transformations (augmentation, normalization, etc.) are defined here
transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.13, 0.31)])
# Apply the transformation to the image (first element) of each sample in the dataset
mnist_train = mnist_train.transform_first(transformer)
# Data loader to facilitate loading of data during training.
# num_workers=4: more worker threads help with complicated transforms and larger batch sizes
batch_size = 256
train_data = gluon.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)
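
# Peek at one batch to verify shapes: ToTensor converts HWC uint8 images to
# CHW float32, so each batch should be (batch_size, 1, 28, 28). A minimal
# debugging sketch, not required for training:
for data, label in train_data:
    print(data.shape, label.shape)  # e.g. (256, 1, 28, 28) (256,)
    break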
# Validation dataset creation ~~~
mnist_valid = datasets.FashionMNIST(train=False)
valid_data = gluon.data.DataLoader(mnist_valid.transform_first(transformer), batch_size=batch_size, num_workers=4)
# Create an instance of the network and other necessities for training ~~~
net = MyLeNet()
net.initialize(init=init.Xavier())  # Initialize the parameters using Xavier initialization
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()  # Define the loss
# Create a trainer with SGD optimization and a learning rate of 0.1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
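
# The optimizer is pluggable: any optimizer registered with MXNet can be named
# here. A hedged alternative sketch using Adam (kept commented out so it does
# not override the SGD trainer above; the hyperparameters are illustrative):
# trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.001})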
# Metric function ~~~
def acc(output, label):
    """
    Utility function to return the accuracy of the network for given outputs and labels.
    Arguments:
        output: output from the network (in this case the last fc layer)
        label: ground-truth label from the dataset
    Returns:
        accuracy (scalar): fraction of samples for which the prediction is correct
    """
    return (output.argmax(axis=1) == label.astype('float32')).mean().asscalar()
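
# A tiny worked example of acc() on dummy data (a sketch, assuming 2-class
# logits): argmax picks class 1 then class 0, both match the labels, so the
# result is 1.0.
_demo_out = mx.nd.array([[0.1, 0.9], [0.8, 0.2]])
_demo_lbl = mx.nd.array([1, 0])
print(acc(_demo_out, _demo_lbl))  # -> 1.0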
# Training loop ~~~
for epoch in range(10):
    train_loss, train_acc, valid_acc = 0., 0., 0.
    tic = time.time()
    for data, label in train_data:  # Iterate through the training dataset
        with mx.autograd.record():  # Record the computation graph for autograd
            output = net(data)      # Forward pass
            loss = softmax_cross_entropy(output, label)  # Compute loss
        loss.backward()             # Back-propagate
        trainer.step(batch_size)    # Update parameters
        train_loss += loss.mean().asscalar()
        train_acc += acc(output, label)
    for data, label in valid_data:  # Validation loop ~~
        valid_acc += acc(net(data), label)
    # Log metrics to standard output ~~
    print("Epoch %d: loss %.3f, train acc %.3f, valid acc %.3f, in %.1f sec" % (
        epoch, train_loss/len(train_data), train_acc/len(train_data),
        valid_acc/len(valid_data), time.time()-tic))
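
# After training, the learned parameters can be saved and restored into a
# fresh instance of the same architecture. A minimal sketch; the filename
# 'lenet.params' is illustrative:
net.save_parameters('lenet.params')
restored = MyLeNet()
restored.load_parameters('lenet.params')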