Recently, I tried to rewrite a MATLAB program in Julia. The program solves a PDE derived from a continuous-time economic model. I got the same result as the MATLAB program, but the Julia version was much slower. After reviewing the Performance Tips in the Julia manual, I realized that the problem lay in using global variables. An economic model typically has many parameters, and they are usually defined directly as global variables. For convenience, I had written several functions that compute formulae using these parameters. Since those functions were called frequently in a long loop, performance was poor.
To guide future programming practice, I experiment here with several ways to avoid this problem.
Before digging into the various ways to avoid this problem, let's first check how slow global variables can be. To compare the computation time of different approaches, I use the BenchmarkTools package:
using BenchmarkTools
Consider the following code:
μ = 1.0
σ = 0.8
a = 0.7
f(x) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
function repeval()
    for i in 1:10000
        res = f(0.5)
    end
end
Here, I define three parameters `μ`, `σ`, and `a`, as well as a function `f` of the argument `x` that also uses the three parameters. Then, I evaluate the function in a long loop. To measure the time with BenchmarkTools, I wrap the loop in a function and apply the `@btime` macro:
@btime repeval()
The result is
22.630 ms (390000 allocations: 9.61 MiB)
Next, consider the following code without using global variables, where the parameters are explicitly passed as arguments to the function:
μ = 1.0
σ = 0.8
a = 0.7
f(x, a, μ, σ) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
function repeval()
    for i in 1:10000
        res = f(0.5, 0.7, 1.0, 0.8)
    end
end
@btime repeval()
The benchmark result is
38.799 μs (0 allocations: 0 bytes)
By avoiding global variables and passing the parameters directly as arguments, the code becomes over 500 times faster. However, when an economic model has many parameters, it is troublesome to pass all of them to every function in the code. How can we achieve high performance without passing the parameters explicitly to each function?
There are several candidates:
- Define global constants.
- Wrap the code in a function.
- Use a NamedTuple or the Parameters package.
- Use a closure.
Instead of defining global variables, we can define global constants by prefixing each definition with the keyword `const`:
const μ = 1.0
const σ = 0.8
const a = 0.7
f(x) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
function repeval()
    for i in 1:10000
        res = f(0.5)
    end
end
@btime repeval()
The result is
38.799 μs (0 allocations: 0 bytes)
The performance is similar to passing all parameters as arguments. However, this solution has a drawback: global constants cannot be reliably redefined. Suppose I am unhappy with the value `μ = 1.0` and want to try `μ = 0.9`. If I simply change the definition to `const μ = 0.9`, I receive a warning
WARNING: redefinition of constant μ. This may fail, cause incorrect answers, or produce other errors.
and the value of `μ` is still 1.0. To change the value, I have to restart the Julia session, which means loading all packages again with `using`, and that can take a while.
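One possible workaround, which is not benchmarked in this post, is to keep the binding constant but make its content mutable by storing the value in a `Ref`. The sketch below is only an illustration under that assumption; the names `μ_ref` and `shifted` are hypothetical and the function is a stand-in for the real formula.

```julia
# Hypothetical names (μ_ref, shifted) chosen for illustration only.
const μ_ref = Ref(1.0)      # the binding is constant; the stored Float64 can be changed

shifted(x) = x + μ_ref[]    # read the current value with []

before = shifted(0.5)       # evaluated with μ = 1.0
μ_ref[] = 0.9               # update in place: no redefinition, hence no warning
after = shifted(0.5)        # evaluated with μ = 0.9
```

Because the type of `μ_ref` is fixed, the compiler should still be able to infer the type of `μ_ref[]`, but each call now reads the value from memory, so it is worth timing this with `@btime` before relying on it.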
The second solution is to wrap the code in a function, including the parameter definitions. This resembles a C/C++ program, where the entry point of the code is the `main()` function:
function main()
    μ = 1.0
    σ = 0.8
    a = 0.7
    f(x) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
    for i in 1:10000
        res = f(0.5)
    end
end
# Check time to evaluate the function
@btime main()
The result is also similar to the previous solution:
38.799 μs (0 allocations: 0 bytes)
From the perspective of the `f` function, the parameters still look "global". However, from Julia's perspective, as long as they are wrapped in a function, they are not global variables. The drawback of this solution is that debugging can be difficult in a Jupyter Notebook, because you cannot easily run the code inside the function block by block. This is not a problem in Juno, where you can debug by stepping into the function and/or setting breakpoints.
Another, slightly more laborious, way is to wrap all parameters in a NamedTuple or a `struct` and pass it to the function. The Parameters package facilitates this.
Using a NamedTuple, we have the following code:
params = (μ = 1.0, σ = 0.8, a = 0.7)
function f(x, params)
    μ, σ, a = params  # positional destructuring: the order must match the fields of params
    return (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
end
function repeval()
    for i in 1:10000
        res = f(0.5, params)
    end
end
@btime repeval()
The output is
316.599 μs (10000 allocations: 156.25 KiB)
Although this is roughly 8 times slower than the versions that avoid global variables altogether, it is still a large improvement over the global-variable version.
We get a similar result using the Parameters package:
using Parameters  # provides @with_kw and @unpack
@with_kw struct Params
    μ::Float64
    σ::Float64
    a::Float64
end
params = Params(μ = 1.0, σ = 0.8, a = 0.7)
function f(x, params)
    @unpack a, μ, σ = params  # unpacks by field name, so the order does not matter
    return (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
end
function repeval()
    for i in 1:10000
        res = f(0.5, params)
    end
end
@btime repeval()
The output is
309.800 μs (10000 allocations: 156.25 KiB)
which is slightly faster than the NamedTuple version.
However, we can see the complication. Previously, I needed only one line to define `f`, but now I need an extra line to unpack the parameters. There is also an extra argument to pass to the function.
The last (but not least) solution I consider is to use a closure. We can define a function `getfunc` that receives the parameters and returns a function with the parameters enclosed. If we need multiple functions like `f` here, we can give `getfunc` another parameter that selects which function to return, and use an `if` statement inside it to return the correct one. Essentially, `getfunc` becomes a "closure generator" that returns the desired function with the parameters enclosed (a closure).
Consider the following code:
params = (μ = 1.0, σ = 0.8, a = 0.7)
function getfunc(params)
    μ, σ, a = params  # positional destructuring: the order must match the fields of params
    f(x) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
    return f
end
function repeval()
    f = getfunc(params)
    for i in 1:10000
        res = f(0.5)
    end
end
@btime repeval()
The output is
295.699 μs (10001 allocations: 156.28 KiB)
which is slightly better than using a NamedTuple. The drawback is that we have to define an additional function that returns functions. On the other hand, collecting the functions used in the code inside a single function can also be a benefit, since it makes them easier to manage.
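The selector variant of `getfunc` described earlier is not shown in the code, so here is a minimal sketch of how it might look. The second argument `name`, the symbol labels `:f` and `:g`, and the formula behind `:g` are all hypothetical and serve only to illustrate the `if` branch.

```julia
params = (μ = 1.0, σ = 0.8, a = 0.7)

# Closure generator: `name` selects which closure to return (hypothetical labels).
function getfunc(params, name::Symbol)
    μ, σ, a = params.μ, params.σ, params.a   # access by field name, order-independent
    if name === :f
        # the formula used throughout the post
        return x -> (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
    elseif name === :g
        # a made-up second formula, included only to show the selection
        return x -> (x .- μ) ./ σ
    else
        throw(ArgumentError("unknown function name: $name"))
    end
end

f = getfunc(params, :f)
g = getfunc(params, :g)
```

Each branch returns a closure of a different type, so the return type of `getfunc` depends on the value of `name`. This is probably harmless if the closures are created once and then passed as arguments to the function that runs the hot loop, but that is an assumption worth checking with `@btime`.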
The experiments above can be summarized as follows:
- Slowest: use global variables.
- ~75x faster: use a NamedTuple, the Parameters package, or a closure.
- ~500x faster: use global constants, pass parameters to functions, or wrap the code in a function.
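The speedup factors in this summary follow directly from the timings reported above. As a quick check, the ratios can be recomputed from those numbers (the variable names are only labels for this sketch):

```julia
# Timings reported in the post, converted to seconds.
t_global  = 22.630e-3    # global variables
t_fast    = 38.799e-6    # arguments, const, or wrapping in a function
t_ntuple  = 316.599e-6   # NamedTuple
t_params  = 309.800e-6   # Parameters package
t_closure = 295.699e-6   # closure

# Speedup relative to the global-variable version.
speedup(t) = t_global / t

for (label, t) in [("fast group", t_fast), ("NamedTuple", t_ntuple),
                   ("Parameters", t_params), ("closure", t_closure)]
    println(label, ": ", round(speedup(t), digits = 1), "x")
end
```

The fast group comes out at a bit under 600x and the other three between roughly 70x and 77x, which matches the rounded "~500x" and "~75x" labels.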
The cons of these solutions are summarized as follows:
- Use global constants:
  - Cannot change parameters without restarting the session.
- Pass parameters to functions:
  - Troublesome if there are a lot of functions and a lot of parameters.
- Wrap the code in a function:
  - Possibly difficult to debug in a Notebook.
- Use a NamedTuple or the Parameters package:
  - Still need to pass an additional argument to the function and need extra code to unpack parameters.
- Use a closure:
  - Need an extra function to return closures.
Which one would you choose?
I came across this post. Just in case someone bumps into it, I wanted to add some remarks. There are two issues:
- `params` is not passed to `repeval` as an argument
- `main()` as a benchmark should be compared differently

**ABOUT 1)**
Look at the code of the NamedTuple example. In `repeval` you're not including `params` as an argument of the function. Consequently, it's as if `params` was a global variable. It should be written with `params` as an argument of `repeval`, which reduces the time. The result can then be compared with `main`.

**ABOUT 2)**
The function `main()` you use as a benchmark already includes `x = 0.5` within the function. Look at the times when `x` is an argument of the function: they are practically the same as the time using NamedTuples.
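The two remarks can be sketched as follows. This is only an illustration of the idea, not the commenter's original code: `repeval` takes both `x` and `params` as arguments so that nothing inside the loop refers to a global. No timings are claimed here.

```julia
params = (μ = 1.0, σ = 0.8, a = 0.7)

function f(x, params)
    μ, σ, a = params.μ, params.σ, params.a   # access by field name
    return (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)
end

# Both x and params are arguments, so their types are known when repeval is compiled.
function repeval(x, params)
    res = f(x, params)
    for i in 1:10000
        res = f(x, params)
    end
    return res
end

result = repeval(0.5, params)

# With BenchmarkTools loaded, interpolating the global with `$` keeps the
# global lookup out of the measurement:
#   @btime repeval(0.5, $params)
```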
MISCELLANEOUS
When you have large broadcasting operations, consider two improvements. The first is purely notational: some tips to write the function in a cleaner way. The second is the macro `@turbo` from LoopVectorization.jl, which speeds up broadcasting operations. Also, if you have a computer with multiple cores, you can use `@Threads.threads`.
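A minimal sketch of what the cleaner notation could look like, assuming the tip refers to Base's `@.` macro (that is an assumption, and the names `f_dots` and `f_macro` are hypothetical). The macro adds a dot to every call and operator in the expression, so the formula can be written without the individual dots:

```julia
# The formula from the post, with explicit dots.
f_dots(x, a, μ, σ) = (x .+ a) ./ (μ + σ^2) .* (1-a)/a + (μ+0.5*σ^2) ./ (x.+μ.+0.5*σ^2) - log.(x.^2 .+ μ/σ*(1-a)/a)

# The same formula with @., which broadcasts every operation in the expression.
f_macro(x, a, μ, σ) = @. (x + a) / (μ + σ^2) * (1 - a) / a + (μ + 0.5 * σ^2) / (x + μ + 0.5 * σ^2) - log(x^2 + μ / σ * (1 - a) / a)

xs = [0.1, 0.5, 1.0, 2.0]
ys_dots  = f_dots(xs, 0.7, 1.0, 0.8)
ys_macro = f_macro(xs, 0.7, 1.0, 0.8)
```

Both versions accept a scalar or an array for `x`. With `@.` the whole expression becomes a single fused broadcast, which may also avoid some temporary arrays when `x` is large.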
Hope it helps!!!