Last active
January 17, 2017 11:47
-
-
Save Arkoniak/5b9b33f399daae726bfe581cbd37fee1 to your computer and use it in GitHub Desktop.
Custom loss function in Julia MXNet, linear regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using MXNet | |
import MXNet.mx: _update_single_output, reset!, get | |
using Distributions | |
##################################### | |
# Custom evaluation metric | |
# It just sums up predictions, because in the case of a custom
# loss layer, the ANN output equals the loss function itself
type IdentityMetric <: mx.AbstractEvalMetric | |
total_sum :: Float64 | |
n_sample :: Int | |
name :: Symbol | |
IdentityMetric(;name=:Identity) = new(0.0, 0, name) | |
end | |
function mx._update_single_output( | |
metric :: IdentityMetric, label :: mx.NDArray, pred :: mx.NDArray) | |
pred = copy(pred) | |
n_sample = size(pred)[end] | |
metric.n_sample += n_sample | |
metric.total_sum += sum(pred) | |
end | |
function mx.get(metric :: IdentityMetric) | |
return [(metric.name, metric.total_sum / metric.n_sample)] | |
end | |
function mx.reset!(metric :: IdentityMetric) | |
metric.total_sum = 0.0 | |
metric.n_sample = 0 | |
end | |
##################################### | |
# Data preparation | |
# Since this is just a proof of concept, the training data is also used for evaluation
srand(2016) | |
n = 30 | |
noise = rand(Normal(), n) | |
x = collect(linspace(0, 5, n)) | |
y = 1.5 * x + 2.0 + noise | |
###################################### | |
# NN and train function | |
# In this example we solve simple linear regression, so no hidden layers needed | |
batch_size = 10 | |
base_net = @mx.chain mx.Variable(:data) => | |
mx.FullyConnected(name=:fc1, num_hidden=1) | |
function train_and_predict(base_nn :: mx.SymbolicNode, nn :: mx.SymbolicNode, | |
x, y; label_name :: Symbol=:label, lr=0.01, batch_size=10, | |
eval_metric=mx.MSE(), n_epoch = 500) | |
train_provider = mx.ArrayDataProvider(:data => x, label_name => y, | |
batch_size=batch_size) | |
model = mx.FeedForward(nn, context=mx.cpu()) | |
prediction_model = mx.FeedForward(base_nn, context=mx.cpu()) | |
mx.fit(model, mx.SGD(lr=lr), train_provider, n_epoch=n_epoch, | |
initializer=mx.NormalInitializer(), eval_metric=eval_metric) | |
# Not sure that it is correct way to transfer weights | |
prediction_model.arg_params = model.arg_params | |
prediction_model.aux_params = model.aux_params | |
mx.predict(prediction_model, train_provider) | |
end | |
########################################## | |
# Evaluation | |
lro_nn = mx.LinearRegressionOutput(base_net, name=:lro) | |
label = mx.Variable(:label) | |
se_loss = mx.MakeLoss( | |
mx.square(base_net - mx.Reshape(label, shape=(1, batch_size)))) | |
abs_loss = mx.MakeLoss( | |
mx.abs(base_net - mx.Reshape(label, shape=(1, batch_size)))) | |
qe_loss = mx.MakeLoss( | |
mx.square(mx.square(base_net - mx.Reshape(label, shape=(1, batch_size))))) | |
lro_pred = train_and_predict(base_net, lro_nn, x, y, label_name=:lro_label) | |
se_pred = train_and_predict(base_net, se_loss, x, y, batch_size=batch_size, | |
eval_metric=IdentityMetric(name=:SE)) | |
abs_pred = train_and_predict(base_net, abs_loss, x, y, batch_size=batch_size, | |
eval_metric=IdentityMetric(name=:ABS)) | |
qe_pred = train_and_predict(base_net, qe_loss, x, y, batch_size=batch_size, | |
eval_metric=IdentityMetric(name=:QE), lr=0.0002) | |
println("LRO: $(sum((y - lro_pred[:]).^2)/n)") | |
println("SE: $(sum((y - se_pred[:]).^2)/n)") | |
println("ABS: $(sum(abs(y - abs_pred[:]))/n)") | |
println("QE: $(sum((y - qe_pred[:]).^4)/n)") | |
######################################## | |
# Plots | |
using Plots | |
gr() | |
scatter(x, y, markercolor=:black, markersize=1, label="data") | |
plot!(x, lro_pred[:], linecolor=:red, label = "LRO") | |
plot!(x, se_pred[:], linecolor = :blue, label = "SE") | |
plot!(x, abs_pred[:], linecolor = :green, label = "ABS") | |
plot!(x, qe_pred[:], linecolor = :violet, label = "QE") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment