Training a LeNet on MNIST using MXNet.jl, without using the `fit` function
using MXNet

# Count how many predictions in a batch match their labels.
# Note: this returns the number of correct predictions, not a fraction.
function accuracy(predicted_probability::mx.NDArray, label::mx.NDArray)
    predicted_label = copy(mx.argmax(predicted_probability, axis=1))
    julia_label = copy(label)
    sum(predicted_label .== julia_label)
end
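
# Quick self-test of `accuracy` (illustrative, not part of the original gist):
# with all-zero scores, argmax picks class 0 for every sample, so all four
# zero labels should be counted as correct.
println("accuracy self-test (expect 4): ", accuracy(mx.zeros(10, 4), mx.zeros(4)))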

# Initialize every trainable weight in-place with Xavier initialization.
# Arguments without gradients and the label arrays are skipped.
function initialize_model(ex, model)
    # TODO: remove the model parameter and get the arg_names from ex instead
    initializer = mx.XavierInitializer(distribution=mx.xv_uniform, regularization=mx.xv_avg, magnitude=3)
    arg_names = mx.list_arguments(model)
    args = Dict(zip(arg_names, ex.arg_arrays))
    grads = Dict(zip(arg_names, ex.grad_arrays))
    for (k, v) in args
        if grads[k] !== nothing && !endswith(string(k), "label")
            mx.init(initializer, k, v)
        end
    end
end
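
# A possible refactor for the TODO above (a sketch): derive the argument names
# from the executor itself. This assumes `ex.symbol` holds the SymbolicNode the
# executor was bound from, which is a field of MXNet.jl's Executor type.
initialize_model(ex) = initialize_model(ex, ex.symbol)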

# Apply one optimizer step to every trainable argument. The updater closure
# keeps per-weight state (e.g. momentum buffers) keyed by the weight's index.
# (`iteration` is currently unused.)
function update_weights(updater::Function, model, ex, iteration::Int)
    arg_names = mx.list_arguments(model)
    args = Dict(zip(arg_names, ex.arg_arrays))
    grads = Dict(zip(arg_names, ex.grad_arrays))
    for (k, v) in args
        if grads[k] !== nothing && !endswith(string(k), "label")
            weight_index = findfirst(arg_names, k)
            updater(weight_index, grads[k], args[k])
        end
    end
end
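
# For reference, a plain-Julia sketch of the step an SGD-with-momentum updater
# applies to a single weight array (illustrative only: the real updater works
# on mx.NDArrays and keeps the momentum buffer in its closure state):
function sgd_momentum_step!(w::Array, g::Array, v::Array; lr=0.05, momentum=0.9, wd=0.00001)
    g = g .+ wd .* w               # fold weight decay into the gradient
    v .= momentum .* v .- lr .* g  # update the momentum buffer
    w .= w .+ v                    # apply the step
    return w
end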

#--------------------------------------------------------------------------------
# define lenet
# input
data = mx.Variable(:data)
# first conv
conv1 = @mx.chain mx.Convolution(data, kernel=(5,5), num_filter=20) =>
        mx.Activation(act_type=:tanh) =>
        mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))
# second conv
conv2 = @mx.chain mx.Convolution(conv1, kernel=(5,5), num_filter=50) =>
        mx.Activation(act_type=:tanh) =>
        mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))
# first fully-connected
fc1 = @mx.chain mx.Flatten(conv2) =>
      mx.FullyConnected(num_hidden=500) =>
      mx.Activation(act_type=:tanh)
# second fully-connected
fc2 = mx.FullyConnected(fc1, num_hidden=10)
# softmax loss
lenet = mx.SoftmaxOutput(fc2, name=:softmax)
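
# Optional sanity check (not in the original gist): infer the shapes the
# network produces for a single 28x28 grayscale image. `mx.infer_shape`
# returns (arg_shapes, out_shapes, aux_shapes).
arg_shapes, out_shapes, aux_shapes = mx.infer_shape(lenet, data=(28, 28, 1, 1))
println("output shape: ", out_shapes[1])  # should be (10, 1): one score per digit class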

#--------------------------------------------------------------------------------
# load data
batch_size = 512
include("mnist-data.jl")
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
#--------------------------------------------------------------------------------
# bind the symbol to an executor, allocating argument, gradient and output
# arrays for the given input shape, then initialize the weights
data_shape = (28, 28, 1, batch_size)
ex = mx.simple_bind(lenet, mx.cpu(), data=data_shape)
initialize_model(ex, lenet)
# set up the optimizer; get_updater returns a closure (index, grad, weight)
# that creates per-weight optimizer state the first time each index is seen
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
op_state = mx.OptimizationState(batch_size)
optimizer.state = op_state
updater = mx.get_updater(optimizer)
n_epoch = 10
for epoch in 1:n_epoch
    train_acc = 0
    val_acc = 0
    nbatch = 0
    op_state.curr_epoch = epoch
    op_state.curr_batch = 0
    println("Epoch $epoch")
    # training pass: copy each batch into the bound arrays, run
    # forward/backward, then apply the optimizer update
    for batch in train_provider
        data = mx.get(train_provider, batch, :data)
        ex.arg_dict[:data][:] = data
        label = mx.get(train_provider, batch, :softmax_label)
        ex.arg_dict[:softmax_label][:] = label
        mx.forward(ex, is_train=true)
        predicted_probability = ex.outputs[1]
        mx.backward(ex)
        update_weights(updater, lenet, ex, nbatch)
        train_acc += accuracy(predicted_probability, label)
        nbatch += 1
        op_state.curr_batch += 1
    end
    println("Train correct: $train_acc/$(nbatch*batch_size)")
    train_acc /= nbatch * batch_size
    println("Train accuracy: $train_acc")
    # evaluation pass: forward only, no gradients or weight updates
    nbatch = 0
    for batch in eval_provider
        ex.arg_dict[:data][:] = mx.get(eval_provider, batch, :data)
        label = mx.get(eval_provider, batch, :softmax_label)
        ex.arg_dict[:softmax_label][:] = label
        mx.forward(ex, is_train=false)
        predicted_probability = ex.outputs[1]
        val_acc += accuracy(predicted_probability, label)
        nbatch += 1
    end
    val_acc /= nbatch * batch_size
    println("Val accuracy: $val_acc")
end
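
#--------------------------------------------------------------------------------
# Persist the trained model (a sketch, assuming MXNet.jl's `mx.save` methods
# for a Dict of NDArrays and for a SymbolicNode; note `arg_dict` also contains
# the `:data` and `:softmax_label` arrays, which are harmless to save):
mx.save("lenet-weights.nd", ex.arg_dict)
mx.save("lenet-symbol.json", lenet)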