The JAX developers optimized a differential equation benchmark in this issue, which used DiffEqFlux.jl as a performance baseline. The Julia code from that issue was updated with some standard performance tricks and is the benchmark code here. Thus both codes have been optimized by their respective library developers.
# --------------------------------------------------------------------------------
# Ewing model
# translation by: [email protected] (July 2020)
# --------------------------------------------------------------------------------
using DifferentialEquations
using Plots
using LabelledArrays
using StaticArrays
using OrdinaryDiffEq
using Plots
using Flux, DiffEqFlux, Optim
function lotka_volterra(du,u,p,t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end
retcode: Success
Interpolation: 3rd order Hermite
t: [0.0, 0.5, 1.0]
u: Vector{Num}[[x0, y0], [x0 + 0.08333333333333333((1.5x0) + (1.5(x0 + (0.5((1.5(x0 + (0.25((1.5(x0 + (0.25((1.5x0) - (x0*y0))))) - ((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))))))) - ((x0 + (0.25((1.5(x0 + (0.25((1.5x0) - (x0*y0))))) - ((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))))))*(y0 + (0.25(((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))) - (3(y0 + (0.25((x0*y0) - (3y0))))))))))))) + (2((1.5(x0 + (0.25((1.5x0) - (x0*y0))))) + (1.5(x0 + (0.25((1.5(x0 + (0.25((1.5x0) - (x0*y0))))) - ((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))))))) - ((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))) - ((x0 + (0.25((1.5(x0 + (0.25((1.5x0) - (x0*y0))))) - ((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))))))*(y0 + (0.25(((x0 + (0.25((1.5x0) - (x0*y0))))*(y0 + (0.25((x0*y0) - (3y0))))) - (3(y0 + (0.25((x0*y0) - (3y0))))))))))) - (x0*y0) - ((x0 + (0.5( |
using Cassette, DiffRules
using Core: CodeInfo, SlotNumber, SSAValue, ReturnNode, GotoIfNot
const printbranch = true
Cassette.@context HasBranchingCtx
function Cassette.overdub(ctx::HasBranchingCtx, f, args...)
    if Cassette.canrecurse(ctx, f, args...)
        return Cassette.recurse(ctx, f, args...)
Brusselator Stiff Partial Differential Equation Benchmark: Julia DifferentialEquations.jl vs Python SciPy
This benchmark compares DifferentialEquations.jl against Python's SciPy ODE solvers. Notes:
- Stiff ODE solvers are used, since they are required to solve this problem efficiently.
- The Python code is vectorized for maximum performance.
- All of the performance features are tried: automatic sparsity detection, preconditioners, etc.
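To make the notes above concrete, here is a minimal sketch of what a vectorized stiff solve with a sparsity hint can look like on the SciPy side. It is not the benchmark code from the gist. It assumes a simplified 1D Brusselator with periodic boundaries, and the grid size `N`, the parameters `A`, `B`, `ALPHA`, and the helper names `brusselator_1d` and `jacobian_sparsity` are all illustrative choices rather than values from the benchmark.

```python
import numpy as np
from scipy.integrate import solve_ivp
from scipy.sparse import bmat, diags, identity

# Illustrative values only; the benchmark's own grid and parameters may differ.
N = 32                         # grid points per species
A, B, ALPHA = 1.0, 3.0, 0.02   # reaction and diffusion constants (assumed)
DX = 1.0 / N


def brusselator_1d(t, y):
    """Vectorized right-hand side: no Python loop over grid points."""
    u, v = y[:N], y[N:]
    # Periodic second differences computed with whole-array shifts.
    lap_u = (np.roll(u, 1) - 2.0 * u + np.roll(u, -1)) / DX**2
    lap_v = (np.roll(v, 1) - 2.0 * v + np.roll(v, -1)) / DX**2
    du = A + u**2 * v - (B + 1.0) * u + ALPHA * lap_u
    dv = B * u - u**2 * v + ALPHA * lap_v
    return np.concatenate([du, dv])


def jacobian_sparsity(n):
    """Nonzero pattern of the Jacobian: each equation couples to its two
    grid neighbours, to itself, and to the other species at the same point."""
    tri = diags([1, 1, 1], [-1, 0, 1], shape=(n, n)).tolil()
    tri[0, n - 1] = 1   # periodic wrap-around
    tri[n - 1, 0] = 1
    eye = identity(n, format="csr")
    return bmat([[tri, eye], [eye, tri]], format="csr")


x = np.linspace(0.0, 1.0, N, endpoint=False)
y0 = np.concatenate([1.0 + np.sin(2.0 * np.pi * x), 3.0 * np.ones(N)])

# BDF is one of SciPy's implicit (stiff) methods; jac_sparsity lets it build
# the finite-difference Jacobian with far fewer right-hand-side evaluations.
sol = solve_ivp(
    brusselator_1d,
    (0.0, 1.0),
    y0,
    method="BDF",
    jac_sparsity=jacobian_sparsity(N),
    rtol=1e-6,
    atol=1e-8,
)
```

The sparsity pattern here is written by hand. SciPy's `solve_ivp` does not detect it automatically, so it has to be supplied explicitly whenever the Jacobian is approximated by finite differences.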
<!DOCTYPE html>
<HTML lang = "en">
<HEAD>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
<title>Forward and Reverse Automatic Differentiation In A Nutshell</title>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
function (ˍ₋out, ˍ₋arg1, ˍ₋arg2, t)
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:349 =#
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:350 =#
    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:351 =#
    begin
        begin
            #= C:\Users\accou\.julia\packages\Symbolics\vQXbU\src\build_function.jl:452 =#
            #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:398 =# @inbounds begin
                    #= C:\Users\accou\.julia\packages\SymbolicUtils\v2ZkM\src\code.jl:394 =#
ERROR: MethodError: Cannot `convert` an object of type Float32 to an object of type Vector{Float32}
Closest candidates are:
  convert(::Type{Array{T, N}}, ::SizedArray{S, T, N, N, Array{T, N}}) where {S, T, N} at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\SizedArray.jl:121
  convert(::Type{Array{T, N}}, ::SizedArray{S, T, N, M, TData} where {M, TData<:AbstractArray{T, M}}) where {T, S, N} at C:\Users\accou\.julia\packages\StaticArrays\0T5rI\src\SizedArray.jl:115
  convert(::Type{<:Array}, ::LabelledArrays.LArray) at C:\Users\accou\.julia\packages\LabelledArrays\lfn1b\src\larray.jl:133
  ...
Stacktrace:
 [1] setproperty!(x::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Vector{Float32}, Float32}, f::Symbol, v::Float32)
   @ Base .\Base.jl:43
 [2] initialize!(integrator::OrdinaryDiffEq.ODEIntegrator{Tsit5, false, Vector{Float32}, Float32})