Skip to content

Instantly share code, notes, and snippets.

@korkridake
Last active December 4, 2018 13:46
Show Gist options
  • Save korkridake/71b39e8ee7ae635cb4b1b66dbb1272c7 to your computer and use it in GitHub Desktop.
Save korkridake/71b39e8ee7ae635cb4b1b66dbb1272c7 to your computer and use it in GitHub Desktop.
# -------------------------------------------------------------------
# Handwritten Digits Classification Competition
# Source: https://mxnet.incubator.apache.org/tutorials/r/mnistCompetition.html
# MNIST is a handwritten digits image data set created by Yann LeCun. Every digit is represented by a 28 x 28 pixel image. It’s become a standard data set for testing classifiers on simple image input. A neural network is a strong model for image classification tasks. There’s a long-term hosted competition on Kaggle using this data set. This tutorial shows how to use MXNet to compete in this challenge.
# -------------------------------------------------------------------
# -------------------------------------------------------------------
# Load the dependency packages
# -------------------------------------------------------------------
library(mxnet)
# -------------------------------------------------------------------
# Load the data
# -------------------------------------------------------------------
# Build both CSV paths from a single directory with file.path() so the train
# and test locations stay consistent and use the platform's path separator.
data.dir <- file.path("..", "data", "handwriting_recognition")
train <- read.csv(file.path(data.dir, "train.csv"), header = TRUE)
test <- read.csv(file.path(data.dir, "test.csv"), header = TRUE)
# Convert the data frames to numeric matrices.
train <- data.matrix(train)
test <- data.matrix(test)
# The first column of train is used as the label; the remaining columns are
# the pixel values. test is used as-is (all columns are treated as pixels).
train.x <- train[, -1]
train.y <- train[, 1]
# Every image is represented as a single row in train/test.
# The greyscale of each image falls in the range [0, 255]; linearly rescale it
# to [0, 1].
# Transpose the input matrix to npixel x nexamples, which is the column-major
# format accepted by MXNet (and the convention of R).
train.x <- t(train.x / 255)
test <- t(test / 255)
# Show how many training examples there are per label.
table(train.y)
## train.y
## 0 1 2 3 4 5 6 7 8 9
## 4132 4684 4177 4351 4072 3795 4137 4401 4063 4188
# Network definition: two convolution -> tanh -> max-pool stages, then two
# fully connected layers, finished with a softmax output.

# Input placeholder symbol.
data <- mx.symbol.Variable("data")

# Stage 1: 5x5 convolution with 20 filters, tanh, 2x2 max pooling (stride 2).
conv1 <- mx.symbol.Convolution(data = data, kernel = c(5, 5), num_filter = 20)
tanh1 <- mx.symbol.Activation(data = conv1, act_type = "tanh")
pool1 <- mx.symbol.Pooling(
  data = tanh1,
  pool_type = "max",
  kernel = c(2, 2),
  stride = c(2, 2)
)

# Stage 2: 5x5 convolution with 50 filters, tanh, 2x2 max pooling (stride 2).
conv2 <- mx.symbol.Convolution(data = pool1, kernel = c(5, 5), num_filter = 50)
tanh2 <- mx.symbol.Activation(data = conv2, act_type = "tanh")
pool2 <- mx.symbol.Pooling(
  data = tanh2,
  pool_type = "max",
  kernel = c(2, 2),
  stride = c(2, 2)
)

# First fully connected layer: flatten the pooled maps, 500 hidden units, tanh.
flatten <- mx.symbol.Flatten(data = pool2)
fc1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 500)
tanh3 <- mx.symbol.Activation(data = fc1, act_type = "tanh")

# Second fully connected layer: 10 outputs, one per digit class.
fc2 <- mx.symbol.FullyConnected(data = tanh3, num_hidden = 10)

# Softmax output layer; this symbol is what gets trained.
lenet <- mx.symbol.SoftmaxOutput(data = fc2)
# Work on copies so the 2-D matrices train.x and test remain available.
train.array <- train.x
test.array <- test
# Reinterpret each column (one image) as a 28 x 28 x 1 block, giving 4-D
# arrays of shape 28 x 28 x 1 x nexamples for the convolutional network.
dim(train.array) <- c(28, 28, 1, ncol(train.x))
dim(test.array) <- c(28, 28, 1, ncol(test))
# Train on the CPU with a fixed seed.
devices <- mx.cpu()
mx.set.seed(0)

# Start the clock immediately before training so only the fit is timed.
tic <- proc.time()
model <- mx.model.FeedForward.create(
  lenet,
  # Training data and labels
  X = train.array,
  y = train.y,
  ctx = devices,
  # Training schedule: 10 passes over the data in batches of 100
  num.round = 10,
  array.batch.size = 100,
  # Optimizer settings
  learning.rate = 0.05,
  momentum = 0.9,
  wd = 0.00001,
  # Report training accuracy
  eval.metric = mx.metric.accuracy,
  epoch.end.callback = mx.callback.log.train.metric(100)
)
# Start training with 1 devices
# [1] Train-accuracy=0.535833334009207
# [2] Train-accuracy=0.97138096108323
# [3] Train-accuracy=0.983595249482564
# [4] Train-accuracy=0.989023818430446
# [5] Train-accuracy=0.992619053806577
# [6] Train-accuracy=0.995166671276092
# [7] Train-accuracy=0.996928574357714
# [8] Train-accuracy=0.99816666841507
# [9] Train-accuracy=0.998738096441541
# [10] Train-accuracy=0.999071429456983

# How long did training take? (time elapsed since tic was recorded)
print(proc.time() - tic)
# user system elapsed
# 1366.15 362.99 1004.2
# -----------------------------------------------------------------
# Making Predictions
# -----------------------------------------------------------------
# Score the test images with the trained model. preds is transposed below
# before the per-row arg-max, i.e. each column of preds is treated as one
# image.
preds <- predict(model, test.array)
# Pick the highest-scoring class for each image and subtract 1 to turn R's
# 1-based column index into the digit label (0-9).
# max.col() breaks ties at random by default; pin ties.method so repeated runs
# on the same model produce the same submission.
pred.label <- max.col(t(preds), ties.method = "first") - 1
# seq_len() is the safe form of 1:n (it yields an empty sequence when n is 0).
submission <- data.frame(ImageId = seq_len(ncol(test)), Label = pred.label)
# Write the submission as an unquoted CSV without row names.
write.csv(submission,
          file = "submission_LeNet.csv",
          row.names = FALSE,
          quote = FALSE)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment