Julia MXNet with custom loss and exponential activation
# https://github.com/dmlc/MXNet.jl/issues/167
using MXNet

# Custom eval metric
import MXNet.mx: get, reset!, _update_single_output

type CustomMetric <: mx.AbstractEvalMetric
    loss::Float64
    n::Int

    CustomMetric() = new(0.0, 0)
end

function mx.reset!(metric::CustomMetric)
    metric.loss = 0.0
    metric.n = 0
end

function mx.get(metric::CustomMetric)
    [(:CustomMetric, metric.loss / metric.n)]
end

function mx._update_single_output(metric::CustomMetric, label::mx.NDArray, pred::mx.NDArray)
    pred = copy(pred)
    n_sample = size(pred)[end]
    metric.n += n_sample
    metric.loss += sum(pred)
end
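# Note: with MakeLoss (defined further below) the network's forward output is the
# per-sample loss itself, so this metric just sums the predictions over each batch
# and reports the running mean loss (lower is better).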
# Custom initializer
import MXNet.mx: _init_weight

immutable CustomInitializer <: mx.AbstractInitializer
    scale :: AbstractFloat
end
CustomInitializer() = CustomInitializer(0.07)

function _init_weight(self :: CustomInitializer, name :: Base.Symbol, array :: mx.NDArray)
    if string(name) == "output1_weight"
        input_size = size(array)[1]
        array[:] = reshape(collect(1:input_size), size(array))
    else
        mx.rand!(-self.scale, self.scale, array)
    end
end
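# output1_weight is set to the column vector [1, 2, 3, 4] -- the known outer
# coefficients of the target function below -- while every other parameter is drawn
# uniformly from [-scale, scale], similar to the stock uniform initializer.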
# Base net
label = mx.Variable(:label)
net = @mx.chain mx.Variable(:data) =>
      mx.FullyConnected(name = :fc1_in, num_hidden = 10) =>
      mx.Activation(name = :fc1_out, act_type = :softrelu) =>
      mx.FullyConnected(name = :fc2_in, num_hidden = 4)

netexp = @mx.chain net =>
         mx.exp(name = :fc2_out) =>
         mx.FullyConnected(name = :output1, num_hidden = 1, attrs = Dict(:grad => "freeze"))
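# The :grad => "freeze" attribute is meant to keep output1's parameters fixed at their
# initialized values (the weights are set to 1:4 by CustomInitializer above), so
# output1 acts as a fixed linear combination of the four exponentials while only the
# fc1/fc2 layers are learned.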
# Network with loss layer
netloss = mx.sqrt(mx.abs(netexp - label))
netloss = mx.MakeLoss(netloss)
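# mx.MakeLoss turns the symbol itself into the training objective: the forward pass
# outputs sqrt(|netexp - label|) for each sample and gradients are taken with respect
# to that value, which is why CustomMetric above simply averages the network outputs.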
# data
x = rand(Float32, 1, 8)  # 8 observations of 1 variable
y = exp(x) + 2.0 * exp(0.5 * x) + 3.0 * exp(0.3 * x) + 4.0 * exp(0.25 * x)

# Connect net, data and hyperparameters
batch_size = 4
train_prov = mx.ArrayDataProvider(:data => x, :label => y; batch_size = batch_size)
eval_prov = mx.ArrayDataProvider(:data => x, :label => y; batch_size = batch_size)
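# Both providers wrap the same 8 samples here, so the reported metric is an
# in-sample (training) loss rather than a held-out validation score.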
# train
mdl = mx.FeedForward(netloss, context = mx.cpu())
# opt = mx.SGD(lr = 0.1, momentum = 0.9, weight_decay = 0.00001)
opt = mx.ADAM()  # Optimizing algorithm
mx.fit(mdl, opt, train_prov, initializer = CustomInitializer(),
       n_epoch = 2000, eval_data = eval_prov, eval_metric = CustomMetric())

# Label prediction requires network without loss layer
model_exp = mx.FeedForward(netexp)
# transfer weights from trained model
model_exp.arg_params = mdl.arg_params
model_exp.aux_params = mdl.aux_params
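# netexp is a sub-graph of netloss with identical layer names, so copying
# arg_params/aux_params from the trained model reuses the learned weights directly.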
# prediction
y_exp = mx.predict(model_exp, eval_prov)
println(y)
println(y_exp)
# On my machine result is the following:
# Float32[12.1982 10.528 11.5016 11.3951 14.8478 10.4443 13.5633 13.3432]
# Float32[12.2024 10.5181 11.5073 11.4001 14.7479 10.4321 13.5332 13.3215]

# compare coefficients inside exp
model_coeff = mx.FeedForward(net)
model_coeff.arg_params = Dict(k => v for (k, v) in mdl.arg_params)
delete!(model_coeff.arg_params, :output1_weight)
delete!(model_coeff.arg_params, :output1_bias)
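# output1's weight and bias are not arguments of `net`, so they are removed from the
# parameter dict before prediction; the Dict(...) comprehension above makes a shallow
# copy, so delete! does not touch mdl.arg_params.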
model_coeff.aux_params = mdl.aux_params
y_coeff = mx.predict(model_coeff, eval_prov)
println(y_coeff)
println([x; 0.5*x; 0.3*x; 0.25*x])
# Float32[0.125768 0.0166951 0.0826378 0.0757601 0.265383 0.0106833 0.202016 0.190394; 0.228201 0.0920547 0.174112 0.165517 0.405438 0.0846132 0.324604 0.309848; 0.0941143 -0.0491133 0.0378486 0.0288325 0.273273 -0.0571001 0.192491 0.177583; 0.268321 0.104899 0.203436 0.193121 0.480636 0.0959561 0.383852 0.366177]
# Float32[0.492818 0.130827 0.35046 0.327667 0.947673 0.110677 0.742261 0.704412; 0.246409 0.0654135 0.17523 0.163834 0.473836 0.0553387 0.37113 0.352206; 0.147845 0.0392481 0.105138 0.0983002 0.284302 0.0332032 0.222678 0.211324; 0.123204 0.0327068 0.087615 0.0819168 0.236918 0.0276693 0.185565 0.176103]