Skip to content

Instantly share code, notes, and snippets.

@antimon2
Last active February 3, 2018 12:03
Show Gist options
  • Save antimon2/32f7d9951865f5748e7a9afbfbf556a5 to your computer and use it in GitHub Desktop.
Save antimon2/32f7d9951865f5748e7a9afbfbf556a5 to your computer and use it in GitHub Desktop.
Cifar10TrainSample.jl.ipynb
# cifar10.jl
module CIFAR10
export CIFAR10Record, getlabel, getdata, getlabelastext
## Define Record Type
# 24584 == 3073 * 8
# One CIFAR-10 binary record: 1 label byte + 32*32*3 image bytes = 3073 bytes.
# Declared as a primitive (isbits) type so records can be `reinterpret`-ed
# directly to/from the raw byte stream without copying field-by-field.
primitive type CIFAR10Record 24584 end
"""
    read(stream, CIFAR10Record)

Read one 3073-byte CIFAR-10 record from `stream` and reinterpret it as a
`CIFAR10Record` value.
"""
function Base.read(stream::IO, ::Type{CIFAR10Record})
    raw = read(stream, UInt8, 3073)
    return reinterpret(CIFAR10Record, raw)[1]
end
## Show `CIFAR10Record` as Simple String representation
# Compact textual form: the label byte plus a hash of the 3072 image
# bytes, so a record prints short instead of dumping the whole image.
function Base.show(io::IO, record::CIFAR10Record)
    raw = reinterpret(UInt8, [record])
    print(io, "CIFAR10Record(")
    show(io, raw[1])             # label byte
    print(io, ", ")
    show(io, hash(raw[2:end]))   # digest of the image payload
    print(io, ')')
end
## Show `CIFAR10Record` as Image
### prepare1: CRC32
# 256-entry lookup table for the reflected CRC-32 polynomial 0xEDB88320
# (the polynomial used by PNG/zlib), built once at load time.
const CRC32_TABLE = let poly::UInt32 = 0xedb88320
    table = zeros(UInt32, 256)
    for n in 0:255
        c = UInt32(n)
        for _ in 1:8
            c = (c & 0x01) == 0x01 ? ((c >> 1) ⊻ poly) : (c >> 1)
        end
        table[n+1] = c
    end
    table
end;
"""
    crc32(data, crc=0x00000000)

Standard (zlib-style) CRC-32 of `data`. Pass a previous result as `crc`
to continue a running checksum over several byte chunks.
"""
function crc32(data::Vector{UInt8}, crc::UInt32=zero(UInt32))
    acc = ~crc
    for byte in data
        acc = CRC32_TABLE[((acc ⊻ byte) & 0xff) + 1] ⊻ (acc >> 8)
    end
    return ~acc
end
### prepare2: Adler32
# Largest prime below 2^16, per the Adler-32 definition (RFC 1950).
const MOD_ADLER = UInt32(65521)
"""
    adler32(data) -> UInt32

Adler-32 checksum of `data` (the trailer of the zlib stream inside the
PNG IDAT chunk). The running sums are folded modulo 65521 every 5550
bytes, which is within the safe bound (NMAX = 5552) that keeps `b` from
overflowing a `UInt32`.
"""
function adler32(data::Vector{UInt8})
    a = one(UInt32)
    b = zero(UInt32)
    l = length(data)
    for i in 1:5550:l
        # index directly instead of slicing (`data[i:e]`), which copied
        # each chunk and allocated on every iteration
        for j in i:min(i + 5549, l)
            a += data[j]
            b += a
        end
        a %= MOD_ADLER
        b %= MOD_ADLER
    end
    return (b << 16) | a
end
### prepare3: PNG format
#### PNG Signature (8 bytes)
# The fixed 8-byte signature every PNG file starts with.
const PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n";
# Write the PNG signature to `io`; returns the byte count written (8).
function write_png_signature(io::IO)
    return write(io, PNG_SIGNATURE)
end
#### IHDR Chunk (25 bytes)
# Chunk length (13, big-endian) followed by the "IHDR" chunk type.
const IHDR_00 = b"\0\0\0\rIHDR";
# True color (24bit-depth) RGB:
# bit depth 8, color type 2, compression 0, filter 0, interlace 0.
const IHDR_10 = b"\b\x02\0\0\0";
"""
    write_png_ihdr(io, width, height)

Write a complete PNG IHDR chunk (25 bytes) for an 8-bit truecolor RGB
image of `width` x `height`. The chunk CRC covers the chunk type and
data but not the length field, hence the running `crc32` starts from
`IHDR_00[5:end]`. Returns the number of bytes written.
"""
function write_png_ihdr(io::IO, width::Int, height::Int)
# write IHDR_00, width, height, IHDR_10, crc to IO
c = write(io, IHDR_00)
# CRC is accumulated over type + data, per the PNG chunk layout
crc = crc32(IHDR_00[5:end])
# width/height go on the wire as 4-byte big-endian (network order)
ihdrw = reinterpret(UInt8, [hton(width % UInt32)])
ihdrh = reinterpret(UInt8, [hton(height % UInt32)])
c += write(io, ihdrw)
crc = crc32(ihdrw, crc)
c += write(io, ihdrh)
crc = crc32(ihdrh, crc)
c += write(io, IHDR_10)
crc = crc32(IHDR_10, crc)
# trailing 4-byte big-endian CRC closes the chunk
c += write(io, hton(crc))
c
end
#### IDAT Chunk
const IDAT_04 = b"IDAT";
# zlib CMF+FLG bytes (Deflate, compression level 0)
const CMF_FLG = b"\b\x1d";
# Deflate block header (final block, stored/uncompressed)
const BH = b"\x01";
"""
    write_png_idat(io, img_src)

Write a single PNG IDAT chunk containing `img_src` (depth x width x
height bytes) as one stored (uncompressed) Deflate block inside a zlib
stream. Each scanline is prefixed with a 0x00 filter byte (filter type
"None"). Returns the number of bytes written.

NOTE(review): the stored-block LEN field is a UInt16, so the payload
`height * (1 + width*depth)` must not exceed 65535 bytes — fine for the
32x32x3 CIFAR-10 images this is used with, but worth confirming if
reused for larger images.
"""
function write_png_idat(io::IO, img_src::AbstractArray{UInt8,3})
depth, width, height = size(img_src)
# @assert depth == 3
# write length, IDAT_04, CM_FLG, BH, LEN, NLEN, DAT, ADL, crc to IO
# raw zlib payload: one filter byte per scanline plus the pixel bytes
l = height * (1 + width * depth)
# chunk data length = CMF/FLG(2) + BH(1) + LEN(2) + NLEN(2) + Adler(4) + l
c = write(io, hton((l + 11) % UInt32))
c += write(io, IDAT_04)
crc = crc32(IDAT_04)
c += write(io, CMF_FLG)
crc = crc32(CMF_FLG, crc)
c += write(io, BH)
crc = crc32(BH, crc)
# stored-block length and its one's complement, little-endian per Deflate
LEN = htol(l % UInt16)
c += write(io, LEN)
crc = crc32(reinterpret(UInt8, [LEN]), crc)
NLEN = ~LEN
c += write(io, NLEN)
crc = crc32(reinterpret(UInt8, [NLEN]), crc)
# prepend a zero filter byte to every scanline (column) of the image
IDAT_DAT = vec([zeros(UInt8, 1, height);reshape(img_src, :, height)])
c += write(io, IDAT_DAT)
crc = crc32(IDAT_DAT, crc)
# the zlib stream ends with the big-endian Adler-32 of the raw data
ADL = hton(adler32(IDAT_DAT))
c += write(io, ADL)
crc = crc32(reinterpret(UInt8, [ADL]), crc)
# chunk CRC covers everything after the length field
c += write(io, hton(crc))
c
end
#### IEND Chunk (12 bytes)
# The IEND chunk is constant (zero-length data), so its 12 bytes —
# length (0), "IEND", and CRC 0xAE426082 — are precomputed.
const IEND = b"\0\0\0\0IEND\xaeB`\x82";
# Write the terminating IEND chunk to `io`; returns the byte count (12).
function write_png_iend(io::IO)
    return write(io, IEND)
end
#### format to PNG
"""
    write_png(io, record)

Serialize `record`'s 32x32 RGB image as a minimal PNG (signature, IHDR,
one uncompressed IDAT, IEND). Returns the total bytes written.
"""
function write_png(io::IO, record::CIFAR10Record)
    # raw layout is channel-planar (32x32 red, green, blue); reorder to
    # (channel, x, y) so scanlines come out as interleaved RGB pixels
    pixels = permutedims(reshape(reinterpret(UInt8, [record])[2:end], (32, 32, 3)), (3, 1, 2))
    total = write_png_signature(io)
    total += write_png_ihdr(io, 32, 32)
    total += write_png_idat(io, pixels)
    total += write_png_iend(io)
    return total
end
### Show MIME
# Advertise PNG rendering so display systems (e.g. IJulia) show records as images.
Base.mimewritable(::MIME"image/png", ::CIFAR10Record) = true
# Render a record as raw PNG bytes by delegating to `write_png`.
function Base.show(io::IO, ::MIME"image/png", record::CIFAR10Record)
write_png(io, record)
end
Base.mimewritable(::MIME"text/html", ::CIFAR10Record) = true
"""
Render a record as an inline HTML `<img>` whose `src` is a base64
`data:` URI of the PNG produced by `write_png`.
"""
function Base.show(io::IO, ::MIME"text/html", record::CIFAR10Record)
print(io, "<img src=\"data:image/png;base64,")
iobuf = IOBuffer()
b64pipe = Base64EncodePipe(iobuf)
write_png(b64pipe, record)
# close() flushes the encoder's final partial group and padding; without
# it the base64 output is truncated (the PNG length is not a multiple of
# 3) and the data URI is invalid. close() does not close `iobuf` itself.
close(b64pipe)
write(io, read(seekstart(iobuf)))
print(io, "\">")
end
Base.mimewritable(::MIME"text/html", ::AbstractArray{CIFAR10Record}) = true
# Render an array of records as a single-column HTML table, one image per row.
function Base.show(io::IO, mime::MIME"text/html", records::AbstractArray{CIFAR10Record})
    print(io, "<table>")
    foreach(records) do rec
        print(io, "<tr><td>")
        show(io, mime, rec)
        print(io, "</td></tr>")
    end
    print(io, "</table>")
end
## getter
# Label byte of `record` as an `Int` (0-9 for valid CIFAR-10 data).
function getlabel(record::CIFAR10Record)::Int
    return Int(reinterpret(UInt8, [record])[1])
end
# The 3072 image bytes of `record` (label byte stripped).
function getdata(record::CIFAR10Record)::Vector{UInt8}
    return reinterpret(UInt8, [record])[2:end]
end
# CIFAR-10 class names, indexed by label value + 1.
const labels = String["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
# English class name for `record`'s label.
function getlabelastext(record::CIFAR10Record)::String
    return labels[getlabel(record) + 1]
end
end # module
# cifar10_test.jl
# Smoke tests for the CIFAR10 module (Julia 0.6 `Base.Test`).
# Requires the extracted CIFAR-10 binary dataset on disk (see note below).
using Base.Test
include("./cifar10.jl")
using CIFAR10
# a record must be a plain-bits value of exactly 3073 bytes
@test isbits(CIFAR10Record)
@test sizeof(CIFAR10Record) == 3073
# Must download and extract `cifar-10-binary.tar.gz`.
record0 = open("cifar-10-batches-bin/test_batch.bin", "r") do f
return read(f, CIFAR10Record)
end;
@test typeof(record0) == CIFAR10Record
# first record of the test batch: label 3 ("cat"); the second value is the
# hash of the image bytes as produced by `show`, pinning the exact payload
@test string(record0) == "CIFAR10Record(0x03, 0xd0b45b812aae12b1)"
@test getlabel(record0) == 3
@test getlabelastext(record0) == "cat"
data0 = getdata(record0)
@test length(data0) == 3072
@test typeof(data0) == Vector{UInt8}
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "## Preparation"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:54.504000+09:00",
"start_time": "2018-02-01T11:31:54.238Z"
},
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "include(\"./cifar10.jl\")",
"execution_count": 1,
"outputs": [
{
"data": {
"text/plain": "CIFAR10"
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:55.096000+09:00",
"start_time": "2018-02-01T11:31:55.083Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "using CIFAR10",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:56.088000+09:00",
"start_time": "2018-02-01T11:31:55.710Z"
},
"trusted": false
},
"cell_type": "code",
"source": "datadir = \"./cifar-10-batches-bin\"\n# train data: `data_batch_$(i).bin`\n# test data: `test_batch.bin`",
"execution_count": 3,
"outputs": [
{
"data": {
"text/plain": "\"./cifar-10-batches-bin\""
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:51.865000+09:00",
"start_time": "2018-02-01T11:31:50.248Z"
},
"trusted": false
},
"cell_type": "code",
"source": "include(\"./layers.jl\")\ninclude(\"./cnnutil.jl\")\ninclude(\"./optimizer.jl\")",
"execution_count": 4,
"outputs": [
{
"data": {
"text/plain": "update (generic function with 2 methods)"
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Train Batch Reader"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:58.310000+09:00",
"start_time": "2018-02-01T11:31:58.220Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function readtrain(channel::Channel{Tuple{CIFAR10Record}}, fid::Int, datadir::String=datadir)\n if isopen(channel)\n filepath = joinpath(datadir, \"data_batch_$(fid).bin\")\n open(filepath, \"r\") do f\n while isopen(channel)\n _randidxs = shuffle(0:9999)\n for idx in _randidxs\n seek(f, idx * 3073)\n record = read(f, CIFAR10Record)\n # label = getlabel(record)\n try\n put!(channel, (record,))\n sleep(0.001) # yield to others\n catch ex\n if isa(ex, InvalidStateException)\n return\n else\n rethrow(ex)\n end\n end\n end\n end\n end\n end\nend",
"execution_count": 5,
"outputs": [
{
"data": {
"text/plain": "readtrain (generic function with 2 methods)"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:59.104000+09:00",
"start_time": "2018-02-01T11:31:59.014Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function train_batch_produce(channel::Channel{Tuple{CIFAR10Record}}, datadir=datadir)\n train_channels = [Channel{Tuple{CIFAR10Record}}(32) for _=1:5]\n for fid in 1:5\n @schedule readtrain(train_channels[fid], fid, datadir)\n end\n while true\n try\n put!(channel, take!(rand(train_channels)))\n catch ex\n if isa(ex, InvalidStateException)\n for ch in train_channels\n close(ch)\n end\n return\n else\n rethrow(ex)\n end\n end\n end\nend",
"execution_count": 6,
"outputs": [
{
"data": {
"text/plain": "train_batch_produce (generic function with 2 methods)"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:05.753000+09:00",
"start_time": "2018-02-01T11:32:03.428Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "struct CF10Batch\n channel::Channel{Tuple{CIFAR10Record}}\n batchsize::Int\nend\n\nfunction (f::CF10Batch)()\n buf = reshape(reinterpret(UInt8, collect(Iterators.take(f.channel, f.batchsize))), (:, f.batchsize))\n return (buf[2:3073, :], buf[1, :]) # Data and Labels\nend",
"execution_count": 7,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:57.519000+09:00",
"start_time": "2018-02-01T11:31:57.347Z"
},
"trusted": false
},
"cell_type": "code",
"source": "train_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule train_batch_produce(train_channel)\ntrainbatch = CF10Batch(train_channel, 128)",
"execution_count": 8,
"outputs": [
{
"data": {
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:7), 128)"
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:07.296000+09:00",
"start_time": "2018-02-01T11:32:05.380Z"
},
"trusted": false
},
"cell_type": "code",
"source": "data0, labels0 = trainbatch()",
"execution_count": 9,
"outputs": [
{
"data": {
"text/plain": "(UInt8[0xd1 0x78 … 0xa9 0xa6; 0xd5 0x71 … 0xab 0xa5; … ; 0x34 0x4d … 0x42 0x78; 0x33 0x54 … 0x39 0x46], UInt8[0x07, 0x06, 0x03, 0x06, 0x02, 0x00, 0x02, 0x08, 0x09, 0x01 … 0x07, 0x03, 0x01, 0x00, 0x00, 0x01, 0x09, 0x07, 0x03, 0x01])"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:07.623000+09:00",
"start_time": "2018-02-01T11:32:07.022Z"
},
"trusted": false
},
"cell_type": "code",
"source": "size(data0)",
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "(3072, 128)"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:09.227000+09:00",
"start_time": "2018-02-01T11:32:08.992Z"
},
"trusted": false
},
"cell_type": "code",
"source": "size(labels0)",
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "(128,)"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Inference"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:10.633000+09:00",
"start_time": "2018-02-01T11:32:10.496Z"
},
"trusted": false
},
"cell_type": "code",
"source": "T = Float32\ninput_h = 32\ninput_w = 32\ninput_c = 3",
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "3"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Conv1"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:13.006000+09:00",
"start_time": "2018-02-01T11:32:11.324Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "fh1 = 5\nfw1 = 5\noutput_c1 = 64\ncstride1 = 1\ncpad1 = 2\nweight_init_std1 = 5f-2\n\ncW1 = weight_init_std1 .* randn(T, (fh1, fw1, input_c, output_c1))\ncb1 = zeros(T, output_c1)\nconv1lyr = Convolution(cW1, cb1, cstride1, cpad1);",
"execution_count": 13,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Relu1"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:13.010000+09:00",
"start_time": "2018-02-01T11:32:12.388Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "relu1lyr = ReluLayer{T}();",
"execution_count": 14,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Pool1"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:14.112000+09:00",
"start_time": "2018-02-01T11:32:14.103Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "ph1 = 3\npw1 = 3\npstride1 = 2\nppad1 = 1\npool1lyr = Pooling{T}(ph1, pw1, pstride1, ppad1);",
"execution_count": 15,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Conv2"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:15.220000+09:00",
"start_time": "2018-02-01T11:32:15.212Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "fh2 = 5\nfw2 = 5\ninput_c2 = output_c1\noutput_c2 = 64\ncstride2 = 1\ncpad2 = 2\nweight_init_std2 = 5f-2\n\ncW2 = weight_init_std2 .* randn(T, (fh2, fw2, input_c2, output_c2))\n# cb2 = zeros(T, output_c2)\ncb2 = fill(0.1f0, output_c2)\nconv2lyr = Convolution(cW2, cb2, cstride2, cpad2);",
"execution_count": 16,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Relu2"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:16.406000+09:00",
"start_time": "2018-02-01T11:32:16.404Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "relu2lyr = ReluLayer{T}();",
"execution_count": 17,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Pool2"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:17.396000+09:00",
"start_time": "2018-02-01T11:32:17.393Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "ph2 = 3\npw2 = 3\npstride2 = 2\nppad2 = 1\npool2lyr = Pooling{T}(ph2, pw2, pstride2, ppad2);",
"execution_count": 18,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### FC3~5"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:18.779000+09:00",
"start_time": "2018-02-01T11:32:18.571Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "input_size = 8 * 8 * 64\nhidden1_size = 384\nhidden2_size = 192\noutput_size = 10\nweight_init_std = 0.04f0\nW3 = weight_init_std .* randn(T, hidden1_size, input_size)\nb3 = fill(0.1f0, hidden1_size)\nW4 = weight_init_std .* randn(T, hidden2_size, hidden1_size)\nb4 = fill(0.1f0, hidden2_size)\nW5 = (1f0/192) .* randn(T, output_size, hidden2_size)\nb5 = zeros(T, output_size);",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:19.812000+09:00",
"start_time": "2018-02-01T11:32:19.807Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "a3lyr = AffineLayer(W3, b3)\nrelu3lyr = ReluLayer{T}();",
"execution_count": 20,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:20.296000+09:00",
"start_time": "2018-02-01T11:32:20.293Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "a4lyr = AffineLayer(W4, b4)\nrelu4lyr = ReluLayer{T}();",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:20.901000+09:00",
"start_time": "2018-02-01T11:32:20.895Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "a5lyr = AffineLayer(W5, b5)\nsoftmaxlyr = SoftmaxWithLossLayer{T,2}();",
"execution_count": 22,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Optimizer (Momentum)"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:21.696000+09:00",
"start_time": "2018-02-01T11:32:21.344Z"
},
"trusted": false
},
"cell_type": "code",
"source": "opt = Momentum{T}()",
"execution_count": 23,
"outputs": [
{
"data": {
"text/plain": "Momentum{Float32}(0.01f0, 0.9f0)"
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Train"
},
{
"metadata": {
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "mutable struct TrainParams\n pcW1::AbstractOptimizerParam\n pcb1::AbstractOptimizerParam\n pcW2::AbstractOptimizerParam\n pcb2::AbstractOptimizerParam\n pW3::AbstractOptimizerParam\n pb3::AbstractOptimizerParam\n pW4::AbstractOptimizerParam\n pb4::AbstractOptimizerParam\n pW5::AbstractOptimizerParam\n pb5::AbstractOptimizerParam\nend",
"execution_count": 24,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:30.610000+09:00",
"start_time": "2018-02-01T11:32:30.603Z"
},
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "mutable struct Trainer\n conv1lyr::Convolution\n relu1lyr::ReluLayer\n pool1lyr::Pooling\n conv2lyr::Convolution\n relu2lyr::ReluLayer\n pool2lyr::Pooling\n a3lyr::AffineLayer\n relu3lyr::ReluLayer\n a4lyr::AffineLayer\n relu4lyr::ReluLayer\n a5lyr::AffineLayer\n softmaxlyr::SoftmaxWithLossLayer\n opt::AbstractOptimizer\n params::TrainParams\n function Trainer(\n conv1lyr::Convolution,\n relu1lyr::ReluLayer,\n pool1lyr::Pooling,\n conv2lyr::Convolution,\n relu2lyr::ReluLayer,\n pool2lyr::Pooling,\n a3lyr::AffineLayer,\n relu3lyr::ReluLayer,\n a4lyr::AffineLayer,\n relu4lyr::ReluLayer,\n a5lyr::AffineLayer,\n softmaxlyr::SoftmaxWithLossLayer,\n opt::AbstractOptimizer\n )\n new(\n conv1lyr,\n relu1lyr,\n pool1lyr,\n conv2lyr,\n relu2lyr,\n pool2lyr,\n a3lyr,\n relu3lyr,\n a4lyr,\n relu4lyr,\n a5lyr,\n softmaxlyr,\n opt,\n TrainParams(\n initializeparam(opt, conv1lyr.W), # pcW1\n initializeparam(opt, conv1lyr.b), # pcb1\n initializeparam(opt, conv2lyr.W), # pcW2\n initializeparam(opt, conv2lyr.b), # pcb2\n initializeparam(opt, a3lyr.W), # pW3\n initializeparam(opt, a3lyr.b), # pb3\n initializeparam(opt, a4lyr.W), # pW4\n initializeparam(opt, a4lyr.b), # pb4\n initializeparam(opt, a5lyr.W), # pW5\n initializeparam(opt, a5lyr.b) # pb5\n )\n )\n end\nend",
"execution_count": 25,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:34.107000+09:00",
"start_time": "2018-02-01T11:32:34.016Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function forward(trainer::Trainer, x1::AbstractArray{T,4}) where {T}\n a1 = forward(trainer.conv1lyr, x1)\n z1 = forward(trainer.relu1lyr, a1)\n p1 = forward(trainer.pool1lyr, z1)\n a2 = forward(trainer.conv2lyr, p1)\n z2 = forward(trainer.relu2lyr, a2)\n p2 = forward(trainer.pool2lyr, z2)\n a3 = forward(trainer.a3lyr, reshape(p2, (8*8*64, :)))\n z3 = forward(trainer.relu3lyr, a3)\n a4 = forward(trainer.a4lyr, z3)\n z4 = forward(trainer.relu4lyr, a4)\n a5 = forward(trainer.a5lyr, z4)\n a5\nend",
"execution_count": 26,
"outputs": [
{
"data": {
"text/plain": "forward (generic function with 8 methods)"
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:34.919000+09:00",
"start_time": "2018-02-01T11:32:34.831Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function loss(trainer::Trainer, x1::AbstractArray{T,4}, t::AbstractArray{T,2}) where {T}\n y = forward(trainer, x1)\n forward(trainer.softmaxlyr, y, t)\nend",
"execution_count": 27,
"outputs": [
{
"data": {
"text/plain": "loss (generic function with 1 method)"
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:36.399000+09:00",
"start_time": "2018-02-01T11:32:36.312Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function computegradient(trainer::Trainer)\n dout = one(T)\n da5 = backward(trainer.softmaxlyr, dout)\n dz4 = backward(trainer.a5lyr, da5)\n da4 = backward(trainer.relu4lyr, dz4)\n dz3 = backward(trainer.a4lyr, da4)\n da3 = backward(trainer.relu3lyr, dz3)\n dp2 = backward(trainer.a3lyr, da3)\n dz2 = backward(trainer.pool2lyr, reshape(dp2, (8, 8, 64, :)))\n da2 = backward(trainer.relu2lyr, dz2)\n dp1 = backward(trainer.conv2lyr, da2)\n dz1 = backward(trainer.pool1lyr, dp1)\n da1 = backward(trainer.relu1lyr, dz1)\n _dx = backward(trainer.conv1lyr, da1)\nend",
"execution_count": 28,
"outputs": [
{
"data": {
"text/plain": "computegradient (generic function with 1 method)"
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:37.390000+09:00",
"start_time": "2018-02-01T11:32:37.287Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function applygradient(trainer::Trainer)\n params = trainer.params\n trainer.conv1lyr.W, params.pcW1 = update(trainer.opt, trainer.conv1lyr.W, trainer.conv1lyr.dW, params.pcW1)\n trainer.conv1lyr.b, params.pcb1 = update(trainer.opt, trainer.conv1lyr.b, trainer.conv1lyr.db, params.pcb1)\n trainer.conv2lyr.W, params.pcW2 = update(trainer.opt, trainer.conv2lyr.W, trainer.conv2lyr.dW, params.pcW2)\n trainer.conv2lyr.b, params.pcb2 = update(trainer.opt, trainer.conv2lyr.b, trainer.conv2lyr.db, params.pcb2)\n trainer.a3lyr.W, params.pW3 = update(trainer.opt, trainer.a3lyr.W, trainer.a3lyr.dW, params.pW3)\n trainer.a3lyr.b, params.pb3 = update(trainer.opt, trainer.a3lyr.b, trainer.a3lyr.db, params.pb3)\n trainer.a4lyr.W, params.pW4 = update(trainer.opt, trainer.a4lyr.W, trainer.a4lyr.dW, params.pW4)\n trainer.a4lyr.b, params.pb4 = update(trainer.opt, trainer.a4lyr.b, trainer.a4lyr.db, params.pb4)\n trainer.a5lyr.W, params.pW5 = update(trainer.opt, trainer.a5lyr.W, trainer.a5lyr.dW, params.pW5)\n trainer.a5lyr.b, params.pb5 = update(trainer.opt, trainer.a5lyr.b, trainer.a5lyr.db, params.pb5)\nend",
"execution_count": 29,
"outputs": [
{
"data": {
"text/plain": "applygradient (generic function with 1 method)"
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Evaluation"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "validate_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule train_batch_produce(validate_channel)\nvalidatebatch = CF10Batch(validate_channel, 100)",
"execution_count": 30,
"outputs": [
{
"data": {
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:5), 100)"
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:31:58.310000+09:00",
"start_time": "2018-02-01T11:31:58.220Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function readtest(channel::Channel{Tuple{CIFAR10Record}}, datadir::String=datadir)\n if isopen(channel)\n filepath = joinpath(datadir, \"test_batch.bin\")\n open(filepath, \"r\") do f\n while isopen(channel)\n seek(f, 0)\n for _ in 1:10000\n record = read(f, CIFAR10Record)\n # label = getlabel(record)\n try\n put!(channel, (record,))\n sleep(0.001) # yield to others\n catch ex\n if isa(ex, InvalidStateException)\n return\n else\n rethrow(ex)\n end\n end\n end\n end\n end\n end\nend",
"execution_count": 31,
"outputs": [
{
"data": {
"text/plain": "readtest (generic function with 2 methods)"
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "test_channel = Channel{Tuple{CIFAR10Record}}(32)\n@schedule readtest(test_channel)\ntestbatch = CF10Batch(test_channel, 100)",
"execution_count": 32,
"outputs": [
{
"data": {
"text/plain": "CF10Batch(Channel{Tuple{CIFAR10.CIFAR10Record}}(sz_max:32,sz_curr:1), 100)"
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "function calcaccuracy(trainer::Trainer, batch::CF10Batch, numsteps::Int=10)\n if numsteps < 1\n numsteps = 1\n end\n total_count = batch.batchsize * numsteps\n true_count = 0\n for step in 1:numsteps\n data, labels = batch()\n x1 = reshape(data ./ 255f0, (32, 32, 3, :))\n y = vec(mapslices(indmax, forward(trainer, x1), 1)) .- 1\n true_count += sum(y .== labels)\n end\n return true_count / total_count\nend",
"execution_count": 33,
"outputs": [
{
"data": {
"text/plain": "calcaccuracy (generic function with 2 methods)"
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Execute train"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:38.520000+09:00",
"start_time": "2018-02-01T11:32:38.419Z"
},
"trusted": false
},
"cell_type": "code",
"source": "function train(trainer::Trainer, trainbatch::CF10Batch, numsteps::Int, \n validatebatch::CF10Batch=nothing, testbatch::CF10Batch=nothing, startstep::Int=1)\n batch_size = trainbatch.batchsize\n endstep = startstep + numsteps - 1\n for step = startstep:endstep\n data, labels = trainbatch()\n x1 = reshape(data, (32, 32, 3, :)) / 255f0\n t = zeros(T, (10, batch_size))\n for (i, labelidx) in enumerate(labels)\n t[labelidx + 1, i] = 1\n end\n ### calcurate loss (forward)\n _loss = loss(trainer, x1, t)\n if step % 10 == 0\n println(\"$(step): $(_loss)\")\n end\n ### calcurate gradient (backward)\n computegradient(trainer)\n ### apply gradient\n applygradient(trainer)\n ### periodic evaluation (check accuracy)\n if step % 100 == 0\n if validatebatch !== nothing\n train_acc = calcaccuracy(trainer, validatebatch)\n println(\"train_acc: $(train_acc)\")\n end\n if testbatch !== nothing\n test_acc = calcaccuracy(trainer, testbatch)\n println(\"test_acc: $(test_acc)\")\n end\n end\n end\nend",
"execution_count": 34,
"outputs": [
{
"data": {
"text/plain": "train (generic function with 4 methods)"
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:32:41.521000+09:00",
"start_time": "2018-02-01T11:32:40.105Z"
},
"trusted": false
},
"cell_type": "code",
"source": "trainer = Trainer(\n conv1lyr,\n relu1lyr,\n pool1lyr,\n conv2lyr,\n relu2lyr,\n pool2lyr,\n a3lyr,\n relu3lyr,\n a4lyr,\n relu4lyr,\n a5lyr,\n softmaxlyr,\n opt\n);",
"execution_count": 35,
"outputs": []
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:33:01.980000+09:00",
"start_time": "2018-02-01T11:32:47.506Z"
},
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "@time train(trainer, trainbatch, 100, validatebatch, testbatch)",
"execution_count": 36,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "10: 2.3280575\n20: 2.291\n30: 2.290505\n40: 2.227149\n50: 2.172277\n60: 2.0068963\n70: 2.1987274\n80: 2.0112386\n90: 1.9883025\n100: 1.8631837\ntrain_acc: 0.271\ntest_acc: 0.281\n451.029012 seconds (11.82 M allocations: 329.287 GiB, 28.30% gc time)\n"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### Save the model (weights)"
},
{
"metadata": {
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "# Pkg.add(\"JLD\")\nusing JLD, FileIO",
"execution_count": 37,
"outputs": []
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "@time train_acc = calcaccuracy(trainer, validatebatch, 100)",
"execution_count": 40,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "157.848018 seconds (1.12 M allocations: 98.626 GiB, 25.84% gc time)\n"
},
{
"data": {
"text/plain": "0.2655"
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "@time test_acc = calcaccuracy(trainer, testbatch, 100)",
"execution_count": 41,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "177.764737 seconds (568.86 k allocations: 98.600 GiB, 23.55% gc time)\n"
},
{
"data": {
"text/plain": "0.2805"
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "cW1, cb1 = conv1lyr.W, conv1lyr.b\ncW2, cb2 = conv2lyr.W, conv2lyr.b\nW3, b3 = a3lyr.W, a3lyr.b\nW4, b4 = a4lyr.W, a4lyr.b\nW5, b5 = a5lyr.W, a5lyr.b\n@save \"ckpt_sample_201802021840_0000100.jld\" cW1 cb1 cW2 cb2 W3 b3 W4 b4 W5 b5",
"execution_count": 38,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### (re)train"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2018-02-01T20:33:01.980000+09:00",
"start_time": "2018-02-01T11:32:47.506Z"
},
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "@time train(trainer, trainbatch, 900, validatebatch, testbatch, 101)",
"execution_count": 43,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "110: 1.926867\n120: 1.8587545\n130: 1.8940685\n140: 1.7620988\n150: 1.8591789\n160: 1.9014171\n170: 1.7010934\n180: 1.7665153\n190: 1.7036247\n200: 1.6929762\ntrain_acc: 0.411\ntest_acc: 0.403\n210: 1.6664252\n220: 1.7507999\n230: 1.568876\n240: 1.6144001\n250: 1.5455773\n260: 1.6112003\n270: 1.4552757\n280: 1.5788555\n290: 1.6090884\n300: 1.4770257\ntrain_acc: 0.469\ntest_acc: 0.443\n310: 1.5068269\n320: 1.4742835\n330: 1.4423931\n340: 1.4298854\n350: 1.5663092\n360: 1.4549472\n370: 1.4903312\n380: 1.4007056\n390: 1.3008981\n400: 1.4737234\ntrain_acc: 0.461\ntest_acc: 0.502\n410: 1.4489878\n420: 1.4234551\n430: 1.3521829\n440: 1.5339124\n450: 1.4638265\n460: 1.3159835\n470: 1.285465\n480: 1.3323457\n490: 1.1709098\n500: 1.4797423\ntrain_acc: 0.518\ntest_acc: 0.53\n510: 1.429904\n520: 1.4692785\n530: 1.3512571\n540: 1.2711961\n550: 1.2339783\n560: 1.3167245\n570: 1.2011247\n580: 1.2615013\n590: 1.294272\n600: 1.18332\ntrain_acc: 0.514\ntest_acc: 0.506\n610: 1.2178104\n620: 1.4519489\n630: 1.3791533\n640: 1.1602271\n650: 1.2617515\n660: 1.178378\n670: 1.2598214\n680: 1.3636239\n690: 1.139215\n700: 1.1494153\ntrain_acc: 0.57\ntest_acc: 0.531\n710: 1.2209207\n720: 1.2084265\n730: 1.1454576\n740: 1.3233155\n750: 1.120363\n760: 1.1156836\n770: 1.3812096\n780: 1.0658343\n790: 1.241617\n800: 1.1230841\ntrain_acc: 0.604\ntest_acc: 0.575\n810: 1.1207128\n820: 1.2319574\n830: 1.1797975\n840: 1.1806254\n850: 1.157985\n860: 1.0172669\n870: 1.0593016\n880: 1.0236645\n890: 1.1836785\n900: 1.2015376\ntrain_acc: 0.603\ntest_acc: 0.57\n910: 1.0798593\n920: 0.9859744\n930: 0.9686902\n940: 1.0859641\n950: 1.1518943\n960: 0.965214\n970: 1.1084701\n980: 1.0611815\n990: 1.0680001\n1000: 1.1381717\ntrain_acc: 0.592\ntest_acc: 0.554\n4058.505781 seconds (17.18 M allocations: 2.912 TiB, 31.23% gc time)\n"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "@time train_acc = calcaccuracy(trainer, validatebatch, 100)",
"execution_count": 44,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "156.976842 seconds (562.53 k allocations: 99.860 GiB, 26.48% gc time)\n"
},
{
"data": {
"text/plain": "0.5981"
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"scrolled": true,
"trusted": false
},
"cell_type": "code",
"source": "@time test_acc = calcaccuracy(trainer, testbatch, 100)",
"execution_count": 45,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": "175.079569 seconds (569.14 k allocations: 99.863 GiB, 24.50% gc time)\n"
},
{
"data": {
"text/plain": "0.5604"
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### Save the model (weights)"
},
{
"metadata": {
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "cW1, cb1 = conv1lyr.W, conv1lyr.b\ncW2, cb2 = conv2lyr.W, conv2lyr.b\nW3, b3 = a3lyr.W, a3lyr.b\nW4, b4 = a4lyr.W, a4lyr.b\nW5, b5 = a5lyr.W, a5lyr.b\n@save \"ckpt_sample_201802022222_0001000.jld\" cW1 cb1 cW2 cb2 W3 b3 W4 b4 W5 b5",
"execution_count": 46,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Prediction"
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "function predict_proba(trainer::Trainer, record::CIFAR10Record)\n res = forward(trainer, reshape(getdata(record) ./ 255f0, (32, 32, 3, 1)))\n softmax(vec(res))\nend",
"execution_count": 47,
"outputs": [
{
"data": {
"text/plain": "predict_proba (generic function with 1 method)"
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "sample_record = open(joinpath(datadir, \"test_batch.bin\")) do f\n read(f, CIFAR10Record)\nend",
"execution_count": 48,
"outputs": [
{
"data": {
"text/plain": "CIFAR10Record(0x03, 0xd0b45b812aae12b1)"
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "sample_softmax = predict_proba(trainer, sample_record)",
"execution_count": 49,
"outputs": [
{
"data": {
"text/plain": "10-element Array{Float32,1}:\n 0.0430505\n 0.0326076\n 0.0940189\n 0.285728 \n 0.0198259\n 0.247169 \n 0.0868483\n 0.0154235\n 0.158206 \n 0.0171223"
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "getlabel(sample_record)",
"execution_count": 50,
"outputs": [
{
"data": {
"text/plain": "3"
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "CIFAR10.labels[indmax(sample_softmax)]",
"execution_count": 51,
"outputs": [
{
"data": {
"text/plain": "\"cat\""
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"trusted": false
},
"cell_type": "code",
"source": "getlabelastext(sample_record)",
"execution_count": 52,
"outputs": [
{
"data": {
"text/plain": "\"cat\""
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"metadata": {
"collapsed": true,
"trusted": false
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "julia-0.6",
"display_name": "Julia 0.6.2",
"language": "julia"
},
"language_info": {
"file_extension": ".jl",
"name": "julia",
"mimetype": "application/julia",
"version": "0.6.2"
},
"gist": {
"id": "",
"data": {
"description": "Cifar10TrainSample.jl.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# cnnutil.jl
# require `layers.jl`
"""
    zeropad(a, pad_width)

Pad array `a` with zeros. `pad_width` gives, per dimension, a `(before, after)`
pair of pad counts; it may be passed as one `NTuple` or as trailing tuples.
Returns a new array; `a` is copied into the interior.
"""
function zeropad(a::AbstractArray{T,N}, pad_width::NTuple{N,Tuple{Int,Int}}) where {T,N}
    # Extent of each padded axis: original size plus both pads.
    padded_dims = ntuple(d -> size(a, d) + pad_width[d][1] + pad_width[d][2], N)
    out = zeros(T, padded_dims)
    # Interior region that receives the original data.
    dest = ntuple(d -> (pad_width[d][1] + 1):(pad_width[d][1] + size(a, d)), N)
    out[dest...] = a
    return out
end
@inline zeropad(a::AbstractArray{T,N}, pad_width::Tuple{Int,Int}...) where {T,N} = zeropad(a, pad_width)
"""
    im2col(input_data, filter_w, filter_h, stride=1, pad=0)

Unfold sliding filter windows of a 4-D array laid out as (W, H, C, N) into a
2-D matrix of shape (filter_w*filter_h*C, out_w*out_h*N), so that a
convolution can be computed as a single matrix multiply.

NOTE(review): the call sites in this file (`forward(::Convolution)`,
`forward(::Pooling)`) pass the height argument first, i.e. they treat the
signature as (filter_h, filter_w) — the order used by `col2im`. This only
matters for non-square filters; confirm the intended argument order.
"""
function im2col(input_data::AbstractArray{T,4}, filter_w::Int, filter_h::Int, stride::Int=1, pad::Int=0) where {T}
W, H, C, N = size(input_data)
# Spatial extent of the convolution output.
out_h = (H + 2pad - filter_h) ÷ stride + 1
out_w = (W + 2pad - filter_w) ÷ stride + 1
# Zero-pad only the two spatial dimensions (channel and batch untouched).
img = pad==0 ? input_data : zeropad(input_data, (pad, pad), (pad, pad), (0, 0), (0, 0))
col = zeros(T, (out_w, out_h, filter_w, filter_h, C, N))
for y = 1:filter_h
y_max = y + stride*out_h - 1
for x = 1:filter_w
x_max = x + stride*out_w - 1
# Gather the (x, y) filter tap of every output position in one strided slice.
col[:, :, x, y, :, :] = img[x:stride:x_max, y:stride:y_max, :, :]
end
end
# Reorder so filter taps and channels are contiguous rows, output positions columns.
reshape(permutedims(col, (3, 4, 5, 1, 2, 6)), filter_w*filter_h*C, out_w*out_h*N)
end
"""
    col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0)

Inverse of `im2col`: scatter-add the unfolded 2-D matrix `col` back into an
image of shape `input_shape` = (W, H, C, N). Overlapping filter windows are
summed, which is exactly what the convolution gradient needs.

NOTE(review): here the height parameter comes before the width, the opposite
of `im2col`'s declared order — verify for non-square filters.
"""
function col2im(col::AbstractArray{T,2}, input_shape::NTuple{4,Int}, filter_h::Int, filter_w::Int, stride::Int=1, pad::Int=0) where {T}
W, H, C, N = input_shape
out_h = (H + 2pad - filter_h) ÷ stride + 1
out_w = (W + 2pad - filter_w) ÷ stride + 1
# Undo the reshape/permute performed at the end of `im2col`.
_col = permutedims(reshape(col, filter_w, filter_h, C, out_w, out_h, N), (4, 5, 1, 2, 3, 6))
# Slightly oversized buffer so the strided writes below never go out of bounds.
img = zeros(T, (W + 2*pad + stride - 1, H + 2*pad + stride - 1, C, N))
for y = 1:filter_h
y_max = y + stride*out_h - 1
for x = 1:filter_w
x_max = x + stride*out_w - 1
# Accumulate (not overwrite): overlapping windows sum their contributions.
img[x:stride:x_max, y:stride:y_max, :, :] += _col[:, :, x, y, :, :]
end
end
# Strip the padding (and the oversize margin) before returning.
return img[pad+1:pad+W, pad+1:pad+H, :, :]
end
# 2-D convolution layer implemented via im2col + matrix multiply.
mutable struct Convolution{T<:AbstractFloat} <: AbstractLayer{T}
    W::Array{T,4}      # filter weights, laid out (FW, FH, C, FN)
    b::Array{T,1}      # one bias per output filter
    stride::Int
    pad::Int
    x::Array{T,4}      # cached input (set by `forward`)
    col::Array{T,2}    # cached im2col matrix (set by `forward`)
    col_w::Array{T,2}  # cached reshaped weights (set by `forward`)
    dW::Array{T,4}     # gradient w.r.t. W (set by `backward`)
    db::Array{T,1}     # gradient w.r.t. b (set by `backward`)
    function Convolution(W::Array{T,4}, b::Array{T,1}, stride::Int=1, pad::Int=0) where {T}
        # Cache/gradient fields stay undefined until forward/backward run.
        new{T}(W, b, stride, pad)
    end
end
# Forward pass: unfold the input with im2col, multiply by the reshaped
# filters, add bias, and fold back to (out_w, out_h, FN, N).
function forward(self::Convolution{T}, x::AbstractArray{T,4}) where {T<:AbstractFloat}
FW, FH, C0, FN = size(self.W)
W, H, C, N = size(x)
# Input channels must match the filters' channel dimension.
@assert C0 == C
out_h = 1 + (H + 2*self.pad - FH) ÷ self.stride
out_w = 1 + (W + 2*self.pad - FW) ÷ self.stride
# NOTE(review): FH is passed where im2col declares filter_w — harmless for
# square filters, but confirm for FH != FW.
col = im2col(x, FH, FW, self.stride, self.pad)
# (FW*FH*C, FN) transposed to (FN, FW*FH*C) so each row is one filter.
col_w = reshape(self.W, (:, FN))'
# Bias broadcasts across all output positions (columns).
out_ = col_w * col .+ self.b
out = permutedims(reshape(out_, (:, out_w, out_h, N)), (2, 3, 1, 4))
# Cache everything the backward pass needs.
self.x = x
self.col = col
self.col_w = col_w
return out
end
# Backward pass: given dL/dout, store dL/dW and dL/db on the layer and
# return dL/dx (same shape as the cached forward input).
function backward(self::Convolution{T}, dout::AbstractArray{T,4}) where {T<:AbstractFloat}
FW, FH, C, FN = size(self.W)
# Bring the filter axis first, then flatten positions: (FN, out_w*out_h*N).
dout_ = reshape(permutedims(dout, (3, 1, 2, 4)), (FN, :))
# Bias gradient: sum over all output positions.
# (positional-dims `mapslices(f, A, 2)` is the Julia 0.6 API this file targets)
self.db = vec(mapslices(sum, dout_, 2))
dW_ = dout_ * self.col'
self.dW = reshape(dW_', (FW, FH, C, FN))
# Propagate through the im2col unfolding via its scatter-add inverse.
dcol = self.col_w' * dout_
dx = col2im(dcol, size(self.x), FH, FW, self.stride, self.pad)
return dx
end
# Max-pooling layer; records per-window argmax positions for the backward pass.
mutable struct Pooling{T<:AbstractFloat} <: AbstractLayer{T}
    pool_h::Int           # pooling window height
    pool_w::Int           # pooling window width
    stride::Int
    pad::Int
    x::Array{T,4}         # cached forward input
    argmax::Array{Int,1}  # linear index of each window's maximum (set by `forward`)
    function Pooling{T}(pool_h::Int, pool_w::Int, stride::Int=1, pad::Int=0) where {T<:AbstractFloat}
        # `x` and `argmax` remain undefined until `forward` runs.
        new{T}(pool_h, pool_w, stride, pad)
    end
end
# Forward max-pooling: unfold windows with im2col, take the max of each
# window, and remember where each max came from for the backward pass.
function forward(self::Pooling{T}, x::AbstractArray{T,4}) where {T<:AbstractFloat}
W, H, C, N = size(x)
out_h = 1 + (H + 2*self.pad - self.pool_h) ÷ self.stride
out_w = 1 + (W + 2*self.pad - self.pool_w) ÷ self.stride
# NOTE(review): pool_h is passed where im2col declares filter_w — harmless
# for square windows, but confirm for pool_h != pool_w.
col_ = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
# One window per column: (pool_size, C*out_w*out_h*N).
col = reshape(col_, (self.pool_h*self.pool_w, :))
self.x = x
# findmax over dim 1 (Julia 0.6 positional form) yields maxima + linear indices.
out, _argmax = findmax(col, 1)
self.argmax = vec(_argmax)
# Column order of `col` is C fastest, then out_w, out_h, N — hence this reshape.
return permutedims(reshape(out, (C, out_w, out_h, N)), (2, 3, 1, 4))
end
# Backward max-pooling: route each upstream gradient entry to the position
# that produced the window maximum; everything else gets zero gradient.
function backward(self::Pooling{T}, dout::AbstractArray{T,4}) where {T<:AbstractFloat}
dout_ = permutedims(dout, (3, 1, 2, 4))
pool_size = self.pool_h * self.pool_w
# Same shape as the `col` matrix in forward, so the stored linear indices align.
dmax = zeros(T, (pool_size, length(dout_)))
# dmax[argmax] .= vec(dout)
for (oidx, midx) in enumerate(self.argmax)
dmax[midx] = dout_[oidx]
end
dcol = reshape(dmax, (pool_size * size(dout_, 1), :))
# Scatter-add back to input coordinates through the im2col inverse.
dx = col2im(dcol, size(self.x), self.pool_h, self.pool_w, self.stride, self.pad)
return dx
end
# layers.jl
abstract type AbstractLayer{T<:AbstractFloat} end
## Relu
# Rectified linear unit layer.
mutable struct ReluLayer{T<:AbstractFloat} <: AbstractLayer{T}
    mask::AbstractArray{Bool}  # true where the forward input was <= 0
    ReluLayer{T}() where {T} = new{T}()
end
# ReLU forward: max(x, 0) elementwise; remembers which entries were clamped
# so the backward pass can zero the same positions.
function forward(self::ReluLayer{T}, x::AbstractArray{T}) where {T<:AbstractFloat}
    self.mask = (x .<= 0)
    return ifelse.(self.mask, zero(T), x)
end
# ReLU backward: zero the gradient wherever the forward input was
# non-positive. Mutates and returns `dout`.
function backward(self::ReluLayer{T}, dout::AbstractArray{T}) where {T<:AbstractFloat}
    for i in eachindex(dout)
        if self.mask[i]
            dout[i] = zero(T)
        end
    end
    return dout
end
## Sigmoid
sigmoid(x::T) where {T<:AbstractFloat} = inv(one(T) + exp(-x))
# Sigmoid activation layer.
mutable struct SigmoidLayer{T<:AbstractFloat} <: AbstractLayer{T}
    out::AbstractArray{T}  # cached forward output; backward uses y*(1-y)
    SigmoidLayer{T}() where {T} = new{T}()
end
# Sigmoid forward: apply elementwise and cache the result for backward.
function forward(self::SigmoidLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    y = sigmoid.(x)
    self.out = y
    return y
end
# Sigmoid backward: dL/dx = dL/dy * (1 - y) * y with the cached output y.
function backward(self::SigmoidLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    y = self.out
    return dout .* (one(T) .- y) .* y
end
### 5.6.2 Affine layer (batched version)
# Fully-connected (affine) layer: y = W*x .+ b.
mutable struct AffineLayer{T<:AbstractFloat} <: AbstractLayer{T}
    W::Matrix{T}         # weights
    b::Vector{T}         # bias
    x::AbstractArray{T}  # cached forward input
    dW::Matrix{T}        # gradient w.r.t. W (set by `backward`)
    db::Vector{T}        # gradient w.r.t. b (set by `backward`)
    # Cache/gradient fields stay undefined until forward/backward run.
    AffineLayer(W::Matrix{T}, b::Vector{T}) where {T} = new{T}(W, b)
end
# Affine forward: cache the input for backward, then compute W*x .+ b.
function forward(self::AffineLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    self.x = x
    return self.W * x .+ self.b
end
# Affine backward: store parameter gradients on the layer, return dL/dx.
function backward(self::AffineLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    self.dW = dout * self.x'
    self.db = _sumvec(dout)
    return self.W' * dout
end
# Collapse a gradient to a per-row vector: a vector passes through unchanged
# (NOTE: returned as-is, not copied); higher-rank arrays are summed over all
# trailing dimensions. The reshape+loop form avoids the positional-dims
# `mapslices(sum, A, dims)` API (removed after Julia 0.6) and its per-slice
# allocations, so it runs on every Julia version.
@inline _sumvec(dout::AbstractVector{T}) where {T} = dout
function _sumvec(dout::AbstractArray{T}) where {T}
    mat = reshape(dout, size(dout, 1), :)  # flatten trailing dims into columns
    out = zeros(T, size(mat, 1))
    for j in 1:size(mat, 2), i in 1:size(mat, 1)
        out[i] += mat[i, j]
    end
    return out
end
### 5.6.3 Softmax-with-Loss layer
"""
    softmax(a)

Numerically stable softmax. For a vector, returns the probability vector;
for a matrix, applies the vector version to each column independently.
"""
function softmax(a::AbstractVector{T}) where {T<:AbstractFloat}
    c = maximum(a)  # subtract the max so exp cannot overflow
    exp_a = exp.(a .- c)
    return exp_a ./ sum(exp_a)
end
function softmax(a::AbstractMatrix{T}) where {T<:AbstractFloat}
    # Explicit column loop instead of `mapslices(softmax, a, 1)`: the
    # positional-dims form was removed after Julia 0.6, and the loop avoids
    # mapslices' per-slice overhead while producing identical output.
    out = similar(a)
    for j in 1:size(a, 2)
        out[:, j] = softmax(a[:, j])
    end
    return out
end
# Cross-entropy loss of prediction `y` against one-hot target `t`.
function crossentropyerror(y::Vector{T}, t::Vector{T})::T where {T<:AbstractFloat}
    # Small constant so log never sees an exact zero (underflow guard).
    δ = T(1f-7)
    return -sum(t .* log.(y .+ δ))
end
# Batched cross-entropy: columns of `y`/`t` are samples; returns the mean loss.
function crossentropyerror(y::Matrix{T}, t::Matrix{T})::T where {T<:AbstractFloat}
    batch_size = size(y, 2)
    δ = T(1f-7)  # guard against log(0) (underflow)
    # Elementwise sum instead of `vecdot`, which was removed in Julia 1.0;
    # mathematically identical and portable across versions.
    return -sum(t .* log.(y .+ δ)) / batch_size
end
# Batched cross-entropy with integer class labels (0-based, as stored in the
# CIFAR-10 records): picks each sample's predicted probability for its label.
function crossentropyerror(y::Matrix{T}, t::Vector{<:Integer})::T where {T<:AbstractFloat}
    batch_size = size(y, 2)
    δ = T(1f-7)  # guard against log(0) (underflow)
    # BUG FIX: δ must be added INSIDE the log, as the other two methods do.
    # The original added δ to the log values, which does not stop log(0) = -Inf.
    return -sum(log(y[t[i] + 1, i] + δ) for i = 1:batch_size) / batch_size
end
# Final layer combining softmax activation with cross-entropy loss.
mutable struct SoftmaxWithLossLayer{T<:AbstractFloat,N} <: AbstractLayer{T}
    loss::T        # most recent forward loss
    y::Array{T,N}  # cached softmax output
    t::Array{T,N}  # cached teacher signal
    SoftmaxWithLossLayer{T,N}() where {T,N} = new{T,N}()
end
# Forward: caches target and softmax output, returns (and caches) the loss.
function forward(self::SoftmaxWithLossLayer{T,N}, x::AbstractArray{T,N}, t::AbstractArray{T,N}) where {T<:AbstractFloat,N}
    self.t = t
    self.y = softmax(x)
    self.loss = crossentropyerror(self.y, t)
end
# Backward: gradient of softmax+cross-entropy is simply (y - t), scaled by
# the upstream `dout` (batch forms are averaged inside `_swlvec`).
function backward(lyr::SoftmaxWithLossLayer{T}, dout::T=one(T)) where {T<:AbstractFloat}
    grad = _swlvec(lyr.y, lyr.t)
    return dout .* grad
end
# Softmax-with-loss gradient helper: single sample vs. batched matrix target
# (the batched form is averaged over the batch, i.e. the column count).
@inline _swlvec(y::AbstractArray{T}, t::AbstractVector{T}) where {T<:AbstractFloat} = y .- t
@inline _swlvec(y::AbstractArray{T}, t::AbstractMatrix{T}) where {T<:AbstractFloat} = (y .- t) / size(t, 2)
## Swish
#= ```https://arxiv.org/pdf/1710.05941.pdf
${\rm swish}(x) = x \cdot {\rm sigmoid}(x)$```
=#
# Swish activation layer: swish(x) = x * sigmoid(x).
mutable struct SwishLayer{T<:AbstractFloat} <: AbstractLayer{T}
    out::AbstractArray{T}  # cached forward output x * sigmoid(x)
    ς::AbstractArray{T}    # cached sigmoid(x)
    SwishLayer{T}() where {T} = new{T}()
end
# Swish forward: cache sigmoid(x) and the output for the backward pass.
function forward(self::SwishLayer{T}, x::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    self.ς = sigmoid.(x)
    self.out = x .* self.ς
end
# Swish backward: d/dx swish(x) = swish(x) + σ(x) * (1 - swish(x)),
# using the cached σ(x) and swish(x) from forward.
function backward(self::SwishLayer{T}, dout::A) where {T<:AbstractFloat, A<:AbstractArray{T}}
    y = self.out
    σx = self.ς
    return dout .* (y .+ σx .* (one(T) .- y))
end
# optimizer.jl
# Supertype of all optimizers (SGD, Momentum, ...); each implements `update`
# and may specialize `initializeparam`.
abstract type AbstractOptimizer{T<:AbstractFloat} end
# Supertype of per-parameter optimizer state threaded through `update`.
abstract type AbstractOptimizerParam end
# Plain stochastic gradient descent (no per-parameter state).
struct SGD{T<:AbstractFloat} <: AbstractOptimizer{T}
    lr::T  # learning rate
    SGD{T}(lr::T=T(0.01)) where {T<:AbstractFloat} = new{T}(lr)
end
# Infer the element type from the learning rate.
@inline SGD(lr::T) where {T<:AbstractFloat} = SGD{T}(lr)
# Convert a learning rate of a different float type to T.
@inline (::Type{SGD{T}})(lr::AbstractFloat) where {T<:AbstractFloat} = SGD{T}(T(lr))
# One SGD step: W ← W - lr * gW. SGD keeps no state, so `param` is returned as-is.
function update(opt::SGD{T}, W::AbstractArray{T,N}, gW::AbstractArray{T,N}, param) where {T,N}
    stepped = W .- opt.lr .* gW
    return (stepped, param)
end
# Stateless marker for optimizers that keep no per-weight state.
struct SGDParam <: AbstractOptimizerParam end
# Fallback: any optimizer without a specialized method gets the stateless param.
initializeparam(::AbstractOptimizer{T}, w::AbstractArray{T,N}) where {T,N} = SGDParam()
# Classical momentum optimizer.
struct Momentum{T<:AbstractFloat} <: AbstractOptimizer{T}
    lr::T        # learning rate
    momentum::T  # velocity decay factor
    Momentum{T}(lr::T=T(0.01), momentum::T=T(0.9)) where {T<:AbstractFloat} = new{T}(lr, momentum)
end
# Infer the element type from the learning rate.
@inline Momentum(lr::T, momentum::T=T(0.9)) where {T<:AbstractFloat} = Momentum{T}(lr, momentum)
# Convert hyper-parameters of other float types to T.
@inline (::Type{Momentum{T}})(lr::AbstractFloat, momentum::AbstractFloat) where {T<:AbstractFloat} = Momentum{T}(T(lr), T(momentum))
struct MomentumParam{T<:AbstractFloat,N} <: AbstractOptimizerParam
# Velocity buffer carried between updates.
# NOTE(review): the abstract field type prevents specialization; a concrete
# parametric field (`A<:AbstractArray{T,N}`) would be faster, but changing
# the type parameters would break `update`'s signature, so it is left as-is.
v::AbstractArray{T,N}
end
initializeparam(::Momentum{T}, w::AbstractArray{T,N}) where {T,N} = MomentumParam(zeros(w))
# One momentum step: v ← momentum*v - lr*gW, then W ← W + v.
# Returns the updated weights and a new param carrying the updated velocity.
function update(opt::Momentum{T}, W::AbstractArray{T,N}, gW::AbstractArray{T,N}, param::MomentumParam{T,N}) where {T,N}
    v = opt.momentum .* param.v .- opt.lr .* gW
    return (W .+ v, MomentumParam(v))
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment