Note: this does get a bit carried away with using closures. Cleaner is really to just set the logger (see the `global_logger` cell below).
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Pkg\n",
    "pkg\"activate .\"\n",
    "#pkg\"add TensorBoardLogger Optim Flux\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "using TensorBoardLogger\n",
    "using Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Flux, Flux.Data.MNIST, Statistics\n",
    "using Flux: onehotbatch, onecold, crossentropy, throttle\n",
    "using Base.Iterators: repeated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Chain(Dense(784, 32, NNlib.relu), Dense(32, 10), NNlib.softmax)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Classify MNIST digits with a simple multi-layer-perceptron\n",
    "\n",
    "imgs = MNIST.images()\n",
    "# Stack images into one large batch\n",
    "X = hcat(float.(reshape.(imgs, :))...) |> gpu\n",
    "\n",
    "labels = MNIST.labels()\n",
    "# One-hot-encode the labels\n",
    "Y = onehotbatch(labels, 0:9) |> gpu\n",
    "\n",
    "m = Chain(\n",
    "    Dense(28^2, 32, relu),\n",
    "    Dense(32, 10),\n",
    "    softmax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "accuracy (generic function with 1 method)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "loss(x, y) = crossentropy(m(x), y)\n",
    "accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(::getfield(Main, Symbol(\"#callback#8\")){TBLogger,Tracker.Params}) (generic function with 1 method)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "flat(xs) = reduce(vcat, vec.(Flux.data.(xs)))\n",
    "function make_tensorboardlogger_callback_flux(dir=mkpath(\"logs/flux\"))\n",
    "    logger = TBLogger(dir)\n",
    "    network_parameters = params(m)\n",
    "    function callback()\n",
    "        with_logger(logger) do\n",
    "            @info \"\" loss=Flux.data(loss(X, Y)) acc=Flux.data(accuracy(X, Y))\n",
    "            @info \"\" network_parameters=flat(network_parameters)\n",
    "        end\n",
    "    end\n",
    "    return callback\n",
    "end\n",
    "evalcb = make_tensorboardlogger_callback_flux()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9236"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "opt = ADAM()\n",
    "dataset = repeated((X, Y), 200)\n",
    "Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 5))\n",
    "\n",
    "accuracy(X, Y)\n",
    "\n",
    "# Test set accuracy\n",
    "tX = hcat(float.(reshape.(MNIST.images(:test), :))...)\n",
    "tY = onehotbatch(MNIST.labels(:test), 0:9)\n",
    "\n",
    "accuracy(tX, tY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (never run): log the current loss without closing over a logger.\n",
    "# A bare `@info` is routed to whichever logger is currently active, so once a\n",
    "# TBLogger is installed (next cell) this lands straight in TensorBoard.\n",
    "@info \"\" loss=Flux.data(loss(X, Y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cleaner alternative to the closures above: install a TBLogger as the global\n",
    "# logger, so every subsequent `@info` in the session goes to TensorBoard.\n",
    "# (logs/global is just an arbitrary directory for this sketch.)\n",
    "global_logger(TBLogger(mkpath(\"logs/global\")))\n",
    "@info \"\" loss=Flux.data(loss(X, Y)) acc=Flux.data(accuracy(X, Y))"
   ]
  },
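  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (untested), illustrating the note at the top: with the TBLogger installed\n",
    "# as the global logger, the training callback no longer needs to close over a logger.\n",
    "# (`simple_cb` and the 10-step dataset are just names/values for this illustration.)\n",
    "simple_cb() = @info \"\" loss=Flux.data(loss(X, Y)) acc=Flux.data(accuracy(X, Y))\n",
    "Flux.train!(loss, params(m), repeated((X, Y), 10), opt, cb = throttle(simple_cb, 5))"
   ]
  },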
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "using Optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "make_tensorboardlogger_callback (generic function with 2 methods)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Returns a closure over a logger that takes an Optim.OptimizationState\n",
    "# (or a whole OptimizationTrace) as input and logs it to TensorBoard.\n",
    "function make_tensorboardlogger_callback(dir=mkpath(\"logs/optim\"))\n",
    "    logger = TBLogger(dir)\n",
    "\n",
    "    function callback(opt_state::Optim.OptimizationState)\n",
    "        with_logger(logger) do\n",
    "            @info \"\" opt_step=opt_state.iteration function_value=opt_state.value gradient_norm=opt_state.g_norm\n",
    "        end\n",
    "        return false  # do not terminate the optimisation\n",
    "    end\n",
    "    callback(trace::Optim.OptimizationTrace) = callback(last(trace))\n",
    "    return callback\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Results of Optimization Algorithm\n",
       " * Algorithm: BFGS\n",
       " * Starting Point: [1000.0,1000.0]\n",
       " * Minimizer: [0.9999999926657492,0.999999985331392]\n",
       " * Minimum: 5.379124e-17\n",
       " * Iterations: 90\n",
       " * Convergence: true\n",
       "   * |x - x'| ≤ 0.0e+00: false\n",
       "     |x - x'| = 7.00e-08\n",
       "   * |f(x) - f(x')| ≤ 0.0e+00 |f(x)|: false\n",
       "     |f(x) - f(x')| = 2.09e+01 |f(x)|\n",
       "   * |g(x)| ≤ 1.0e-08: true\n",
       "     |g(x)| = 4.15e-11\n",
       "   * Stopped by an increasing objective: false\n",
       "   * Reached Maximum Number of Iterations: false\n",
       " * Objective Calls: 281\n",
       " * Gradient Calls: 281"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2\n",
    "result = optimize(\n",
    "    rosenbrock,\n",
    "    1000ones(2),\n",
    "    BFGS(),\n",
    "    Optim.Options(callback=make_tensorboardlogger_callback())\n",
    ")"
   ]
  },
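  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: launch TensorBoard from Julia to inspect the runs written above\n",
    "# (logs/flux, logs/optim, logs/global). Assumes the `tensorboard` CLI is on PATH;\n",
    "# it serves on http://localhost:6006 by default.\n",
    "run(`tensorboard --logdir logs`; wait=false)"
   ]
  },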
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 1.1.0",
   "language": "julia",
   "name": "julia-1.1"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.1.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}