"""Plot the PPO clipped objective L_clip as a function of the ratio r.

For each of ``n`` evenly spaced values ``A`` in [-3, 3], this script plots

    L_clip(r) = min(r * A, clip(r, 1 - eps, 1 + eps) * A)

against ``r``, coloured by ``A`` (with a colourbar). It marks the clipping
boundaries ``1 - eps`` and ``1 + eps`` with dashed vertical lines and saves
the figure under the name "PPO loss" in the directory ".".
"""
import numpy as np
from jutility import plotting, util, cli, units

# NOTE(review): util, cli and units are not referenced in this script.


def clipped_objective(r, a, eps):
    """Return ``min(r * a, clip(r, 1 - eps, 1 + eps) * a)`` element-wise.

    Args:
        r: Ratio values; broadcast against ``a``.
        a: Values of ``A``. Passing a column vector of shape ``(k, 1)``
            together with a 1-D ``r`` of length ``m`` gives a ``(k, m)``
            array with one row per value of ``A``.
        eps: Half-width of the clipping interval around 1.

    Returns:
        A numpy array with the broadcast shape of ``r`` and ``a``.
    """
    return np.minimum(r * a, np.clip(r, 1 - eps, 1 + eps) * a)


def main(n=10, eps=0.2, r_max=2, num_points=200):
    """Build the L_clip figure and save it as "PPO loss" in ".".

    The defaults reproduce the original hard-coded figure exactly.

    Args:
        n: Number of ``A`` values, evenly spaced over [-3, 3].
        eps: Clipping half-width.
        r_max: Upper end of the plotted ratio range [0, r_max].
        num_points: Number of ratio samples along the x-axis.
    """
    r = np.linspace(0, r_max, num_points)
    # Column vector so that broadcasting against r yields one row per A value.
    a = np.linspace(-3, 3, n).reshape(n, 1)
    l_clip = clipped_objective(r, a, eps)

    cp = plotting.ColourPicker(len(a), cyclic=False)
    # One line per A value; colours are drawn in the same order as the rows
    # of l_clip (increasing A), matching the colourbar range below.
    lines = [plotting.Line(r, row, c=cp.next()) for row in l_clip]

    mp = plotting.MultiPlot(
        plotting.Subplot(
            *lines,
            plotting.VLine(1 - eps, c="k", ls="--"),
            plotting.VLine(1 + eps, c="k", ls="--"),
            title="$L_{clip}$",
            xlabel="r",
            grid=False,
            xticks=[0, 1 - eps, 1, 1 + eps, r_max],
        ),
        cp.get_colourbar(a.min(), a.max(), label="A"),
        width_ratios=[1, 0.1],
        figsize=[8, 6],
    )
    mp.save("PPO loss", ".")


if __name__ == "__main__":
    main()