Last active
July 12, 2019 12:10
-
-
Save antimon2/2939e75f2620112ac0525c677beab41c to your computer and use it in GitHub Desktop.
FashionMNIST_Sample
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:47:32.851Z", | |
| "end_time": "2019-07-12T18:47:34.972000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Julia Version\nVERSION", | |
| "execution_count": 1, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 1, | |
| "data": { | |
| "text/plain": "v\"1.1.1\"" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:47:37.241Z", | |
| "end_time": "2019-07-12T18:47:39.232000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "]st Knet", | |
| "execution_count": 2, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "\u001b[36m\u001b[1mProject \u001b[22m\u001b[39mOSCNagoyaDemoJL v0.1.0\n\u001b[32m\u001b[1m Status\u001b[22m\u001b[39m `~/Documents/ipynb_root/toybox/OSCNagoyaDemoJL/Project.toml`\n \u001b[90m [1902f260]\u001b[39m\u001b[37m Knet v1.2.2\u001b[39m\n \u001b[90m [9a3f8284]\u001b[39m\u001b[37m Random \u001b[39m\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:47:49.476Z", | |
| "end_time": "2019-07-12T18:47:57.859000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "using CSV: write\nusing DataFrames: DataFrame\nusing Knet\nusing Random\n# using RDatasets\nusing StatsBase: sample # NOTE(review): `sample` is not used by any cell in this notebook\nRandom.seed!(123);", | |
| "execution_count": 3, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:03.726Z", | |
| "end_time": "2019-07-12T18:48:10.247000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "include(Knet.dir(\"data\", \"fashion-mnist.jl\")) # Knet.dir constructs a path relative to Knet root\nxtrn,ytrn,xtst,ytst = fmnist() # fmnist() loads Fashion-MNIST data and converts it into Julia arrays\nprintln.(summary.((xtrn,ytrn,xtst,ytst)));", | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "┌ Info: Loading FMNIST...\n└ @ Main /Users/antimon2/.julia/packages/Knet/HwZrA/data/fashion-mnist.jl:37\n", | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": "28×28×1×60000 Array{Float32,4}\n60000-element Array{UInt8,1}\n28×28×1×10000 Array{Float32,4}\n10000-element Array{UInt8,1}\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:11.799Z", | |
| "end_time": "2019-07-12T18:48:12.348000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "dtrn = minibatch(xtrn, ytrn, 1000);\ndtst = minibatch(xtst, ytst, 1000);", | |
| "execution_count": 5, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:16.782Z", | |
| "end_time": "2019-07-12T18:48:16.850000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "println.(summary.((dtrn,dtst)));", | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "60-element Knet.Data{Tuple{Array{Float32,4},Array{UInt8,1}}}\n10-element Knet.Data{Tuple{Array{Float32,4},Array{UInt8,1}}}\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:26.573Z", | |
| "end_time": "2019-07-12T18:48:27.045000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define the convolutional layer\nstruct Conv; w; b; f; end\n(c::Conv)(x) = c.f.(conv4(c.w, x) .+ c.b)\nConv(w1::Int,w2::Int,cx::Int,cy::Int,f=relu) = Conv(param(w1,w2,cx,cy), param0(1,1,cy,1), f)", | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 7, | |
| "data": { | |
| "text/plain": "Conv" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:27.690Z", | |
| "end_time": "2019-07-12T18:48:27.802000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define pooling layer\nstruct MaxPool; w; p; s; end\n(p::MaxPool)(x) = pool(x; window=p.w, padding=p.p, stride=p.s)\nMaxPool(w=2, p=0) = MaxPool(w, p, w)", | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 8, | |
| "data": { | |
| "text/plain": "MaxPool" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T04:11:20.189Z", | |
| "end_time": "2019-07-12T13:11:20.339000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define dropout layer\nstruct Dropout; p; end\n(d::Dropout)(x) = dropout(x, d.p)\nDropout() = Dropout(0.5)", | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:39.288Z", | |
| "end_time": "2019-07-12T18:48:39.295000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define the dense layer\nstruct Dense; w; b; f; end\nDense(i::Int, o::Int, f = relu) = Dense(param(o, i), param0(o), f); # constructor\n(d::Dense)(x) = d.f.(d.w * Knet.mat(x) .+ d.b); # define method for dense layer", | |
| "execution_count": 9, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:48:42.249Z", | |
| "end_time": "2019-07-12T18:48:42.255000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Define the chain layer\nstruct Chain\n layers\n Chain(layers...) = new(layers)\nend\n(c::Chain)(x) = (for l in c.layers; x = l(x); end; x); # define method for feed-forward\n(c::Chain)(x, y) = nll(c(x), y, dims = 1); # define method for negative-log likelihood loss function", | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": {}, | |
| "cell_type": "markdown", | |
| "source": "※↑ `nll()` 関数が softmax 相当の処理を内部に組み込んでいるため、↓モデル作成時に `softmax()` 組み込み不要" | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:49:00.808Z", | |
| "end_time": "2019-07-12T18:49:01.924000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Specify the model\nmodel = Chain(\n Conv(5,5,1,20),\n MaxPool(),\n Conv(5,5,20,50),\n MaxPool(),\n Dense(800,500),\n Dense(500,10,identity)\n);", | |
| "execution_count": 11, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:49:15.302Z", | |
| "end_time": "2019-07-12T18:49:15.752000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
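| "source": "# Compute the fraction of samples in `d` whose arg-max prediction matches the label\n# NOTE(review): reads the Knet.Data fields `d.x`/`d.y` directly; `yprd` keeps the 1×N shape of `yidx`,\n# so this relies on `d.y` also being stored as 1×N (a plain Vector would broadcast to an N×N comparison) - confirm\nfunction accuracy(m::Chain, d::Knet.Data)\n _, yidx = findmax(m(d.x), dims = 1);\n yprd = [i[1] for i in yidx];\n\n return sum(yprd .== d.y) / length(d.y)\nend", | |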
| "execution_count": 12, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 12, | |
| "data": { | |
| "text/plain": "accuracy (generic function with 1 method)" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:49:26.058Z", | |
| "end_time": "2019-07-12T18:49:26.092000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "(model::Chain)(x::AbstractMatrix) = model(reshape(x, (28, 28, 1, :)))", | |
| "execution_count": 13, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:49:39.217Z", | |
| "end_time": "2019-07-12T18:57:10.094000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
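| "source": "# Train with Adam for 5 epochs (repeat(dtrn, 5)); row 1 of `err`/`acc` holds the metrics before training,\n# and every 60 steps (one pass over the 60 training minibatches) a new loss/accuracy row is appended.\n# NOTE(review): 60 is hard-coded to length(dtrn); update it if the minibatch size changes.\n@time begin\n err = hcat(nll(model(dtrn.x), dtrn.y), nll(model(dtst.x), dtst.y))\n acc = hcat(accuracy(model, dtrn), accuracy(model, dtst))\n for (step, x) in enumerate(adam(model, repeat(dtrn, 5)))\n if step % 60 == 0\n losstrn = nll(model(dtrn.x), dtrn.y)\n losstst = nll(model(dtst.x), dtst.y)\n @info \"loss\" step losstrn losstst\n global err = vcat(err, hcat(losstrn, losstst))\n global acc = vcat(acc, hcat(accuracy(model, dtrn), accuracy(model, dtst)))\n end\n end\nend", | |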
| "execution_count": 14, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "text": "┌ Info: loss\n│ step = 60\n│ losstrn = 0.5851358\n│ losstst = 0.60313046\n└ @ Main In[14]:8\n┌ Info: loss\n│ step = 120\n│ losstrn = 0.46183926\n│ losstst = 0.48129898\n└ @ Main In[14]:8\n┌ Info: loss\n│ step = 180\n│ losstrn = 0.40904018\n│ losstst = 0.42961675\n└ @ Main In[14]:8\n┌ Info: loss\n│ step = 240\n│ losstrn = 0.37732702\n│ losstst = 0.40059045\n└ @ Main In[14]:8\n┌ Info: loss\n│ step = 300\n│ losstrn = 0.34378177\n│ losstst = 0.3697316\n└ @ Main In[14]:8\n", | |
| "name": "stderr" | |
| }, | |
| { | |
| "output_type": "stream", | |
| "text": "449.968652 seconds (31.27 M allocations: 206.921 GiB, 14.73% gc time)\n", | |
| "name": "stdout" | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:57:15.836Z", | |
| "end_time": "2019-07-12T18:57:16.890000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "err", | |
| "execution_count": 15, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 15, | |
| "data": { | |
| "text/plain": "6×2 Array{Float32,2}:\n 2.30234 2.30247 \n 0.585136 0.60313 \n 0.461839 0.481299\n 0.40904 0.429617\n 0.377327 0.40059 \n 0.343782 0.369732" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:57:18.159Z", | |
| "end_time": "2019-07-12T18:57:19.024000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "acc", | |
| "execution_count": 16, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 16, | |
| "data": { | |
| "text/plain": "6×2 Array{Float64,2}:\n 0.0961667 0.0959\n 0.784717 0.7775\n 0.834067 0.8305\n 0.853683 0.8493\n 0.866033 0.8581\n 0.877017 0.8709" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:57:22.652Z", | |
| "end_time": "2019-07-12T18:57:24.611000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Save loss and accuracy to csv for visualization\nwrite(\"fmnist-error-knet.csv\", DataFrame(err, [:training, :testing]))\nwrite(\"fmnist-accuracy-knet.csv\", DataFrame(acc, [:training, :testing]))", | |
| "execution_count": 17, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 17, | |
| "data": { | |
| "text/plain": "\"fmnist-accuracy-knet.csv\"" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# or, faster, train without recording the loss\n# (note: this runs 100 epochs, vs. 5 in the recorded training cell above)\n# adam!(model, repeat(dtrn, 100));", | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "end_time": "2019-07-11T14:52:41.909000+09:00", | |
| "start_time": "2019-07-11T05:52:41.758Z" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Predict class labels (arg-max over the 10 model outputs)\n_, trn_yidx = findmax(model(dtrn.x), dims = 1); # training set\ntrn_yprd = [i[1] for i in trn_yidx];\n\n_, tst_yidx = findmax(model(dtst.x), dims = 1); # testing set\ntst_yprd = [i[1] for i in tst_yidx];", | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:57:44.910Z", | |
| "end_time": "2019-07-12T18:58:00.678000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "# Check accuracy of the model\naccuracy(model, dtrn)", | |
| "execution_count": 18, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 18, | |
| "data": { | |
| "text/plain": "0.8770166666666667" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "ExecuteTime": { | |
| "start_time": "2019-07-12T09:58:15.555Z", | |
| "end_time": "2019-07-12T18:58:17.955000+09:00" | |
| }, | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "accuracy(model, dtst)", | |
| "execution_count": 19, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "execution_count": 19, | |
| "data": { | |
| "text/plain": "0.8709" | |
| }, | |
| "metadata": {} | |
| } | |
| ] | |
| }, | |
| { | |
| "metadata": { | |
| "trusted": true | |
| }, | |
| "cell_type": "code", | |
| "source": "", | |
| "execution_count": null, | |
| "outputs": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "name": "julia-1.1", | |
| "display_name": "Julia 1.1.1", | |
| "language": "julia" | |
| }, | |
| "language_info": { | |
| "file_extension": ".jl", | |
| "name": "julia", | |
| "mimetype": "application/julia", | |
| "version": "1.1.1" | |
| }, | |
| "gist": { | |
| "id": "", | |
| "data": { | |
| "description": "FashionMNIST_Sample", | |
| "public": true | |
| } | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment