$ julia-master --project --threads=auto flux_multithreaded_training.jl
[ Info: Warming up Flux.train!()
ERROR: LoadError: AssertionError: i == j
Stacktrace:
  [1] _try_static
    @ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:213 [inlined]
  [2] _try_static
    @ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:231 [inlined]
  [3] maybe_static_size_arg
    @ ~/.julia/packages/SimpleChains/mTIEp/src/simple_chain.jl:286 [inlined]
  [4] train_batched!(g::Vector{Float32}, p::Vector{Float32}, _chn::SimpleChain{4, Tuple{Int64}, Tuple{TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{1}, typeof(σ)}}}, X::Vector{Matrix{Float64}}, opt::SimpleChains.ADAM, iters::Int64; batchsize::Nothing)
    @ SimpleChains ~/.julia/packages/SimpleChains/mTIEp/src/optimize.jl:419
  [5] train_batched!(g::Vector{Float32}, p::Vector{Float32}, _chn::SimpleChain{4, Tuple{Int64}, Tuple{TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{8}, typeof(σ)}, TurboDense{true, Static.StaticInt{1}, typeof(σ)}}}, X::Vector{Matrix{Float64}}, opt::SimpleChains.ADAM, iters::Int64)
    @ SimpleChains ~/.julia/packages/SimpleChains/mTIEp/src/optimize.jl:412
  [6] macro expansion
    @ ~/src/surrogate_testing/flux_multithreaded_training.jl:41 [inlined]
  [7] macro expansion
    @ ./timing.jl:440 [inlined]
  [8] top-level scope
    @ ~/src/surrogate_testing/flux_multithreaded_training.jl:39
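For context: the assertion fires inside SimpleChains' static input-size check (maybe_static_size_arg) during the warm-up loop in the script below, before the threaded section is reached. In the trace, X is a Vector{Matrix{Float64}} (one matrix per minibatch). My reading, which is an assumption rather than something the trace states, is that train_batched! expects a single feature array with observations along the last axis. A minimal sketch of a call in that shape, with illustrative names and sizes not taken from the gist:

using SimpleChains

# All observations concatenated into one feature matrix (8 features per column)
x_flat = randn(Float32, 8, 128 * 256)
y_flat = randn(Float32, 1, 128 * 256)
chn = SimpleChain(8,
    TurboDense(tanh, 8),
    TurboDense(tanh, 1),
)
# add_loss returns a new chain with the loss appended
chn_loss = SimpleChains.add_loss(chn, SimpleChains.SquaredLoss(y_flat))
p = SimpleChains.init_params(chn_loss)
g = similar(p)
SimpleChains.train_batched!(g, p, chn_loss, x_flat, SimpleChains.ADAM(1e-4), 10)

Whether passing a vector of per-minibatch matrices was ever meant to be supported is not clear from the trace; the sketch above simply sidesteps the size check that raises the assertion.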
# flux_multithreaded_training.jl (the script referenced in the trace above)
using Flux, Printf, Statistics, SimpleChains
function gen_model(ys)
    # Four-layer sigmoid MLP; `ys` are the targets used by the squared loss
    model = SimpleChain(8,
        TurboDense(σ, 8),
        TurboDense(σ, 8),
        TurboDense(σ, 8),
        TurboDense(σ, 1),
    )
    SimpleChains.add_loss(model, SimpleChains.SquaredLoss(ys))
    return model
end
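# Note: SimpleChains.add_loss does not mutate its argument (no `!`); it returns a new chain
# with the loss appended, so as written gen_model returns the chain without the loss attached.
# The SimpleChain type in the stack trace above (four TurboDense layers, no loss layer) is
# consistent with that. If attaching the loss was the intent, a sketch of the change would be:
#     model = SimpleChains.add_loss(model, SimpleChains.SquaredLoss(ys))
#     return model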
function gen_dataset(batch_size = 128, num_minibatches = 256)
    # Return a vector of per-minibatch feature matrices and a matching vector of target matrices
    xs = [randn(8, batch_size) for _ in 1:num_minibatches]
    ys = [randn(1, batch_size) for _ in 1:num_minibatches]
    return xs, ys
end
function info_stats(msg, stats, num_epochs)
    @info(
        msg,
        time=stats.time,
        time_per_epoch=stats.time/num_epochs,
        gc=@sprintf("%.1f%%", stats.gctime*100.0/stats.time),
        allocated=Base.format_bytes(stats.bytes),
    )
end
@info("Warming up Flux.train!()")
begin
xs, ys = gen_dataset()
model = gen_model(ys)
opt = SimpleChains.ADAM(1e-4)
num_epochs = 20
p = SimpleChains.init_params(model)
g = similar(p)
stats = @timed begin
for idx in 1:num_epochs
SimpleChains.train_batched!(g, p, model, xs, opt, 1)
end
end
info_stats("First warm-up completed", stats, num_epochs)
stats = @timed begin
for idx in 1:num_epochs
SimpleChains.train_batched!(g, p, model, xs, opt, 1)
end
end
info_stats("Second warm-up completed", stats, num_epochs)
end
num_models = 32
datasets = [gen_dataset() for _ in 1:num_models]
models = [gen_model(datasets[idx][2]) for idx in 1:num_models]
training_stats = [Any[] for _ in 1:num_models]
@warn("Beginning training with $(Threads.nthreads()) threads")
Threads.@threads for model_idx in 1:num_models
    model = models[model_idx]
    xs, ys = datasets[model_idx]
    opt = SimpleChains.ADAM(1e-4)
    # Give each task its own parameter and gradient buffers so the threads don't
    # share the `p`/`g` vectors left over from the warm-up block above.
    p = SimpleChains.init_params(model)
    g = similar(p)
    # Later models run for more timed iterations (model_idx * num_epochs)
    for idx in 1:(model_idx*num_epochs)
        push!(training_stats[model_idx], @timed begin
            SimpleChains.train_batched!(g, p, model, xs, opt, 1)
        end)
    end
    # Calculate mean statistics over the final num_epochs iterations
    tail_stats = training_stats[model_idx][end-num_epochs+1:end]
    mean_stats = (;
        time = mean(s.time for s in tail_stats),
        gctime = mean(s.gctime for s in tail_stats),
        bytes = mean(s.bytes for s in tail_stats),
    )
    info_stats("Finished model $(model_idx)", mean_stats, model_idx*num_epochs)
end
@info("Finished training")
using CairoMakie
using DSP  # assumed source of the `conv` used below for moving-average smoothing
fig = Figure()
ax = Axis(fig[1,1])
# Plot only the longer runs (models 10 and up)
for model_idx in 10:num_models
    # Smooth per-iteration timings with a moving-average filter of length 2*num_epochs
    filt_len = 2*num_epochs
    gctimes = [s.gctime for s in training_stats[model_idx]]
    gctimes_filt = conv(gctimes[:,:,:], ones(filt_len,1,1)./filt_len)[:]
    times = [s.time for s in training_stats[model_idx]]
    times_filt = conv(times[:,:,:], ones(filt_len,1,1)./filt_len)[:]
    lines!(ax, gctimes_filt.*100.0./times_filt)
    #lines!(ax, times_filt .- gctimes_filt)
end
ax.title = "average GC time % by epoch"
save("training_stats.png", fig)