Last active
April 20, 2022 17:10
-
-
Save staticfloat/eba2ed3b533b9c012a06228342a225f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
$ julia-master --project --threads=auto flux_multithreaded_training.jl | |
[ Info: Warming up Flux.train!() | |
ERROR: LoadError: AssertionError: i == j | |
Stacktrace: | |
[1] _try_static | |
@ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:213 [inlined] | |
[2] _try_static | |
@ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:231 [inlined] | |
[3] maybe_static_size_arg | |
@ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:286 [inlined] | |
[4] train_batched!(g::Vector{Float32}, p::Vector{Float32}, _chn::SimpleChain{4, Tuple{Int64}, Tuple{TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{1}, typeof(σ)}}}, X::Vector{Matrix{Float64}}, opt::SimpleChains.ADAM, iters::Int64; batchsize::Nothing) | |
@ SimpleChains ~/.julia/packages/SimpleChains/mTIEp/src/optimize.jl:419 | |
[5] train_batched!(g::Vector{Float32}, p::Vector{Float32}, _chn::SimpleChain{4, Tuple{Int64}, Tuple{TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{1}, typeof(σ)}}}, X::Vector{Matrix{Float64}}, opt::SimpleChains.ADAM, iters::Int64) | |
@ SimpleChains ~/.julia/packages/SimpleChains/mTIEp/src/optimize.jl:412 | |
[6] macro expansion | |
@ ~/src/surrogate_testing/flux_multithreaded_training.jl:41 [inlined] | |
[7] macro expansion | |
@ ./timing.jl:440 [inlined] | |
[8] top-level scope | |
@ ~/src/surrogate_testing/flux_multithreaded_training.jl:39 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux, Printf, Statistics, SimpleChains | |
function gen_model(ys) | |
model = SimpleChain(8, | |
TurboDense(σ, 8), | |
TurboDense(σ, 8), | |
TurboDense(σ, 8), | |
TurboDense(σ, 1), | |
) | |
SimpleChains.add_loss(model, SimpleChains.SquaredLoss(ys)) | |
return model | |
end | |
function gen_dataset(batch_size = 128, num_minibatches = 256) | |
# Return an array of (x, y) tuples | |
xs = [randn(8, batch_size) for _ in 1:num_minibatches] | |
ys = [randn(1, batch_size) for _ in 1:num_minibatches] | |
return xs, ys | |
end | |
function info_stats(msg, stats, num_epochs) | |
@info( | |
msg, | |
time=stats.time, | |
time_per_epoch=stats.time/num_epochs, | |
gc=@sprintf("%.1f%%", stats.gctime*100.0/stats.time), | |
allocated=Base.format_bytes(stats.bytes), | |
) | |
end | |
@info("Warming up Flux.train!()") | |
begin | |
xs, ys = gen_dataset() | |
model = gen_model(ys) | |
opt = SimpleChains.ADAM(1e-4) | |
num_epochs = 20 | |
p = SimpleChains.init_params(model) | |
g = similar(p) | |
stats = @timed begin | |
for idx in 1:num_epochs | |
SimpleChains.train_batched!(g, p, model, xs, opt, 1) | |
end | |
end | |
info_stats("First warm-up completed", stats, num_epochs) | |
stats = @timed begin | |
for idx in 1:num_epochs | |
SimpleChains.train_batched!(g, p, model, xs, opt, 1) | |
end | |
end | |
info_stats("Second warm-up completed", stats, num_epochs) | |
end | |
num_models = 32 | |
datasets = [gen_dataset() for _ in 1:num_models] | |
models = [gen_model(datasets[idx][2]) for idx in 1:num_models] | |
training_stats = [Any[] for _ in 1:num_models] | |
@warn("Beginning training with $(Threads.nthreads()) threads") | |
Threads.@threads for model_idx in 1:num_models | |
model = models[model_idx] | |
xs, ys = datasets[model_idx] | |
opt = SimpleChains.ADAM(1e-4) | |
for idx in 1:(model_idx*num_epochs) | |
push!(training_stats[model_idx], @timed begin | |
SimpleChains.train_batched!(g, p, model, xs, opt, 1) | |
end) | |
end | |
# Calculate mean statistics | |
tail_stats = training_stats[model_idx][end-num_epochs+1:end] | |
mean_stats = (; | |
time = mean(s.time for s in tail_stats), | |
gctime = mean(s.gctime for s in tail_stats), | |
bytes = mean(s.bytes for s in tail_stats), | |
) | |
info_stats("Finished model $(model_idx)", mean_stats, model_idx*num_epochs) | |
end | |
@info("Finished training") | |
using CairoMakie | |
fig = Figure() | |
ax = Axis(fig[1,1]) | |
for model_idx in 10:num_models | |
filt_len = 2*num_epochs | |
gctimes = [s.gctime for s in training_stats[model_idx]] | |
gctimes_filt = conv(gctimes[:,:,:], ones(filt_len,1,1)./filt_len)[:] | |
times = [s.time for s in training_stats[model_idx]] | |
times_filt = conv(times[:,:,:], ones(filt_len,1,1)./filt_len)[:] | |
lines!(ax, gctimes_filt.*100.0./times_filt) | |
#lines!(ax, times_filt .- gctimes_filt) | |
end | |
ax.title = "average GC time % by epoch" | |
save("training_stats.png", fig) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment