# FastDiff or NoDiff?
# GitHub gist baggepinnen/eda5999fbe22232f445698546b5e282c, created July 11, 2023 14:16.
# NOTE: GitHub reports that this file contains bidirectional Unicode text, which may be
# interpreted or compiled differently than it appears. Review it in an editor that
# reveals hidden Unicode characters.
using FastDifferentiation
# ==============================================================================
## Try simple things
# ==============================================================================
# Two symbolic decision variables `x` and two symbolic parameters `mu`.
x = make_variables(:x, 2)
mu = make_variables(:mu, 2)

# Scalar expression: sum of squares of the elementwise products x .* mu.
ex = sum(abs2, x .* mu)

# Sparse Hessian with respect to `x` only, compiled to an out-of-place function of (x, mu).
hs = sparse_hessian(ex, x)
fun = make_function(hs, x, mu; in_place=false)

# Test: evaluate at random numeric inputs shaped like the symbolic ones.
xv, muv = randn(size(x)), randn(size(mu))
fun(xv, muv)
# ==============================================================================
## Try on MPC generated functions
# The code below relies on `prob` being defined using one of our MPC examples
# ==============================================================================
using JuliaSimControl
using Serialization
## MPC controller
using HardwareAbstractions
import HardwareAbstractions as hw
using QuanserInterface
using JuliaSimControl.MPC
import JuliaSimControl.ControlDemoSystems as demo
using FiniteDiff
using Test # provides @test; @btime below additionally needs BenchmarkTools to be loaded
# cd("../../examples/quanser/")
# prob = deserialize("prob")
p = prob.p
vars = prob.vars
v = copy(vars.vars)
optfun = prob.optprob.f

# Test grad: compare the generated gradient against finite differences
out = similar(v)
optfun.grad(out, v, p)
@test out ≈ FiniteDiff.finite_difference_gradient(x -> optfun.f(x, p), v) rtol = 1e-4

# Test jac
# Number of constraint outputs (formula taken from the problem layout in `prob`)
nc = length(optfun.cons.constraint) * vars.n_robust +
     vars.nu * optfun.cons.robust_horizon * (vars.n_robust - 1)
# Copy the prototype so that `cons_j` does not overwrite the prototype stored in `optfun`.
out = copy(optfun.cons_jac_prototype)
optfun.cons_j(out, v, p)
# Out-of-place wrapper around the in-place constraint function, for FiniteDiff.
# It returns the filled buffer explicitly instead of whatever `cons` happens to return.
cons_oop = function (x)
    inner_out = zeros(nc)
    optfun.cons(inner_out, x, p)
    return inner_out
end
@test out ≈ FiniteDiff.finite_difference_jacobian(cons_oop, v) rtol = 1e-4
@btime $optfun.cons_j($out, $v, $p)
# 63.701 μs (0 allocations: 0 bytes)

# Test lag hess: output buffer holds only the nonzeros of the sparse prototype
out = similar(optfun.lag_hess_prototype.nzval)
mu = randn(nc)
sig = 1
optfun.lag_h(out, v, sig, mu, p)
## Trace the constraint function with FastDifferentiation variables
vs = make_variables(:v, length(v))
ps = make_variables(:p, length(p))
mu = make_variables(:mu, nc)
# outs = similar(optfun.cons_jac_prototype, eltype(vs))
# optfun.cons_j(outs, vs, ps)
consout = zeros(eltype(vs), nc)
# consout = Vector{Any}(undef, nc)
# Let the in-place `cons` write symbolic expressions into `consout`, then use that buffer.
# This does not rely on `cons` returning its output.
optfun.cons(consout, copy(vs), ps)
consex = consout
fun = make_function(consex, vs, ps; in_place=false)
fun(v, p)
@btime $fun($v, $p) # interpolate globals so only the call itself is timed
J = sparse_jacobian(consex, vs)

# Minimal sparse-Jacobian example on two symbolic variables.
# It uses a separate name `ws` so the MPC variables `vs` above are not clobbered.
ws = make_variables(:v, 2)
sparse_jacobian([transpose(ws) * ws; ws'ws], ws)
xs = make_variables(:x, 2)
# res = similar(xs)
res = zeros(eltype(xs), length(xs))
"""
    foo!(res, x)

Write the elementwise square of `x` into the preallocated buffer `res` and return `res`.
Toy in-place function used to try tracing with symbolic buffers.
"""
function foo!(res, x)
    @. res = x^2
    return res
end
foo!(res, xs)
##
# Evaluate the constraints numerically at the nominal point `v`.
# Use the filled buffer rather than relying on the in-place `cons` returning it.
consout = zeros(nc);
optfun.cons(consout, v, p)
consex = consout
# ==============================================================================
## Simple hessian example
# ==============================================================================
using Symbolics: variables
using Symbolics
"""
    cartpole(x, u, p=0, t=0)

Continuous-time cart-pole dynamics.

`x` is a 4-element state `[q; qd]` with `q = x[1:2]` and `qd = x[3:4]`; presumably
`q[1]` is the cart position and `q[2]` the pole angle (it enters through sin/cos).
Only `u[1]` is read; `p` and `t` are accepted for interface compatibility and unused.
The physical constants `mc, mp, l, g` are hard-coded.

Returns `[qd; qdd]` with `qdd = -H⁻¹ (C*qd + G - B*u[1])`. The 2×2 inverse is written
out as adjugate / determinant instead of using `\\`, presumably so that symbolic
number types can be traced through it.
"""
function cartpole(x, u, p=0, t=0)
    mc, mp, l, g = 1.0, 0.2, 0.5, 9.81
    q = x[1:2]
    qd = x[3:4]
    s = sin(q[2])
    c = cos(q[2])
    H = [mc+mp mp*l*c; mp*l*c mp*l^2]
    C = [0 -mp*qd[2]*l*s; 0 0]
    G = [0, mp * g * l * s]
    B = [1, 0]
    # Reference: qdd = (-H) \ (C * qd + G - B * u[1])
    den = H[1, 1] * H[2, 2] - H[1, 2] * H[2, 1]      # det(H)
    Hadj = [H[2, 2] -H[1, 2]; -H[2, 1] H[1, 1]]      # adjugate of H, so H⁻¹ = Hadj / den
    qdd = -(Hadj * (C * qd + G - B * u[1])) / den
    return [qd; qdd]
end
# Symbolics.jl: sparse Hessian of the squared norm of the cart-pole dynamics
x = variables(:x, 1:4)
u = variables(:u, 1:2)
xp = cartpole(x, u)
c = sum(abs2, xp)
vars = [x; u]
hs = Symbolics.sparsehessian(c, vars)
h = build_function(hs, x, u; expression=Val{false}, cse=true)
# Numeric inputs sized like `x` (4) and `u` (2); reused by the benchmark so randn is not timed.
xv, uv = randn(4), randn(2)
H = h[1](xv, uv) # Show sparsity pattern
@btime $(h[1])($xv, $uv)

# FastDifferentiation.jl: the same Hessian
import FastDifferentiation as fad
x = fad.make_variables(:x, 4)
u = fad.make_variables(:u, 2)
xp = cartpole(x, u)
c = sum(abs2, xp)
vars = [x; u]
hs = fad.sparse_hessian(c, vars)
h = fad.make_function(hs, vars)
@btime $h($(randn(length(vars))))
# ==============================================================================
## One step more demanding
# ==============================================================================
"""
    rk4(f, Ts0; supersample = 1)

Discretize the continuous-time dynamics `f(x, u, p, t)` with the classical
fourth-order Runge–Kutta method.

Returns a function `(x, u, p=0, t=0)` that advances the state `x` by one sample
interval `Ts0`, with `u` held constant over the interval. When `supersample > 1`,
the interval is split into `supersample` RK4 substeps of length `Ts0 / supersample`.
Substep `k` starts at time `t + (k - 1) * Ts0 / supersample`.

Throws `ArgumentError` if `supersample < 1`.
"""
function rk4(f::F, Ts0; supersample::Integer = 1) where {F}
    supersample ≥ 1 || throw(ArgumentError("supersample must be positive."))
    Ts = Ts0 / supersample # to preserve type stability in case Ts0 is an integer
    let Ts = Ts
        function (x, u, p=0, t=0)
            # Binding the result to a new name `y` keeps this type stable when `x + add`
            # is not of the same type as `x`.
            y = _rk4_step(f, x, u, p, t, Ts)
            for k in 2:supersample
                # Advance time so time-varying dynamics see the correct substep start time.
                y = _rk4_step(f, y, u, p, t + (k - 1) * Ts, Ts)
            end
            return y
        end
    end
end

# One classical RK4 step of length `h` from state `x` at time `t`, with `u` held constant.
function _rk4_step(f::F, x, u, p, t, h) where {F}
    f1 = f(x, u, p, t)
    f2 = f(x + h / 2 * f1, u, p, t + h / 2)
    f3 = f(x + h / 2 * f2, u, p, t + h / 2)
    f4 = f(x + h * f3, u, p, t + h)
    return x + h / 6 * (f1 + 2 * f2 + 2 * f3 + f4)
end
"""
    rollout(f, x0, u)

Simulate the discrete-time map `f(x, u)` starting from `x0`, driven by the columns of `u`.

Returns `(x, u)`. `x` is a `length(x0) × size(u, 2)` matrix whose first column is `x0`
and whose column `k` is `f(x[:, k-1], u[:, k-1])`. The element type of `x` is the
promotion of the element types of `x0` and `u`. The last column of `u` is not used.
"""
function rollout(f, x0::AbstractVector, u)
    S = promote_type(eltype(x0), eltype(u))
    traj = zeros(S, length(x0), size(u, 2))
    traj[:, 1] .= x0
    return rollout!(f, traj, u)
end

"""
    rollout!(f, x, u)

Fill columns `2:end` of `x` in place by iterating `x[:, k] = f(x[:, k-1], u[:, k-1])`
(time is not passed to `f`). Returns `(x, u)`.
"""
function rollout!(f, x, u)
    for k in 2:size(x, 2)
        x[:, k] = f(x[:, k-1], u[:, k-1]) # TODO: i * Ts
    end
    return (x, u)
end
Ts = 0.01
N = 2
discrete_cartpole = rk4(cartpole, Ts)

# Symbolics.jl variant: Hessian of a rollout cost w.r.t. initial state and inputs
xs0 = variables(:x0, 1:4)
us = variables(:u, 1:2, 1:N)
xtraj, _ = rollout(discrete_cartpole, xs0, us)
vars = [xs0; vec(us)]
c = sum(abs2, xtraj) + sum(abs2, us)
hs = Symbolics.sparsehessian(c, vars)
# Build a function of the decision variables (x0, u), not of the simulated trajectory.
h = build_function(hs, xs0, us; expression=Val{false}, cse=true)
H = h[1](randn(4), randn(2, N)) # Show sparsity pattern

# FastDifferentiation.jl variant: the same Hessian
u = reshape(fad.make_variables(:u, 2*N), 2, N)
x0 = fad.make_variables(:x0, 4)
x, _ = rollout(discrete_cartpole, x0, u) # Never finishes?
vars = [x0; vec(u)]
c = sum(abs2, x) + sum(abs2, u)
hs = fad.sparse_hessian(c, vars)
h = fad.make_function(hs, vars)
# The input vector must have one entry per variable: 4 + 2N.
@btime $h($(randn(length(vars))))
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.