Gist briochemc/5dfcf6de436bc7ba9545af6821a82969 (last active July 6, 2020 10:42)
# A minimal example of counting how many times a function gets called, using Cassette
# Simple recursive Fibonacci function
fibo(x) = x < 3 ? 1 : fibo(x - 2) + fibo(x - 1)
# So how many times is `fibo` called when evaluating `fibo(n)` for some `n`?
n = 3
# Cassette can tell us, as shown below:
# Load the package
using Cassette
# Create a context type.
# Let me try to explain the "context" terminology:
# I don't want to simply invoke `fibo(n)`.
# I want to invoke `fibo(n)` but, at the same time, do some extra stuff (here, count the number of calls).
# So I want `fibo(n)` to be executed in the "context" of doing this extra stuff.
Cassette.@context CountFibo
# Create a "prehook" to do the extra stuff before each call to `fibo`.
# Here, "prehook" means that it will execute some code *before* every call.
# The prehook method takes the context `ctx` (which stores the counter),
# the function being called (dispatching on its type, `typeof(fibo)`, so the hook only fires for `fibo`),
# and that call's arguments `args...`.
# Here we use a `Ref` for the counter.
# A `Ref` is like a pointer to a single value (the counter here), so it can be updated in place.
# (Note that the value inside a `Ref` is accessed via `[]`.)
Cassette.prehook(ctx::CountFibo, ::typeof(fibo), args...) = ctx.metadata[] += 1
# Initialize the counter at 0 by creating an instance of the context whose metadata is a `Ref`
counter = CountFibo(metadata=Ref(0))
# Run the overdubbed `fibo` with argument `n`
Cassette.@overdub(counter, fibo(n))
# Show the number of calls (read the value stored in the `Ref` with `[]`)
counter.metadata[]
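As a cross-check on the Cassette count, here is a minimal sketch that counts the same calls by hand, without Cassette, by threading a `Ref` counter through a copy of the recursion. `fibo_counted` is a hypothetical helper introduced only for this check. Every call with `x >= 3` makes two more calls, so the call count C(x) satisfies C(x) = 1 + C(x - 1) + C(x - 2) with C(1) = C(2) = 1, which works out to C(x) = 2 fibo(x) - 1 (so 3 calls for `n = 3`).

```julia
# Same recursion as `fibo` above, repeated so this block is self-contained
fibo(x) = x < 3 ? 1 : fibo(x - 2) + fibo(x - 1)

# Hypothetical hand-instrumented copy of `fibo`: same recursion, but every
# call increments the counter stored in the `Ref` that is passed along.
function fibo_counted(x, calls::Ref{Int})
    calls[] += 1
    return x < 3 ? 1 : fibo_counted(x - 2, calls) + fibo_counted(x - 1, calls)
end

for n in 1:10
    calls = Ref(0)
    fibo_counted(n, calls)
    # The number of calls should match the closed form 2 * fibo(n) - 1
    @assert calls[] == 2 * fibo(n) - 1
end
```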
# Here I store more data as I go,
# which I could use to plot figures comparing
# the convergence speed of different methods.
# I use the classic Rosenbrock example, with
# Optim.jl as the iterative process, which
# calls f many times to find its minimum,
# and I plot the convergence vs time.
using Optim, Cassette
# 2D Rosenbrock function
f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
# Create the context
Cassette.@context BenchmarkData
# Add a prehook that counts the calls to f
Cassette.prehook(ctx::BenchmarkData, ::typeof(f), args...) = ctx.metadata.fcalls[] += 1
# Add a posthook that stores, after each call to f, the call count,
# the value of f (`output`), and the time at which it was computed
function Cassette.posthook(ctx::BenchmarkData, output, ::typeof(f), args...)
    push!(ctx.metadata.fcounter, ctx.metadata.fcalls[])
    push!(ctx.metadata.fvalues, output)
    push!(ctx.metadata.ftimer, time())
end
# Storing all of these is a bit redundant here, but the purpose
# is to learn how to use Cassette, so it's ok, right? :)
# Define a struct to hold the recorded data (the context's metadata)
struct ProfileCtx
    fcalls::Ref{Int64}
    fcounter::Vector{Int64}
    fvalues::Vector{Float64}
    ftimer::Vector{Float64}
end
# Create an instance of BenchmarkData, the `tape`, with the counter at 0 and empty records
tape = BenchmarkData(metadata=ProfileCtx(Ref(0), Int64[], Float64[], Float64[]))
# Run the overdubbed optimization
x0 = [0.0, 0.0]
Cassette.@overdub(tape, optimize(f, x0))
# Run it a second time with a fresh tape: the first run includes
# compilation time, which we do not want in the timings!
x0 = [0.0, 0.0]
tape = BenchmarkData(metadata=ProfileCtx(Ref(0), Int64[], Float64[], Float64[]))
Cassette.@overdub(tape, optimize(f, x0))
# Now plot the results using `Plots.jl`
using Plots
# Convergence vs time (measured from the first call)
timer = tape.metadata.ftimer .- tape.metadata.ftimer[1]
y = tape.metadata.fvalues
p1 = plot(timer, y, yaxis = :log)
xlabel!("computing time (s)")
ylabel!("f")
# Convergence vs number of calls
counter = tape.metadata.fcounter
p2 = plot(counter, y, yaxis = :log)
xlabel!("number of f calls")
ylabel!("f")
# Combine the subplots into a single figure
plot(p1, p2, layout = (2, 1))
# Below is extra stuff that may be useful later (all commented out, as it is not used for now)
# From Lyndon White: recording the time spent in each call of f; copied here for safekeeping until I make it work in my MWE
# Not finished yet
#function Cassette.overdub(ctx::BenchmarkData, ::typeof(f), args...)
#    local start = time()
#    try
#        return f(args...)
#    finally
#        timing = time() - start
#        push!(ctx.metadata.ftimer, timing)
#    end
#end
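The commented-out `overdub` method above records how long each call to `f` takes, rather than when it finished. The try/finally timing pattern it relies on can be tried on its own in plain Julia, without Cassette; the sketch below assumes a hypothetical wrapper `timed_call` that pushes durations into a vector supplied by the caller. It uses `time_ns()`, Base's integer nanosecond clock, which suits short calls better than `time()`; because of `finally`, a call that throws is still recorded.

```julia
# Hypothetical wrapper: call `f(args...)` and push its wall-clock duration
# (in seconds) onto `durations`. The `finally` clause runs even if `f`
# throws, so every call is recorded.
function timed_call(durations::Vector{Float64}, f, args...)
    start = time_ns()
    try
        return f(args...)
    finally
        push!(durations, (time_ns() - start) / 1e9)
    end
end

# 2D Rosenbrock function, as in the example above
f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

durations = Float64[]
val = timed_call(durations, f, [0.0, 0.0])  # f at the origin is 1.0
```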
# Suggestion from Kristoffer Carlsson on Slack that does not use Cassette (to test too; saving it for later)
#function run()
#    times = Float64[]
#    function objective_function(x)
#        push!(times, time())
#        return f(x)
#    end
#    optimize(objective_function, [0.0, 0.0])
#    return times
#end
# `objective_function` is a closure (closing over `times`).
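Kristoffer Carlsson's suggestion can be generalized into a small wrapper that records the same three quantities as `ProfileCtx` (call index, value of f, time stamp) without Cassette. This is a minimal sketch; `Recorder` and `record` are hypothetical names, and it assumes `f` returns a `Float64`, as the Rosenbrock function does. The wrapped function `rec_f` could then be passed to Optim in place of `f` (for example `optimize(rec_f, x0)`), although that call is not exercised here.

```julia
# Hypothetical storage for what the Cassette hooks above record
struct Recorder
    fcounter::Vector{Int}
    fvalues::Vector{Float64}
    ftimer::Vector{Float64}
end
Recorder() = Recorder(Int[], Float64[], Float64[])

# Return a closure that evaluates `f` and logs each call into `rec`.
# The closure captures `f` and `rec`, so no global state is needed.
function record(f, rec::Recorder)
    return function (x)
        out = f(x)
        push!(rec.fcounter, length(rec.fcounter) + 1)
        push!(rec.fvalues, out)
        push!(rec.ftimer, time())
        return out
    end
end

# 2D Rosenbrock function, as in the example above
f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

rec = Recorder()
rec_f = record(f, rec)
rec_f([0.0, 0.0])  # 1.0
rec_f([1.0, 1.0])  # 0.0, the minimum
```

The trade-off is explicit wrapping instead of hooks: only calls that go through `rec_f` are logged, and no context type is needed.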