Julia MXNet custom loss function, logistic regression
using MXNet
import MXNet.mx: _update_single_output, reset!, get
using Distributions
# This example is based on
# http://deeplearning.net/software/theano/tutorial/examples.html#a-real-example-logistic-regression
#####################################
# Custom evaluation metric
# It just sums the predictions, because with a custom loss layer
# the network output is the loss function itself
type IdentityMetric <: mx.AbstractEvalMetric
  total_sum :: Float64
  n_sample  :: Int
  name      :: Symbol

  IdentityMetric(;name=:Identity) = new(0.0, 0, name)
end
function mx._update_single_output(
    metric :: IdentityMetric, label :: mx.NDArray, pred :: mx.NDArray)
  pred = copy(pred)
  # Accumulate the sum of the network outputs; with a MakeLoss head these
  # are the per-sample loss values
  n_sample = size(pred)[end]
  metric.n_sample += n_sample
  metric.total_sum += sum(pred)
end
function mx.get(metric :: IdentityMetric)
  return [(metric.name, metric.total_sum / metric.n_sample)]
end
function mx.reset!(metric :: IdentityMetric)
  metric.total_sum = 0.0
  metric.n_sample = 0
end
# We also need a custom binary cross-entropy metric
type CEMetric <: mx.AbstractEvalMetric
  total_sum :: Float64
  n_sample  :: Int

  CEMetric() = new(0.0, 0)
end
function mx._update_single_output(
    metric :: CEMetric, label :: mx.NDArray, pred :: mx.NDArray)
  label = copy(label)
  pred = copy(pred)
  n_sample = size(pred)[end]
  metric.n_sample += n_sample
  # Binary cross entropy: -y .* log(p) - (1 - y) .* log(1 - p), summed over the batch
  metric.total_sum += sum(-label .* log(pred) - (1 - label) .* log(1 - pred))
end
function mx.get(metric :: CEMetric)
  return [(:CE, metric.total_sum / metric.n_sample)]
end
function mx.reset!(metric :: CEMetric)
  metric.total_sum = 0.0
  metric.n_sample = 0
end
#####################################
# Data preparation
# Since this is just a proof of concept, the training data is also used for evaluation
srand(2016)
n = 400
feats = 784
x = rand(Normal(), feats, n)
y = reshape(sample(0:1, n), (1, n))
n_epoch = 10000
######################################
# NN and train function
batch_size = n
base_net = @mx.chain mx.Variable(:data) =>
           mx.FullyConnected(name=:fc1, num_hidden=1)
           # =>
           # mx.Activation(name=:act1, act_type=:sigmoid)
function train_and_predict(base_nn :: mx.SymbolicNode, nn :: mx.SymbolicNode,
                           x, y; label_name :: Symbol=:label, lr=0.01, batch_size=400,
                           eval_metric=mx.ACE(), n_epoch=1000)
  train_provider = mx.ArrayDataProvider(:data => x, label_name => y,
                                        batch_size=batch_size)
  model = mx.FeedForward(nn, context=mx.cpu())
  prediction_model = mx.FeedForward(base_nn, context=mx.cpu())
  mx.fit(model, mx.SGD(lr=lr), train_provider, n_epoch=n_epoch,
         initializer=mx.NormalInitializer(), eval_metric=eval_metric)
  # Not sure this is the correct way to transfer the trained weights
  # to the prediction network
  prediction_model.arg_params = model.arg_params
  prediction_model.aux_params = model.aux_params
  mx.predict(prediction_model, train_provider)
end
##########################################
# Evaluation
label = mx.Variable(:label)
lgr_nn = mx.LogisticRegressionOutput(data=base_net, label=label)
ce_loss = mx.Activation(data=base_net, name=:act1, act_type=:sigmoid)
ce_loss = mx.MakeLoss(
  -label .* mx.log(ce_loss) - (1 - label) .* mx.log(1 - ce_loss))
# ce_loss = mx.MakeLoss(
#   -label .* mx.log(base_net) - (1 - label) .* mx.log(1 - base_net))
lgr_pred = train_and_predict(base_net, lgr_nn, x, y, eval_metric=CEMetric())
ce_pred = train_and_predict(base_net, ce_loss, x, y, eval_metric=IdentityMetric(name=:CE))
# Since the predictions are taken from the FC layer, they have to be passed through a sigmoid
lgr_pred = 1./(1 + exp(-lgr_pred))
ce_pred = 1./(1 + exp(-ce_pred))
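# Worked example of the transform above: an FC output of z = 0 maps to
# 1/(1 + exp(-0)) = 0.5, and z = 2 maps to roughly 0.88, so thresholding the
# transformed predictions at 0.5 (below) is equivalent to thresholding the
# raw FC output at 0.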
# Results
println("Original data: $(y[1:10])")
println("LogRegressionOutput prediction: $([Int(x > 0.5) for x in lgr_pred[1:10]])")
println("Custom CrossEntropy prediction: $([Int(x > 0.5) for x in ce_pred[1:10]])")
println("LogReg Cross Entropy Loss: $(mean(-y .* log(lgr_pred) - (1 - y) .* log(1 - lgr_pred)))")
println("Custom Cross Entropy Loss: $(mean(-y .* log(ce_pred) - (1 - y) .* log(1 - ce_pred)))")