Created
August 20, 2020 10:32
-
-
Save natema/a5fe8da406e9fe1e51127bceec4740c4 to your computer and use it in GitHub Desktop.
Resnet18 on CIFAR10 in Julia
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux | |
using Flux: @functor | |
"""
    conv3x3(in_planes, out_planes; stride=1)

Return a `Conv` layer with a 3×3 kernel mapping `in_planes` channels to `out_planes`
channels, using padding 1 and the given `stride`.
"""
conv3x3(in_planes, out_planes; stride=1) =
    Conv((3, 3), in_planes => out_planes; stride=stride, pad=1)
"""
    conv1x1(in_planes, out_planes; stride=1)

Return a `Conv` layer with a 1×1 kernel mapping `in_planes` channels to `out_planes`
channels, using no padding and the given `stride`.
"""
conv1x1(in_planes, out_planes; stride=1) =
    Conv((1, 1), in_planes => out_planes; stride=stride, pad=0)
"""
    BasicBlock(conv1, bn1, conv2, bn2)

Residual block made of two convolutions, each followed by a batch-normalization layer.

The field types are type parameters (bounded by `Conv` / `BatchNorm`) rather than the
bare `Conv` / `BatchNorm` names: those names are parametric, so using them directly as
field types would leave every field access abstractly typed.
"""
struct BasicBlock{C1<:Conv,B1<:BatchNorm,C2<:Conv,B2<:BatchNorm}
    conv1::C1
    bn1::B1
    conv2::C2
    bn2::B2
end

# Register the fields with Flux so they are seen by `params`, `gpu`, etc.
@functor BasicBlock
"""
    BasicBlock(inplanes, planes; stride=1, base_width=64, dilation=1)

Build a `BasicBlock` going from `inplanes` to `planes` channels. Only the first
convolution uses `stride`; the first batch norm applies `relu`, the second has no
activation. `base_width` and `dilation` are accepted but not used here.
"""
function BasicBlock(inplanes, planes; stride=1, base_width=64, dilation=1)
    first_conv = conv3x3(inplanes, planes; stride=stride)
    first_norm = BatchNorm(planes, relu)
    second_conv = conv3x3(planes, planes)
    second_norm = BatchNorm(planes)
    return BasicBlock(first_conv, first_norm, second_conv, second_norm)
end
# Forward pass: run `x` through conv1 → bn1 → conv2 → bn2, add the untouched input back
# (identity shortcut) and apply `relu` to the sum.
function (m::BasicBlock)(x)
    branch = m.bn2(m.conv2(m.bn1(m.conv1(x))))
    return relu.(x + branch)
end
"""
    BasicBlockDownsample(conv1, bn1, conv2, bn2, downsample)

Residual block like `BasicBlock`, with an extra `downsample` chain that is applied to
the input before it is added to the convolutional branch.

The field types are type parameters (bounded by `Conv` / `BatchNorm` / `Chain`) rather
than the bare parametric names, so that field accesses are concretely typed.
"""
struct BasicBlockDownsample{C1<:Conv,B1<:BatchNorm,C2<:Conv,B2<:BatchNorm,D<:Chain}
    conv1::C1
    bn1::B1
    conv2::C2
    bn2::B2
    downsample::D
end

# Register the fields with Flux so they are seen by `params`, `gpu`, etc.
@functor BasicBlockDownsample
"""
    BasicBlockDownsample(inplanes, planes; stride=1, base_width=64, dilation=1, downsample)

Build a `BasicBlockDownsample` going from `inplanes` to `planes` channels. Only the
first convolution uses `stride`. `downsample` is a required keyword holding the chain
applied to the shortcut input. `base_width` and `dilation` are accepted but not used.
"""
function BasicBlockDownsample(
    inplanes, planes; stride=1, base_width=64, dilation=1, downsample
)
    first_conv = conv3x3(inplanes, planes; stride=stride)
    first_norm = BatchNorm(planes, relu)
    second_conv = conv3x3(planes, planes)
    second_norm = BatchNorm(planes)
    return BasicBlockDownsample(first_conv, first_norm, second_conv, second_norm, downsample)
end
# Forward pass: run `x` through conv1 → bn1 → conv2 → bn2, add the `downsample`-d input
# (projection shortcut) and apply `relu` to the sum.
function (m::BasicBlockDownsample)(x)
    branch = m.bn2(m.conv2(m.bn1(m.conv1(x))))
    shortcut = m.downsample(x)
    return relu.(shortcut + branch)
end
# ignore Bottleneck module
"""
    make_layer(inplanes::Int, planes::Int, blocks::Int; base_width=64, stride=1)

Build one ResNet stage as a `Chain` of residual blocks.

The first block maps `inplanes` to `planes` channels with the given `stride`; the
remaining `blocks - 1` blocks are plain `BasicBlock`s from `planes` to `planes`.

The first block gets a projection shortcut (1×1 convolution + batch norm) whenever the
identity shortcut could not be added to the branch output, i.e. when `stride > 1` or
when `inplanes != planes`. Otherwise a plain `BasicBlock` is used.
"""
function make_layer(inplanes::Int, planes::Int, blocks::Int; base_width=64, stride=1)
    # `x + y` in BasicBlock needs matching sizes, so a change in resolution *or* in
    # channel count requires projecting the shortcut.
    needs_projection = stride > 1 || inplanes != planes
    first_block = if needs_projection
        shortcut = Chain(conv1x1(inplanes, planes; stride=stride), BatchNorm(planes))
        BasicBlockDownsample(
            inplanes, planes; stride=stride, base_width=base_width, downsample=shortcut
        )
    else
        BasicBlock(inplanes, planes; stride=stride, base_width=base_width)
    end
    # Empty when `blocks < 2`, in which case the stage is just the first block.
    remaining = [BasicBlock(planes, planes; base_width=base_width) for _ in 2:blocks]
    return Chain(first_block, remaining...)
end
"""
    ResNet(; conv1, bn1, maxpool, layer1, layer2, layer3, layer4, avgpool, fc)

Container for the layers of a ResNet: an initial convolution, batch norm and max-pool,
four stages, a global mean-pool and a final dense layer.

Fields whose layer types are parametric (`Conv`, `BatchNorm`, `MaxPool`, `Chain`,
`Dense`) are stored through bounded type parameters so that field accesses are
concretely typed. `Base.@kwdef` provides the keyword constructor.
"""
Base.@kwdef struct ResNet{
    C<:Conv,B<:BatchNorm,P<:MaxPool,L1<:Chain,L2<:Chain,L3<:Chain,L4<:Chain,D<:Dense
}
    conv1::C
    bn1::B
    maxpool::P
    layer1::L1
    layer2::L2
    layer3::L3
    layer4::L4
    avgpool::GlobalMeanPool
    fc::D
end

# Register the fields with Flux so they are seen by `params`, `gpu`, etc.
@functor ResNet
"""
    ResNet(layers::Vector{Int}, num_classes::Int)

Build a `ResNet` for 3-channel input. `layers[1]` … `layers[4]` give the number of
blocks in each of the four stages (64, 128, 256 and 512 channels); stages 2–4 use
stride 2. The final dense layer maps 512 features to `num_classes` outputs.
"""
function ResNet(layers::Vector{Int}, num_classes::Int)
    # Layers are created in the same order as the struct fields.
    stem_conv = Conv((7, 7), 3 => 64; stride=2, pad=3)
    stem_norm = BatchNorm(64, relu)
    stem_pool = MaxPool((3, 3); stride=2, pad=1)
    stage1 = make_layer(64, 64, layers[1])
    stage2 = make_layer(64, 128, layers[2]; stride=2)
    stage3 = make_layer(128, 256, layers[3]; stride=2)
    stage4 = make_layer(256, 512, layers[4]; stride=2)
    pool = GlobalMeanPool()
    classifier = Dense(512, num_classes)
    return ResNet(;
        conv1=stem_conv,
        bn1=stem_norm,
        maxpool=stem_pool,
        layer1=stage1,
        layer2=stage2,
        layer3=stage3,
        layer4=stage4,
        avgpool=pool,
        fc=classifier,
    )
end
# Forward pass: stem (conv1 → bn1 → maxpool), the four stages, global mean-pool,
# `flatten`, then the dense layer `fc`. No activation is applied after `fc`.
function (m::ResNet)(x)
    stem = m.maxpool(m.bn1(m.conv1(x)))
    features = m.layer4(m.layer3(m.layer2(m.layer1(stem))))
    pooled = flatten(m.avgpool(features))
    return m.fc(pooled)
end
# Forward `mode` to every sub-layer of the ResNet that holds batch-norm state
# (`bn1` and the four stages) and return the model itself, so that the call can be
# used as `m = Flux.testmode!(m)`.
function Flux.testmode!(m::ResNet, mode = true)
    # A tuple avoids allocating a `Vector{Any}`; `layer` avoids shadowing `m`.
    for layer in (m.bn1, m.layer1, m.layer2, m.layer3, m.layer4)
        Flux.testmode!(layer, mode)
    end
    return m
end
# Forward `mode` to both batch-norm layers of the block and return the block itself.
function Flux.testmode!(m::BasicBlock, mode = true)
    Flux.testmode!(m.bn1, mode)
    Flux.testmode!(m.bn2, mode)
    return m
end
# Forward `mode` to both batch-norm layers and to the `downsample` chain (which holds
# a batch-norm layer of its own), then return the block itself.
function Flux.testmode!(m::BasicBlockDownsample, mode = true)
    Flux.testmode!(m.bn1, mode)
    Flux.testmode!(m.bn2, mode)
    Flux.testmode!(m.downsample, mode)
    return m
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@info "Loading libraries" | |
@info " Loading Flux" | |
using Flux | |
using Statistics | |
using Flux: onehotbatch, crossentropy, Momentum, update!, onecold | |
@info " Loading MLDatasets" | |
using MLDatasets: CIFAR10 | |
using Base.Iterators: partition | |
# Mini-batch size. The training range stops `batchsize` samples short of 50000; those
# trailing samples are used as the validation set below.
batchsize = 1000
trainsize = 50000 - batchsize

@info "Loading training data"
trainimgs = CIFAR10.traintensor(Float32);
# Labels are shifted by one to match the 1:10 class range
# (presumably the raw labels are 0–9 — verify against MLDatasets).
trainlabels = onehotbatch(CIFAR10.trainlabels(Float32) .+ 1, 1:10);

@info "Building the trainset"
# One (images, one-hot labels) tuple per mini-batch, kept in host memory.
trainset = map(partition(1:trainsize, batchsize)) do idx
    (trainimgs[:, :, :, idx], trainlabels[:, idx])
end;
batchnum = length(trainset)

@info "Loading validation data"
valset = (trainsize + 1):(trainsize + batchsize)
valX = gpu(trainimgs[:, :, :, valset]);
valY = gpu(trainlabels[:, valset]);
# Training loss. The model `m` ends in a `Dense` layer without softmax, i.e. it outputs
# raw logits, so the logit form of cross-entropy is required; plain `crossentropy`
# would take the logarithm of unnormalised (possibly negative) scores.
loss(x, y) = Flux.logitcrossentropy(m(x), y)

# SGD with momentum, learning rate 0.01.
opt = Momentum(0.01)

# Predicted class index for every sample of `x`, from a single batched forward pass.
# Results are brought to the CPU before `onecold` to avoid element-wise GPU indexing.
max_pred(x) = onecold(cpu(m(x)))

# True class index for every column of the one-hot label matrix `y`.
max_lab(y) = onecold(cpu(y))

# Fraction of samples in `x` whose predicted class equals the label in `y`.
accuracy(x, y) = mean(max_pred(x) .== max_lab(y))
@info "Loading the model"
include("yiyu-resnet.jl")
m = gpu(ResNet([2, 2, 2, 2], 10)); # ResNet18

epochs = 10
for epoch in 1:epochs
    @info "epoch" epoch
    for (i, host_batch) in enumerate(trainset)
        # Batches live in host memory; move only the current one to the device.
        x, y = gpu(host_batch)
        grads = gradient(() -> loss(x, y), params(m))
        @info "batch fraction" i/batchnum
        update!(opt, params(m), grads)
    end
    # Report validation accuracy after every epoch.
    @show accuracy(valX, valY)
end
@info "Loading test data"
testimgs = CIFAR10.testtensor(Float32);
# Labels are shifted by one to match the 1:10 class range
# (presumably the raw labels are 0–9 — verify against MLDatasets).
testlabels = onehotbatch(CIFAR10.testlabels(Float32) .+ 1, 1:10);

# Number of test samples, read from the sample dimension instead of hard-coding 10000.
ntest = size(testimgs, 4)
testset = [
    (testimgs[:, :, :, i], testlabels[:, i]) for i in partition(1:ntest, batchsize)
] |> gpu;

# Per-class counters: correct predictions and total occurrences of each true class.
class_correct = zeros(10)
class_total = zeros(10)

# Iterate the batches directly: `1:(10000/batchsize)` is a Float64 range and its
# elements are not valid array indices.
for (i, (x, y)) in enumerate(testset)
    @info "Evaluating testset batch " i
    # Decode whole batches on the CPU rather than indexing device arrays per sample.
    pred_classes = onecold(cpu(m(x)))
    actual_classes = onecold(cpu(y))
    # `zip` also copes with a final batch shorter than `batchsize`.
    for (pred_class, actual_class) in zip(pred_classes, actual_classes)
        if pred_class == actual_class
            class_correct[pred_class] += 1
        end
        class_total[actual_class] += 1
    end
end

# Per-class accuracy on the test set.
class_correct ./ class_total
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
yiyu-resnet.jl is copied from @yiyuezhuo's resnet.jl.