@briochemc
Last active July 6, 2020 10:42
# A minimal example of counting the number of times a function gets called using Cassette
# Simple recursive Fibonacci function
fibo(x) = x < 3 ? 1 : fibo(x - 2) + fibo(x - 1)
# So how many times is `fibo` called when evaluating `fibo(n)` for some `n`?
n = 3
# We can use Cassette to tell us, as below:
# Load package
using Cassette
# Create the context type.
# Let me try to explain the "context" terminology:
# I don't want to simply invoke `fibo(n)`.
# I want to invoke `fibo(n)` but, at the same time, do some extra stuff (count the number of calls here).
# So here I want `fibo(n)` to be executed in the "context" of doing this extra stuff.
Cassette.@context CountFibo
# Create a "prehook" to do the extra stuff before each call to `fibo`.
# Here, "prehook" means that it will execute some code *before* every call.
# The prehook takes the context `ctx` (whose metadata stores the counter),
# the type of the function, `typeof(fibo)`, and its arguments `args...`.
# Here we use a `Ref` for the counter.
# A `Ref` is like a pointer, pointing to a single scalar variable (the counter here).
# (Note that `Ref`s are "accessed" via `[]`.)
Cassette.prehook(ctx::CountFibo, ::typeof(fibo), args...) = ctx.metadata[] += 1
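# A minimal sketch of how a `Ref` behaves on its own, independent of Cassette.
# (`demo_ref` is a hypothetical name used only for this illustration.)
demo_ref = Ref(0)   # a mutable box holding the integer 0
demo_ref[] += 1     # read the value, add 1, and write it back
demo_ref[]          # the box now holds 1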
# Initialize the counter at 0 by creating a context instance whose metadata is a `Ref`
counter = CountFibo(metadata=Ref(0))
# Run the overdubbed `fibo` with argument `n`
Cassette.@overdub(counter, fibo(n))
# Show the number of calls
counter.metadata[]
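# As a sanity check that does not rely on Cassette, the number of calls can also
# be derived from the recursion itself: each call counts once, plus the calls made
# by its two recursive branches. This is only a sketch, and `expected_calls` is a
# hypothetical helper that is not part of the original example. If the prehook
# fires exactly once per call to `fibo`, the counter above should agree with it.
expected_calls(x) = x < 3 ? 1 : 1 + expected_calls(x - 2) + expected_calls(x - 1)
expected_calls(3)   # 3 calls: fibo(3), fibo(1), and fibo(2)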
# Here I store more data as I go, which I could later use to plot figures
# comparing the convergence speed of different methods.
# I use the classic Rosenbrock example and Optim.jl as an iterative process
# that calls `f` multiple times to find its minimum,
# and I plot the convergence against time and against the number of calls.
using Optim, Cassette
# 2D Rosenbrock function
f(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2
# Create context
Cassette.@context BenchmarkData
# Add prehook for storing number of calls of f
Cassette.prehook(ctx::BenchmarkData, ::typeof(f), args...) = ctx.metadata.fcalls[] += 1
# Add posthook for storing values of f and time at which it was computed
function Cassette.posthook(ctx::BenchmarkData, output, ::typeof(f), args...)
    push!(ctx.metadata.fcounter, ctx.metadata.fcalls[])
    push!(ctx.metadata.fvalues, output)
    push!(ctx.metadata.ftimer, time())
end
# It is a bit redundant here to store all of these but this is for the purpose
# of learning how I can use Cassette so it's ok, right? :)
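# Define the metadata container, then initialize the context instance (`tape`)
# with a zero call counter and empty, typed storage vectors
struct ProfileCtx
    fcalls::Ref{Int64}         # running number of calls to f
    fcounter::Vector{Int64}    # call number recorded after each call
    fvalues::Vector{Float64}   # value of f returned by each call
    ftimer::Vector{Float64}    # wall-clock time (s) at the end of each call
end
tape = BenchmarkData(metadata=ProfileCtx(Ref(0), Int64[], Float64[], Float64[]))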
# Run the overdubbed optimization
x0 = [0.0, 0.0]
Cassette.@overdub(tape, optimize(f, x0))
# Run it twice because you are timing it now
# and do not want to time the compiling part!
x0 = [0.0, 0.0]
tape = BenchmarkData(metadata=ProfileCtx(Ref(0), Int64[], Float64[], Float64[]))
Cassette.@overdub(tape, optimize(f, x0))
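# To illustrate why the first run is discarded, here is a small sketch that times
# the same call twice with `@elapsed`. `sum_of_squares` is a hypothetical function
# used only for this illustration. The first timing usually includes compilation,
# so it is typically (though not always measurably) larger than the second.
sum_of_squares(v) = sum(abs2, v)
t_first = @elapsed sum_of_squares([1.0, 2.0, 3.0])    # likely includes compilation
t_second = @elapsed sum_of_squares([1.0, 2.0, 3.0])   # already compiled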
# Now plot the results using `Plots.jl`
using Plots
# Convergence vs time
timer = tape.metadata.ftimer .- tape.metadata.ftimer[1]
y = tape.metadata.fvalues
p1 = plot(timer, y, yaxis = :log)
xlabel!("computing time (s)")
ylabel!("f")
# Convergence vs number of calls
counter = tape.metadata.fcounter
p2 = plot(counter, y, yaxis = :log)
xlabel!("number of f calls")
ylabel!("f")
# Combine subplots into a single figure
plot(p1, p2, layout = (2, 1))
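# The recorded values of `f` need not decrease monotonically, so the curves can be noisy.
# One possible variation (an assumption, not part of the original example) is to plot
# the best value found so far, which can be computed as a running minimum.
# `raw_values` is made-up data standing in for `tape.metadata.fvalues`.
raw_values = [3.0, 5.0, 2.0, 4.0, 1.0]
best_so_far = accumulate(min, raw_values)   # [3.0, 3.0, 2.0, 2.0, 1.0]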
# Below is extra material saved for later (all commented out because it is not used yet).
# From Lyndon White, for recording the time spent; copied here for safekeeping until I make it work on my MWE.
# Not finished yet.
#function Cassette.overdub(ctx::BenchmarkData, ::typeof(f), args...)
#    local start = time()
#    try
#        return f(args...)
#    finally
#        timing = time() - start
#        push!(ctx.metadata.ftimer, timing)
#    end
#end
# A suggestion from Kristoffer Carlsson on Slack that does not use Cassette. Still to be tested (saving for later).
#function run()
#    times = []
#    function objective_function(x)
#        push!(times, time())
#        return x*5
#    end
#    optimize(objective_function, 0.5)
#end
# `objective_function` is a closure (closing over `times`).
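# A minimal, runnable sketch of that closure idea, using neither Cassette nor Optim.
# `make_recorded` is a hypothetical helper: it wraps any function `g` so that every
# call also stores the returned value and the wall-clock time of the call.
function make_recorded(g)
    vals = Float64[]
    times = Float64[]
    function recorded(x)
        y = g(x)
        push!(vals, y)         # value returned by this call
        push!(times, time())   # time at which the call finished
        return y
    end
    return recorded, vals, times
end
# Usage: wrap a toy objective and call it a few times by hand
# (an optimizer would make these calls instead).
recorded_obj, rec_values, rec_times = make_recorded(x -> (x - 2.0)^2)
foreach(recorded_obj, [0.0, 1.0, 2.0])
length(rec_values)   # 3, one entry per call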