A Flux.jl example with a toy neural network that learns and, or, and xor
Gist kleinschmidt/87bd4aace452da66aec846f878d89e1a, last active October 2, 2017 16:36
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Multi-task network in Julia with Flux\n",
"\n", | |
"This is to illustrate how easy it is to build and train neural network models in Julia with [Flux.jl](https://github.com/FluxML/Flux.jl). Flux is a framework that gives you both a convenient, high-level API and the ability to mess with things at as low a level as you want.\n", | |
"\n", | |
"## Task and training data\n", | |
"\n", | |
"This network learns to do three different \"tasks\": binary and, or, and xor. The network gets two binary inputs plus a \"task\" signal, which is a one-hot encoding of the task variable. There's one hidden layer, and the task inputs also connect directly to the output."
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12-element Array{Int64,1}:\n", | |
" 0\n", | |
" 0\n", | |
" 0\n", | |
" 1\n", | |
" 0\n", | |
" 1\n", | |
" 1\n", | |
" 1\n", | |
" 0\n", | |
" 1\n", | |
" 1\n", | |
" 0" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"using Flux\n", | |
"\n", | |
"# inputs for each task (columns are observations)\n", | |
"x1 = [0 0 1 1\n", | |
" 0 1 0 1]\n", | |
"\n", | |
"funcs = [&, |, xor]\n", | |
"\n", | |
"y = vcat((f.(x1[1,:], x1[2,:]) for f in funcs)...)" | |
] | |
}, | |
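{
"cell_type": "markdown",
"metadata": {},
"source": [
"The broadcast-and-splat one-liner above is compact but dense. For comparison, here is a rough sketch that builds the same target vector with explicit loops. It assumes Julia 1.x and Base only, and `y_loop` is a hypothetical name chosen so it doesn't overwrite `y`. Tasks form the outer loop and input columns the inner one. That is why `y` lists the four `&` targets first, then the `|` targets, then the `xor` targets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: build the same 12-element target vector with explicit loops\n",
"y_loop = let\n",
"    inputs = [0 0 1 1\n",
"              0 1 0 1]  # same layout as x1: columns are observations\n",
"    tasks = [(&), (|), xor]\n",
"    out = Int[]\n",
"    for f in tasks, j in 1:size(inputs, 2)  # f varies slowest, j fastest\n",
"        push!(out, f(inputs[1, j], inputs[2, j]))\n",
"    end\n",
"    out\n",
"end"
]
},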
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"In the cell that builds `y`, I use Julia's broadcasting syntax `f.()` to apply each function `f` elementwise to the pairs of inputs (`x1[1,:]` is the first row of `x1`), and `vcat(...)` splices the three resulting vectors into one.\n",
"\n", | |
"We can use the `onehot` and `onehotbatch` functions to convert a vector of class labels (in this case, the task functions) into a boolean matrix using one-hot encoding:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3×12 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}:\n", | |
" true true true true false … false false false false false\n", | |
" false false false false true true false false false false\n", | |
" false false false false false false true true true true" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x_task = Flux.onehotbatch(repeat(funcs, inner=size(x1, 2)), funcs)" | |
] | |
}, | |
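{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a minimal Base-Julia sketch of what one-hot encoding computes. It assumes Julia 1.x, and `onehot_sketch` and the `_demo` names are hypothetical helpers, not part of Flux. Entry `(i, j)` is `true` exactly when the `j`-th label equals the `i`-th class, so every column has a single `true`. Flux's `OneHotMatrix` stores this more compactly, but its entries should match this dense version."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: dense one-hot encoding; rows are classes, columns are observations\n",
"onehot_sketch(labels, classes) = [label == c for c in classes, label in labels]\n",
"\n",
"# same label sequence as in x_task: each task repeated once per input column\n",
"task_labels_demo = repeat([(&), (|), xor], inner = 4)\n",
"onehot_demo = onehot_sketch(task_labels_demo, [(&), (|), xor])"
]
},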
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We'll put the task inputs first so it'll be easy to pick them out later:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"5×12 Array{Int64,2}:\n", | |
" 1 1 1 1 0 0 0 0 0 0 0 0\n", | |
" 0 0 0 0 1 1 1 1 0 0 0 0\n", | |
" 0 0 0 0 0 0 0 0 1 1 1 1\n", | |
" 0 0 1 1 0 0 1 1 0 0 1 1\n", | |
" 0 1 0 1 0 1 0 1 0 1 0 1" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"x = vcat(x_task, repmat(x1, 1, 3))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Specifying the network\n", | |
"\n", | |
"First, I'll show how to specify the network using as much of Flux's high-level API as possible. It's not just a bunch of stacked layers, so we have to do a little bit of manual fiddling (which goes to show how flexible this abstraction is!).\n",
"\n",
"In Flux, a network is trained by specifying a loss function whose parameters are Flux's special \"tracked\" arrays, which support automatic differentiation and training. In practice, we can usually build the model from Flux layers, which are just functions that map input to output but carry their own tracked parameters. We can also do this by hand, as we'll see below.\n",
"\n",
"To start, we'll specify the two pathways. The pathway through the hidden layer has two densely connected layers, with a sigmoid nonlinearity after the first. (There is also an output nonlinearity, but we need to add the outputs of the two pathways together before applying it.)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"loss (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"using Flux: Dense\n", | |
"\n", | |
"n_hidden = 10\n", | |
"hidden = Chain(\n", | |
" Dense(size(x, 1), n_hidden, σ),\n", | |
" Dense(n_hidden, 1)\n", | |
")\n", | |
"\n", | |
"## direct pathway from task units to output\n", | |
"direct = Dense(size(x_task, 1), 1)\n", | |
"\n", | |
"## overall output: add hidden and direct output and apply sigmoid (all elementwise)\n", | |
"m(x::Matrix) = σ.(hidden(x) .+ direct(x[1:size(x_task,1),:]))\n", | |
"\n", | |
"# define the loss function we'll optimize\n", | |
"loss(x,y) = Flux.mse(m(x), y')" | |
] | |
}, | |
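{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the data flow explicit, here is a plain-array sketch of the same two-pathway computation. It assumes Julia 1.x and Base only. The weight names, the random initialization and the `logistic` helper are illustrative rather than Flux's internals, and the bias vectors mirror the ones `Dense` layers carry. The hidden pathway applies a sigmoid after its first layer and stays linear after the second. The direct pathway reads only the three task rows. The two are summed before the output sigmoid."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: the two-pathway forward pass with plain arrays (no Flux)\n",
"logistic(z) = 1 / (1 + exp(-z))\n",
"\n",
"function two_pathway_forward(X, W1, b1, W2, b2, Wd, bd; n_task = 3)\n",
"    h = logistic.(W1 * X .+ b1)                 # hidden layer: n_hidden × n_obs\n",
"    hidden_out = W2 * h .+ b2                   # linear second layer: 1 × n_obs\n",
"    direct_out = Wd * X[1:n_task, :] .+ bd      # task rows go straight to the output\n",
"    return logistic.(hidden_out .+ direct_out)  # sum first, then the output sigmoid\n",
"end\n",
"\n",
"yhat_demo = let n_in = 5, n_hid = 10\n",
"    X = Float64[1 1 1 1 0 0 0 0 0 0 0 0\n",
"                0 0 0 0 1 1 1 1 0 0 0 0\n",
"                0 0 0 0 0 0 0 0 1 1 1 1\n",
"                0 0 1 1 0 0 1 1 0 0 1 1\n",
"                0 1 0 1 0 1 0 1 0 1 0 1]  # same values as x\n",
"    W1, b1 = randn(n_hid, n_in), randn(n_hid)\n",
"    W2, b2 = randn(1, n_hid), randn(1)\n",
"    Wd, bd = randn(1, 3), randn(1)\n",
"    two_pathway_forward(X, W1, b1, W2, b2, Wd, bd)  # 1 × 12 matrix of outputs\n",
"end"
]
},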
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"The initial loss before training, based on the randomly initialized weights:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 0-dimensional Array{Float64,0}:\n", | |
"0.249905" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss(x,y)" | |
] | |
}, | |
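{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why is the untrained loss so close to 0.25? Assume `Flux.mse` averages the squared errors. A network whose outputs all sit near 0.5 (which is typical when the pre-activations start out small) then scores (0.5 − 0)² = (0.5 − 1)² = 0.25 on every observation. Exactly 6 of the 12 targets are 1, so 0.5 is also the mean target. That makes 0.25 the lowest loss any constant output can reach (it is the variance of `y`), so training has to make use of the inputs to get below it. Here is a quick check in plain Julia (the `_demo` names are hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: mean squared error of constant predictions against the targets\n",
"targets_demo = [0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0]  # same values as y\n",
"mse_sketch(pred, t) = sum(abs2, pred .- t) / length(t)\n",
"\n",
"mean_target_demo = sum(targets_demo) / length(targets_demo)  # 0.5 here\n",
"loss_at_half_demo = mse_sketch(fill(0.5, 12), targets_demo)\n",
"loss_at_quarter_demo = mse_sketch(fill(0.25, 12), targets_demo)  # any other constant does worse"
]
},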
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Notice that this is a special \"tracked\" array. This is how Flux supports automatic backprop.\n", | |
"\n", | |
"## Training the network\n", | |
"\n", | |
"Flux provides a number of conveniences for training models as well. _Optimizers_ create a function that, when called, updates the parameters using the gradients computed by backprop. (We can also do this manually, as we'll see below.)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(::#58) (generic function with 1 method)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ps = vcat(params(hidden), params(direct)) # trainable parameters of both pathways\n",
"η = 1 # learning rate\n",
"opt = SGD(ps, η)"
] | |
}, | |
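{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rule an SGD optimizer applies is, for every parameter array `p` with gradient `∇p`, the in-place step `p ← p − η∇p`. Here is a minimal plain-array sketch of one such step, assuming Julia 1.x. `sgd_step!` and the `_demo` arrays are hypothetical, and the sketch does not show how Flux's `SGD` finds each parameter's gradient or resets it:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: one in-place gradient-descent step over a list of parameter arrays\n",
"function sgd_step!(param_arrays, grad_arrays, η)\n",
"    for (p, g) in zip(param_arrays, grad_arrays)\n",
"        p .-= η .* g  # move each parameter against its gradient\n",
"    end\n",
"    return param_arrays\n",
"end\n",
"\n",
"params_demo = [[1.0 2.0; 3.0 4.0], [0.5]]\n",
"grads_demo = [[0.1 0.2; 0.3 0.4], [1.0]]\n",
"sgd_step!(params_demo, grads_demo, 1.0)  # η = 1, as in the cell above"
]
},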
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`opt` is a function that will update the parameters we've extracted from the two layers. The `Flux.train!` function wraps the loop of calculating the loss, backpropagating, and calling the optimizer. Slightly annoyingly, the data needs to be a vector of `(x, y)` tuples, even if `x` is a matrix and `y` is a vector of targets:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loss(x, y) = param(0.246506)\n" | |
] | |
} | |
], | |
"source": [ | |
"dataset = (x,y)\n", | |
"Flux.train!(loss, [dataset], opt, cb = () -> @show(loss(x,y)))" | |
] | |
}, | |
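{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, each pass of `train!` goes through the dataset element by element: compute the loss, backpropagate, call the optimizer, then call the callback. The sketch below mimics that loop on a toy one-parameter least-squares problem whose gradient is written out by hand, so no automatic differentiation is needed. `train_sketch!`, the toy data and the hand-derived gradient are assumptions for illustration, not Flux code (Julia 1.x, Base only)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch of a train!-style loop: for each (x, y) batch, compute the gradient,\n",
"# take a gradient-descent step, then call the callback\n",
"function train_sketch!(w, data, η; cb = () -> nothing)\n",
"    for (xb, yb) in data\n",
"        # loss(w) = sum((w .* xb .- yb).^2), so dloss/dw = 2 * sum((w .* xb .- yb) .* xb)\n",
"        g = 2 * sum((w[] .* xb .- yb) .* xb)\n",
"        w[] -= η * g\n",
"        cb()\n",
"    end\n",
"    return w\n",
"end\n",
"\n",
"w_demo = Ref(0.0)                      # a single scalar parameter\n",
"xb_demo = [1.0, 2.0, 3.0]\n",
"data_demo = [(xb_demo, 2 .* xb_demo)]  # one batch; the true slope is 2\n",
"for epoch in 1:50                      # repeated passes, as the next cell does with train!\n",
"    train_sketch!(w_demo, data_demo, 0.02)\n",
"end"
]
},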
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`train!` runs through all the data we pass it once. We can put it in a loop to train until the loss falls below a given tolerance:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loss(x, y) = param(0.243395)\n", | |
"loss(x, y) = param(1.61531e-5)\n", | |
" 15.496159 seconds (71.65 M allocations: 5.114 GiB, 4.17% gc time)\n" | |
] | |
} | |
], | |
"source": [ | |
"using Base.Iterators: repeated\n", | |
"\n", | |
"tol = 1e-5\n", | |
"stop() = loss(x,y).data[] < tol\n", | |
"\n", | |
"# show the value of the loss function at most once every 10 seconds\n",
"callback = Flux.throttle(() -> @show(loss(x,y)), 10)\n", | |
"\n", | |
"dataset = repeated((x,y), 500)\n", | |
"\n", | |
"@time begin\n", | |
" while !stop()\n", | |
" ## currently there's a performance problem in the train! function that makes\n", | |
" ## gc time dominate for small models like this, so we'll run the train loop manually\n", | |
" #Flux.train!(loss, dataset, opt, cb = Flux.throttle(callback, 10))\n", | |
" for d in dataset\n", | |
" Flux.back!(loss(d...))\n", | |
" opt()\n", | |
" callback()\n", | |
" end\n", | |
" end\n", | |
"end" | |
] | |
}, | |
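{
"cell_type": "markdown",
"metadata": {},
"source": [
"`Flux.throttle(f, 10)` wraps `f` so that, however often the wrapper is called, `f` actually runs at most about once every 10 seconds. That keeps a callback invoked on every step of a tight loop from flooding the output. Here is a simplified sketch of the idea, assuming Julia 1.x. `throttle_sketch` is hypothetical, and Flux's own version may treat details such as arguments or a final trailing call differently:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: wrap f so that it runs at most once per `timeout` seconds\n",
"function throttle_sketch(f, timeout)\n",
"    last_run = -Inf  # time of the last call that actually ran f\n",
"    return function ()\n",
"        t = time()\n",
"        if t - last_run >= timeout\n",
"            last_run = t\n",
"            return f()\n",
"        end\n",
"        return nothing  # calls inside the window are simply dropped\n",
"    end\n",
"end\n",
"\n",
"throttle_calls_demo = Ref(0)\n",
"throttled_demo = throttle_sketch(() -> throttle_calls_demo[] += 1, 10)\n",
"for _ in 1:1000\n",
"    throttled_demo()  # only the first of these runs within 10 s\n",
"end"
]
},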
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 0-dimensional Array{Float64,0}:\n", | |
"9.98587e-6" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"loss(x,y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12×2 Array{Float64,2}:\n", | |
" 0.000630202 0.0\n", | |
" 0.00138896 0.0\n", | |
" 0.00138876 0.0\n", | |
" 0.997339 1.0\n", | |
" 0.00297521 0.0\n", | |
" 0.998626 1.0\n", | |
" 0.998627 1.0\n", | |
" 0.997396 1.0\n", | |
" 0.00236595 0.0\n", | |
" 0.995386 1.0\n", | |
" 0.995389 1.0\n", | |
" 0.00639834 0.0" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hcat(m(x).data', y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Low-level interface\n", | |
"\n", | |
"We can also specify the network using a low-level interface and still benefit from Flux's abstractions (tracked arrays, automatic backprop)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 10-element Array{Float64,1}:\n", | |
" 2.44412 \n", | |
" 0.606339\n", | |
" -3.49427 \n", | |
" -0.467204\n", | |
" 1.7908 \n", | |
" -2.58228 \n", | |
" -0.329082\n", | |
" -2.60199 \n", | |
" 0.448517\n", | |
" -3.82918 " | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hidden[1].W\n", | |
"hidden[1].b" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 0-dimensional Array{Float64,0}:\n", | |
"3.15446" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"using Flux.Tracker: param, back!, data, grad\n", | |
"\n", | |
"W_hidden = param(randn(n_hidden, size(x,1)))\n", | |
"b_hidden = param(randn(n_hidden))\n", | |
"hidden2(x) = σ.(W_hidden*x .+ b_hidden)\n", | |
"\n", | |
"W_out = param(randn(1, n_hidden))\n", | |
"b_out = param(randn(1))\n", | |
"W_direct = param(randn(1, size(x_task, 1)))\n", | |
"predict(x::Matrix) = σ.(W_out*hidden2(x) .+ W_direct*x[1:3,:] .+ b_out)\n", | |
"\n", | |
"loss2(x,y) = sum((predict(x) .- y').^2)\n", | |
"loss2(x,y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Note how the value of `loss2` is also `Tracked`. Calling `back!` on the tracked value will calculate the gradients of the loss with respect to the individual parameters:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 0-dimensional Array{Float64,0}:\n", | |
"3.15446" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"l = loss2(x,y)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"10×5 Array{Float64,2}:\n", | |
" -0.00548446 -0.0199413 0.0175428 -0.0195804 -0.0199099 \n", | |
" 0.0214485 0.183165 -0.132873 0.122418 0.152889 \n", | |
" -0.00695408 -0.0751547 0.0498836 -0.0555422 -0.0568866 \n", | |
" 0.0316629 0.0468998 -0.0684927 0.0805919 0.116873 \n", | |
" -0.00050766 0.000505941 -0.00204762 0.00315389 0.00371962\n", | |
" 0.00411865 0.0195692 -0.0139061 0.0203315 0.0170381 \n", | |
" 0.00433668 0.0209703 -0.014808 0.0208767 0.0201817 \n", | |
" -0.00399932 0.0717888 -0.0267617 0.0879622 0.058855 \n", | |
" -0.0600394 -0.197018 0.0356492 -0.218219 -0.266648 \n", | |
" -0.00335029 -0.123585 0.119496 -0.0332793 -0.0536155 " | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"back!(loss2(x,y))\n", | |
"grad(W_hidden)" | |
] | |
}, | |
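{
"cell_type": "markdown",
"metadata": {},
"source": [
"One way to sanity-check gradients like these is to compare them with central finite differences, (L(w + h·eᵢ) − L(w − h·eᵢ)) / 2h for each parameter wᵢ. Here is a plain-Julia sketch for a single sigmoid unit with a sum-of-squares loss (the same form as `loss2`), assuming Julia 1.x. The toy inputs, weights and step size `h` are illustrative. The analytic gradient is derived by hand with the chain rule rather than taken from Flux:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: check a hand-derived gradient against central finite differences\n",
"sigm(z) = 1 / (1 + exp(-z))\n",
"\n",
"# sum-of-squares loss of one sigmoid unit over the columns of X\n",
"loss_fd(w, X, t) = sum(abs2, sigm.(X' * w) .- t)\n",
"\n",
"function grad_analytic(w, X, t)\n",
"    p = sigm.(X' * w)\n",
"    return X * (2 .* (p .- t) .* p .* (1 .- p))  # chain rule, using σ' = σ(1 - σ)\n",
"end\n",
"\n",
"function grad_numeric(f, w; h = 1e-6)\n",
"    g = similar(w)\n",
"    for i in eachindex(w)\n",
"        e = zeros(length(w)); e[i] = h\n",
"        g[i] = (f(w .+ e) - f(w .- e)) / (2h)  # central difference along axis i\n",
"    end\n",
"    return g\n",
"end\n",
"\n",
"X_fd = [0.0 0.0 1.0 1.0\n",
"        0.0 1.0 0.0 1.0\n",
"        1.0 1.0 1.0 1.0]  # last row is a constant input acting as a bias\n",
"t_fd = [0.0, 1.0, 1.0, 1.0]  # or targets\n",
"w_fd = [0.3, -0.2, 0.1]\n",
"\n",
"g_exact = grad_analytic(w_fd, X_fd, t_fd)\n",
"g_approx = grad_numeric(w -> loss_fd(w, X_fd, t_fd), w_fd)"
]
},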
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"...and we can update `W_hidden` in place (through its `.data` field, which holds the underlying values) by taking a small step against the gradient:"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Tracked 0-dimensional Array{Float64,0}:\n", | |
"3.11955" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"W_hidden.data .-= 0.1*grad(W_hidden)\n", | |
"\n", | |
"loss2(x,y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We can define our own update function (which is essentially what `Flux.SGD` does). It loops over all the parameters, subtracts η times each parameter's gradient from its `.data`, and then resets the gradient to zero, because gradients accumulate across calls to `back!`. Then we update until the loss is below tolerance (as above):"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loss2(x, y) = param(2.82972)\n", | |
"loss2(x, y) = param(0.000383281)\n", | |
"loss2(x, y) = param(0.000171005)\n", | |
"loss2(x, y) = param(0.000109063)\n" | |
] | |
} | |
], | |
"source": [ | |
"# SGD step: assumes back! has already been called on the current loss\n",
"function update!(ps, η)\n",
"    for p in ps\n",
"        ∇, dat = grad(p), data(p)\n",
"        dat .-= η .* ∇\n",
"        ∇ .= 0  # reset, since back! accumulates gradients\n",
"    end\n",
"end\n",
"\n",
"cb = Flux.throttle(() -> @show(loss2(x,y)), 5)\n",
"\n",
"ps = [W_hidden, b_hidden, W_out, b_out, W_direct]\n",
"tol = 1e-4\n",
"while loss2(x,y).data[] > tol\n",
"    back!(loss2(x,y))  # compute gradients once per step\n",
"    update!(ps, 0.1)\n",
"    cb()\n",
"end"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12×2 Array{Float64,2}:\n", | |
" 0.0 0.0\n", | |
" 0.0 0.0\n", | |
" 0.0 0.0\n", | |
" 1.0 1.0\n", | |
" 0.0 0.0\n", | |
" 1.0 1.0\n", | |
" 1.0 1.0\n", | |
" 1.0 1.0\n", | |
" 0.0 0.0\n", | |
" 1.0 1.0\n", | |
" 1.0 1.0\n", | |
" 0.01 0.0" | |
] | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"hcat(round.(predict(x).data', 2), y)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Performance\n",
"\n", | |
"To compare with the MATLAB implementation, we need to do a few extra things. That implementation shuffles the input data on every iteration, takes one backprop step per observation (instead of per batch), and saves the MSE throughout training. The MATLAB version also learns no bias weights, but removing them here probably wouldn't speed things up much."
] | |
}, | |
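{
"cell_type": "markdown",
"metadata": {},
"source": [
"The practical difference from the batch loop above is in how steps are taken. A batch step uses the gradient summed over all 12 observations. The per-observation version instead takes one smaller step per observation, visiting them in a freshly shuffled order on every pass. Here is a Base-only sketch of such a pass on a toy one-parameter least-squares problem, assuming Julia 1.x. `sgd_epoch!`, `shuffle_sketch!` and the data are hypothetical, and `shuffle_sketch!` is a hand-written Fisher–Yates shuffle so that the sketch doesn't need the `Random` standard library. Like `mses` in the training function below, it records the summed loss seen during each pass:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Sketch: one epoch of per-observation SGD, visiting observations in random order\n",
"function shuffle_sketch!(v)\n",
"    for i in length(v):-1:2\n",
"        j = rand(1:i)  # Fisher–Yates: swap v[i] with a random element at index ≤ i\n",
"        v[i], v[j] = v[j], v[i]\n",
"    end\n",
"    return v\n",
"end\n",
"\n",
"function sgd_epoch!(w, xs, ys, η)\n",
"    sse = 0.0\n",
"    for k in shuffle_sketch!(collect(eachindex(xs)))\n",
"        r = w[] * xs[k] - ys[k]   # residual for a single observation\n",
"        sse += r^2                # loss seen during this pass\n",
"        w[] -= η * 2 * r * xs[k]  # gradient of r^2 with respect to w\n",
"    end\n",
"    return sse\n",
"end\n",
"\n",
"w_obs = Ref(0.0)\n",
"xs_obs = [1.0, 2.0, 3.0]\n",
"ys_obs = 2 .* xs_obs  # true slope is 2\n",
"sse_history = [sgd_epoch!(w_obs, xs_obs, ys_obs, 0.05) for _ in 1:200]"
]
},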
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12-element Array{Tuple{Array{Float64,2},Float64},1}:\n", | |
" ([1.0; 0.0; … ; 0.0; 0.0], 0.0)\n", | |
" ([1.0; 0.0; … ; 0.0; 1.0], 0.0)\n", | |
" ([1.0; 0.0; … ; 1.0; 0.0], 0.0)\n", | |
" ([1.0; 0.0; … ; 1.0; 1.0], 1.0)\n", | |
" ([0.0; 1.0; … ; 0.0; 0.0], 0.0)\n", | |
" ([0.0; 1.0; … ; 0.0; 1.0], 1.0)\n", | |
" ([0.0; 1.0; … ; 1.0; 0.0], 1.0)\n", | |
" ([0.0; 1.0; … ; 1.0; 1.0], 1.0)\n", | |
" ([0.0; 0.0; … ; 0.0; 0.0], 0.0)\n", | |
" ([0.0; 0.0; … ; 0.0; 1.0], 1.0)\n", | |
" ([0.0; 0.0; … ; 1.0; 0.0], 1.0)\n", | |
" ([0.0; 0.0; … ; 1.0; 1.0], 0.0)" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"## make a vector of single-observation (x, y) tuples (x stays a one-column matrix so we don't need new methods)\n",
"xys = [(float(x[:,i:i]), float(y[i])) for i in 1:length(y)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"12-element Array{TrackedArray{…,Array{Float64,0}},1}:\n", | |
" param(7.2204e-7) \n", | |
" param(1.89381e-6)\n", | |
" param(4.1702e-6) \n", | |
" param(1.28171e-5)\n", | |
" param(3.99636e-6)\n", | |
" param(2.52668e-9)\n", | |
" param(1.21483e-8)\n", | |
" param(6.62773e-6)\n", | |
" param(5.34491e-6)\n", | |
" param(1.47294e-5)\n", | |
" param(1.62699e-5)\n", | |
" param(3.34133e-5)" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"[loss2(xy...) for xy in xys]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loss2(x, y) = param(0.000400447)\n", | |
" 3.513151 seconds (14.27 M allocations: 957.211 MiB, 3.84% gc time)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"2000-element Array{Float64,1}:\n", | |
" 2.06891 \n", | |
" 1.83597 \n", | |
" 1.97047 \n", | |
" 1.81257 \n", | |
" 1.34066 \n", | |
" 1.29541 \n", | |
" 1.11231 \n", | |
" 1.05008 \n", | |
" 0.954588 \n", | |
" 0.939797 \n", | |
" 0.897048 \n", | |
" 0.726046 \n", | |
" 0.671461 \n", | |
" ⋮ \n", | |
" 0.000231743\n", | |
" 0.000231601\n", | |
" 0.000231493\n", | |
" 0.000231337\n", | |
" 0.00023122 \n", | |
" 0.000231087\n", | |
" 0.000230968\n", | |
" 0.000230809\n", | |
" 0.000230671\n", | |
" 0.000230542\n", | |
" 0.000230429\n", | |
" 0.000230311" | |
] | |
}, | |
"execution_count": 20, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"reset!(x) = randn!(x.data)\n", | |
"reset!.(ps)\n", | |
"\n", | |
"n_training = 2000\n", | |
"η = 0.3\n", | |
"\n", | |
"function train!(ps, xys, η, n_training, cb)\n", | |
"\n", | |
" mses = zeros(n_training)\n", | |
"\n", | |
" for iter in 1:n_training\n", | |
" shuffle!(xys)\n", | |
" for (x,y) in xys\n", | |
" l = loss2(x,y)\n", | |
" mses[iter] += l.data[]\n", | |
" back!(l)\n", | |
" update!(ps, η)\n",
" end\n",
" cb()\n",
" end\n",
"\n",
" return mses\n",
"end\n",
"\n",
"@time train!(ps, xys, η, n_training, cb)\n"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"I'm getting 3.2 seconds for 2,000 training iterations (each a shuffled pass over all 12 observations), vs. about 2.6 seconds for the MATLAB code. MATLAB does surprisingly well here! The tradeoff is that this code is more concise and more expressive, without sacrificing detailed control. I'd have to dig in a little to see whether it can be optimized further."
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Profiling" | |
] | |
}, | |
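{
"cell_type": "markdown",
"metadata": {},
"source": [
"A profile like the one in the next cell comes from Julia's sampling profiler. Here is a minimal sketch of collecting one, assuming Julia 1.x, where the profiler lives in the `Profile` standard library (in the Julia 0.6 used for this notebook it was part of `Base`). `profile_workload` is a hypothetical stand-in for a call such as the `train!` loop above. The numbers in the printed tree are sample counts, so larger numbers mean more time spent in that call and its children."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using Profile\n",
"\n",
"# hypothetical workload: repeated small matrix products and sigmoids\n",
"function profile_workload(n)\n",
"    acc = 0.0\n",
"    for _ in 1:n\n",
"        h = 1 ./ (1 .+ exp.(-(randn(10, 5) * randn(5, 12))))\n",
"        acc += sum(h)\n",
"    end\n",
"    return acc\n",
"end\n",
"\n",
"profile_workload(10)  # run once first so compilation isn't profiled\n",
"Profile.clear()\n",
"@profile profile_workload(20_000)\n",
"Profile.print(maxdepth = 10)  # tree view; counts are samples"
]
},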
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"loss2(x, y) = param(2.31568)\n", | |
"1 ./event.jl:436; (::Base.##300#301{IJulia.#send_st...\n", | |
" 1 ...0.6/IJulia/src/stdio.jl:88; send_stdio(::String)\n", | |
" 1 ....6/IJulia/src/stdio.jl:130; send_stream(::String)\n", | |
" 1 ...v0.6/IJulia/src/msg.jl:48; send_ipython(::ZMQ.Socket, ::IJul...\n", | |
"3630 ./task.jl:335; (::IJulia.##14#17)()\n", | |
" 3630 ...Julia/src/eventloop.jl:8; eventloop(::ZMQ.Socket)\n", | |
" 3630 ...rc/execute_request.jl:154; execute_request(::ZMQ.Socket, ::...\n", | |
" 3630 ...Compat/src/Compat.jl:464; include_string(::Module, ::Stri...\n", | |
" 3630 ./<missing>:?; anonymous\n", | |
" 3630 ./profile.jl:23; macro expansion\n", | |
" 1 ./In[20]:13; train!(::Array{Flux.Tracker.Tr...\n", | |
" 1258 ./In[20]:14; train!(::Array{Flux.Tracker.Tr...\n", | |
" 1249 ./In[12]:12; loss2\n", | |
" 929 ./In[12]:10; predict(::Array{Float64,2})\n", | |
" 8 ./In[12]:5; hidden2\n", | |
" 5 .../src/tracker/lib.jl:63; *(::TrackedArray{…,Array{Fl...\n", | |
" 5 .../tracker/Tracker.jl:36; Type\n", | |
" 1 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
" 4 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
" 1 ./linalg/matmul.jl:146; *\n", | |
" 1 ./linalg/matmul.jl:341; gemm_wrapper!(::Array{Float...\n", | |
" 2 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
" 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
" 2 ./abstractarray.jl:882; getindex\n", | |
" 2 ./multidimensional.jl:438; _getindex\n", | |
" 2 ./multidimensional.jl:442; macro expansion\n", | |
" 1 ./multidimensional.jl:453; _unsafe_getindex(::IndexLin...\n", | |
" 1 ...ltidimensional.jl:460; macro expansion\n", | |
" 1 ...ltidimensional.jl:466; _unsafe_getindex!\n", | |
" 1 ...ltidimensional.jl:472; macro expansion\n", | |
" 1 ./cartesian.jl:64; macro expansion\n", | |
" 1 ...tidimensional.jl:474; macro expansion\n", | |
" 8 ./broadcast.jl:0; broadcast(::Function, ::Trac...\n", | |
" 850 ./broadcast.jl:434; broadcast(::Function, ::Trac...\n", | |
" 8 ./broadcast.jl:0; containertype(::TrackedArra...\n", | |
" 2 ./broadcast.jl:34; containertype(::TrackedArra...\n", | |
" 6 ...src/tracker/lib.jl:0; broadcast_c(::Function, ::T...\n", | |
" 314 ...src/tracker/lib.jl:129; broadcast_c(::Function, ::T...\n", | |
" 16 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
" 16 ./tuple.jl:181; map\n", | |
" 7 ./array.jl:455; _collect(::Array{Float64,2}...\n", | |
" 6 ./array.jl:375; _similar_for(::Array{Float...\n", | |
" 2 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
" 1 ./tuple.jl:0; map(::Flux.Tracker.##8#10{3...\n", | |
" 6 ./tuple.jl:178; map(::Flux.Tracker.##8#10{3...\n", | |
" 4 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
" 3 ./array.jl:375; _similar_for(::Array{Float...\n", | |
" 1 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
" 1 ...rc/tracker/lib.jl:97; #8\n", | |
" 240 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
" 236 ./broadcast.jl:434; broadcast\n", | |
" 166 ./broadcast.jl:310; broadcast_c\n", | |
" 3 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
" 18 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
" 1 ./broadcast.jl:312; broadcast_c\n", | |
" 69 ./broadcast.jl:314; broadcast_c\n", | |
" 9 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
" 32 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
" 2 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
" 8 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
" 1 ./broadcast.jl:0; _broadcast!(::##11#12, ::A...\n", | |
" 4 ./broadcast.jl:139; _broadcast!(::##11#12, ::A...\n", | |
" 4 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:68; macro expansion\n", | |
" 3 ./simdloop.jl:73; macro expansion\n", | |
" 3 ./broadcast.jl:153; macro expansion\n", | |
" 3 ./<missing>:0; (::##11#12)(::ForwardDif...\n", | |
" 3 ...c/activation.jl:1; σ(::ForwardDiff.Dual{Vo...\n", | |
" 2 ...ff/src/dual.jl:169; -\n", | |
" 1 ...ff/src/dual.jl:170; exp\n", | |
" 1 ...rc/partials.jl:82; *\n", | |
" 1 ...c/partials.jl:109; *\n", | |
" 1 ...c/partials.jl:199; scale_tuple\n", | |
" 1 ...c/partials.jl:155; macro expansion\n", | |
" 45 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
" 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(:...\n", | |
" 11 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted...\n", | |
" 11 ./array.jl:455; _collect(::Array{ForwardDi...\n", | |
" 6 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
" 4 ...src/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
" 41 ...src/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
" 41 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
" 39 ./array.jl:455; _collect(::Array{Float64,1}...\n", | |
" 4 ./array.jl:375; _similar_for(::Array{Float6...\n", | |
" 2 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
" 1 ./array.jl:477; collect_to!(::Array{Forward...\n", | |
" 193 ...src/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
" 191 ./broadcast.jl:434; broadcast\n", | |
" 128 ./broadcast.jl:310; broadcast_c\n", | |
" 1 ./reflection.jl:0; _methods_by_ftype(::Any, :...\n", | |
" 1 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
" 15 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
" 1 ./broadcast.jl:311; broadcast_c\n", | |
" 1 ./broadcast.jl:53; broadcast_indices\n", | |
" 1 ./broadcast.jl:48; broadcast_indices\n", | |
" 1 ./broadcast.jl:52; broadcast_indices\n", | |
" 1 ./abstractarray.jl:64; indices\n", | |
" 62 ./broadcast.jl:314; broadcast_c\n", | |
" 11 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
" 30 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
" 3 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
" 8 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
" 7 ./broadcast.jl:139; _broadcast!(::##9#10, ::Ar...\n", | |
" 7 ./broadcast.jl:147; macro expansion\n", | |
" 7 ./simdloop.jl:73; macro expansion\n", | |
" 7 ./broadcast.jl:153; macro expansion\n", | |
" 7 ./<missing>:0; (::##9#10)(::ForwardDiff....\n", | |
" 4 ...rc/activation.jl:1; σ(::ForwardDiff.Dual{Vo...\n", | |
" 4 ...iff/src/dual.jl:169; exp\n", | |
" 1 ./special/exp.jl:68; exp(::Float64)\n", | |
" 1 ./special/exp.jl:111; exp(::Float64)\n", | |
" 1 ./special/exp.jl:112; exp(::Float64)\n", | |
" 1 ./special/exp.jl:131; exp(::Float64)\n", | |
" 46 ...src/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
" 2 .../tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::F...\n", | |
" 5 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
" 1 ./array.jl:451; _collect(::Array{ForwardDif...\n", | |
" 4 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
" 2 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
" 1 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
" 9 ...src/tracker/lib.jl:62; *(::TrackedArray{…,Array{F...\n", | |
" 8 .../tracker/Tracker.jl:36; Type\n", | |
" 8 .../tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*,...\n", | |
" 2 ./linalg/matmul.jl:146; *\n", | |
" 1 ./linalg/matmul.jl:348; gemm_wrapper!(::Array{Float...\n", | |
" 5 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
" 5 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
" 5 ...src/tracker/lib.jl:63; *(::TrackedArray{…,Array{F...\n", | |
" 4 .../tracker/Tracker.jl:36; Type\n", | |
" 4 .../tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*,...\n", | |
" 1 ./linalg/matmul.jl:146; *\n", | |
" 3 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
" 1 ./linalg/blas.jl:0; gemm!(::Char, ::Char, ::Fl...\n", | |
" 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
" 2 ./broadcast.jl:0; broadcast(::Function, ::Track...\n", | |
" 295 ./broadcast.jl:434; broadcast(::Function, ::Track...\n", | |
" 15 ...src/tracker/lib.jl:0; tracked_broadcast(::Function...\n", | |
" 10 ...src/tracker/lib.jl:97; tracked_broadcast(::Function...\n", | |
" 10 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2}...\n", | |
" 1 ./array.jl:454; _collect(::Array{Float64,2},...\n", | |
" 1 ./generator.jl:44; next\n", | |
" 9 ./array.jl:455; _collect(::Array{Float64,2},...\n", | |
" 9 ./array.jl:375; _similar_for(::Array{Float6...\n", | |
" 2 ./abstractarray.jl:524; similar(::Array{Float64,2},...\n", | |
" 191 ...src/tracker/lib.jl:100; tracked_broadcast(::Function...\n", | |
" 186 ./broadcast.jl:434; broadcast\n", | |
" 128 ./broadcast.jl:310; broadcast_c\n", | |
" 1 ./reflection.jl:510; _methods_by_ftype(::Any, ::...\n", | |
" 3 ./reflection.jl:512; _methods_by_ftype(::Any, ::...\n", | |
" 16 ./reflection.jl:521; _methods_by_ftype(::Any, ::...\n", | |
" 58 ./broadcast.jl:314; broadcast_c\n", | |
" 8 ./broadcast.jl:0; broadcast_t(::Function, ::T...\n", | |
" 23 ./broadcast.jl:266; broadcast_t(::Function, ::T...\n", | |
" 1 ./boot.jl:317; Array{ForwardDiff.Dual{Void...\n", | |
" 2 ./broadcast.jl:268; broadcast_t(::Function, ::T...\n", | |
" 1 ...src/tracker/lib.jl:88; Flux.Tracker.Broadcasted(::...\n", | |
" 71 ...src/tracker/lib.jl:101; tracked_broadcast(::Function...\n", | |
" 1 .../tracker/Tracker.jl:0; Flux.Tracker.TrackedArray(::F...\n", | |
" 1 .../tracker/Tracker.jl:11; Flux.Tracker.Call{Flux.Tracke...\n", | |
" 26 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
" 26 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
" 26 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
" 23 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
" 8 .../src/tracker/lib.jl:52; sum(::TrackedArray{…,Array{...\n", | |
" 4 ...x/src/tracker/lib.jl:4; toarray\n", | |
" 25 ./In[20]:15; train!(::Array{Flux.Tracker.Tr...\n", | |
" 1 ./abstractarray.jl:0; getindex(::Array{Float64,0})\n", | |
" 1 ./abstractarray.jl:882; getindex(::Array{Float64,0})\n", | |
" 388 ./In[20]:16; train!(::Array{Flux.Tracker.Tr...\n", | |
" 386 ...src/tracker/back.jl:43; back!(::TrackedArray{…,Array...\n", | |
" 45 ...src/tracker/back.jl:39; back!\n", | |
" 45 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array{...\n", | |
" 42 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sca...\n", | |
" 42 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array...\n", | |
" 35 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sc...\n", | |
" 29 ...rc/tracker/back.jl:8; scan(::TrackedArray{…,Arr...\n", | |
" 26 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#s...\n", | |
" 1 ...c/tracker/back.jl:6; scan(::TrackedArray{…,Ar...\n", | |
" 24 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Ar...\n", | |
" 20 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#...\n", | |
" 1 .../tracker/back.jl:0; scan(::TrackedArray{…,A...\n", | |
" 1 .../tracker/back.jl:6; scan(::TrackedArray{…,A...\n", | |
" 17 .../tracker/back.jl:8; scan(::TrackedArray{…,A...\n", | |
" 7 ./abstractarray.jl:1731; foreach(::Flux.Tracker....\n", | |
" 7 .../tracker/back.jl:8; scan(::TrackedArray{…,...\n", | |
" 2 ...stractarray.jl:1731; foreach(::Flux.Tracker...\n", | |
" 1 ...tracker/back.jl:6; scan(::TrackedArray{…...\n", | |
" 341 ...src/tracker/back.jl:24; back(::TrackedArray{…,Array...\n", | |
" 337 ...src/tracker/back.jl:15; back(::Flux.Tracker.Call{Base...\n", | |
" 335 ...rc/tracker/back.jl:24; back(::TrackedArray{…,Arra...\n", | |
" 326 ...rc/tracker/back.jl:15; back(::Flux.Tracker.Call{Fl...\n", | |
" 319 ./abstractarray.jl:1732; foreach(::Function, ::Tupl...\n", | |
" 1 ./iterators.jl:183; next(::Base.Iterators.Zip2...\n", | |
" 40 ./iterators.jl:185; next(::Base.Iterators.Zip2...\n", | |
" 1 ./iterators.jl:0; zip(::Tuple{TrackedArray{…...\n", | |
" 247 ...rc/tracker/lib.jl:117; (::Flux.Tracker.##17#19)(:...\n", | |
" 247 ...c/tracker/back.jl:32; macro expansion\n", | |
" 246 .../tracker/back.jl:24; back(::TrackedArray{…,A...\n", | |
" 242 .../tracker/back.jl:15; back(::Flux.Tracker.Call...\n", | |
" 238 ./abstractarray.jl:1732; foreach(::Function, ::T...\n", | |
" 1 ./iterators.jl:0; done(::Base.Iterators.Z...\n", | |
" 20 ./iterators.jl:185; next(::Base.Iterators.Z...\n", | |
" 198 .../tracker/lib.jl:117; (::Flux.Tracker.##17#19...\n", | |
" 198 ...tracker/back.jl:32; macro expansion\n", | |
" 159 ...racker/back.jl:24; back(::TrackedArray{…...\n", | |
" 155 ...racker/back.jl:15; back(::Flux.Tracker.C...\n", | |
" 12 ...racker/lib.jl:71; back\n", | |
" 12 ...acker/back.jl:32; macro expansion\n", | |
" 1 ...alg/matmul.jl:0; A_mul_Bt(::Array{Flo...\n", | |
" 7 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
" 5 ...alg/matmul.jl:191; A_mul_Bt!\n", | |
" 1 ...lg/matmul.jl:366; gemm_wrapper!(::Arr...\n", | |
" 3 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
" 2 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
" 1 ...alg/blas.jl:1036; gemm!(::Char, ::Ch...\n", | |
" 1 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
" 1 ./broadcast.jl:204; broadcast!\n", | |
" 1 ./broadcast.jl:211; broadcast_c!\n", | |
" 1 ./broadcast.jl:139; _broadcast!\n", | |
" 1 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:73; macro expansion\n", | |
" 1 ./broadcast.jl:153; macro expansion\n", | |
" 2 ...acker/back.jl:22; back(::TrackedArray{...\n", | |
" 5 ...racker/lib.jl:71; back(::Base.#*, ::Arr...\n", | |
" 5 ...racker/back.jl:32; macro expansion\n", | |
" 2 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
" 1 ...alg/matmul.jl:191; A_mul_Bt!\n", | |
" 1 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
" 1 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
" 2 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
" 2 ./broadcast.jl:204; broadcast!\n", | |
" 1 ./broadcast.jl:208; broadcast_c!\n", | |
" 1 ./broadcast.jl:90; check_broadcast_indices\n", | |
" 1 ./broadcast.jl:86; check_broadcast_in...\n", | |
" 1 ./broadcast.jl:48; broadcast_indices\n", | |
" 1 ./broadcast.jl:52; broadcast_indices\n", | |
" 1 ...actarray.jl:64; indices\n", | |
" 1 ./broadcast.jl:211; broadcast_c!\n", | |
" 1 ./broadcast.jl:139; _broadcast!\n", | |
" 1 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:73; macro expansion\n", | |
" 1 ./broadcast.jl:153; macro expansion\n", | |
" 137 ...racker/lib.jl:72; back\n", | |
" 137 ...acker/back.jl:32; macro expansion\n", | |
" 4 ...lg/matmul.jl:182; At_mul_B(::Array{Fl...\n", | |
" 3 ...alg/matmul.jl:184; At_mul_B!\n", | |
" 3 ...lg/matmul.jl:367; gemm_wrapper!(::Arr...\n", | |
" 1 ...alg/blas.jl:0; gemm!(::Char, ::Ch...\n", | |
" 1 ...alg/blas.jl:1022; gemm!(::Char, ::Ch...\n", | |
" 1 ...alg/blas.jl:1027; gemm!(::Char, ::Ch...\n", | |
" 133 ...cker/back.jl:24; back(::TrackedArray...\n", | |
" 132 ...cker/back.jl:15; back(::Flux.Tracker...\n", | |
" 129 ...actarray.jl:1732; foreach(::Functio...\n", | |
" 34 ./iterators.jl:185; next(::Base.Iterat...\n", | |
" 83 ...cker/lib.jl:117; (::Flux.Tracker.##...\n", | |
" 83 ...ker/back.jl:32; macro expansion\n", | |
" 1 ...ker/back.jl:19; back(::TrackedArr...\n", | |
" 1 ...ker/back.jl:21; back(::TrackedArr...\n", | |
" 1 ./broadcast.jl:204; broadcast!\n", | |
" 1 ...oadcast.jl:211; broadcast_c!\n", | |
" 1 ...oadcast.jl:139; _broadcast!\n", | |
" 1 ...adcast.jl:147; macro expansion\n", | |
" 1 ...mdloop.jl:73; macro expansion\n", | |
" 1 ...adcast.jl:151; macro expansion\n", | |
" 1 ...ker/back.jl:22; back(::TrackedArr...\n", | |
" 7 ...ker/back.jl:24; back(::TrackedArr...\n", | |
" 4 ...ker/back.jl:15; back(::Flux.Track...\n", | |
" 4 ...cker/lib.jl:71; back(::Base.#*, :...\n", | |
" 4 ...er/back.jl:32; macro expansion\n", | |
" 4 ...matmul.jl:189; A_mul_Bt(::Arra...\n", | |
" 2 ...matmul.jl:191; A_mul_Bt!\n", | |
" 2 ...matmul.jl:367; gemm_wrapper!(...\n", | |
" 2 .../blas.jl:1027; gemm!(::Char...\n", | |
" 72 ...cker/lib.jl:106; unbroadcast(::Tra...\n", | |
" 44 ./array.jl:1819; filter(::Functi...\n", | |
" 4 ...tarray.jl:882; getindex\n", | |
" 4 ...nsional.jl:438; _getindex\n", | |
" 4 ...sional.jl:442; macro expansion\n", | |
" 4 ...sional.jl:453; _unsafe_getind...\n", | |
" 3 ...sional.jl:458; macro expansion\n", | |
" 1 ...sional.jl:460; macro expansion\n", | |
" 1 ...ional.jl:466; _unsafe_getindex!\n", | |
" 1 ...ional.jl:472; macro expansion\n", | |
" 1 ...esian.jl:62; macro expansion\n", | |
" 1 ...onal.jl:352; next\n", | |
" 30 ...tarray.jl:1865; map(::Function,...\n", | |
" 5 ./array.jl:455; _collect(::Unit...\n", | |
" 3 ./array.jl:184; reshape(::Array...\n", | |
" 22 ...ducedim.jl:572; sum\n", | |
" 22 ...ducedim.jl:570; sum\n", | |
" 22 ...ucedim.jl:241; mapreducedim\n", | |
" 2 ...ucedim.jl:210; mapreducedim!\n", | |
" 2 ...ucedim.jl:173; _mapreducedim!...\n", | |
" 1 ...ucedim.jl:0; check_reducedi...\n", | |
" 1 ...ucedim.jl:169; check_reducedi...\n", | |
" 20 ...ucedim.jl:73; reducedim_init...\n", | |
" 2 ./array.jl:0; fill!(::Array{...\n", | |
" 4 ...ucedim.jl:33; reduced_indice...\n", | |
" 1 ./array.jl:0; vect(::Base.On...\n", | |
" 3 ./array.jl:76; vect(::Base.On...\n", | |
" 14 ...ucedim.jl:43; reduced_indice...\n", | |
" 3 ...cker/lib.jl:116; back\n", | |
" 3 ...acker/lib.jl:116; (::Flux.Tracker.##...\n", | |
" 3 ./broadcast.jl:434; broadcast\n", | |
" 2 ./broadcast.jl:311; broadcast_c\n", | |
" 2 ./broadcast.jl:53; broadcast_indices\n", | |
" 2 ./tuple.jl:159; map\n", | |
" 2 ...oadcast.jl:48; broadcast_indices\n", | |
" 2 ...oadcast.jl:52; broadcast_indices\n", | |
" 2 ...tarray.jl:64; indices\n", | |
" 1 ./broadcast.jl:314; broadcast_c\n", | |
" 1 ./broadcast.jl:266; broadcast_t\n", | |
" 37 ...tracker/lib.jl:106; unbroadcast(::Tracked...\n", | |
" 23 ./array.jl:1819; filter(::Function, ::...\n", | |
" 5 ...tractarray.jl:882; getindex\n", | |
" 5 ...imensional.jl:438; _getindex\n", | |
" 5 ...imensional.jl:442; macro expansion\n", | |
" 5 ...imensional.jl:453; _unsafe_getindex(::I...\n", | |
" 4 ...mensional.jl:458; macro expansion\n", | |
" 1 ...mensional.jl:460; macro expansion\n", | |
" 1 ...mensional.jl:466; _unsafe_getindex!\n", | |
" 1 ...ensional.jl:472; macro expansion\n", | |
" 1 ./cartesian.jl:62; macro expansion\n", | |
" 1 ...ensional.jl:352; next\n", | |
" 11 ...tractarray.jl:1865; map(::Function, ::Un...\n", | |
" 2 ./array.jl:455; _collect(::UnitRange{...\n", | |
" 1 ./generator.jl:32; Base.Generator(::Flux...\n", | |
" 13 ./reducedim.jl:572; sum\n", | |
" 13 ./reducedim.jl:570; sum\n", | |
" 13 ./reducedim.jl:241; mapreducedim\n", | |
" 13 ./reducedim.jl:73; reducedim_initarray(...\n", | |
" 5 ./reducedim.jl:33; reduced_indices(::Tu...\n", | |
" 1 ./array.jl:0; vect(::Base.OneTo{In...\n", | |
" 4 ./array.jl:76; vect(::Base.OneTo{In...\n", | |
" 7 ./reducedim.jl:43; reduced_indices(::Tu...\n", | |
" 3 .../tracker/lib.jl:116; back\n", | |
" 1 .../tracker/lib.jl:0; (::Flux.Tracker.##16#18{...\n", | |
" 2 .../tracker/lib.jl:116; (::Flux.Tracker.##16#18{...\n", | |
" 2 ./broadcast.jl:434; broadcast\n", | |
" 1 ./broadcast.jl:311; broadcast_c\n", | |
" 1 ./broadcast.jl:53; broadcast_indices\n", | |
" 1 ./broadcast.jl:57; broadcast_shape\n", | |
" 1 ./broadcast.jl:57; broadcast_shape\n", | |
" 1 ./broadcast.jl:64; _bcs\n", | |
" 1 ./broadcast.jl:63; _bcs\n", | |
" 1 ./broadcast.jl:314; broadcast_c\n", | |
" 1 ./broadcast.jl:266; broadcast_t\n", | |
" 1 .../tracker/back.jl:26; back(::TrackedArray{…,A...\n", | |
" 6 ...rc/tracker/lib.jl:116; back\n", | |
" 5 ...rc/tracker/lib.jl:116; (::Flux.Tracker.##16#18{Flu...\n", | |
" 5 ./broadcast.jl:434; broadcast\n", | |
" 5 ./broadcast.jl:314; broadcast_c\n", | |
" 4 ./broadcast.jl:266; broadcast_t\n", | |
" 1 ./broadcast.jl:268; broadcast_t\n", | |
" 1 ./broadcast.jl:139; _broadcast!\n", | |
" 1 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:73; macro expansion\n", | |
" 1 ./broadcast.jl:153; macro expansion\n", | |
" 1 ...src/tracker/lib.jl:55; back\n", | |
" 1953 ./In[20]:17; train!(::Array{Flux.Tracker.Tr...\n", | |
" 1879 ./In[16]:2; update!(::Array{Flux.Tracker.T...\n", | |
" 1378 ./In[12]:12; loss2\n", | |
" 1120 ./In[12]:10; predict(::Array{Int64,2})\n", | |
" 21 ./In[12]:5; hidden2\n", | |
" 18 ...src/tracker/lib.jl:63; *(::TrackedArray{…,Array{F...\n", | |
" 17 ...tracker/Tracker.jl:36; Type\n", | |
" 17 ...racker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
" 2 ./linalg/matmul.jl:146; *\n", | |
" 14 ./linalg/matmul.jl:483; generic_matmatmul!(::Array...\n", | |
" 1 ./linalg/matmul.jl:0; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:490; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:375; lapack_size(::Char, ::Arr...\n", | |
" 1 ./linalg/matmul.jl:508; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:509; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:515; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:389; copy_transpose!(::Array{F...\n", | |
" 1 ...alg/transpose.jl:142; copy_transpose!(::Array{...\n", | |
" 1 ./abstractarray.jl:362; checkbounds\n", | |
" 2 ./linalg/matmul.jl:516; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:379; copy!(::Array{Int64,2}, :...\n", | |
" 1 ./abstractarray.jl:716; copy!(::Array{Int64,2}, ...\n", | |
" 1 ./linalg/matmul.jl:522; _generic_matmatmul!(::Arra...\n", | |
" 6 ./linalg/matmul.jl:523; _generic_matmatmul!(::Arra...\n", | |
" 9 ./abstractarray.jl:882; getindex\n", | |
" 9 ./multidimensional.jl:438; _getindex\n", | |
" 9 ./multidimensional.jl:442; macro expansion\n", | |
" 9 ...ltidimensional.jl:453; _unsafe_getindex(::IndexLin...\n", | |
" 1 ...ltidimensional.jl:456; macro expansion\n", | |
" 1 ...ltidimensional.jl:457; macro expansion\n", | |
" 1 ...ltidimensional.jl:311; index_shape\n", | |
" 1 ./abstractarray.jl:64; indices\n", | |
" 6 ...ltidimensional.jl:458; macro expansion\n", | |
" 1 ...ltidimensional.jl:459; macro expansion\n", | |
" 1 ./tuple.jl:284; ==(::Tuple{Int64,Int64}, :...\n", | |
" 7 ./broadcast.jl:0; broadcast(::Function, ::Tra...\n", | |
" 1040 ./broadcast.jl:434; broadcast(::Function, ::Tra...\n", | |
" 7 ./broadcast.jl:0; containertype(::TrackedArra...\n", | |
" 14 ./broadcast.jl:34; containertype(::TrackedArra...\n", | |
" 11 ...rc/tracker/lib.jl:0; broadcast_c(::Function, ::T...\n", | |
" 364 ...rc/tracker/lib.jl:129; broadcast_c(::Function, ::T...\n", | |
" 1 ...rc/tracker/lib.jl:0; tracked_broadcast(::Functi...\n", | |
" 28 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functi...\n", | |
" 28 ./tuple.jl:181; map\n", | |
" 1 ./array.jl:454; _collect(::Array{Float64,2...\n", | |
" 1 ./generator.jl:44; next\n", | |
" 17 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
" 14 ./array.jl:375; _similar_for(::Array{Floa...\n", | |
" 7 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
" 1 ./array.jl:474; collect_to!(::Array{Forwa...\n", | |
" 10 ./tuple.jl:178; map(::Flux.Tracker.##8#10{...\n", | |
" 8 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
" 3 ./array.jl:375; _similar_for(::Array{Floa...\n", | |
" 1 ./abstractarray.jl:524; similar(::Array{Float64,...\n", | |
" 2 ./array.jl:474; collect_to!(::Array{Forwa...\n", | |
" 1 ...c/tracker/lib.jl:94; (::Flux.Tracker.##6#7{Tup...\n", | |
" 2 ./array.jl:477; collect_to!(::Array{Forwa...\n", | |
" 2 ...rc/tracker/lib.jl:97; #8\n", | |
" 1 ...rc/tracker/lib.jl:94; dualify\n", | |
" 266 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functi...\n", | |
" 263 ./broadcast.jl:434; broadcast\n", | |
" 165 ./broadcast.jl:310; broadcast_c\n", | |
" 5 ./reflection.jl:510; _methods_by_ftype(::Any, ...\n", | |
" 6 ./reflection.jl:512; _methods_by_ftype(::Any, ...\n", | |
" 20 ./reflection.jl:521; _methods_by_ftype(::Any, ...\n", | |
" 98 ./broadcast.jl:314; broadcast_c\n", | |
" 5 ./broadcast.jl:0; broadcast_t(::Function, :...\n", | |
" 46 ./broadcast.jl:266; broadcast_t(::Function, :...\n", | |
" 4 ./boot.jl:317; Array{ForwardDiff.Dual{Vo...\n", | |
" 24 ./broadcast.jl:268; broadcast_t(::Function, :...\n", | |
" 16 ./broadcast.jl:139; _broadcast!(::##11#12, :...\n", | |
" 16 ./broadcast.jl:147; macro expansion\n", | |
" 16 ./simdloop.jl:73; macro expansion\n", | |
" 15 ./broadcast.jl:153; macro expansion\n", | |
" 15 ./<missing>:0; (::##11#12)(::ForwardDi...\n", | |
" 1 ...ff/src/dual.jl:198; +\n", | |
" 1 ...rc/partials.jl:117; _mul_partials\n", | |
" 1 ...c/partials.jl:219; mul_tuples\n", | |
" 1 ...c/partials.jl:155; macro expansion\n", | |
" 13 .../activation.jl:1; σ(::ForwardDiff.Dual...\n", | |
" 2 ...f/src/dual.jl:169; -\n", | |
" 1 ...f/src/dual.jl:208; /\n", | |
" 1 ...rc/partials.jl:82; *\n", | |
" 1 ...c/partials.jl:109; *\n", | |
" 1 ...c/partials.jl:199; scale_tuple\n", | |
" 1 .../partials.jl:155; macro expansion\n", | |
" 10 ...f/src/dual.jl:169; exp\n", | |
" 1 ./special/exp.jl:0; exp(::Float64)\n", | |
" 1 ./special/exp.jl:89; exp(::Float64)\n", | |
" 8 ./special/exp.jl:112; exp(::Float64)\n", | |
" 1 ./broadcast.jl:154; macro expansion\n", | |
" 1 ...idimensional.jl:247; setindex!\n", | |
" 1 ...rc/tracker/lib.jl:88; Flux.Tracker.Broadcasted(:...\n", | |
" 49 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functi...\n", | |
" 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(:...\n", | |
" 8 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted...\n", | |
" 8 ./array.jl:455; _collect(::Array{ForwardDi...\n", | |
" 8 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
" 1 ./abstractarray.jl:524; similar(::Array{ForwardDi...\n", | |
" 8 ...rc/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
" 47 ...rc/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
" 1 ./tuple.jl:0; map(::Flux.Tracker.##8#10{2...\n", | |
" 46 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
" 19 ./array.jl:455; _collect(::Array{Float64,2...\n", | |
" 13 ./array.jl:375; _similar_for(::Array{Float...\n", | |
" 10 ./abstractarray.jl:524; similar(::Array{Float64,2...\n", | |
" 2 ./array.jl:473; collect_to!(::Array{Forwar...\n", | |
" 1 ./array.jl:474; collect_to!(::Array{Forwar...\n", | |
" 1 ...rc/tracker/lib.jl:97; #8\n", | |
" 1 ...rc/tracker/lib.jl:94; dualify\n", | |
" 296 ...rc/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
" 295 ./broadcast.jl:434; broadcast\n", | |
" 135 ./broadcast.jl:310; broadcast_c\n", | |
" 1 ./reflection.jl:0; _methods_by_ftype(::Any, :...\n", | |
" 2 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
" 1 ./reflection.jl:512; _methods_by_ftype(::Any, :...\n", | |
" 12 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
" 160 ./broadcast.jl:314; broadcast_c\n", | |
" 11 ./broadcast.jl:0; broadcast_t(::Function, :...\n", | |
" 32 ./broadcast.jl:266; broadcast_t(::Function, :...\n", | |
" 5 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
" 107 ./broadcast.jl:268; broadcast_t(::Function, :...\n", | |
" 1 ./broadcast.jl:0; _broadcast!(::##9#10, ::...\n", | |
" 101 ./broadcast.jl:139; _broadcast!(::##9#10, ::...\n", | |
" 101 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:72; macro expansion\n", | |
" 100 ./simdloop.jl:73; macro expansion\n", | |
" 4 ./broadcast.jl:151; macro expansion\n", | |
" 94 ./broadcast.jl:153; macro expansion\n", | |
" 92 ./<missing>:0; (::##9#10)(::ForwardDiff...\n", | |
" 2 ./special/exp.jl:119; exp(::Float64)\n", | |
" 2 ...ff/src/dual.jl:197; +\n", | |
" 1 ...ff/src/dual.jl:198; +\n", | |
" 1 ...rc/partials.jl:117; _mul_partials\n", | |
" 1 ...rc/partials.jl:219; mul_tuples\n", | |
" 1 ...c/partials.jl:155; macro expansion\n", | |
" 2 .../activation.jl:0; σ(::ForwardDiff.Dual{...\n", | |
" 74 .../activation.jl:1; σ(::ForwardDiff.Dual{...\n", | |
" 1 ...ff/src/dual.jl:207; +\n", | |
" 3 ...ff/src/dual.jl:169; -\n", | |
" 3 ...ff/src/dual.jl:207; /\n", | |
" 5 ...ff/src/dual.jl:208; /\n", | |
" 5 ...rc/partials.jl:82; *\n", | |
" 5 ...c/partials.jl:109; *\n", | |
" 5 ...c/partials.jl:199; scale_tuple\n", | |
" 5 ...c/partials.jl:155; macro expansion\n", | |
" 57 ...ff/src/dual.jl:169; exp\n", | |
" 1 ./special/exp.jl:0; exp(::Float64)\n", | |
" 1 ./special/exp.jl:69; exp(::Float64)\n", | |
" 2 ./special/exp.jl:74; exp(::Float64)\n", | |
" 2 ./special/exp.jl:85; exp(::Float64)\n", | |
" 1 ./special/exp.jl:94; exp(::Float64)\n", | |
" 1 ./special/exp.jl:105; exp(::Float64)\n", | |
" 2 ./special/exp.jl:110; exp(::Float64)\n", | |
" 19 ./special/exp.jl:111; exp(::Float64)\n", | |
" 11 ./special/exp.jl:112; exp(::Float64)\n", | |
" 4 ./special/exp.jl:117; exp(::Float64)\n", | |
" 3 ./special/exp.jl:118; exp(::Float64)\n", | |
" 6 ./special/exp.jl:119; exp(::Float64)\n", | |
" 4 ./special/exp.jl:132; exp(::Float64)\n", | |
" 2 ...ff/src/dual.jl:170; exp\n", | |
" 2 ...rc/partials.jl:82; *\n", | |
" 2 ...c/partials.jl:109; *\n", | |
" 2 ...c/partials.jl:199; scale_tuple\n", | |
" 2 ...c/partials.jl:155; macro expansion\n", | |
" 2 ./broadcast.jl:154; macro expansion\n", | |
" 2 ...idimensional.jl:247; setindex!\n", | |
" 50 ...rc/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
" 1 ...tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::...\n", | |
" 9 ...src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{...\n", | |
" 9 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
" 6 ./array.jl:375; _similar_for(::Array{Forwa...\n", | |
" 2 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
" 1 ./array.jl:473; collect_to!(::Array{Float6...\n", | |
" 1 ./array.jl:477; collect_to!(::Array{Float6...\n", | |
" 5 ...rc/tracker/lib.jl:62; *(::TrackedArray{…,Array{...\n", | |
" 3 .../tracker/Tracker.jl:36; Type\n", | |
" 1 ...tracker/Tracker.jl:0; (::Flux.Tracker.Call{Base.#*...\n", | |
" 2 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#*...\n", | |
" 2 ./linalg/matmul.jl:367; gemm_wrapper!(::Array{Float...\n", | |
" 2 ./linalg/blas.jl:1027; gemm!(::Char, ::Char, ::Fl...\n", | |
" 11 ...rc/tracker/lib.jl:63; *(::TrackedArray{…,Array{...\n", | |
" 10 ...tracker/Tracker.jl:36; Type\n", | |
" 10 ...tracker/Tracker.jl:17; (::Flux.Tracker.Call{Base.#...\n", | |
" 1 ./linalg/matmul.jl:146; *\n", | |
" 9 ./linalg/matmul.jl:483; generic_matmatmul!(::Array{...\n", | |
" 1 ./linalg/matmul.jl:494; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:508; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:509; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:515; _generic_matmatmul!(::Arra...\n", | |
" 1 ./linalg/matmul.jl:389; copy_transpose!(::Array{Fl...\n", | |
" 1 ...alg/transpose.jl:143; copy_transpose!(::Array{F...\n", | |
" 1 ./abstractarray.jl:362; checkbounds\n", | |
" 4 ./linalg/matmul.jl:516; _generic_matmatmul!(::Arra...\n", | |
" 4 ./linalg/matmul.jl:379; copy!(::Array{Int64,2}, ::...\n", | |
" 1 ./abstractarray.jl:706; copy!(::Array{Int64,2}, :...\n", | |
" 1 ./range.jl:393; length\n", | |
" 1 ./checked.jl:221; checked_sub\n", | |
" 2 ./abstractarray.jl:715; copy!(::Array{Int64,2}, :...\n", | |
" 1 ./abstractarray.jl:716; copy!(::Array{Int64,2}, :...\n", | |
" 1 ./linalg/matmul.jl:581; _generic_matmatmul!(::Arra...\n", | |
" 239 ./broadcast.jl:434; broadcast(::Function, ::Tra...\n", | |
" 8 ...src/tracker/lib.jl:0; tracked_broadcast(::Functio...\n", | |
" 14 ...src/tracker/lib.jl:97; tracked_broadcast(::Functio...\n", | |
" 14 ./tuple.jl:178; map(::Flux.Tracker.##8#10{2...\n", | |
" 12 ./array.jl:455; _collect(::Array{Float64,2}...\n", | |
" 10 ./array.jl:375; _similar_for(::Array{Float...\n", | |
" 9 ./abstractarray.jl:524; similar(::Array{Float64,2}...\n", | |
" 2 ./array.jl:473; collect_to!(::Array{Forwar...\n", | |
" 1 ...rc/tracker/lib.jl:97; #8\n", | |
" 170 ...src/tracker/lib.jl:100; tracked_broadcast(::Functio...\n", | |
" 168 ./broadcast.jl:434; broadcast\n", | |
" 98 ./broadcast.jl:310; broadcast_c\n", | |
" 3 ./reflection.jl:510; _methods_by_ftype(::Any, :...\n", | |
" 1 ./reflection.jl:512; _methods_by_ftype(::Any, :...\n", | |
" 11 ./reflection.jl:521; _methods_by_ftype(::Any, :...\n", | |
" 2 ./broadcast.jl:311; broadcast_c\n", | |
" 2 ./broadcast.jl:53; broadcast_indices\n", | |
" 2 ./tuple.jl:158; map\n", | |
" 2 ./broadcast.jl:48; broadcast_indices\n", | |
" 2 ./broadcast.jl:52; broadcast_indices\n", | |
" 2 ...alg/rowvector.jl:112; indices\n", | |
" 2 ./abstractarray.jl:64; indices\n", | |
" 68 ./broadcast.jl:314; broadcast_c\n", | |
" 8 ./broadcast.jl:0; broadcast_t(::Function, ::...\n", | |
" 32 ./broadcast.jl:266; broadcast_t(::Function, ::...\n", | |
" 1 ./boot.jl:317; Array{ForwardDiff.Dual{Voi...\n", | |
" 9 ./broadcast.jl:268; broadcast_t(::Function, ::...\n", | |
" 2 ./broadcast.jl:139; _broadcast!(::##13#14, ::A...\n", | |
" 2 ./broadcast.jl:147; macro expansion\n", | |
" 2 ./simdloop.jl:73; macro expansion\n", | |
" 2 ./broadcast.jl:153; macro expansion\n", | |
" 2 ./<missing>:0; (::##13#14)(::ForwardDiff...\n", | |
" 2 ./intfuncs.jl:205; literal_pow\n", | |
" 1 ...iff/src/dual.jl:362; ^\n", | |
" 1 ...iff/src/dual.jl:363; ^\n", | |
" 1 ...src/partials.jl:82; *\n", | |
" 1 ...rc/partials.jl:109; *\n", | |
" 1 ...c/partials.jl:199; scale_tuple\n", | |
" 1 ...c/partials.jl:155; macro expansion\n", | |
" 39 ...src/tracker/lib.jl:101; tracked_broadcast(::Functio...\n", | |
" 1 .../tracker/Tracker.jl:0; Flux.Tracker.TrackedArray(::F...\n", | |
" 2 .../tracker/Tracker.jl:34; Flux.Tracker.TrackedArray(::F...\n", | |
" 7 .../src/tracker/lib.jl:91; (::Flux.Tracker.Broadcasted{A...\n", | |
" 7 ./array.jl:455; _collect(::Array{ForwardDif...\n", | |
" 6 ./array.jl:375; _similar_for(::Array{Forwar...\n", | |
" 1 ./abstractarray.jl:524; similar(::Array{ForwardDif...\n", | |
" 1 ./reduce.jl:276; _mapreduce(::Base.#identity,...\n", | |
" 497 ...rc/tracker/back.jl:43; back!(::TrackedArray{…,Arr...\n", | |
" 41 ...src/tracker/back.jl:39; back!\n", | |
" 41 ...src/tracker/back.jl:8; scan(::TrackedArray{…,Array...\n", | |
" 39 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sca...\n", | |
" 39 ...rc/tracker/back.jl:8; scan(::TrackedArray{…,Arra...\n", | |
" 33 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#sc...\n", | |
" 32 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Arr...\n", | |
" 31 ./abstractarray.jl:1731; foreach(::Flux.Tracker.#...\n", | |
" 1 ...c/tracker/back.jl:6; scan(::TrackedArray{…,Ar...\n", | |
" 30 ...c/tracker/back.jl:8; scan(::TrackedArray{…,Ar...\n", | |
" 22 ./abstractarray.jl:1731; foreach(::Flux.Tracker....\n", | |
" 22 .../tracker/back.jl:8; scan(::TrackedArray{…,...\n", | |
" 13 ...stractarray.jl:1731; foreach(::Flux.Tracker...\n", | |
" 1 ...tracker/back.jl:0; scan(::TrackedArray{…...\n", | |
" 1 ...tracker/back.jl:6; scan(::TrackedArray{…...\n", | |
" 10 ...tracker/back.jl:8; scan(::TrackedArray{…...\n", | |
" 3 ...stractarray.jl:1731; foreach(::Flux.Tracke...\n", | |
" 1 ...tracker/back.jl:8; scan(::TrackedArray{…...\n", | |
" 1 ...tracker/back.jl:12; scan(::TrackedArray{…...\n", | |
" 456 ...src/tracker/back.jl:24; back(::TrackedArray{…,Array...\n", | |
" 453 ...rc/tracker/back.jl:15; back(::Flux.Tracker.Call{Bas...\n", | |
" 446 ...rc/tracker/back.jl:24; back(::TrackedArray{…,Arr...\n", | |
" 435 ...c/tracker/back.jl:15; back(::Flux.Tracker.Call{Fl...\n", | |
" 410 ./abstractarray.jl:1732; foreach(::Function, ::Tup...\n", | |
" 26 ./iterators.jl:185; next(::Base.Iterators.Zip...\n", | |
" 1 ./iterators.jl:0; zip(::Tuple{TrackedArray{...\n", | |
" 361 ...c/tracker/lib.jl:117; (::Flux.Tracker.##17#19)(...\n", | |
" 361 .../tracker/back.jl:32; macro expansion\n", | |
" 1 ...tracker/back.jl:19; back(::TrackedArray{…,...\n", | |
" 1 ./refpointer.jl:121; setindex!\n", | |
" 359 ...tracker/back.jl:24; back(::TrackedArray{…,...\n", | |
" 357 ...tracker/back.jl:15; back(::Flux.Tracker.Call...\n", | |
" 345 ...stractarray.jl:1732; foreach(::Function, ::...\n", | |
" 32 ./iterators.jl:185; next(::Base.Iterators....\n", | |
" 294 ...tracker/lib.jl:117; (::Flux.Tracker.##17#1...\n", | |
" 294 ...racker/back.jl:32; macro expansion\n", | |
" 1 ...acker/back.jl:22; back(::TrackedArray{…...\n", | |
" 245 ...acker/back.jl:24; back(::TrackedArray{…...\n", | |
" 238 ...acker/back.jl:15; back(::Flux.Tracker.C...\n", | |
" 13 ...racker/lib.jl:71; back\n", | |
" 13 ...acker/back.jl:32; macro expansion\n", | |
" 1 ...alg/matmul.jl:0; A_mul_Bt(::Array{Flo...\n", | |
" 8 ...alg/matmul.jl:189; A_mul_Bt(::Array{Flo...\n", | |
" 8 ...lg/matmul.jl:191; A_mul_Bt!\n", | |
" 1 ...lg/matmul.jl:366; gemm_wrapper!(::Ar...\n", | |
" 7 ...lg/matmul.jl:367; gemm_wrapper!(::Ar...\n", | |
" 6 ...alg/blas.jl:1027; gemm!(::Char, ::C...\n", | |
" 2 ...acker/back.jl:21; back(::TrackedArray{...\n", | |
" 2 ./broadcast.jl:204; broadcast!\n", | |
" 1 ./broadcast.jl:208; broadcast_c!\n", | |
" 1 ./broadcast.jl:211; broadcast_c!\n", | |
" 1 ./broadcast.jl:139; _broadcast!\n", | |
" 1 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:66; macro expansion\n", | |
" 2 ...acker/back.jl:22; back(::TrackedArray{...\n", | |
" 17 ...racker/lib.jl:71; back(::Base.#*, ::Ar...\n", | |
" 17 ...acker/back.jl:32; macro expansion\n", | |
" 13 ...lg/matmul.jl:483; generic_matmatmul!(...\n", | |
" 5 ...lg/matmul.jl:508; _generic_matmatmul!...\n", | |
" 3 ...lg/matmul.jl:509; _generic_matmatmul!...\n", | |
" 2 ...lg/matmul.jl:515; _generic_matmatmul!...\n", | |
" 2 ...lg/matmul.jl:389; copy_transpose!(::...\n", | |
" 2 ...ranspose.jl:148; copy_transpose!(::...\n", | |
" 3 ...lg/matmul.jl:516; _generic_matmatmul!...\n", | |
" 3 ...lg/matmul.jl:381; copy!(::Array{Int6...\n", | |
" 1 ...ranspose.jl:143; copy_transpose!(::...\n", | |
" 1 ...actarray.jl:362; checkbounds\n", | |
" 1 ...ranspose.jl:147; copy_transpose!(::...\n", | |
" 1 ...ranspose.jl:148; copy_transpose!(::...\n", | |
" 1 ...cker/back.jl:0; back(::TrackedArray...\n", | |
" 1 ...cker/back.jl:19; back(::TrackedArray...\n", | |
" 1 ...cker/back.jl:21; back(::TrackedArray...\n", | |
" 207 ...racker/lib.jl:72; back\n", | |
" 207 ...acker/back.jl:32; macro expansion\n", | |
" 4 ...lg/matmul.jl:182; At_mul_B(::Array{F...\n", | |
" 4 ...lg/matmul.jl:184; At_mul_B!\n", | |
" 4 ...lg/matmul.jl:367; gemm_wrapper!(::Ar...\n", | |
" 1 ...alg/blas.jl:0; gemm!(::Char, ::C...\n", | |
" 3 ...alg/blas.jl:1027; gemm!(::Char, ::C...\n", | |
" 203 ...cker/back.jl:24; back(::TrackedArra...\n", | |
" 199 ...cker/back.jl:15; back(::Flux.Tracke...\n", | |
" 178 ...actarray.jl:1732; foreach(::Functio...\n", | |
" 1 ./iterators.jl:187; done(::Base.Itera...\n", | |
" 22 ./iterators.jl:185; next(::Base.Itera...\n", | |
" 132 ...cker/lib.jl:117; (::Flux.Tracker.#...\n", | |
" 132 ...ker/back.jl:32; macro expansion\n", | |
" 1 ...er/back.jl:21; back(::TrackedAr...\n", | |
" 1 ...oadcast.jl:204; broadcast!\n", | |
" 1 ...oadcast.jl:211; broadcast_c!\n", | |
" 1 ...adcast.jl:139; _broadcast!\n", | |
" 1 ...adcast.jl:147; macro expansion\n", | |
" 1 ...mdloop.jl:73; macro expansion\n", | |
" 1 ...dcast.jl:151; macro expansion\n", | |
" 2 ...er/back.jl:22; back(::TrackedAr...\n", | |
" 48 ...er/back.jl:24; back(::TrackedAr...\n", | |
" 45 ...er/back.jl:15; back(::Flux.Trac...\n", | |
" 43 ...ker/lib.jl:71; back(::Base.#*,...\n", | |
" 43 ...r/back.jl:32; macro expansion\n", | |
" 1 ...matmul.jl:189; A_mul_Bt\n", | |
" 1 ...matmul.jl:473; generic_matmat...\n", | |
" 40 ...matmul.jl:483; generic_matmat...\n", | |
" 4 ...atmul.jl:509; _generic_matm...\n", | |
" 5 ...atmul.jl:515; _generic_matm...\n", | |
" 5 ...atmul.jl:389; copy_transpos...\n", | |
" 1 ...spose.jl:147; copy_transpo...\n", | |
" 4 ...spose.jl:148; copy_transpo...\n", | |
" 4 ...atmul.jl:516; _generic_matm...\n", | |
" 3 ...atmul.jl:381; copy!(::Array...\n", | |
" 1 ...spose.jl:134; copy_transpo...\n", | |
" 1 ./range.jl:393; length\n", | |
" 1 ...cked.jl:164; checked_add\n", | |
" 1 ...spose.jl:146; copy_transpo...\n", | |
" 1 ...spose.jl:147; copy_transpo...\n", | |
" 1 ...atmul.jl:519; _generic_matm...\n", | |
" 4 ...atmul.jl:522; _generic_matm...\n", | |
" 17 ...atmul.jl:523; _generic_matm...\n", | |
" 2 ...atmul.jl:525; _generic_matm...\n", | |
" 1 ...atmul.jl:581; _generic_matm...\n", | |
" 1 ...r/back.jl:21; back(::Tracked...\n", | |
" 1 ...adcast.jl:204; broadcast!\n", | |
" 1 ...dcast.jl:207; broadcast_c!\n", | |
" 1 ...array.jl:64; indices\n", | |
" 80 ...ker/lib.jl:106; unbroadcast(::Tr...\n", | |
" 49 ./array.jl:1819; filter(::Functi...\n", | |
" 2 ...tarray.jl:882; getindex\n", | |
" 2 ...sional.jl:438; _getindex\n", | |
" 2 ...sional.jl:442; macro expansion\n", | |
" 1 ...sional.jl:453; _unsafe_getind...\n", | |
" 1 ...ional.jl:458; macro expansion\n", | |
" 33 ...tarray.jl:1865; map(::Function...\n", | |
" 4 ./array.jl:455; _collect(::Unit...\n", | |
" 1 ./array.jl:0; reshape(::Array...\n", | |
" 1 ./array.jl:184; reshape(::Array...\n", | |
" 25 ...ucedim.jl:572; sum\n", | |
" 25 ...ucedim.jl:570; sum\n", | |
" 25 ...ucedim.jl:241; mapreducedim\n", | |
" 2 ...ucedim.jl:210; mapreducedim!\n", | |
" 2 ...ucedim.jl:202; _mapreducedim!...\n", | |
" 2 ...mdloop.jl:71; macro expansion\n", | |
" 23 ...ucedim.jl:73; reducedim_init...\n", | |
" 3 ...ucedim.jl:33; reduced_indice...\n", | |
" 1 ./array.jl:0; vect(::Base.On...\n", | |
" 1 ./array.jl:76; vect(::Base.On...\n", | |
" 16 ...ucedim.jl:43; reduced_indice...\n", | |
" 20 ...cker/lib.jl:116; back\n", | |
" 1 ...cker/lib.jl:0; (::Flux.Tracker.#...\n", | |
" 19 ...cker/lib.jl:116; (::Flux.Tracker.#...\n", | |
" 19 ./broadcast.jl:434; broadcast\n", | |
" 2 ...oadcast.jl:311; broadcast_c\n", | |
" 2 ./broadcast.jl:53; broadcast_indices\n", | |
" 1 ...oadcast.jl:48; broadcast_indices\n", | |
" 1 ...oadcast.jl:52; broadcast_indices\n", | |
" 1 ...tarray.jl:64; indices\n", | |
" 1 ...oadcast.jl:57; broadcast_shape\n", | |
" 1 ...oadcast.jl:57; broadcast_shape\n", | |
" 1 ...adcast.jl:63; _bcs\n", | |
" 1 ...adcast.jl:0; _bcs1(::Base.On...\n", | |
" 17 ...oadcast.jl:314; broadcast_c\n", | |
" 8 ...oadcast.jl:266; broadcast_t\n", | |
" 9 ...oadcast.jl:268; broadcast_t\n", | |
" 9 ...oadcast.jl:139; _broadcast!\n", | |
" 9 ...adcast.jl:147; macro expansion\n", | |
" 9 ...mdloop.jl:73; macro expansion\n", | |
" 8 ...adcast.jl:151; macro expansion\n", | |
" 1 ...adcast.jl:154; macro expansion\n", | |
" 1 ...ional.jl:247; setindex!\n", | |
" 47 ...racker/lib.jl:106; unbroadcast(::Tracked...\n", | |
" 23 ./array.jl:1819; filter(::Function, :...\n", | |
" 2 ...tractarray.jl:882; getindex\n", | |
" 2 ...imensional.jl:438; _getindex\n", | |
" 1 ...imensional.jl:441; macro expansion\n", | |
" 1 ...ractarray.jl:362; checkbounds\n", | |
" 1 ...imensional.jl:442; macro expansion\n", | |
" 1 ...mensional.jl:453; _unsafe_getindex(::...\n", | |
" 1 ...mensional.jl:458; macro expansion\n", | |
" 18 ...tractarray.jl:1865; map(::Function, ::Un...\n", | |
" 2 ./array.jl:455; _collect(::UnitRange...\n", | |
" 1 ./array.jl:184; reshape(::Array{Floa...\n", | |
" 22 ./reducedim.jl:572; sum\n", | |
" 22 ./reducedim.jl:570; sum\n", | |
" 22 ./reducedim.jl:241; mapreducedim\n", | |
" 2 ./reducedim.jl:210; mapreducedim!\n", | |
" 2 ./reducedim.jl:194; _mapreducedim!(::Ba...\n", | |
" 2 ./simdloop.jl:73; macro expansion\n", | |
" 2 ./reducedim.jl:195; macro expansion\n", | |
" 20 ./reducedim.jl:73; reducedim_initarray...\n", | |
" 1 ./reducedim.jl:0; reduced_indices(::T...\n", | |
" 1 ./reducedim.jl:33; reduced_indices(::T...\n", | |
" 1 ./array.jl:76; vect(::Base.OneTo{I...\n", | |
" 17 ./reducedim.jl:43; reduced_indices(::T...\n", | |
" 12 ...tracker/lib.jl:116; back\n", | |
" 11 .../tracker/lib.jl:116; (::Flux.Tracker.##16#18...\n", | |
" 11 ./broadcast.jl:434; broadcast\n", | |
" 1 ./broadcast.jl:311; broadcast_c\n", | |
" 1 ./broadcast.jl:53; broadcast_indices\n", | |
" 1 ./tuple.jl:159; map\n", | |
" 1 ./broadcast.jl:48; broadcast_indices\n", | |
" 1 ./broadcast.jl:52; broadcast_indices\n", | |
" 1 ...tractarray.jl:64; indices\n", | |
" 10 ./broadcast.jl:314; broadcast_c\n", | |
" 8 ./broadcast.jl:266; broadcast_t\n", | |
" 2 ./broadcast.jl:268; broadcast_t\n", | |
" 2 ./broadcast.jl:139; _broadcast!\n", | |
" 2 ./broadcast.jl:147; macro expansion\n", | |
" 2 ./simdloop.jl:73; macro expansion\n", | |
" 2 ./broadcast.jl:153; macro expansion\n", | |
" 1 .../tracker/lib.jl:106; unbroadcast(::TrackedArr...\n", | |
" 21 ...c/tracker/lib.jl:116; back\n", | |
" 1 ...rc/tracker/lib.jl:0; (::Flux.Tracker.##16#18{Fl...\n", | |
" 19 ...rc/tracker/lib.jl:116; (::Flux.Tracker.##16#18{Fl...\n", | |
" 19 ./broadcast.jl:434; broadcast\n", | |
" 19 ./broadcast.jl:314; broadcast_c\n", | |
" 17 ./broadcast.jl:266; broadcast_t\n", | |
" 2 ./broadcast.jl:268; broadcast_t\n", | |
" 2 ./broadcast.jl:139; _broadcast!\n", | |
" 2 ./broadcast.jl:147; macro expansion\n", | |
" 2 ./simdloop.jl:73; macro expansion\n", | |
" 2 ./broadcast.jl:153; macro expansion\n", | |
" 7 ...src/tracker/lib.jl:55; back\n", | |
" 1 ...src/tracker/lib.jl:52; sum(::TrackedArray{…,Array...\n", | |
" 1 ...x/src/tracker/lib.jl:4; toarray\n", | |
" 2 ./In[16]:3; update!(::Array{Flux.Tracker.T...\n", | |
" 6 ./In[16]:4; update!(::Array{Flux.Tracker.T...\n", | |
" 57 ./In[16]:5; update!(::Array{Flux.Tracker.T...\n", | |
" 29 ./broadcast.jl:204; broadcast!(::Function, ::Arra...\n", | |
" 1 ./broadcast.jl:208; broadcast_c!\n", | |
" 1 ./broadcast.jl:89; check_broadcast_indices\n", | |
" 1 ./broadcast.jl:84; check_broadcast_shape(::Tuple...\n", | |
" 1 ./broadcast.jl:0; check_broadcast_shape(::Tuple...\n", | |
" 26 ./broadcast.jl:211; broadcast_c!\n", | |
" 6 ./broadcast.jl:139; _broadcast!(::##15#16, ::Arra...\n", | |
" 2 ./broadcast.jl:144; macro expansion\n", | |
" 4 ./broadcast.jl:147; macro expansion\n", | |
" 1 ./simdloop.jl:66; macro expansion\n", | |
" 2 ./simdloop.jl:72; macro expansion\n", | |
" 1 ./simdloop.jl:73; macro expansion\n", | |
" 1 ./broadcast.jl:154; macro expansion\n", | |
" 1 ...ltidimensional.jl:247; setindex!\n", | |
" 8 ./In[16]:6; update!(::Array{Flux.Tracker.T...\n", | |
" 2 ./broadcast.jl:22; broadcast!(::Base.#identity, :...\n", | |
" 1 ./array.jl:227; fill!(::Array{Float64,2}, ::I...\n", | |
" 1 ./array.jl:228; fill!(::Array{Float64,1}, ::I...\n", | |
" 5 ./In[20]:19; train!(::Array{Flux.Tracker.Tr...\n", | |
" 4 ....6/Flux/src/utils.jl:37; (::Flux.##throttled#4#9{Bool,Bo...\n" | |
] | |
} | |
], | |
"source": [ | |
"reset!.(ps)\n", | |
"@profile train!(ps, xys, 0.3, 2000, cb)\n", | |
"Profile.print()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 0.6.0", | |
"language": "julia", | |
"name": "julia-0.6" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "0.6.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |