Skip to content

Instantly share code, notes, and snippets.

@antimon2
Last active July 12, 2019 12:10
Show Gist options
  • Save antimon2/2939e75f2620112ac0525c677beab41c to your computer and use it in GitHub Desktop.
Save antimon2/2939e75f2620112ac0525c677beab41c to your computer and use it in GitHub Desktop.
FashionMNIST_Sample
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T09:59:47.097Z",
"end_time": "2019-07-12T18:59:49.323000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Show the Julia version this notebook was executed with\nVERSION",
"execution_count": 1,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 1,
"data": {
"text/plain": "v\"1.1.1\""
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T09:59:50.453Z",
"end_time": "2019-07-12T18:59:52.500000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "]status Flux",
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": "\u001b[36m\u001b[1mProject \u001b[22m\u001b[39mOSCNagoyaDemoJL v0.1.0\n\u001b[32m\u001b[1m Status\u001b[22m\u001b[39m `~/Documents/ipynb_root/toybox/OSCNagoyaDemoJL/Project.toml`\n \u001b[90m [944b1d66]\u001b[39m\u001b[37m CodecZlib v0.5.2\u001b[39m\n \u001b[90m [587475ba]\u001b[39m\u001b[37m Flux v0.8.3\u001b[39m\n \u001b[90m [2913bbd2]\u001b[39m\u001b[37m StatsBase v0.31.0\u001b[39m\n \u001b[90m [9a3f8284]\u001b[39m\u001b[37m Random \u001b[39m\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T09:59:56.585Z",
"end_time": "2019-07-12T19:00:10.088000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "using Base.Iterators: repeated\n# NOTE(review): this binds the bare name `write` to CSV.write for the rest of the notebook\nusing CSV: write\nusing DataFrames: DataFrame\nusing Flux\nusing Random\n# using RDatasets  (not needed: the data is loaded from Flux.Data)\n# NOTE(review): `repeated` and `sample` do not appear to be used anywhere in this notebook\nusing StatsBase: sample\n# Seed the global RNG so the minibatch shuffle (randperm) and the random weight initialisation are reproducible\nRandom.seed!(123);",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:00:11.904Z",
"end_time": "2019-07-12T19:00:15.078000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Load the raw FashionMNIST images and labels via Flux.Data (training and test splits)\nusing Flux.Data\n_xtrn, _ytrn = FashionMNIST.images(:train), FashionMNIST.labels(:train);\n_xtst, _ytst = FashionMNIST.images(:test), FashionMNIST.labels(:test);",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:00:18.648Z",
"end_time": "2019-07-12T19:00:20.167000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Convert each vector of h×w images into one h×w×1×N Float32 array (a single channel, samples last).\n# `reduce(hcat, ...)` uses Base's specialised method instead of splatting tens of thousands of\n# matrices as arguments into a single `hcat(...)` call.\nfunction to_array4d(imgs)\n    h, w = size(first(imgs))\n    return reshape(Float32.(reduce(hcat, float.(imgs))), (h, w, 1, :))\nend\nxtrn = to_array4d(_xtrn);\nxtst = to_array4d(_xtst);",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:00:21.154Z",
"end_time": "2019-07-12T19:00:21.234000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# One-hot encode the integer class labels 0 through 9 for the training and test sets\nytrn, ytst = Flux.onehotbatch(_ytrn, 0:9), Flux.onehotbatch(_ytst, 0:9);",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:00:25.316Z",
"end_time": "2019-07-12T19:00:25.780000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Print the size and type of every prepared array\nforeach((xtrn, ytrn, xtst, ytst)) do arr\n    println(summary(arr))\nend",
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": "28×28×1×60000 Array{Float32,4}\n10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}\n28×28×1×10000 Array{Float32,4}\n10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:00:35.621Z",
"end_time": "2019-07-12T19:00:36.279000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Build the minibatches used for each parameter update: shuffle the training-sample indices once,\n# then split them into consecutive, non-overlapping chunks of `batch_size` (any remainder is dropped).\n# Each chunk's range is computed from `i` instead of updating global `offset`/`segment` counters,\n# which only works under REPL-style soft-scope rules and raises UndefVarError when run via `include`.\nminibatches = Tuple{typeof(xtrn), typeof(ytrn)}[]\nbatch_size = 1000\nn_batch = div(size(xtrn, 4), batch_size)  # 60 for the 60000 training images\nrandom_idcs = randperm(size(xtrn, 4))\nfor i in 1:n_batch\n    idcs = random_idcs[(i - 1) * batch_size + 1 : i * batch_size]\n    push!(minibatches, (xtrn[:, :, :, idcs], ytrn[:, idcs]))\nend",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:01:06.581Z",
"end_time": "2019-07-12T19:01:06.605000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "#=\nWeight initialisers and element-type-aware constructors for Flux.Conv / Flux.Dense.\n\n`kaiming` follows He et al., 2015,\nDelving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification\nhttps://arxiv.org/abs/1502.01852\nFor an h×w×i×o kernel it samples N(0, 2 / (h*w*o)), i.e. the variance is scaled by\nkernel area × OUTPUT channels (fan-out). NOTE(review): the more common fan-in form uses `i` instead of `o`.\n\n`glorot_uniform` samples U(-a, a) with a = sqrt(6 / sum(dims)):\n(rand - 0.5) lies in [-0.5, 0.5) and 0.5 * sqrt(24 / s) == sqrt(6 / s).\n=#\nkaiming(::Type{T}, h, w, i, o) where {T<:AbstractFloat} = T(sqrt(2 / (w * h * o))) .* randn(T, h, w, i, o)\nglorot_uniform(::Type{T}, dims...) where {T<:AbstractFloat} = (rand(T, dims...) .- T(0.5)) .* sqrt(T(24.0)/(sum(dims)))\n\n# Conv(T, (kh, kw), in=>out, σ; init, stride, pad, dilation): kernel drawn by `init(T, kh, kw, in, out)`,\n# zero bias of length `out`; both are wrapped with `param`.\n(::Type{Flux.Conv})(::Type{T}, k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;\n init = kaiming, stride = 1, pad = 0, dilation = 1) where {T<:AbstractFloat, N} =\n Flux.Conv(param(init(T, k..., ch...)), param(zeros(T, ch[2])), σ, stride=stride, pad=pad, dilation=dilation)\n\n# Dense(T, in, out, σ; initW, initb): `out×in` weight from `initW(T, out, in)`, bias from `initb(T, out)`.\nfunction (::Type{Flux.Dense})(::Type{T}, in::Integer, out::Integer, σ = identity;\n initW = glorot_uniform, initb = zeros) where {T<:AbstractFloat}\n return Flux.Dense(param(initW(T, out, in)), param(initb(T, out)), σ)\nend",
"execution_count": 9,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:02:00.011Z",
"end_time": "2019-07-12T19:02:01.125000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Specify the model: two conv + max-pool stages, then two dense layers and a softmax over the 10 classes.\n# Shape comments assume 28×28×1×N inputs (conv: 5×5, stride 1, no padding; pool: 2×2 windows).\nmodel = Flux.Chain(\n Flux.Conv(Float32, (5, 5), 1=>20, relu), # -> 24×24×20×N\n Flux.MaxPool((2, 2)), # -> 12×12×20×N\n Flux.Conv(Float32, (5, 5), 20=>50, relu), # -> 8×8×50×N\n Flux.MaxPool((2, 2)), # -> 4×4×50×N\n x -> reshape(x, (800, :)), # Flatten: 4*4*50 = 800 features per sample\n Flux.Dense(Float32, 800, 500, relu),\n Flux.Dense(Float32, 500, 10),\n Flux.softmax # -> 10×N class probabilities\n);",
"execution_count": 10,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:02:08.067Z",
"end_time": "2019-07-12T19:02:08.071000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Loss: cross-entropy between the model's predicted class probabilities and the\n# one-hot targets, converted to Float32 to match the Float32 layers\nfunction loss(x, y)\n    yhat = model(x)\n    return Flux.crossentropy(yhat, Float32.(y))\nend;",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:02:39.124Z",
"end_time": "2019-07-12T19:02:39.294000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Training callback for Flux.train!: counts its invocations and, every `per` calls,\n# logs the full training/test loss and appends the losses and accuracies as new rows.\nmutable struct Callback\n    per::Int   # evaluation period, in callback invocations\n    step::Int  # invocations so far\n    err        # rows of [training_loss test_loss]; left undefined by the constructor\n    acc        # rows of [training_accuracy test_accuracy]; left undefined by the constructor\n    Callback(per::Int) = new(per, 0)\nend\n\nfunction (cb::Callback)()\n    cb.step += 1\n    if iszero(cb.step % cb.per)\n        trn_loss = Flux.Tracker.data(loss(xtrn, ytrn))\n        tst_loss = Flux.Tracker.data(loss(xtst, ytst))\n        @info \"loss\" step=cb.step losstrn=trn_loss losstst=tst_loss\n        cb.err = [cb.err; [trn_loss tst_loss]]\n        cb.acc = [cb.acc; [accuracy(xtrn, ytrn) accuracy(xtst, ytst)]]\n    end\nend",
"execution_count": 12,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:02:42.019Z",
"end_time": "2019-07-12T19:02:42.523000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Classification accuracy: fraction of columns (samples) whose predicted class\n# (arg-max of the model output, via `onecold`) equals the class of the one-hot label.\nfunction accuracy(x, y)\n    return sum(Flux.onecold(model(x)) .== Flux.onecold(y)) / size(y, 2)\nend",
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 13,
"data": {
"text/plain": "accuracy (generic function with 1 method)"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T10:05:01.694Z",
"end_time": "2019-07-12T19:15:50.697000+09:00"
},
"scrolled": false,
"trusted": true
},
"cell_type": "code",
"source": "# Train the model for 5 epochs, recording loss/accuracy once per pass over the minibatches\n@time begin\n    # Fire once per pass over `minibatches`; the first row holds the untrained model's metrics\n    callback = Callback(length(minibatches))\n    callback.err = hcat(Flux.Tracker.data(loss(xtrn, ytrn)), Flux.Tracker.data(loss(xtst, ytst)))\n    callback.acc = hcat(accuracy(xtrn, ytrn), accuracy(xtst, ytst))\n    # Create the optimiser ONCE: `@epochs` re-evaluates its whole argument every epoch, so writing\n    # `Flux.ADAM()` inline would build a fresh optimiser and discard ADAM's moment estimates each epoch.\n    opt = Flux.ADAM()\n    Flux.@epochs 5 Flux.train!(loss, Flux.params(model), minibatches, opt, cb=callback);\nend",
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"text": "┌ Info: Epoch 1\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 60\n│ losstrn = 0.5770615\n│ losstst = 0.6001126\n└ @ Main In[12]:16\n┌ Info: Epoch 2\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 120\n│ losstrn = 0.46836329\n│ losstst = 0.49278983\n└ @ Main In[12]:16\n┌ Info: Epoch 3\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 180\n│ losstrn = 0.4173441\n│ losstst = 0.44334346\n└ @ Main In[12]:16\n┌ Info: Epoch 4\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 240\n│ losstrn = 0.38111547\n│ losstst = 0.40958908\n└ @ Main In[12]:16\n┌ Info: Epoch 5\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 300\n│ losstrn = 0.35374358\n│ losstst = 0.38537923\n└ @ Main In[12]:16\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "647.178517 seconds (96.54 M allocations: 208.032 GiB, 15.43% gc time)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T12:08:55.565Z",
"end_time": "2019-07-12T21:08:58.678000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Loss history: one row per logging step (row 1 = before training); columns = training and test loss\ncallback.err",
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 15,
"data": {
"text/plain": "6×2 Array{Float32,2}:\n 2.31074 2.31056 \n 0.577061 0.600113\n 0.468363 0.49279 \n 0.417344 0.443343\n 0.381115 0.409589\n 0.353744 0.385379"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T12:08:59.640Z",
"end_time": "2019-07-12T21:09:03.008000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Accuracy history: one row per logging step (row 1 = before training); columns = training and test accuracy\ncallback.acc",
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 16,
"data": {
"text/plain": "6×2 Array{Float64,2}:\n 0.10095 0.101 \n 0.787633 0.7803\n 0.834583 0.8279\n 0.853567 0.8434\n 0.865633 0.8549\n 0.87435 0.8617"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T12:09:04.753Z",
"end_time": "2019-07-12T21:09:07.362000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Save the loss and accuracy histories to CSV for visualization\nlet columns = [:training, :testing]\n    write(\"fmnist-error-flux.csv\", DataFrame(callback.err, columns))\n    write(\"fmnist-accuracy-flux.csv\", DataFrame(callback.acc, columns))\nend",
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 17,
"data": {
"text/plain": "\"fmnist-accuracy-flux.csv\""
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2019-07-11T14:58:17.832000+09:00",
"start_time": "2019-07-11T05:58:17.826Z"
},
"trusted": true
},
"cell_type": "code",
"source": "# Predict class indices (1 to 10, i.e. label + 1 for the 0:9 labels) on the training and test sets.\n# The inputs are already h×w×1×N arrays and are passed as-is: adjoint (') is not defined for 4-D arrays.\ntrn_yidx = Flux.onecold(model(xtrn));\ntst_yidx = Flux.onecold(model(xtst));",
"execution_count": null,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T12:09:09.505Z",
"end_time": "2019-07-12T21:09:38.418000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Accuracy of the trained model on the training set\naccuracy(xtrn, ytrn)",
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 18,
"data": {
"text/plain": "0.87435"
},
"metadata": {}
}
]
},
{
"metadata": {
"ExecuteTime": {
"start_time": "2019-07-12T12:09:39.677Z",
"end_time": "2019-07-12T21:09:44.491000+09:00"
},
"trusted": true
},
"cell_type": "code",
"source": "# Accuracy of the trained model on the held-out test set\naccuracy(xtst, ytst)",
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 19,
"data": {
"text/plain": "0.8617"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "julia-1.1",
"display_name": "Julia 1.1.1",
"language": "julia"
},
"language_info": {
"file_extension": ".jl",
"name": "julia",
"mimetype": "application/julia",
"version": "1.1.1"
},
"gist": {
"id": "",
"data": {
"description": "FashionMNIST_Sample",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment