Last active
February 3, 2018 12:03
-
-
Save antimon2/32f7d9951865f5748e7a9afbfbf556a5 to your computer and use it in GitHub Desktop.
Cifar10TrainSample.jl.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cifar10.jl | |
module CIFAR10 | |
export CIFAR10Record, getlabel, getdata, getlabelastext | |
## Define Record Type
# One record is stored as a single opaque 3073-byte value (24584 == 3073 * 8):
# the first byte is the label, the remaining 3072 bytes are the image.
primitive type CIFAR10Record 24584 end

# Read exactly one record (`sizeof(CIFAR10Record)` bytes) from `stream`.
# Throws `EOFError` when the stream ends before a whole record is available.
# `read!` into a preallocated buffer replaces `read(stream, UInt8, n)`, which no
# longer exists in Julia 1.0+; this form behaves the same on old and new versions.
function Base.read(stream::IO, ::Type{CIFAR10Record})
    bytes = read!(stream, zeros(UInt8, sizeof(CIFAR10Record)))
    return reinterpret(CIFAR10Record, bytes)[1]
end
## Show `CIFAR10Record` as Simple String representation
# Output has the form `CIFAR10Record(<label byte>, <hash of the image bytes>)`.
function Base.show(io::IO, record::CIFAR10Record)
    raw = reinterpret(UInt8, [record])
    # 1st byte is the label
    label_byte = raw[1]
    # the rest of the bytes are the image; only their hash code is printed
    image_hash = hash(raw[2:end])
    print(io, "CIFAR10Record(", repr(label_byte), ", ", repr(image_hash), ')')
end
## Show `CIFAR10Record` as Image | |
### prepare1: CRC32
# Byte-wise lookup table for the right-shifting CRC-32 with polynomial 0xedb88320:
# entry `i + 1` holds the result of running the 8 shift/xor steps on byte value `i`.
const CRC32_TABLE = let poly = 0xedb88320
    table = zeros(UInt32, 256)
    for byte in 0:255
        crc = UInt32(byte)
        for _ in 1:8
            crc = (crc & 1) == 1 ? (crc >> 1) ⊻ poly : crc >> 1
        end
        table[byte + 1] = crc
    end
    table
end;

# Compute the CRC-32 of `data`.
#
# `crc` is the checksum of everything processed so far (default 0 = fresh start),
# so input can be fed in pieces: `crc32(b, crc32(a))` is the CRC of `a` followed
# by `b`. Empty `data` returns `crc` unchanged.
#
# Any `AbstractVector{UInt8}` is accepted (plain vectors, byte-string literals,
# views, `reinterpret` results), not only `Vector{UInt8}`.
function crc32(data::AbstractVector{UInt8}, crc::UInt32=zero(UInt32))
    crc = ~crc
    for b in data
        # low byte of the running CRC xor the input byte selects the table entry
        crc = CRC32_TABLE[((crc % UInt8) ⊻ b) + 1] ⊻ (crc >> 8)
    end
    return ~crc
end
### prepare2: Adler32
# Modulus applied to both running sums.
const MOD_ADLER = UInt32(65521)

# Compute the Adler-32 checksum of `data`: `a` is 1 plus the sum of all bytes,
# `b` is the sum of every intermediate `a`, both modulo `MOD_ADLER`; the result
# is `(b << 16) | a`. Empty input yields 1.
#
# Any 1-based `AbstractVector{UInt8}` is accepted, not only `Vector{UInt8}`.
function adler32(data::AbstractVector{UInt8})
    a = one(UInt32)
    b = zero(UInt32)
    n = length(data)
    # The modulo is deferred to once per 5550-byte chunk; the chunk is small
    # enough that the UInt32 sums do not wrap around before being reduced.
    for start in 1:5550:n
        stop = min(start + 5549, n)
        # a view avoids copying every chunk
        for v in view(data, start:stop)
            a += v
            b += a
        end
        a %= MOD_ADLER
        b %= MOD_ADLER
    end
    return (b << 16) | a
end
### prepare3: PNG format | |
#### PNG Signature (8 bytes)
const PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n";

# Emit the fixed PNG signature to `io`; returns the number of bytes written.
function write_png_signature(io::IO)
    return write(io, PNG_SIGNATURE)
end
#### IHDR Chunk (25 bytes)
# Chunk data length (13, big-endian) followed by the chunk type "IHDR".
const IHDR_00 = b"\0\0\0\rIHDR";
# True color (24bit-depth) RGB: bit depth 8, color type 2, remaining fields 0.
const IHDR_10 = b"\b\x02\0\0\0";

# Write the IHDR chunk for a `width` x `height` image to `io`:
# IHDR_00, width, height, IHDR_10, then the CRC (all multi-byte values big-endian).
# The CRC covers everything after the 4 length bytes, i.e. the chunk type and data.
# Returns the total number of bytes written.
function write_png_ihdr(io::IO, width::Int, height::Int)
    width_bytes = reinterpret(UInt8, [hton(width % UInt32)])
    height_bytes = reinterpret(UInt8, [hton(height % UInt32)])
    nwritten = write(io, IHDR_00)
    checksum = crc32(IHDR_00[5:end])
    for part in (width_bytes, height_bytes, IHDR_10)
        nwritten += write(io, part)
        checksum = crc32(part, checksum)
    end
    nwritten += write(io, hton(checksum))
    return nwritten
end
#### IDAT Chunk
const IDAT_04 = b"IDAT";
# Compression method + flags (Deflate, compression level 0)
const CMF_FLG = b"\b\x1d";
# Deflate block header (final block, no compression)
const BH = b"\x01";

# Write the IDAT chunk for `img_src` to `io` and return the number of bytes written.
#
# `img_src` is indexed as (depth, width, height). The pixel data is not compressed:
# it is wrapped in one stored Deflate block. Layout written, in order:
# chunk length, IDAT_04, CMF_FLG, BH, LEN, NLEN, scanline data, Adler-32, CRC.
# The CRC covers everything from the chunk type up to and including the Adler-32.
#
# Throws `ArgumentError` when the scanline data exceeds 65535 bytes, because the
# length of a stored block is a 16-bit field; previously such input was silently
# truncated by `% UInt16`, producing a corrupt file.
# NOTE(review): `depth` is not validated here; the IHDR written by this file
# declares RGB, so callers are presumably expected to pass depth == 3.
function write_png_idat(io::IO, img_src::AbstractArray{UInt8,3})
    depth, width, height = size(img_src)
    # every scanline is prefixed with one filter-type byte
    l = height * (1 + width * depth)
    l <= 0xffff || throw(ArgumentError(
        "image needs $(l) bytes of scanline data, but a single stored Deflate block holds at most 65535"))
    # chunk data = CMF_FLG (2) + BH (1) + LEN (2) + NLEN (2) + data (l) + Adler-32 (4)
    c = write(io, hton((l + 11) % UInt32))
    c += write(io, IDAT_04)
    crc = crc32(IDAT_04)
    c += write(io, CMF_FLG)
    crc = crc32(CMF_FLG, crc)
    c += write(io, BH)
    crc = crc32(BH, crc)
    # LEN / NLEN are little-endian; NLEN is the one's complement of LEN
    LEN = htol(l % UInt16)
    c += write(io, LEN)
    crc = crc32(reinterpret(UInt8, [LEN]), crc)
    NLEN = ~LEN
    c += write(io, NLEN)
    crc = crc32(reinterpret(UInt8, [NLEN]), crc)
    # one column per scanline: a leading 0 byte, then that row's pixel bytes
    IDAT_DAT = vec([zeros(UInt8, 1, height); reshape(img_src, :, height)])
    c += write(io, IDAT_DAT)
    crc = crc32(IDAT_DAT, crc)
    ADL = hton(adler32(IDAT_DAT))
    c += write(io, ADL)
    crc = crc32(reinterpret(UInt8, [ADL]), crc)
    c += write(io, hton(crc))
    return c
end
#### IEND Chunk (12 bytes)
const IEND = b"\0\0\0\0IEND\xaeB`\x82";

# Emit the constant IEND chunk to `io`; returns the number of bytes written.
function write_png_iend(io::IO)
    return write(io, IEND)
end
#### format to PNG
# Write `record`'s image to `io` as a complete 32x32 PNG file
# (signature, IHDR, IDAT, IEND) and return the total number of bytes written.
function write_png(io::IO, record::CIFAR10Record)
    # drop the leading label byte, keep the 3072 image bytes
    pixels = reinterpret(UInt8, [record])[2:end]
    # view them as 32 x 32 x 3, then move the last axis to the front to get the
    # (depth, width, height) layout that `write_png_idat` destructures
    planes = reshape(pixels, (32, 32, 3))
    img_src = permutedims(planes, (3, 1, 2))
    written = write_png_signature(io)
    written += write_png_ihdr(io, 32, 32)
    written += write_png_idat(io, img_src)
    written += write_png_iend(io)
    return written
end
### Show MIME
# A record can be displayed as a PNG image: the output is exactly what `write_png` emits.
Base.mimewritable(::MIME"image/png", ::CIFAR10Record) = true
Base.show(io::IO, ::MIME"image/png", record::CIFAR10Record) = write_png(io, record)
Base.mimewritable(::MIME"text/html", ::CIFAR10Record) = true

# Render the record as an inline `<img>` tag whose `src` is a base64 `data:` URI
# holding the PNG produced by `write_png`.
function Base.show(io::IO, ::MIME"text/html", record::CIFAR10Record)
    print(io, "<img src=\"data:image/png;base64,")
    # Encode straight into `io`; no intermediate buffer is needed.
    b64pipe = Base64EncodePipe(io)
    write_png(b64pipe, record)
    # The encoder works on groups of 3 bytes and holds back an incomplete trailing
    # group. Closing the pipe flushes it (with padding); without this the encoded
    # PNG loses its final byte(s). Closing the pipe does not close `io` itself.
    close(b64pipe)
    print(io, "\">")
end
Base.mimewritable(::MIME"text/html", ::AbstractArray{CIFAR10Record}) = true

# Render an array of records as a one-column HTML table, one row per record,
# each cell containing that record's own `text/html` representation.
function Base.show(io::IO, mime::MIME"text/html", records::AbstractArray{CIFAR10Record})
    print(io, "<table>")
    foreach(records) do rec
        print(io, "<tr><td>")
        show(io, mime, rec)
        print(io, "</td></tr>")
    end
    print(io, "</table>")
end
## getter
# Label of the record: its first byte, as an `Int`.
function getlabel(record::CIFAR10Record)::Int
    return Int(first(reinterpret(UInt8, [record])))
end

# Image bytes of the record: everything after the first (label) byte.
function getdata(record::CIFAR10Record)::Vector{UInt8}
    raw = reinterpret(UInt8, [record])
    return raw[2:end]
end

# Class names, indexed by label + 1.
const labels = String["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]

# Class name corresponding to the record's label.
function getlabelastext(record::CIFAR10Record)::String
    idx = getlabel(record) + 1
    return labels[idx]
end
end # module |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cifar10_test.jl | |
using Base.Test | |
include("./cifar10.jl") | |
using CIFAR10 | |
# The record type must be a plain-bits type of exactly one record's size.
@test isbits(CIFAR10Record)
@test sizeof(CIFAR10Record) == 3073

# Must download and extract `cifar-10-binary.tar.gz`.
# Read the first record of the test batch.
record0 = open(f -> read(f, CIFAR10Record), "cifar-10-batches-bin/test_batch.bin", "r");
@test isa(record0, CIFAR10Record)
@test string(record0) == "CIFAR10Record(0x03, 0xd0b45b812aae12b1)"

# Label accessors
@test getlabel(record0) == 3
@test getlabelastext(record0) == "cat"

# Image data accessor
data0 = getdata(record0)
@test length(data0) == 3072
@test isa(data0, Vector{UInt8})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Preparation" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:54.504000+09:00", | |
"start_time": "2018-02-01T11:31:54.238Z" | |
}, | |
"scrolled": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "include(\"./cifar10.jl\")", | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "CIFAR10" | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:55.096000+09:00", | |
"start_time": "2018-02-01T11:31:55.083Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "using CIFAR10", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:56.088000+09:00", | |
"start_time": "2018-02-01T11:31:55.710Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "datadir = \"./cifar-10-batches-bin\"\n# train data: `data_batch_$(i).bin`\n# test data: `test_batch.bin`", | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "\"./cifar-10-batches-bin\"" | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:51.865000+09:00", | |
"start_time": "2018-02-01T11:31:50.248Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "include(\"./layers.jl\")\ninclude(\"./cnnutil.jl\")\ninclude(\"./optimizer.jl\")", | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "update (generic function with 2 methods)" | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Train Batch Reader" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:58.310000+09:00", | |
"start_time": "2018-02-01T11:31:58.220Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function readtrain(channel::Channel{Tuple{CIFAR10Record}}, fid::Int, datadir::String=datadir)\n if isopen(channel)\n filepath = joinpath(datadir, \"data_batch_$(fid).bin\")\n open(filepath, \"r\") do f\n while isopen(channel)\n _randidxs = shuffle(0:9999)\n for idx in _randidxs\n seek(f, idx * 3073)\n record = read(f, CIFAR10Record)\n # label = getlabel(record)\n try\n put!(channel, (record,))\n sleep(0.001) # yield to others\n catch ex\n if isa(ex, InvalidStateException)\n return\n else\n rethrow(ex)\n end\n end\n end\n end\n end\n end\nend", | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "readtrain (generic function with 2 methods)" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:59.104000+09:00", | |
"start_time": "2018-02-01T11:31:59.014Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function train_batch_produce(channel::Channel{Tuple{CIFAR10Record}}, datadir=datadir)\n train_channels = [Channel{Tuple{CIFAR10Record}}(32) for _=1:5]\n for fid in 1:5\n @schedule readtrain(train_channels[fid], fid, datadir)\n end\n while true\n try\n put!(channel, take!(rand(train_channels)))\n catch ex\n if isa(ex, InvalidStateException)\n for ch in train_channels\n close(ch)\n end\n return\n else\n rethrow(ex)\n end\n end\n end\nend", | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "train_batch_produce (generic function with 2 methods)" | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:05.753000+09:00", | |
"start_time": "2018-02-01T11:32:03.428Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "struct CF10Batch\n channel::Channel{Tuple{CIFAR10Record}}\n batchsize::Int\nend\n\nfunction (f::CF10Batch)()\n buf = reshape(reinterpret(UInt8, collect(Iterators.take(f.channel, f.batchsize))), (:, f.batchsize))\n return (buf[2:3073, :], buf[1, :]) # Data and Labels\nend", | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:57.519000+09:00", | |
"start_time": "2018-02-01T11:31:57.347Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "train_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule train_batch_produce(train_channel)\ntrainbatch = CF10Batch(train_channel, 128)", | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:7), 128)" | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:07.296000+09:00", | |
"start_time": "2018-02-01T11:32:05.380Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "data0, labels0 = trainbatch()", | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(UInt8[0xd1 0x78 … 0xa9 0xa6; 0xd5 0x71 … 0xab 0xa5; … ; 0x34 0x4d … 0x42 0x78; 0x33 0x54 … 0x39 0x46], UInt8[0x07, 0x06, 0x03, 0x06, 0x02, 0x00, 0x02, 0x08, 0x09, 0x01 … 0x07, 0x03, 0x01, 0x00, 0x00, 0x01, 0x09, 0x07, 0x03, 0x01])" | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:07.623000+09:00", | |
"start_time": "2018-02-01T11:32:07.022Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "size(data0)", | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(3072, 128)" | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:09.227000+09:00", | |
"start_time": "2018-02-01T11:32:08.992Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "size(labels0)", | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "(128,)" | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Inference" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:10.633000+09:00", | |
"start_time": "2018-02-01T11:32:10.496Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "T = Float32\ninput_h = 32\ninput_w = 32\ninput_c = 3", | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "3" | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Conv1" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:13.006000+09:00", | |
"start_time": "2018-02-01T11:32:11.324Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "fh1 = 5\nfw1 = 5\noutput_c1 = 64\ncstride1 = 1\ncpad1 = 2\nweight_init_std1 = 5f-2\n\ncW1 = weight_init_std1 .* randn(T, (fh1, fw1, input_c, output_c1))\ncb1 = zeros(T, output_c1)\nconv1lyr = Convolution(cW1, cb1, cstride1, cpad1);", | |
"execution_count": 13, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Relu1" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:13.010000+09:00", | |
"start_time": "2018-02-01T11:32:12.388Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "relu1lyr = ReluLayer{T}();", | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Pool1" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:14.112000+09:00", | |
"start_time": "2018-02-01T11:32:14.103Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "ph1 = 3\npw1 = 3\npstride1 = 2\nppad1 = 1\npool1lyr = Pooling{T}(ph1, pw1, pstride1, ppad1);", | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Conv2" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:15.220000+09:00", | |
"start_time": "2018-02-01T11:32:15.212Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "fh2 = 5\nfw2 = 5\ninput_c2 = output_c1\noutput_c2 = 64\ncstride2 = 1\ncpad2 = 2\nweight_init_std2 = 5f-2\n\ncW2 = weight_init_std2 .* randn(T, (fh2, fw2, input_c2, output_c2))\n# cb2 = zeros(T, output_c2)\ncb2 = fill(0.1f0, output_c2)\nconv2lyr = Convolution(cW2, cb2, cstride2, cpad2);", | |
"execution_count": 16, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Relu2" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:16.406000+09:00", | |
"start_time": "2018-02-01T11:32:16.404Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "relu2lyr = ReluLayer{T}();", | |
"execution_count": 17, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Pool2" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:17.396000+09:00", | |
"start_time": "2018-02-01T11:32:17.393Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "ph2 = 3\npw2 = 3\npstride2 = 2\nppad2 = 1\npool2lyr = Pooling{T}(ph2, pw2, pstride2, ppad2);", | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### FC3~5" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:18.779000+09:00", | |
"start_time": "2018-02-01T11:32:18.571Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "input_size = 8 * 8 * 64\nhidden1_size = 384\nhidden2_size = 192\noutput_size = 10\nweight_init_std = 0.04f0\nW3 = weight_init_std .* randn(T, hidden1_size, input_size)\nb3 = fill(0.1f0, hidden1_size)\nW4 = weight_init_std .* randn(T, hidden2_size, hidden1_size)\nb4 = fill(0.1f0, hidden2_size)\nW5 = (1f0/192) .* randn(T, output_size, hidden2_size)\nb5 = zeros(T, output_size);", | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:19.812000+09:00", | |
"start_time": "2018-02-01T11:32:19.807Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "a3lyr = AffineLayer(W3, b3)\nrelu3lyr = ReluLayer{T}();", | |
"execution_count": 20, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:20.296000+09:00", | |
"start_time": "2018-02-01T11:32:20.293Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "a4lyr = AffineLayer(W4, b4)\nrelu4lyr = ReluLayer{T}();", | |
"execution_count": 21, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:20.901000+09:00", | |
"start_time": "2018-02-01T11:32:20.895Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "a5lyr = AffineLayer(W5, b5)\nsoftmaxlyr = SoftmaxWithLossLayer{T,2}();", | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Optimizer (Momentum)" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:21.696000+09:00", | |
"start_time": "2018-02-01T11:32:21.344Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "opt = Momentum{T}()", | |
"execution_count": 23, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "Momentum{Float32}(0.01f0, 0.9f0)" | |
}, | |
"execution_count": 23, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Train" | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "mutable struct TrainParams\n pcW1::AbstractOptimizerParam\n pcb1::AbstractOptimizerParam\n pcW2::AbstractOptimizerParam\n pcb2::AbstractOptimizerParam\n pW3::AbstractOptimizerParam\n pb3::AbstractOptimizerParam\n pW4::AbstractOptimizerParam\n pb4::AbstractOptimizerParam\n pW5::AbstractOptimizerParam\n pb5::AbstractOptimizerParam\nend", | |
"execution_count": 24, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:30.610000+09:00", | |
"start_time": "2018-02-01T11:32:30.603Z" | |
}, | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "mutable struct Trainer\n conv1lyr::Convolution\n relu1lyr::ReluLayer\n pool1lyr::Pooling\n conv2lyr::Convolution\n relu2lyr::ReluLayer\n pool2lyr::Pooling\n a3lyr::AffineLayer\n relu3lyr::ReluLayer\n a4lyr::AffineLayer\n relu4lyr::ReluLayer\n a5lyr::AffineLayer\n softmaxlyr::SoftmaxWithLossLayer\n opt::AbstractOptimizer\n params::TrainParams\n function Trainer(\n conv1lyr::Convolution,\n relu1lyr::ReluLayer,\n pool1lyr::Pooling,\n conv2lyr::Convolution,\n relu2lyr::ReluLayer,\n pool2lyr::Pooling,\n a3lyr::AffineLayer,\n relu3lyr::ReluLayer,\n a4lyr::AffineLayer,\n relu4lyr::ReluLayer,\n a5lyr::AffineLayer,\n softmaxlyr::SoftmaxWithLossLayer,\n opt::AbstractOptimizer\n )\n new(\n conv1lyr,\n relu1lyr,\n pool1lyr,\n conv2lyr,\n relu2lyr,\n pool2lyr,\n a3lyr,\n relu3lyr,\n a4lyr,\n relu4lyr,\n a5lyr,\n softmaxlyr,\n opt,\n TrainParams(\n initializeparam(opt, conv1lyr.W), # pcW1\n initializeparam(opt, conv1lyr.b), # pcb1\n initializeparam(opt, conv2lyr.W), # pcW2\n initializeparam(opt, conv2lyr.b), # pcb2\n initializeparam(opt, a3lyr.W), # pW3\n initializeparam(opt, a3lyr.b), # pb3\n initializeparam(opt, a4lyr.W), # pW4\n initializeparam(opt, a4lyr.b), # pb4\n initializeparam(opt, a5lyr.W), # pW5\n initializeparam(opt, a5lyr.b) # pb5\n )\n )\n end\nend", | |
"execution_count": 25, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:34.107000+09:00", | |
"start_time": "2018-02-01T11:32:34.016Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function forward(trainer::Trainer, x1::AbstractArray{T,4}) where {T}\n a1 = forward(trainer.conv1lyr, x1)\n z1 = forward(trainer.relu1lyr, a1)\n p1 = forward(trainer.pool1lyr, z1)\n a2 = forward(trainer.conv2lyr, p1)\n z2 = forward(trainer.relu2lyr, a2)\n p2 = forward(trainer.pool2lyr, z2)\n a3 = forward(trainer.a3lyr, reshape(p2, (8*8*64, :)))\n z3 = forward(trainer.relu3lyr, a3)\n a4 = forward(trainer.a4lyr, z3)\n z4 = forward(trainer.relu4lyr, a4)\n a5 = forward(trainer.a5lyr, z4)\n a5\nend", | |
"execution_count": 26, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "forward (generic function with 8 methods)" | |
}, | |
"execution_count": 26, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:34.919000+09:00", | |
"start_time": "2018-02-01T11:32:34.831Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function loss(trainer::Trainer, x1::AbstractArray{T,4}, t::AbstractArray{T,2}) where {T}\n y = forward(trainer, x1)\n forward(trainer.softmaxlyr, y, t)\nend", | |
"execution_count": 27, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "loss (generic function with 1 method)" | |
}, | |
"execution_count": 27, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:36.399000+09:00", | |
"start_time": "2018-02-01T11:32:36.312Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function computegradient(trainer::Trainer)\n dout = one(T)\n da5 = backward(trainer.softmaxlyr, dout)\n dz4 = backward(trainer.a5lyr, da5)\n da4 = backward(trainer.relu4lyr, dz4)\n dz3 = backward(trainer.a4lyr, da4)\n da3 = backward(trainer.relu3lyr, dz3)\n dp2 = backward(trainer.a3lyr, da3)\n dz2 = backward(trainer.pool2lyr, reshape(dp2, (8, 8, 64, :)))\n da2 = backward(trainer.relu2lyr, dz2)\n dp1 = backward(trainer.conv2lyr, da2)\n dz1 = backward(trainer.pool1lyr, dp1)\n da1 = backward(trainer.relu1lyr, dz1)\n _dx = backward(trainer.conv1lyr, da1)\nend", | |
"execution_count": 28, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "computegradient (generic function with 1 method)" | |
}, | |
"execution_count": 28, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:37.390000+09:00", | |
"start_time": "2018-02-01T11:32:37.287Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function applygradient(trainer::Trainer)\n params = trainer.params\n trainer.conv1lyr.W, params.pcW1 = update(trainer.opt, trainer.conv1lyr.W, trainer.conv1lyr.dW, params.pcW1)\n trainer.conv1lyr.b, params.pcb1 = update(trainer.opt, trainer.conv1lyr.b, trainer.conv1lyr.db, params.pcb1)\n trainer.conv2lyr.W, params.pcW2 = update(trainer.opt, trainer.conv2lyr.W, trainer.conv2lyr.dW, params.pcW2)\n trainer.conv2lyr.b, params.pcb2 = update(trainer.opt, trainer.conv2lyr.b, trainer.conv2lyr.db, params.pcb2)\n trainer.a3lyr.W, params.pW3 = update(trainer.opt, trainer.a3lyr.W, trainer.a3lyr.dW, params.pW3)\n trainer.a3lyr.b, params.pb3 = update(trainer.opt, trainer.a3lyr.b, trainer.a3lyr.db, params.pb3)\n trainer.a4lyr.W, params.pW4 = update(trainer.opt, trainer.a4lyr.W, trainer.a4lyr.dW, params.pW4)\n trainer.a4lyr.b, params.pb4 = update(trainer.opt, trainer.a4lyr.b, trainer.a4lyr.db, params.pb4)\n trainer.a5lyr.W, params.pW5 = update(trainer.opt, trainer.a5lyr.W, trainer.a5lyr.dW, params.pW5)\n trainer.a5lyr.b, params.pb5 = update(trainer.opt, trainer.a5lyr.b, trainer.a5lyr.db, params.pb5)\nend", | |
"execution_count": 29, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "applygradient (generic function with 1 method)" | |
}, | |
"execution_count": 29, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Evaluation" | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "validate_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule train_batch_produce(validate_channel)\nvalidatebatch = CF10Batch(validate_channel, 100)", | |
"execution_count": 30, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:5), 100)" | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:31:58.310000+09:00", | |
"start_time": "2018-02-01T11:31:58.220Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function readtest(channel::Channel{Tuple{CIFAR10Record}}, datadir::String=datadir)\n if isopen(channel)\n filepath = joinpath(datadir, \"test_batch.bin\")\n open(filepath, \"r\") do f\n while isopen(channel)\n seek(f, 0)\n for _ in 1:10000\n record = read(f, CIFAR10Record)\n # label = getlabel(record)\n try\n put!(channel, (record,))\n sleep(0.001) # yield to others\n catch ex\n if isa(ex, InvalidStateException)\n return\n else\n rethrow(ex)\n end\n end\n end\n end\n end\n end\nend", | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "readtest (generic function with 2 methods)" | |
}, | |
"execution_count": 31, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "test_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule readtest(test_channel)\ntestbatch = CF10Batch(test_channel, 100)", | |
"execution_count": 32, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:1), 100)" | |
}, | |
"execution_count": 32, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function calcaccuracy(trainer::Trainer, batch::CF10Batch, numsteps::Int=10)\n if numsteps < 1\n numsteps = 1\n end\n total_count = batch.batchsize * numsteps\n true_count = 0\n for step in 1:numsteps\n data, labels = batch()\n x1 = reshape(data ./ 255f0, (32, 32, 3, :))\n y = vec(mapslices(indmax, forward(trainer, x1), 1)) .- 1\n true_count += sum(y .== labels)\n end\n return true_count / total_count\nend", | |
"execution_count": 33, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "calcaccuracy (generic function with 2 methods)" | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "### Execute train" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:38.520000+09:00", | |
"start_time": "2018-02-01T11:32:38.419Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function train(trainer::Trainer, trainbatch::CF10Batch, numsteps::Int, \n validatebatch::CF10Batch=nothing, testbatch::CF10Batch=nothing, startstep::Int=1)\n batch_size = trainbatch.batchsize\n endstep = startstep + numsteps - 1\n for step = startstep:endstep\n data, labels = trainbatch()\n x1 = reshape(data, (32, 32, 3, :)) / 255f0\n t = zeros(T, (10, batch_size))\n for (i, labelidx) in enumerate(labels)\n t[labelidx + 1, i] = 1\n end\n ### calcurate loss (forward)\n _loss = loss(trainer, x1, t)\n if step % 10 == 0\n println(\"$(step): $(_loss)\")\n end\n ### calcurate gradient (backward)\n computegradient(trainer)\n ### apply gradient\n applygradient(trainer)\n ### periodic evaluation (check accuracy)\n if step % 100 == 0\n if validatebatch !== nothing\n train_acc = calcaccuracy(trainer, validatebatch)\n println(\"train_acc: $(train_acc)\")\n end\n if testbatch !== nothing\n test_acc = calcaccuracy(trainer, testbatch)\n println(\"test_acc: $(test_acc)\")\n end\n end\n end\nend", | |
"execution_count": 34, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "train (generic function with 4 methods)" | |
}, | |
"execution_count": 34, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:32:41.521000+09:00", | |
"start_time": "2018-02-01T11:32:40.105Z" | |
}, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "trainer = Trainer(\n conv1lyr,\n relu1lyr,\n pool1lyr,\n conv2lyr,\n relu2lyr,\n pool2lyr,\n a3lyr,\n relu3lyr,\n a4lyr,\n relu4lyr,\n a5lyr,\n softmaxlyr,\n opt\n);", | |
"execution_count": 35, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:33:01.980000+09:00", | |
"start_time": "2018-02-01T11:32:47.506Z" | |
}, | |
"scrolled": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time train(trainer, trainbatch, 100, validatebatch, testbatch)", | |
"execution_count": 36, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "10: 2.3280575\n20: 2.291\n30: 2.290505\n40: 2.227149\n50: 2.172277\n60: 2.0068963\n70: 2.1987274\n80: 2.0112386\n90: 1.9883025\n100: 1.8631837\ntrain_acc: 0.271\ntest_acc: 0.281\n451.029012 seconds (11.82 M allocations: 329.287 GiB, 28.30% gc time)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### Save the model (weights)" | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "# Pkg.add(\"JLD\")\nusing JLD, FileIO", | |
"execution_count": 37, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time train_acc = calcaccuracy(trainer, validatebatch, 100)", | |
"execution_count": 40, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "157.848018 seconds (1.12 M allocations: 98.626 GiB, 25.84% gc time)\n" | |
}, | |
{ | |
"data": { | |
"text/plain": "0.2655" | |
}, | |
"execution_count": 40, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time test_acc = calcaccuracy(trainer, testbatch, 100)", | |
"execution_count": 41, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "177.764737 seconds (568.86 k allocations: 98.600 GiB, 23.55% gc time)\n" | |
}, | |
{ | |
"data": { | |
"text/plain": "0.2805" | |
}, | |
"execution_count": 41, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "cW1, cb1 = conv1lyr.W, conv1lyr.b\ncW2, cb2 = conv2lyr.W, conv2lyr.b\nW3, b3 = a3lyr.W, a3lyr.b\nW4, b4 = a4lyr.W, a4lyr.b\nW5, b5 = a5lyr.W, a5lyr.b\n@save \"ckpt_sample_201802021840_0000100.jld\" cW1 cb1 cW2 cb2 W3 b3 W4 b4 W5 b5", | |
"execution_count": 38, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### (re)train" | |
}, | |
{ | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2018-02-01T20:33:01.980000+09:00", | |
"start_time": "2018-02-01T11:32:47.506Z" | |
}, | |
"scrolled": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time train(trainer, trainbatch, 900, validatebatch, testbatch, 101)", | |
"execution_count": 43, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "110: 1.926867\n120: 1.8587545\n130: 1.8940685\n140: 1.7620988\n150: 1.8591789\n160: 1.9014171\n170: 1.7010934\n180: 1.7665153\n190: 1.7036247\n200: 1.6929762\ntrain_acc: 0.411\ntest_acc: 0.403\n210: 1.6664252\n220: 1.7507999\n230: 1.568876\n240: 1.6144001\n250: 1.5455773\n260: 1.6112003\n270: 1.4552757\n280: 1.5788555\n290: 1.6090884\n300: 1.4770257\ntrain_acc: 0.469\ntest_acc: 0.443\n310: 1.5068269\n320: 1.4742835\n330: 1.4423931\n340: 1.4298854\n350: 1.5663092\n360: 1.4549472\n370: 1.4903312\n380: 1.4007056\n390: 1.3008981\n400: 1.4737234\ntrain_acc: 0.461\ntest_acc: 0.502\n410: 1.4489878\n420: 1.4234551\n430: 1.3521829\n440: 1.5339124\n450: 1.4638265\n460: 1.3159835\n470: 1.285465\n480: 1.3323457\n490: 1.1709098\n500: 1.4797423\ntrain_acc: 0.518\ntest_acc: 0.53\n510: 1.429904\n520: 1.4692785\n530: 1.3512571\n540: 1.2711961\n550: 1.2339783\n560: 1.3167245\n570: 1.2011247\n580: 1.2615013\n590: 1.294272\n600: 1.18332\ntrain_acc: 0.514\ntest_acc: 0.506\n610: 1.2178104\n620: 1.4519489\n630: 1.3791533\n640: 1.1602271\n650: 1.2617515\n660: 1.178378\n670: 1.2598214\n680: 1.3636239\n690: 1.139215\n700: 1.1494153\ntrain_acc: 0.57\ntest_acc: 0.531\n710: 1.2209207\n720: 1.2084265\n730: 1.1454576\n740: 1.3233155\n750: 1.120363\n760: 1.1156836\n770: 1.3812096\n780: 1.0658343\n790: 1.241617\n800: 1.1230841\ntrain_acc: 0.604\ntest_acc: 0.575\n810: 1.1207128\n820: 1.2319574\n830: 1.1797975\n840: 1.1806254\n850: 1.157985\n860: 1.0172669\n870: 1.0593016\n880: 1.0236645\n890: 1.1836785\n900: 1.2015376\ntrain_acc: 0.603\ntest_acc: 0.57\n910: 1.0798593\n920: 0.9859744\n930: 0.9686902\n940: 1.0859641\n950: 1.1518943\n960: 0.965214\n970: 1.1084701\n980: 1.0611815\n990: 1.0680001\n1000: 1.1381717\ntrain_acc: 0.592\ntest_acc: 0.554\n4058.505781 seconds (17.18 M allocations: 2.912 TiB, 31.23% gc time)\n" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time train_acc = calcaccuracy(trainer, validatebatch, 100)", | |
"execution_count": 44, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "156.976842 seconds (562.53 k allocations: 99.860 GiB, 26.48% gc time)\n" | |
}, | |
{ | |
"data": { | |
"text/plain": "0.5981" | |
}, | |
"execution_count": 44, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"scrolled": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "@time test_acc = calcaccuracy(trainer, testbatch, 100)", | |
"execution_count": 45, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": "175.079569 seconds (569.14 k allocations: 99.863 GiB, 24.50% gc time)\n" | |
}, | |
{ | |
"data": { | |
"text/plain": "0.5604" | |
}, | |
"execution_count": 45, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "#### Save the model (weights)" | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "cW1, cb1 = conv1lyr.W, conv1lyr.b\ncW2, cb2 = conv2lyr.W, conv2lyr.b\nW3, b3 = a3lyr.W, a3lyr.b\nW4, b4 = a4lyr.W, a4lyr.b\nW5, b5 = a5lyr.W, a5lyr.b\n@save \"ckpt_sample_201802022222_0001000.jld\" cW1 cb1 cW2 cb2 W3 b3 W4 b4 W5 b5", | |
"execution_count": 46, | |
"outputs": [] | |
}, | |
{ | |
"metadata": {}, | |
"cell_type": "markdown", | |
"source": "## Prediction" | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "function predict_proba(trainer::Trainer, record::CIFAR10Record)\n res = forward(trainer, reshape(getdata(record) ./ 255f0, (32, 32, 3, 1)))\n softmax(vec(res))\nend", | |
"execution_count": 47, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "predict_proba (generic function with 1 method)" | |
}, | |
"execution_count": 47, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "sample_record = open(joinpath(datadir, \"test_batch.bin\")) do f\n read(f, CIFAR10Record)\nend", | |
"execution_count": 48, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "CIFAR10Record(0x03, 0xd0b45b812aae12b1)" | |
}, | |
"execution_count": 48, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "sample_softmax = predict_proba(trainer, sample_record)", | |
"execution_count": 49, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "10-element Array{Float32,1}:\n 0.0430505\n 0.0326076\n 0.0940189\n 0.285728 \n 0.0198259\n 0.247169 \n 0.0868483\n 0.0154235\n 0.158206 \n 0.0171223" | |
}, | |
"execution_count": 49, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "getlabel(sample_record)", | |
"execution_count": 50, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "3" | |
}, | |
"execution_count": 50, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "CIFAR10.labels[indmax(sample_softmax)]", | |
"execution_count": 51, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "\"cat\"" | |
}, | |
"execution_count": 51, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "getlabelastext(sample_record)", | |
"execution_count": 52, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "\"cat\"" | |
}, | |
"execution_count": 52, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"collapsed": true, | |
"trusted": false | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "julia-0.6", | |
"display_name": "Julia 0.6.2", | |
"language": "julia" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"name": "julia", | |
"mimetype": "application/julia", | |
"version": "0.6.2" | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "Cifar10TrainSample.jl.ipynb", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# cnnutil.jl | |
# require `layers.jl` | |
"""
    zeropad(a, pad_width::NTuple{N,Tuple{Int,Int}})
    zeropad(a, pad_width::Tuple{Int,Int}...)

Return a new zero-filled array with `a` copied into its interior.

`pad_width` holds one `(before, after)` pair per dimension of `a`; dimension `d` of the
result has length `size(a, d) + before + after`. The element type of `a` is kept.
"""
function zeropad(a::AbstractArray{T,N}, pad_width::NTuple{N,Tuple{Int,Int}}) where {T,N}
    # Work on tuples so the number of dimensions stays known to the compiler.
    padded_size = map((len, p) -> len + p[1] + p[2], size(a), pad_width)
    padded = zeros(T, padded_size...)
    interior = map((len, p) -> (p[1] + 1):(p[1] + len), size(a), pad_width)
    padded[interior...] = a
    return padded
end

# Varargs convenience form: one `(before, after)` pair per trailing argument.
@inline zeropad(a::AbstractArray{T,N}, pad_width::Tuple{Int,Int}...) where {T,N} = zeropad(a, pad_width)
"""
    im2col(input_data, filter_w, filter_h, stride=1, pad=0)

Unfold every `filter_w` x `filter_h` window of a 4-D array into one column of a matrix.

`size(input_data)` is read as `(W, H, C, N)`. With `pad != 0` the first two axes are
zero-padded on both sides first (via `zeropad`). The result has `filter_w * filter_h * C`
rows and `out_w * out_h * N` columns, where `out_w = (W + 2pad - filter_w) ÷ stride + 1`
and `out_h` is computed the same way from `H` and `filter_h`.

Note that the window width comes before the height here, whereas `col2im` takes
`filter_h` first.
"""
function im2col(input_data::AbstractArray{T,4}, filter_w::Int, filter_h::Int,
                stride::Int=1, pad::Int=0) where {T}
    width, height, channels, batch = size(input_data)
    out_h = (height + 2pad - filter_h) ÷ stride + 1
    out_w = (width + 2pad - filter_w) ÷ stride + 1
    padded = if pad == 0
        input_data
    else
        zeropad(input_data, (pad, pad), (pad, pad), (0, 0), (0, 0))
    end
    patches = zeros(T, (out_w, out_h, filter_w, filter_h, channels, batch))
    for fy in 1:filter_h, fx in 1:filter_w
        # All window positions that see filter offset (fx, fy), sampled every `stride`.
        xs = fx:stride:(fx + stride * out_w - 1)
        ys = fy:stride:(fy + stride * out_h - 1)
        patches[:, :, fx, fy, :, :] = padded[xs, ys, :, :]
    end
    # Filter offsets and channels become rows; window positions and samples become columns.
    return reshape(permutedims(patches, (3, 4, 5, 1, 2, 6)),
                   filter_w * filter_h * channels, out_w * out_h * batch)
end
"""
    col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0)

Fold a column matrix laid out like the output of `im2col` back into a 4-D array of size
`input_shape`, read as `(W, H, C, N)`.

Entries of `col` that map to the same position are added together. The padding border
is cropped from the result.

Note that `filter_h` comes before `filter_w` here, the reverse of `im2col`.
"""
function col2im(col::AbstractArray{T,2}, input_shape::NTuple{4,Int}, filter_h::Int,
                filter_w::Int, stride::Int=1, pad::Int=0) where {T}
    W, H, C, N = input_shape
    out_h = (H + 2pad - filter_h) ÷ stride + 1
    out_w = (W + 2pad - filter_w) ÷ stride + 1
    patches = permutedims(reshape(col, filter_w, filter_h, C, out_w, out_h, N),
                          (4, 5, 1, 2, 3, 6))
    # Accumulate on the padded canvas; the extra `stride - 1` is slack along each axis.
    canvas = zeros(T, (W + 2 * pad + stride - 1, H + 2 * pad + stride - 1, C, N))
    for fy in 1:filter_h, fx in 1:filter_w
        xs = fx:stride:(fx + stride * out_w - 1)
        ys = fy:stride:(fy + stride * out_h - 1)
        canvas[xs, ys, :, :] += patches[:, :, fx, fy, :, :]
    end
    return canvas[(pad + 1):(pad + W), (pad + 1):(pad + H), :, :]
end
"""
    Convolution(W, b, stride=1, pad=0)

Convolution layer: parameters `W` and `b`, the `stride` and `pad` settings, and the
working state shared between `forward` and `backward`.

The constructor sets only `W`, `b`, `stride` and `pad`. The remaining fields stay
undefined until `forward` (`x`, `col`, `col_w`) and `backward` (`dW`, `db`) assign them.
"""
mutable struct Convolution{T<:AbstractFloat} <: AbstractLayer{T}
    W::Array{T,4}      # filters; `forward` reads `size(W)` as (FW, FH, C, FN)
    b::Array{T,1}      # bias added to every output column in `forward`
    stride::Int
    pad::Int
    x::Array{T,4}      # last input seen by `forward`
    col::Array{T,2}    # `im2col` expansion of that input
    col_w::Array{T,2}  # `W` flattened to a matrix, as used in `forward`
    dW::Array{T,4}     # gradient for `W`, written by `backward`
    db::Array{T,1}     # gradient for `b`, written by `backward`
    function (::Type{Convolution})(W::Array{T,4}, b::Array{T,1},
                                   stride::Int=1, pad::Int=0) where {T}
        return new{T}(W, b, stride, pad)
    end
end
"""
    forward(self::Convolution, x)

Apply the convolution to `x`, whose size is read as `(W, H, C, N)`, and return an array
of size `(out_w, out_h, FN, N)` where `FN = size(self.W, 4)`.

Stores `x`, its `im2col` expansion and the flattened weight matrix on `self` for use by
`backward`. Fails an assertion when the channel count of `x` differs from
`size(self.W, 3)`.
"""
function forward(self::Convolution{T}, x::AbstractArray{T,4}) where {T<:AbstractFloat}
    FW, FH, C0, FN = size(self.W)
    W, H, C, N = size(x)
    @assert C0 == C "input has $C channels but the filters expect $C0"
    out_h = 1 + (H + 2 * self.pad - FH) ÷ self.stride
    out_w = 1 + (W + 2 * self.pad - FW) ÷ self.stride
    # `im2col` takes the window as (filter_w, filter_h). Passing (FH, FW) only worked
    # for square filters: otherwise the patch rows no longer line up with the flattened
    # (FW, FH, C) weights and the column count disagrees with out_w * out_h * N.
    col = im2col(x, FW, FH, self.stride, self.pad)
    col_w = reshape(self.W, (:, FN))'
    out_ = col_w * col .+ self.b
    out = permutedims(reshape(out_, (:, out_w, out_h, N)), (2, 3, 1, 4))
    self.x = x
    self.col = col
    self.col_w = col_w
    return out
end
"""
    backward(self::Convolution, dout)

Back-propagate `dout` through the convolution. Writes `self.db` and `self.dW`, and
returns the gradient with respect to the input stored by the last `forward` call (same
size as that input).
"""
function backward(self::Convolution{T}, dout::AbstractArray{T,4}) where {T<:AbstractFloat}
    FW, FH, C, FN = size(self.W)
    # Move the filter axis to the front so each row collects one filter's gradients.
    dout_mat = reshape(permutedims(dout, (3, 1, 2, 4)), (FN, :))
    self.db = vec(mapslices(sum, dout_mat, 2))
    dW_mat = dout_mat * self.col'
    self.dW = reshape(dW_mat', (FW, FH, C, FN))
    dcol = self.col_w' * dout_mat
    return col2im(dcol, size(self.x), FH, FW, self.stride, self.pad)
end
"""
    Pooling{T}(pool_h, pool_w, stride=1, pad=0)

Pooling layer: window size, `stride` and `pad`, plus the state `forward` saves for
`backward`.

The constructor sets only the four integer settings; `x` and `argmax` stay undefined
until the first `forward` call.
"""
mutable struct Pooling{T<:AbstractFloat} <: AbstractLayer{T}
    pool_h::Int
    pool_w::Int
    stride::Int
    pad::Int
    x::Array{T,4}         # last input seen by `forward`
    argmax::Array{Int,1}  # index of each window's maximum, as returned by `findmax`
    function (::Type{Pooling{T}})(pool_h::Int, pool_w::Int,
                                  stride::Int=1, pad::Int=0) where {T<:AbstractFloat}
        return new{T}(pool_h, pool_w, stride, pad)
    end
end
"""
    forward(self::Pooling, x)

Take the maximum over every pooling window of `x`, whose size is read as `(W, H, C, N)`,
and return an array of size `(out_w, out_h, C, N)`.

Stores `x` and the index of each window's maximum on `self` for use by `backward`.
"""
function forward(self::Pooling{T}, x::AbstractArray{T,4}) where {T<:AbstractFloat}
    W, H, C, N = size(x)
    out_h = 1 + (H + 2 * self.pad - self.pool_h) ÷ self.stride
    out_w = 1 + (W + 2 * self.pad - self.pool_w) ÷ self.stride
    # `im2col` takes the window as (filter_w, filter_h). Passing (pool_h, pool_w) only
    # worked for square windows: otherwise the number of columns disagrees with
    # out_w * out_h * N and the reshape below fails or scrambles the output.
    col_ = im2col(x, self.pool_w, self.pool_h, self.stride, self.pad)
    # One column per (channel, window position, sample); one row per window element.
    col = reshape(col_, (self.pool_h * self.pool_w, :))
    self.x = x
    out, _argmax = findmax(col, 1)
    self.argmax = vec(_argmax)
    return permutedims(reshape(out, (C, out_w, out_h, N)), (2, 3, 1, 4))
end
"""
    backward(self::Pooling, dout)

Route each entry of `dout` back to the window element recorded in `self.argmax`; every
other element gets zero. Returns the gradient with respect to the input stored by the
last `forward` call.
"""
function backward(self::Pooling{T}, dout::AbstractArray{T,4}) where {T<:AbstractFloat}
    window = self.pool_h * self.pool_w
    # Same axis order that `forward` used for the columns of its window matrix.
    grads = permutedims(dout, (3, 1, 2, 4))
    scattered = zeros(T, (window, length(grads)))
    for k in eachindex(self.argmax)
        scattered[self.argmax[k]] = grads[k]
    end
    dcol = reshape(scattered, (window * size(grads, 1), :))
    return col2im(dcol, size(self.x), self.pool_h, self.pool_w, self.stride, self.pad)
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# layers.jl | |
"""
    AbstractLayer{T<:AbstractFloat}

Common supertype of the layer types in this file. `T` is the floating-point element type
the layer works in.
"""
abstract type AbstractLayer{T <: AbstractFloat} end
## Relu | |
"""
    ReluLayer{T}()

ReLU layer. `mask` is left undefined by the constructor and is assigned by `forward`.
"""
mutable struct ReluLayer{T<:AbstractFloat} <: AbstractLayer{T}
    mask::AbstractArray{Bool}  # `x .<= 0` for the last input seen by `forward`
    function (::Type{ReluLayer{T}})() where {T}
        return new{T}()
    end
end
"""
    forward(self::ReluLayer, x)

Return a new array equal to `x` with every entry that is `<= 0` replaced by zero.
Records which entries were replaced in `self.mask`; `x` itself is not modified.
"""
function forward(self::ReluLayer{T}, x::AbstractArray{T}) where {T<:AbstractFloat}
    nonpositive = x .<= 0
    self.mask = nonpositive
    return ifelse.(nonpositive, zero(T), x)
end
"""
    backward(self::ReluLayer, dout)

Return a copy of `dout` with zero at every position recorded in `self.mask` by the last
`forward` call.

The caller's `dout` is left untouched, matching `forward`, which also works on a copy.
"""
function backward(self::ReluLayer{T}, dout::AbstractArray{T}) where {T<:AbstractFloat}
    # Copy first: zeroing `dout` in place would silently alter an array that the caller
    # (or the layer that produced it) may still hold on to.
    dx = copy(dout)
    dx[self.mask] .= zero(T)
    return dx
end
## Sigmoid | |
"""
    sigmoid(x)

Logistic function `1 / (1 + exp(-x))`, evaluated in the floating-point type of `x`.
"""
function sigmoid(x::T) where {T<:AbstractFloat}
    return one(T) / (one(T) + exp(-x))
end
"""
    SigmoidLayer{T}()

Sigmoid activation layer. `out` is left undefined by the constructor and is assigned by
`forward`.
"""
mutable struct SigmoidLayer{T<:AbstractFloat} <: AbstractLayer{T}
    out::AbstractArray{T}  # output of the last `forward` call, reused by `backward`
    function (::Type{SigmoidLayer{T}})() where {T}
        return new{T}()
    end
end
"""
    forward(self::SigmoidLayer, x)

Apply `sigmoid` to every element of `x`, store the result in `self.out` and return it.
"""
function forward(self::SigmoidLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    activated = sigmoid.(x)
    self.out = activated
    return activated
end
"""
    backward(self::SigmoidLayer, dout)

Return `dout .* (1 .- out) .* out`, where `out` is the output stored by the last
`forward` call.
"""
function backward(self::SigmoidLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    y = self.out
    return dout .* (one(T) .- y) .* y
end
### 5.6.2 Batched Affine layer
"""
    AffineLayer(W, b)

Affine layer computing `W * x .+ b`. The constructor sets only `W` and `b`; `x` is
assigned by `forward`, and `dW` and `db` by `backward`.
"""
mutable struct AffineLayer{T<:AbstractFloat} <: AbstractLayer{T}
    W::Matrix{T}
    b::Vector{T}
    x::AbstractArray{T}  # last input seen by `forward`
    dW::Matrix{T}        # gradient for `W`, written by `backward`
    db::Vector{T}        # gradient for `b`, written by `backward`
    (::Type{AffineLayer})(W::Matrix{T}, b::Vector{T}) where {T} = new{T}(W, b)
end
"""
    forward(self::AffineLayer, x)

Store `x` on the layer and return `self.W * x .+ self.b`.
"""
function forward(self::AffineLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    self.x = x
    projected = self.W * x
    return projected .+ self.b
end
"""
    backward(self::AffineLayer, dout)

Back-propagate `dout` through the layer. Writes `self.dW = dout * x'` and
`self.db = _sumvec(dout)`, where `x` is the input stored by `forward`, and returns
`self.W' * dout`.
"""
function backward(self::AffineLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    input_grad = self.W' * dout
    self.dW = dout * self.x'
    self.db = _sumvec(dout)
    return input_grad
end
# Collapse a gradient array to one value per index along its first axis.
#   * vector      -> returned as-is (the very same object, not a copy)
#   * matrix      -> sum of each row
#   * rank N >= 3 -> sum over every axis after the first
# Written with explicit row sums instead of `mapslices(sum, dout, dims)`: the positional
# `dims` form only exists on Julia 0.6, while this spelling gives the same per-row sums
# on 0.6 and on later versions.
@inline _sumvec(dout::AbstractVector{T}) where {T} = dout
@inline _sumvec(dout::AbstractMatrix{T}) where {T} = [sum(dout[i, :]) for i in 1:size(dout, 1)]
@inline _sumvec(dout::AbstractArray{T,N}) where {T,N} = _sumvec(reshape(dout, (size(dout, 1), :)))
### 5.6.3 Softmax-with-Loss layer
"""
    softmax(a::AbstractVector)
    softmax(a::AbstractMatrix)

Softmax of a vector: `exp.(a .- maximum(a))` divided by its sum. For a matrix the
vector softmax is applied to every column independently, and a new matrix of the same
size is returned.
"""
function softmax(a::AbstractVector{T}) where {T<:AbstractFloat}
    c = maximum(a)  # shift by the maximum so `exp` cannot overflow
    exp_a = exp.(a .- c)
    return exp_a ./ sum(exp_a)
end

function softmax(a::AbstractMatrix{T}) where {T<:AbstractFloat}
    # Explicit column loop instead of `mapslices(softmax, a, 1)`: the positional `dims`
    # form only exists on Julia 0.6, and this yields the same values column by column.
    out = zeros(T, size(a))
    for j in 1:size(a, 2)
        out[:, j] = softmax(a[:, j])
    end
    return out
end
"""
    crossentropyerror(y::AbstractVector, t::AbstractVector)

Cross-entropy `-sum(t .* log.(y .+ δ))` between the vector `y` and the target vector `t`
of the same float type, returned as that float type.

Any `AbstractVector` is accepted (not only `Vector`), so targets that reach
`forward(::SoftmaxWithLossLayer, ...)` as views or other array types dispatch here too.
"""
function crossentropyerror(y::AbstractVector{T}, t::AbstractVector{T})::T where {T<:AbstractFloat}
    δ = T(1f-7)  # keeps the argument of `log` away from zero
    # Element-wise form: needs neither `⋅` (no longer in Base after 0.6) nor BLAS-backed
    # storage, so it works for every AbstractVector.
    return -sum(t .* log.(y .+ δ))
end
"""
    crossentropyerror(y::AbstractMatrix, t::AbstractMatrix)

Cross-entropy `-sum(t .* log.(y .+ δ))` between the matrices `y` and `t`, divided by
`size(y, 2)` (the batch size), returned in the float type of the inputs.

Any `AbstractMatrix` is accepted (not only `Matrix`), so targets that reach
`forward(::SoftmaxWithLossLayer, ...)` as views or other array types dispatch here too.
"""
function crossentropyerror(y::AbstractMatrix{T}, t::AbstractMatrix{T})::T where {T<:AbstractFloat}
    batch_size = size(y, 2)
    δ = T(1f-7)  # keeps the argument of `log` away from zero
    # Element-wise form: `vecdot` only exists on Julia 0.6 and this works for every
    # AbstractMatrix.
    return -sum(t .* log.(y .+ δ)) / batch_size
end
"""
    crossentropyerror(y::Matrix, t::Vector{<:Integer})

Cross-entropy for integer class labels: the mean over columns `i` of
`-log(y[t[i] + 1, i] + δ)`, returned in the float type of `y`.

Labels are zero-based: label `k` selects row `k + 1` of `y`.
"""
function crossentropyerror(y::Matrix{T}, t::Vector{<:Integer})::T where {T<:AbstractFloat}
    batch_size = size(y, 2)
    δ = T(1f-7)  # keeps the argument of `log` away from zero
    # δ has to be added inside the log. Adding it to the logarithm afterwards left
    # log(0) = -Inf unguarded and merely shifted the loss by δ.
    return -sum([log(y[t[i] + 1, i] + δ) for i in 1:batch_size]) / batch_size
end
"""
    SoftmaxWithLossLayer{T,N}()

Combined softmax and cross-entropy layer. The constructor leaves every field
uninitialised; `forward` assigns all three.
"""
mutable struct SoftmaxWithLossLayer{T<:AbstractFloat,N} <: AbstractLayer{T}
    loss::T        # value returned by the last `forward` call
    y::Array{T,N}  # `softmax` of the last input
    t::Array{T,N}  # target passed to the last `forward` call
    function (::Type{SoftmaxWithLossLayer{T,N}})() where {T,N}
        return new{T,N}()
    end
end
"""
    forward(self::SoftmaxWithLossLayer, x, t)

Compute `softmax(x)` and its `crossentropyerror` against the target `t`. Stores the
target, the softmax output and the loss on `self`, and returns the loss.
"""
function forward(self::SoftmaxWithLossLayer{T,N}, x::AbstractArray{T,N},
                 t::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    self.t = t
    probs = softmax(x)
    self.y = probs
    loss = crossentropyerror(probs, t)
    self.loss = loss
    return loss
end
"""
    backward(lyr::SoftmaxWithLossLayer, dout=one(T))

Return `dout` times `_swlvec(lyr.y, lyr.t)`, using the softmax output and target stored
by the last `forward` call.
"""
function backward(lyr::SoftmaxWithLossLayer{T}, dout::T=one(T)) where {T<:AbstractFloat}
    grad = _swlvec(lyr.y, lyr.t)
    return dout .* grad
end
# Difference `y .- t` used by `backward(::SoftmaxWithLossLayer)`.
# For a matrix target the difference is additionally divided by the number of columns
# of `t`; for a vector target it is returned unscaled.
@inline function _swlvec(y::AbstractArray{T}, t::AbstractVector{T}) where {T<:AbstractFloat}
    return y .- t
end
@inline function _swlvec(y::AbstractArray{T}, t::AbstractMatrix{T}) where {T<:AbstractFloat}
    ncols = size(t, 2)
    return (y .- t) / ncols
end
## Swish | |
#= ```https://arxiv.org/pdf/1710.05941.pdf | |
${\rm swish}(x) = x \cdot {\rm sigmoid}(x)$```
=# | |
"""
    SwishLayer{T}()

Swish activation layer (`x .* sigmoid.(x)`). Both fields are left undefined by the
constructor and are assigned by `forward`.
"""
mutable struct SwishLayer{T<:AbstractFloat} <: AbstractLayer{T}
    out::AbstractArray{T}  # output of the last `forward` call
    ς::AbstractArray{T}    # `sigmoid.(x)` for the last input
    function (::Type{SwishLayer{T}})() where {T}
        return new{T}()
    end
end
"""
    forward(self::SwishLayer, x)

Return `x .* sigmoid.(x)`. Stores the sigmoid values in `self.ς` and the result in
`self.out` for use by `backward`.
"""
function forward(self::SwishLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    gate = sigmoid.(x)
    self.ς = gate
    activated = x .* gate
    self.out = activated
    return activated
end
"""
    backward(self::SwishLayer, dout)

Return `dout .* (out .+ ς .* (1 .- out))`, where `out` and `ς` are the values stored by
the last `forward` call.
"""
function backward(self::SwishLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    local_grad = self.out .+ self.ς .* (one(T) .- self.out)
    return dout .* local_grad
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# optimizer.jl | |
"""
    AbstractOptimizer{T<:AbstractFloat}

Common supertype of the optimizers in this file; `T` is the float type of their
hyper-parameters.
"""
abstract type AbstractOptimizer{T <: AbstractFloat} end

"""
    AbstractOptimizerParam

Common supertype of the per-parameter state objects that `initializeparam` creates and
`update` receives and returns.
"""
abstract type AbstractOptimizerParam end
"""
    SGD{T}(lr=T(0.01))
    SGD(lr)

Optimizer holding a single learning rate `lr`. `SGD(lr)` takes `T` from the type of
`lr`; `SGD{T}(lr)` converts a learning rate given in another float type to `T`.
"""
struct SGD{T<:AbstractFloat} <: AbstractOptimizer{T}
    lr::T
    function (::Type{SGD{T}})(lr::T=T(0.01)) where {T<:AbstractFloat}
        return new{T}(lr)
    end
end
# Take the element type from the learning rate itself.
@inline SGD(lr::T) where {T<:AbstractFloat} = SGD{T}(lr)
# Accept a learning rate of a different float type and convert it to `T`.
@inline (::Type{SGD{T}})(lr::AbstractFloat) where {T<:AbstractFloat} = SGD{T}(T(lr))
"""
    update(opt::SGD, W, gW, param)

Return `(W - opt.lr .* gW, param)`: the stepped parameters as a new array, together with
`param` passed through unchanged.
"""
function update(opt::SGD{T}, W::AbstractArray{T,N}, gW::AbstractArray{T,N}, param) where {T,N}
    step = opt.lr .* gW
    return (W - step, param)
end
"""
    SGDParam()

Empty optimizer state, for optimizers that keep nothing between updates.
"""
struct SGDParam <: AbstractOptimizerParam end

# Fallback for any optimizer without a more specific method: an empty `SGDParam`.
function initializeparam(::AbstractOptimizer{T}, w::AbstractArray{T,N}) where {T,N}
    return SGDParam()
end
"""
    Momentum{T}(lr=T(0.01), momentum=T(0.9))
    Momentum(lr, momentum=T(0.9))

Optimizer holding a learning rate `lr` and a `momentum` coefficient.

`Momentum(lr, ...)` takes `T` from its arguments. `Momentum{T}(lr, ...)` also accepts
values of another float type and converts them to `T`; the momentum coefficient may be
omitted there too, just as `SGD{T}(lr)` accepts a bare learning rate.
"""
struct Momentum{T<:AbstractFloat} <: AbstractOptimizer{T}
    lr::T
    momentum::T
    (::Type{Momentum{T}})(lr::T=T(0.01), momentum::T=T(0.9)) where {T<:AbstractFloat} = new{T}(lr, momentum)
end
# Take the element type from the arguments themselves.
@inline Momentum(lr::T, momentum::T=T(0.9)) where {T<:AbstractFloat} = Momentum{T}(lr, momentum)
# Converting constructor. The default for `momentum` makes `Momentum{Float32}(0.01)`
# work; previously a single argument of another float type had no matching method.
@inline (::Type{Momentum{T}})(lr::AbstractFloat, momentum::AbstractFloat=T(0.9)) where {T<:AbstractFloat} = Momentum{T}(T(lr), T(momentum))
"""
    MomentumParam(v)

Per-parameter state of the `Momentum` optimizer: the update term `v` produced by the
previous `update` call (all zeros right after `initializeparam`).
"""
struct MomentumParam{T<:AbstractFloat,N} <: AbstractOptimizerParam
    v::AbstractArray{T,N}  # same element type and rank as the parameter it belongs to
end
# Initial `Momentum` state for the parameter array `w`: zeros with the layout of `w`.
function initializeparam(::Momentum{T}, w::AbstractArray{T,N}) where {T,N}
    return MomentumParam(fill!(similar(w), zero(T)))
end
"""
    update(opt::Momentum, W, gW, param)

Compute `v = opt.momentum .* param.v - opt.lr .* gW` and return `(W + v,
MomentumParam(v))`: the stepped parameters and the state for the next call. Neither `W`
nor `param` is modified.
"""
function update(opt::Momentum{T}, W::AbstractArray{T,N}, gW::AbstractArray{T,N},
                param::MomentumParam{T,N}) where {T,N}
    carried = opt.momentum .* param.v
    step = opt.lr .* gW
    velocity = carried - step
    return (W + velocity, MomentumParam(velocity))
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment