Created
May 20, 2020 21:17
-
-
Save caryan/f4c9afb2dfe4921ca72950bb90b5cfd0 to your computer and use it in GitHub Desktop.
Demonstrate how manual CSE can speed up ModelingToolkit ODE functions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Explore Common Subexpression Elimination\n", | |
"\n", | |
"Many of the linear drive terms have the form of $A(t)H_{drive}$ where we have a time dependent drive amplitude $A$ multiplying the drive Hamiltonion. The same $A(t)$ will appear for every non zero element of $H_{drive}$ and so we would like calculate it once, bind the result to a variable, and then use the variable everywhere. The compiler is supposed to do this via common subexpression elimination. Let's check whether it actually works." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"using BenchmarkTools\n", | |
"using MacroTools: postwalk\n", | |
"using ModelingToolkit, OrdinaryDiffEq" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3-element Array{Equation,1}:\n", | |
" Equation(derivative(ψ₁(t), t), cos(ω * t) * ψ₁(t))\n", | |
" Equation(derivative(ψ₂(t), t), cos(ω * t) * ψ₂(t))\n", | |
" Equation(derivative(ψ₃(t), t), cos(ω * t) * ψ₃(t))" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# create trivial diagonal system of N equations\n", | |
"\n", | |
"N = 3\n", | |
"@parameters ω, t\n", | |
"@variables ψ[1:N](t)\n", | |
"@derivatives D'~t\n", | |
"\n", | |
"drive_amplitude = cos(ω*t)\n", | |
"\n", | |
"eqs = D.(ψ) .~ drive_amplitude.*ψ" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
":((var\"##MTIIPVar#258\", var\"##MTKArg#254\", var\"##MTKArg#255\", var\"##MTKArg#256\")->begin\n", | |
" @inbounds begin\n", | |
" let (ψ₁, ψ₂, ψ₃, ω, t) = (var\"##MTKArg#254\"[1], var\"##MTKArg#254\"[2], var\"##MTKArg#254\"[3], var\"##MTKArg#255\"[1], var\"##MTKArg#256\")\n", | |
" var\"##MTIIPVar#258\"[1] = cos(ω * t) * ψ₁\n", | |
" var\"##MTIIPVar#258\"[2] = cos(ω * t) * ψ₂\n", | |
" var\"##MTIIPVar#258\"[3] = cos(ω * t) * ψ₃\n", | |
" end\n", | |
" end\n", | |
" nothing\n", | |
" end)" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# let's look at the in-place ODE function\n", | |
"de = ODESystem(eqs)\n", | |
"generate_function(de)[2]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 91.733 ns (0 allocations: 0 bytes)\n" | |
] | |
} | |
], | |
"source": [ | |
"# and time its performance\n", | |
"u = collect(range(0,1; length=N)); du = similar(u);\n", | |
"f = eval(generate_function(de)[2])\n", | |
"@btime f($du, $u, 5e9, 2.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
":((var\"##MTIIPVar#282\", var\"##MTKArg#278\", var\"##MTKArg#279\", var\"##MTKArg#280\")->begin\n", | |
" @inbounds begin\n", | |
" let (ψ₁, ψ₂, ψ₃, ω, t) = (var\"##MTKArg#278\"[1], var\"##MTKArg#278\"[2], var\"##MTKArg#278\"[3], var\"##MTKArg#279\"[1], var\"##MTKArg#280\")\n", | |
" drive_amplitude = cos(ω * t)\n", | |
" var\"##MTIIPVar#282\"[1] = drive_amplitude * ψ₁\n", | |
" var\"##MTIIPVar#282\"[2] = drive_amplitude * ψ₂\n", | |
" var\"##MTIIPVar#282\"[3] = drive_amplitude * ψ₃\n", | |
" end\n", | |
" end\n", | |
" nothing\n", | |
" end)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# now let's do some manual CSE\n", | |
"ex = generate_function(de)[2]\n", | |
"ex = postwalk(x -> x == :(cos(ω*t)) ? :(drive_amplitude) : x, ex)\n", | |
"pushfirst!(ex.args[2].args[1].args[3].args[1].args[2].args, :(drive_amplitude = cos(ω*t)))\n", | |
"ex" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
" 46.231 ns (0 allocations: 0 bytes)\n" | |
] | |
} | |
], | |
"source": [ | |
"# and time that\n", | |
"u = collect(range(0,1; length=N)); du = similar(u);\n", | |
"f = eval(ex)\n", | |
"@btime f($du, $u, 2π*5, 2.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3-element Array{Float64,1}:\n", | |
" 0.0\n", | |
" 0.5\n", | |
" 1.0" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# check that it calculated something\n", | |
"f(du, u, 2π*5, 2.0)\n", | |
"du" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Without manual CSE\n", | |
" 3.385 μs (0 allocations: 0 bytes)\n", | |
"\n", | |
" With manual CSE\n", | |
" 80.211 ns (0 allocations: 0 bytes)\n" | |
] | |
} | |
], | |
"source": [ | |
"# repeat with 100 terms to see a much bigger gap\n", | |
"\n", | |
"N = 100\n", | |
"@parameters ω, t\n", | |
"@variables ψ[1:N](t)\n", | |
"@derivatives D'~t\n", | |
"\n", | |
"drive_amplitude = cos(ω*t)\n", | |
"\n", | |
"eqs = D.(ψ) .~ drive_amplitude.*ψ\n", | |
"\n", | |
"de = ODESystem(eqs)\n", | |
"\n", | |
"println(\"Without manual CSE\")\n", | |
"u = collect(range(0,1; length=N)); du = similar(u);\n", | |
"f = eval(generate_function(de)[2])\n", | |
"@btime f($du, $u, 5e9, 2.0)\n", | |
"\n", | |
"println(\"\\n With manual CSE\")\n", | |
"ex = generate_function(de)[2]\n", | |
"ex = postwalk(x -> x == :(cos(ω*t)) ? :(drive_amplitude) : x, ex)\n", | |
"pushfirst!(ex.args[2].args[1].args[3].args[1].args[2].args, :(drive_amplitude = cos(ω*t)))\n", | |
"f = eval(ex)\n", | |
"@btime f($du, $u, 2π*5, 2.0)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Julia 1.4.0", | |
"language": "julia", | |
"name": "julia-1.4" | |
}, | |
"language_info": { | |
"file_extension": ".jl", | |
"mimetype": "application/julia", | |
"name": "julia", | |
"version": "1.4.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment