Skip to content

Instantly share code, notes, and snippets.

@jakelevi1996
Last active November 14, 2024 18:32
Show Gist options
  • Save jakelevi1996/1171a4dd5979eb4fc82a661fb916df09 to your computer and use it in GitHub Desktop.
Save jakelevi1996/1171a4dd5979eb4fc82a661fb916df09 to your computer and use it in GitHub Desktop.
Plot PPO loss
import numpy as np
from jutility import plotting, util, cli, units
# Plot the clipped loss term min(r * A, clip(r, 1 - eps, 1 + eps) * A) as a
# function of r, drawing one curve per value of A.
num_curves = 10
clip_eps = 0.2
clip_lo = 1 - clip_eps
clip_hi = 1 + clip_eps

# x-axis grid of r values.
ratio = np.linspace(0, 2, 200)
# Column vector of A values, so that broadcasting against `ratio` yields one
# row of loss values per A.
advantage = np.linspace(-3, 3, num_curves).reshape(num_curves, 1)

unclipped_term = ratio * advantage
clipped_term = np.clip(ratio, clip_lo, clip_hi) * advantage
# Shape (num_curves, 200): row i is the curve for advantage[i].
loss_curves = np.minimum(unclipped_term, clipped_term)

# One colour per curve; the same picker also supplies the colourbar below.
colour_picker = plotting.ColourPicker(len(advantage), cyclic=False)

curve_lines = []
for curve in loss_curves:
    curve_lines.append(plotting.Line(ratio, curve, c=colour_picker.next()))

# Dashed vertical markers at the two clipping bounds.
loss_subplot = plotting.Subplot(
    *curve_lines,
    plotting.VLine(clip_lo, c="k", ls="--"),
    plotting.VLine(clip_hi, c="k", ls="--"),
    title="$L_{clip}$",
    xlabel="r",
    grid=False,
    xticks=[0, clip_lo, 1, clip_hi, 2],
)
colourbar = colour_picker.get_colourbar(
    advantage.min(),
    advantage.max(),
    label="A",
)

figure = plotting.MultiPlot(
    loss_subplot,
    colourbar,
    width_ratios=[1, 0.1],
    figsize=[8, 6],
)
# NOTE(review): presumably writes the figure as "PPO loss" into the current
# directory -- confirm against jutility's MultiPlot.save signature.
figure.save("PPO loss", ".")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment