import numpy as np
from jutility import plotting, util, cli, units

# Plot the PPO clipped surrogate objective
#     L_clip(r) = min(r * A, clip(r, 1 - eps, 1 + eps) * A)
# against the probability ratio r, drawing one curve per advantage value A
# (coloured by A via a colourbar), with dashed vertical lines marking the
# clipping boundaries 1 - eps and 1 + eps.

n = 10                                  # number of advantage values / curves
r = np.linspace(0, 2, 200)              # ratio values along the x-axis
eps = 0.2                               # clipping half-width around r = 1
a = np.linspace(-3, 3, n).reshape(n, 1) # advantages as an (n, 1) column

# Broadcasting the (n, 1) column against the (200,) ratio grid yields an
# (n, 200) array in which row i is the curve for advantage a[i].
unclipped = r * a
clipped = np.clip(r, 1-eps, 1+eps) * a
l_clip = np.minimum(unclipped, clipped)

# One colour per advantage value, spread across the colour map (non-cyclic).
cp = plotting.ColourPicker(len(a), cyclic=False)

# Curves are built in row order so each consumes the next colour in turn.
curves = [plotting.Line(r, row, c=cp.next()) for row in l_clip]
boundaries = [plotting.VLine(x, c="k", ls="--") for x in (1-eps, 1+eps)]

loss_subplot = plotting.Subplot(
    *curves,
    *boundaries,
    title="$L_{clip}$",
    xlabel="r",
    grid=False,
    xticks=[0, 1-eps, 1, 1+eps, 2],
)

# Main axes plus a narrow colourbar mapping colours back to advantage values.
mp = plotting.MultiPlot(
    loss_subplot,
    cp.get_colourbar(a.min(), a.max(), label="A"),
    width_ratios=[1, 0.1],
    figsize=[8, 6],
)
mp.save("PPO loss", ".")