Last active
July 12, 2019 12:10
-
-
Save antimon2/2939e75f2620112ac0525c677beab41c to your computer and use it in GitHub Desktop.
FashionMNIST_Sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:59:47.097Z", | |
| "end_time": "2019-07-12T18:59:49.323000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Julia Version\nVERSION", | |
| "execution_count": 1, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 1, | |
| "data": { | |
| "text/plain": "v\"1.1.1\"" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:59:50.453Z", | |
| "end_time": "2019-07-12T18:59:52.500000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "]st Flux", | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "\u001b[36m\u001b[1mProject \u001b[22m\u001b[39mOSCNagoyaDemoJL v0.1.0\n\u001b[32m\u001b[1m Status\u001b[22m\u001b[39m `~/Documents/ipynb_root/toybox/OSCNagoyaDemoJL/Project.toml`\n \u001b[90m [944b1d66]\u001b[39m\u001b[37m CodecZlib v0.5.2\u001b[39m\n \u001b[90m [587475ba]\u001b[39m\u001b[37m Flux v0.8.3\u001b[39m\n \u001b[90m [2913bbd2]\u001b[39m\u001b[37m StatsBase v0.31.0\u001b[39m\n \u001b[90m [9a3f8284]\u001b[39m\u001b[37m Random \u001b[39m\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:59:56.585Z", | |
| "end_time": "2019-07-12T19:00:10.088000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "using Base.Iterators: repeated\nusing CSV: write\nusing DataFrames: DataFrame\nusing Flux\nusing Random\n# using RDatasets\nusing StatsBase: sample\nRandom.seed!(123);", | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:00:11.904Z", | |
| "end_time": "2019-07-12T19:00:15.078000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "using Flux.Data\n_xtrn = FashionMNIST.images(:train);\n_ytrn = FashionMNIST.labels(:train);\n_xtst = FashionMNIST.images(:test);\n_ytst = FashionMNIST.labels(:test);", | |
| "execution_count": 4, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:00:18.648Z", | |
| "end_time": "2019-07-12T19:00:20.167000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Translate the xtrn and xtst to Float32 Array\nxtrn = reshape(Float32.(hcat(float.(_xtrn)...)), (28, 28, 1, :));\nxtst = reshape(Float32.(hcat(float.(_xtst)...)), (28, 28, 1, :));", | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:00:21.154Z", | |
| "end_time": "2019-07-12T19:00:21.234000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Translate the ytrn and ytst to one-hot-vector encoding\nytrn = Flux.onehotbatch(_ytrn, 0:9);\nytst = Flux.onehotbatch(_ytst, 0:9);", | |
| "execution_count": 6, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:00:25.316Z", | |
| "end_time": "2019-07-12T19:00:25.780000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "println.(summary.((xtrn,ytrn,xtst,ytst)));", | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "28×28×1×60000 Array{Float32,4}\n10×60000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}\n28×28×1×10000 Array{Float32,4}\n10×10000 Flux.OneHotMatrix{Array{Flux.OneHotVector,1}}\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:00:35.621Z", | |
| "end_time": "2019-07-12T19:00:36.279000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define the batches of data for every updating of the parameter\nminibatches = Tuple{typeof(xtrn), typeof(ytrn)}[]\nbatch_size = 1000\nn_batch = 60\nrandom_idcs = randperm(size(xtrn, 4))\noffset = 1\nsegment = batch_size\nfor i in 1:n_batch\n idcs = random_idcs[offset:segment]\n push!(minibatches, (xtrn[:, :, :, idcs], ytrn[:, idcs]))\n offset += batch_size\n segment += batch_size\nend", | |
| "execution_count": 8, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:01:06.581Z", | |
| "end_time": "2019-07-12T19:01:06.605000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "#=\nInitialization is from\nHe et al., 2015,\nDelving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification\nhttps://arxiv.org/abs/1502.01852\n=#\nkaiming(::Type{T}, h, w, i, o) where {T<:AbstractFloat} = T(sqrt(2 / (w * h * o))) .* randn(T, h, w, i, o)\nglorot_uniform(::Type{T}, dims...) where {T<:AbstractFloat} = (rand(T, dims...) .- T(0.5)) .* sqrt(T(24.0)/(sum(dims)))\n\n(::Type{Flux.Conv})(::Type{T}, k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = identity;\n init = kaiming, stride = 1, pad = 0, dilation = 1) where {T<:AbstractFloat, N} =\n Flux.Conv(param(init(T, k..., ch...)), param(zeros(T, ch[2])), σ, stride=stride, pad=pad, dilation=dilation)\n\nfunction (::Type{Flux.Dense})(::Type{T}, in::Integer, out::Integer, σ = identity;\n initW = glorot_uniform, initb = zeros) where {T<:AbstractFloat}\n return Flux.Dense(param(initW(T, out, in)), param(initb(T, out)), σ)\nend", | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:02:00.011Z", | |
| "end_time": "2019-07-12T19:02:01.125000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Specify the model\nmodel = Flux.Chain(\n Flux.Conv(Float32, (5, 5), 1=>20, relu),\n Flux.MaxPool((2, 2)),\n Flux.Conv(Float32, (5, 5), 20=>50, relu),\n Flux.MaxPool((2, 2)),\n x -> reshape(x, (800, :)), # Flatten\n Flux.Dense(Float32, 800, 500, relu),\n Flux.Dense(Float32, 500, 10),\n Flux.softmax\n);", | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:02:08.067Z", | |
| "end_time": "2019-07-12T19:02:08.071000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define loss function\nloss(x, y) = Flux.crossentropy(model(x), Float32.(y));", | |
| "execution_count": 11, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:02:39.124Z", | |
| "end_time": "2019-07-12T19:02:39.294000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# # Define the callback function that prints the loss every epoch\n# callback = () -> @show(loss(xtrn, ytrn));\nmutable struct Callback\n per::Int\n step::Int\n err\n acc\n Callback(per::Int) = new(per, 0)\nend\nfunction (c::Callback)()\n step = c.step + 1\n c.step = step\n if step % c.per == 0\n losstrn = Flux.Tracker.data(loss(xtrn, ytrn))\n losstst = Flux.Tracker.data(loss(xtst, ytst))\n @info \"loss\" step losstrn losstst\n c.err = vcat(c.err, hcat(losstrn, losstst))\n c.acc = vcat(c.acc, hcat(accuracy(xtrn, ytrn), accuracy(xtst, ytst)))\n end\nend\n# function init(c::Callback); c.step = 0; end", | |
| "execution_count": 12, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:02:42.019Z", | |
| "end_time": "2019-07-12T19:02:42.523000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define the accuracy metric: fraction of samples whose predicted class matches the label\nfunction accuracy(x, y)\n return sum(Flux.onecold(model(x)) .== Flux.onecold(y)) / size(y, 2)\nend", | |
| "execution_count": 13, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 13, | |
| "data": { | |
| "text/plain": "accuracy (generic function with 1 method)" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T10:05:01.694Z", | |
| "end_time": "2019-07-12T19:15:50.697000+09:00" | |
| }, | |
| "scrolled": false, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "@time begin\n callback = Callback(60)\n callback.err = hcat(Flux.Tracker.data(loss(xtrn, ytrn)), Flux.Tracker.data(loss(xtst, ytst)))\n callback.acc = hcat(accuracy(xtrn, ytrn), accuracy(xtst, ytst))\n Flux.@epochs 5 Flux.train!(loss, Flux.params(model), minibatches, Flux.ADAM(), cb=callback);\nend", | |
| "execution_count": 14, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "┌ Info: Epoch 1\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 60\n│ losstrn = 0.5770615\n│ losstst = 0.6001126\n└ @ Main In[12]:16\n┌ Info: Epoch 2\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 120\n│ losstrn = 0.46836329\n│ losstst = 0.49278983\n└ @ Main In[12]:16\n┌ Info: Epoch 3\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 180\n│ losstrn = 0.4173441\n│ losstst = 0.44334346\n└ @ Main In[12]:16\n┌ Info: Epoch 4\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 240\n│ losstrn = 0.38111547\n│ losstst = 0.40958908\n└ @ Main In[12]:16\n┌ Info: Epoch 5\n└ @ Main /Users/antimon2/.julia/packages/Flux/qXNjB/src/optimise/train.jl:105\n┌ Info: loss\n│ step = 300\n│ losstrn = 0.35374358\n│ losstst = 0.38537923\n└ @ Main In[12]:16\n", | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": "647.178517 seconds (96.54 M allocations: 208.032 GiB, 15.43% gc time)\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T12:08:55.565Z", | |
| "end_time": "2019-07-12T21:08:58.678000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "callback.err", | |
| "execution_count": 15, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 15, | |
| "data": { | |
| "text/plain": "6×2 Array{Float32,2}:\n 2.31074 2.31056 \n 0.577061 0.600113\n 0.468363 0.49279 \n 0.417344 0.443343\n 0.381115 0.409589\n 0.353744 0.385379" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T12:08:59.640Z", | |
| "end_time": "2019-07-12T21:09:03.008000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "callback.acc", | |
| "execution_count": 16, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 16, | |
| "data": { | |
| "text/plain": "6×2 Array{Float64,2}:\n 0.10095 0.101 \n 0.787633 0.7803\n 0.834583 0.8279\n 0.853567 0.8434\n 0.865633 0.8549\n 0.87435 0.8617" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T12:09:04.753Z", | |
| "end_time": "2019-07-12T21:09:07.362000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Save loss and accuracy to csv for visualization\nwrite(\"fmnist-error-flux.csv\", DataFrame(callback.err, [:training, :testing]))\nwrite(\"fmnist-accuracy-flux.csv\", DataFrame(callback.acc, [:training, :testing]))", | |
| "execution_count": 17, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 17, | |
| "data": { | |
| "text/plain": "\"fmnist-accuracy-flux.csv\"" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-07-11T14:58:17.832000+09:00", | |
| "start_time": "2019-07-11T05:58:17.826Z" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
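| "source": "# Predict class indices for the training and test sets\n# (xtrn/xtst are already 28×28×1×N arrays, so no transpose is needed)\ntrn_yidx = Flux.onecold(model(xtrn));\ntst_yidx = Flux.onecold(model(xtst));", | |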
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T12:09:09.505Z", | |
| "end_time": "2019-07-12T21:09:38.418000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Check accuracy of the model\naccuracy(xtrn, ytrn)", | |
| "execution_count": 18, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 18, | |
| "data": { | |
| "text/plain": "0.87435" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T12:09:39.677Z", | |
| "end_time": "2019-07-12T21:09:44.491000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "accuracy(xtst, ytst)", | |
| "execution_count": 19, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 19, | |
| "data": { | |
| "text/plain": "0.8617" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "", | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "name": "julia-1.1", | |
| "display_name": "Julia 1.1.1", | |
| "language": "julia" | |
| }, | |
| "language_info": { | |
| "file_extension": ".jl", | |
| "name": "julia", | |
| "mimetype": "application/julia", | |
| "version": "1.1.1" | |
| }, | |
| "gist": { | |
| "id": "", | |
| "data": { | |
| "description": "FashionMNIST_Sample", | |
| "public": true | |
| } | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment