Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save samuela/6e0ff7220f7aeab60aed32656b4faf72 to your computer and use it in GitHub Desktop.
Save samuela/6e0ff7220f7aeab60aed32656b4faf72 to your computer and use it in GitHub Desktop.
import time
import control
import matplotlib.pyplot as plt
from jax import random
from jax import jit
from jax import value_and_grad
from jax import lax
from jax import vmap
import jax.numpy as jp
from jax.experimental import stax
from jax.experimental import ode
from jax.experimental import optimizers
from jax.experimental.stax import Dense
from jax.experimental.stax import Relu
from jax.experimental.stax import Tanh
from research.utils import make_optimizer
from research.utils import DenseNoBias
from research.utils import random_psd
from research import blt
def fixed_env(n):
    """Return the matrices of a fixed ``n``-dimensional LQR problem.

    The matrices are meant for dynamics ``xdot = A x + B u`` and running cost
    ``x'Qx + u'Ru + 2 x'Nu``.

    Args:
        n: Dimension of both the state and the control vector.

    Returns:
        Tuple ``(A, B, Q, R, N)`` of ``(n, n)`` arrays with ``A = -I``,
        ``B = Q = R = I`` and ``N = 0``.
    """
    identity = jp.eye(n)
    A = -1 * identity
    B = identity
    Q = identity
    R = identity
    N = jp.zeros((n, n))
    return A, B, Q, R, N
def random_env(rng, n=2):
    """Sample the matrices of a random LQR problem.

    Args:
        rng: JAX PRNG key; split into four independent keys, one per
            randomly drawn matrix.
        n: Dimension of the state and control vectors. Defaults to 2, the
            size that was previously hard-coded.

    Returns:
        Tuple ``(A, B, Q, R, N)`` of ``(n, n)`` arrays. ``A`` is the negation
        of a ``random_psd`` draw, ``B`` has standard-normal entries, ``Q`` and
        ``R`` are ``random_psd`` draws shifted by ``0.1 * I``, and ``N`` is
        all zeros.
    """
    # NOTE(review): assumes `random_psd(key, n)` returns an (n, n) positive
    # semi-definite matrix -- confirm against research.utils.
    rngA, rngB, rngQ, rngR = random.split(rng, 4)
    ridge = 0.1 * jp.eye(n)
    A = -1 * random_psd(rngA, n)
    B = random.normal(rngB, (n, n))
    Q = random_psd(rngQ, n) + ridge
    R = random_psd(rngR, n) + ridge
    N = jp.zeros((n, n))
    return A, B, Q, R, N
def policy_integrate_cost(dynamics_fn, cost_fn, gamma):
    """Build a differentiable "total cost of a policy" evaluator.

    The result is curried in two stages: first specialise to a policy, then
    evaluate from a start state.

    Args:
        dynamics_fn: ``(x, u) -> xdot``, the time derivative of the state.
        cost_fn: ``(x, u) -> scalar`` instantaneous cost.
        gamma: Discount base; the running cost at time ``t`` is weighted by
            ``gamma ** t``.

    Returns:
        ``eval_policy(policy)`` where ``policy`` is
        ``(policy_params, x) -> u``. It returns
        ``eval_from_x0(policy_params, x0, total_time)``, which integrates the
        discounted cost along the closed-loop trajectory from ``x0`` over
        ``[0, total_time]`` and returns it as a scalar.
    """

    # Specialize to the environment.
    def eval_policy(policy):
        # Specialize to the policy.
        def ofunc(y, t, policy_params):
            # Augmented state: y[0] is the accumulated cost, y[1:] the state.
            x = y[1:]
            u = policy(policy_params, x)
            running_cost = jp.expand_dims((gamma**t) * cost_fn(x, u), axis=0)
            return jp.concatenate((running_cost, dynamics_fn(x, u)))

        def eval_from_x0(policy_params, x0, total_time):
            # odeint reports the solution at every entry of `t`, and the
            # first entry is the time at which `y0` holds; hence the
            # leading 0.0 before the time we actually care about.
            t = jp.array([0.0, total_time])
            # Accumulated cost starts at zero.
            y0 = jp.concatenate((jp.zeros((1, )), x0))
            yT = ode.odeint(ofunc, y0, t, policy_params, rtol=1e-3, mxstep=1e6)
            # Row 1 is the solution at `total_time`; column 0 is the cost.
            return yT[1, 0]

        return eval_from_x0

    return eval_policy
def main():
    """Train a neural-network policy on a continuous-time LQR problem.

    Sets up the fixed LQR environment, computes the cost of the optimal
    linear policy (via ``control.lqr``) as a baseline, then runs 100 Adam
    steps on an MLP policy whose loss is the ODE-integrated cost from a fixed
    start state. Finally plots the per-iteration cost against the baseline.
    """
    # NOTE(review): this function used to start with `blt.plot(); return`,
    # which made the whole experiment below unreachable. That short-circuit
    # has been removed so the script performs the experiment it describes.
    total_time = 10.0
    gamma = 1.0
    x_dim = 2

    rng = random.PRNGKey(0)

    x0 = jp.array([2.0, 1.0])
    # rng_x0, rng = random.split(rng)
    # x0 = random.normal(rng_x0, shape=(x_dim, ))

    ### Set up the problem/environment
    # xdot = Ax + Bu
    # u = - Kx
    # cost = xQx + uRu + 2xNu
    A, B, Q, R, N = fixed_env(x_dim)

    def dynamics_fn(x, u):
        return A @ x + B @ u

    def cost_fn(x, u):
        return x.T @ Q @ x + u.T @ R @ u + 2 * x.T @ N @ u

    policy_loss = policy_integrate_cost(dynamics_fn, cost_fn, gamma)

    ### Solve the Riccati equation to get the infinite-horizon optimal solution.
    K, _, _ = control.lqr(A, B, Q, R, N)
    K = jp.array(K)

    t0 = time.time()
    # The optimal policy has no trainable parameters, hence `None`.
    opt_cost = policy_loss(lambda _, x: -K @ x)(None, x0, total_time)
    print(f"opt_cost = {opt_cost} in {time.time() - t0}s")

    ### Set up the learned policy model.
    policy_init, policy = stax.serial(
        Dense(64),
        Relu,
        Dense(64),
        Relu,
        Dense(x_dim),
    )
    # policy_init, policy = DenseNoBias(2)

    rng_init_params, rng = random.split(rng)
    _, init_policy_params = policy_init(rng_init_params, (x_dim, ))
    opt = make_optimizer(optimizers.adam(1e-3))(init_policy_params)
    cost_and_grad = jit(value_and_grad(policy_loss(policy)))

    ### Main optimization loop.
    costs = []
    for i in range(100):
        t0 = time.time()
        cost, g = cost_and_grad(opt.value, x0, total_time)
        opt = opt.update(g)
        print(f"Episode {i}: excess cost = {cost - opt_cost}, elapsed = {time.time() - t0}")
        costs.append(float(cost))

    print(f"Opt solution cost from starting point: {opt_cost}")
    # print(f"Gradient at opt solution: {opt_g}")

    # Print the identified and optimal policy. Note that layers multiply
    # on the right instead of the left so we need a transpose.
    # print(f"Est solution parameters: {opt.value}")
    # print(f"Opt solution parameters: {-K.T}")

    ### Plot performance per iteration, incl. average optimal policy performance.
    plt.figure()
    plt.plot(costs)
    plt.axhline(opt_cost, linestyle="--", color="gray")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel("Cost")
    plt.legend(["Learned policy", "Direct LQR solution"])
    # The title previously lacked the closing parenthesis.
    plt.title(f"ODE control of LQR problem\n(T = {total_time}s, gamma = {gamma}, A = {A})")

    ### Example rollout plots (learned policy vs optimal policy).
    # framerate = 30
    # timesteps = jp.linspace(0, total_time, num=int(total_time * framerate))
    # est_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, policy(opt.value, x)),
    #                                        y0=x0,
    #                                        t=timesteps)
    # est_policy_rollout_controls = vmap(lambda x: policy(opt.value, x))(est_policy_rollout_states)

    # opt_policy_rollout_states = ode.odeint(lambda x, _: dynamics_fn(x, -K @ x), y0=x0, t=timesteps)
    # opt_policy_rollout_controls = vmap(lambda x: -K @ x)(opt_policy_rollout_states)

    # plt.figure()
    # plt.plot(est_policy_rollout_states[:, 0], est_policy_rollout_states[:, 1], marker='.')
    # plt.plot(opt_policy_rollout_states[:, 0], opt_policy_rollout_states[:, 1], marker='.')
    # plt.xlabel("x_1")
    # plt.ylabel("x_2")
    # plt.legend(["Learned policy", "Direct LQR solution"])
    # plt.title("Phase space trajectory")

    # plt.figure()
    # plt.plot(timesteps, jp.sqrt(jp.sum(est_policy_rollout_controls**2, axis=-1)))
    # plt.plot(timesteps, jp.sqrt(jp.sum(opt_policy_rollout_controls**2, axis=-1)))
    # plt.xlabel("time")
    # plt.ylabel("control input (L2 norm)")
    # plt.legend(["Learned policy", "Direct LQR solution"])
    # plt.title("Policy control over time")

    ### Plot quiver field showing dynamics under learned policy.
    # plot_policy_dynamics(dynamics_fn, cost_fn, lambda x: policy(opt.value, x))

    plt.show()
def plot_policy_dynamics(dynamics_fn, cost_fn, policy):
    """Draw a quiver plot of the closed-loop dynamics on ``[-1, 1]^2``.

    Arrows show ``dynamics_fn(x, policy(x))`` on a 50x50 grid of 2-D states
    and are coloured by the instantaneous cost ``cost_fn(x, policy(x))``.
    A new matplotlib figure is created; the elapsed time is printed.

    Args:
        dynamics_fn: ``(x, u) -> xdot`` for a 2-D state ``x``.
        cost_fn: ``(x, u) -> scalar`` instantaneous cost.
        policy: ``x -> u`` with any parameters already bound.
    """
    t0 = time.time()
    plt.figure()

    x1s = jp.linspace(-1, 1, num=50)
    x2s = jp.linspace(-1, 1, num=50)

    # With 1-D X and Y, quiver expects U, V, C indexed as [row = y, col = x].
    # meshgrid's default "xy" indexing yields exactly that layout: arrays of
    # shape (len(x2s), len(x1s)) whose entry [j, i] is (x1s[i], x2s[j]).
    # The earlier [x1, x2]-ordered grid drew the field transposed.
    x1_grid, x2_grid = jp.meshgrid(x1s, x2s)
    grid_shape = x1_grid.shape
    flatmesh = jp.stack([x1_grid.ravel(), x2_grid.ravel()], axis=-1)

    uv = vmap(lambda x: dynamics_fn(x, policy(x)))(flatmesh)
    uv_grid = jp.reshape(uv, grid_shape + (2, ))

    color = vmap(lambda x: cost_fn(x, policy(x)))(flatmesh)
    color_grid = jp.reshape(color, grid_shape)

    plt.quiver(x1s, x2s, uv_grid[:, :, 0], uv_grid[:, :, 1], color_grid)
    plt.axis("equal")
    plt.xlabel("x_1")
    plt.ylabel("x_2")
    plt.title("Dynamics under policy")

    print(f"[timing] Plotting control dynamics took {time.time() - t0}s")
# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment