# Source: GitHub gist baggepinnen/91c0f688fefa204e0f012d78d4d7d878 (created November 18, 2021 15:10).
using StaticArrays
using Statistics, LinearAlgebra, Test
using ModelingToolkit, Symbolics
""" | |
rk4(f, l, Ts) | |
Discretize dynamics `f` and loss function `l`using RK4 with sample time `Ts`. | |
The returned function is on the form `(xₖ,uₖ,t)-> (xₖ₊₁, loss)`. | |
Both `f` and `l` take the arguments `(x, u, t)`. | |
""" | |
function rk4(f::F, l::LT, Ts) where {F, LT} | |
# Runge-Kutta 4 method | |
function (x, u, t) | |
f1, L1 = f(x, u, t), l(x, u, t) | |
f2, L2 = f(x + Ts / 2 * f1, u, t + Ts / 2), l(x + Ts / 2 * f1, u, t + Ts / 2) | |
f3, L3 = f(x + Ts / 2 * f2, u, t + Ts / 2), l(x + Ts / 2 * f2, u, t + Ts / 2) | |
f4, L4 = f(x + Ts * f3, u, t + Ts), l(x + Ts * f3, u, t + Ts) | |
x += Ts / 6 * (f1 + 2 * f2 + 2 * f3 + f4) | |
L = Ts / 6 * (L1 + 2 * L2 + 2 * L3 + L4) | |
return x, L | |
end | |
end | |
"""
    cartpole(x, u, t)

Continuous-time cart-pole dynamics, returning the state derivative.

The state is split as `x = [q; q̇]`, with `q = x[1:2]` (cart position, pendulum angle) and
`q̇ = x[3:4]`. Only `u[1]` is used; it enters the cart coordinate through the input map
`B = [1, 0]`. The accelerations solve the manipulator equation
`H(q) q̈ + C(q, q̇) q̇ + G(q) = B u`. The time argument is ignored.

Returns `[q̇; q̈]`.
"""
function cartpole(x, u, _)
    # Cart mass, pole mass, pole length, gravitational acceleration.
    mc, mp, l, g = 1.0, 0.2, 0.5, 9.81
    pos = x[SA[1, 2]]
    vel = x[SA[3, 4]]
    sθ, cθ = sin(pos[2]), cos(pos[2])
    mass = @SMatrix [mc+mp mp*l*cθ; mp*l*cθ mp*l^2]
    # Parenthesized so the negative entry cannot be misread as a subtraction inside hcat.
    coriolis = @SMatrix [0 (-mp * vel[2] * l * sθ); 0 0]
    gravity = @SVector [0, mp * g * l * sθ]
    input_map = @SVector [1, 0]
    rhs = coriolis * vel + gravity - input_map * u[1]
    acc = -mass \ rhs
    return vcat(vel, acc)
end
"""
    loss(x, u, t; Qx = Q1, Qu = Q2)

Quadratic running cost `x' * Qx * x + u' * Qu * u` at state `x` and control `u`.

`t` is accepted to match the `(x, u, t)` signature used by `rk4`, but it is not used.
The weight matrices default to the globals `Q1` (state) and `Q2` (control), which are read
at call time. Pass `Qx`/`Qu` explicitly to evaluate the cost without relying on those globals.
"""
function loss(x, u, t; Qx = Q1, Qu = Q2)
    # NOTE: an input-rate penalty `Δu' * Q3 * Δu` with `Δu = u - uprev` is not implemented.
    # It needs the previous control, which this `(x, u, t)` signature does not receive.
    return x' * Qx * x + u' * Qu * u
end
"""
    final_cost(x; Qf = Q1)

Quadratic terminal cost `x' * Qf * x` for the last state of the horizon.

`Qf` defaults to the global state-cost matrix `Q1`, read at call time.
TODO: replace the default by a Riccati solution.
"""
final_cost(x; Qf = Q1) = x' * Qf * x
## Problem parameters
nu = 1        # number of controls
nx = 4        # number of states
Ts = 0.02     # sample time
N = 2         # time horizon in samples; kept tiny because generating the symbolic
              # functions is slow (realistic horizons are in the hundreds)
x0 = zeros(nx)     # initial state
x0[1] = 3          # cart position
x0[2] = pi * 0.5   # pendulum angle
xr = zeros(nx)     # reference state (NOTE(review): not used by `loss`/`final_cost` as written)
Q1 = diagm(ones(nx))        # state cost matrix (identity, sized from `nx`)
Q2 = Ts * diagm(ones(nu))   # control cost matrix
Q3 = nothing                # control-rate cost matrix; the rate penalty is not implemented
# Control limits
umin = -10 * ones(nu)
umax = 10 * ones(nu)
# State limits (be careful with these: they may make the problem infeasible)
xmin = -50 * ones(nx)
xmax = 50 * ones(nx)
# Discretize the continuous dynamics and the loss integral.
# Usage: xp, L = discrete(x, u, t)
discrete = rk4(cartpole, loss, Ts)
## Build a symbolic representation of the optimal control problem.
# Every state and control sample is a decision variable, and the dynamics are imposed
# as equality constraints between consecutive state samples.
w = []      # decision variables
w0 = []     # initial guess for w
lbw = []    # lower bounds on w
ubw = []    # upper bounds on w
g = []      # equality constraints (must equal zero)
L = 0       # accumulated objective

@variables x[1:nx](1)   # state at the first sample
x = collect(x)          # use a plain Vector; symbolic arrays are too buggy
append!(w, x)
append!(w0, x0)
append!(lbw, x)         # lower == upper == x: the initial state is fixed
append!(ubw, x)

for n in 1:N    # for the whole time horizon N
    global x, L
    @variables u[1:nu](n)       # control over sample interval n
    u = collect(u)              # symbolic arrays are too buggy
    append!(w, u)
    append!(w0, zeros(nu))      # one guess per control; TODO: add u0
    append!(lbw, umin)
    append!(ubw, umax)

    # `rk4` treats its third argument as physical time, so pass the interval start time.
    xp, l = discrete(x, u, (n - 1) * Ts)
    L += l

    @variables x[1:nx](n + 1)   # state at the next sample
    x = collect(x)              # symbolic arrays are too buggy
    append!(w, x)
    append!(w0, zeros(nx))      # TODO: add warm start
    append!(lbw, xmin)
    append!(ubw, xmax)
    append!(g, xp .- x)         # the propagated state must equal the next state sample
end
L += final_cost(x)              # terminal cost, applied once to the last state
## Symbolic derivatives and code generation
# J = Symbolics.jacobian(xp, x)
@time A = Symbolics.sparsejacobian(g, w);   # constraint Jacobian; takes forever to print in full form
@time dw = Symbolics.gradient(L, w)         # objective gradient
@time H = Symbolics.sparsehessian(L, w)     # objective Hessian
# Building the expression form of the full sparse Jacobian, build_function(A, w, expression = Val(true)),
# was measured at:
# 222.257406 seconds (963.52 M allocations: 40.937 GiB, 15.66% gc time, 0.41% compilation time)
# Its result was immediately overwritten by the nonzero-values build below, so it is no longer built.
@time hessfun = build_function(H, w, expression = Val(false))
lbfun = build_function(lbw, w, expression = Val(false))
ubfun = build_function(ubw, w, expression = Val(false))

# Sanity check: the initial-state bounds are the initial-state variables themselves,
# so evaluating them at the initial guess must reproduce the initial guess for x.
res = lbfun[1](w0)
@test res[1:nx] == w0[1:nx]

# Compile only the nonzero values of the sparse constraint Jacobian.
@time jacfun = build_function(A.nzval, w, expression = Val(false));
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment