Created
February 26, 2022 15:27
-
-
Save el-hult/327d93397d70c3fa6a404319e3ae2021 to your computer and use it in GitHub Desktop.
A code sample showing how to create an interactive visualization in matplotlib using matplotlib widgets. In this case, it shows how a quadratic whose minimum location is unknown can be upper bounded.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% set up the plot
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button, Slider
# Largest |location| the minimum-location slider allows. The bound f2 is only
# guaranteed to hold for minimum locations inside [-LOC_MAX, LOC_MAX].
LOC_MAX = 2


def f1(t, amplitude, location):
    """Return the "true" quadratic ``amplitude * (t - location) ** 2``.

    For a positive ``amplitude`` the parabola has its minimum (value 0) at
    ``t == location``.
    """
    offset = t - location
    return amplitude * offset ** 2


def f2(t, amplitude, location):
    """Return an upper bound on ``f1`` that does not depend on ``location``.

    For any ``|location| <= LOC_MAX`` the triangle inequality gives
    ``|t - location| <= |t| + LOC_MAX``. Hence, for ``amplitude >= 0``, this
    curve lies on or above ``f1(t, amplitude, location)``.

    ``location`` is accepted only so that both functions share one call
    signature; it is intentionally unused.
    """
    worst_case_offset = np.abs(t) + LOC_MAX
    return amplitude * worst_case_offset ** 2
# Sample points: every possible minimum location, plus a margin of 1 on each side.
t = np.linspace(-(LOC_MAX + 1), LOC_MAX + 1, num=1000)

# Starting slider values; the sliders use these as their initial values.
init_amplitude, init_location = 1, 0
# Create the figure and the two curves that the sliders will update.
fig, ax = plt.subplots()
(line1,) = ax.plot(t, f1(t, init_amplitude, init_location), lw=2, label="True")
(line2,) = ax.plot(t, f2(t, init_amplitude, init_location), lw=2, label="Bound")
ax.set_xlabel("Time [s]")
# The labels above are only displayed once a legend is drawn. A fixed location
# keeps the legend from jumping around while the sliders are dragged.
ax.legend(loc="upper center")
slider_bg_color = "lightgoldenrodyellow"
ax.margins(x=0)
# Shrink the main axes to leave room for the sliders on the left and bottom.
plt.subplots_adjust(left=0.25, bottom=0.25)

# Horizontal slider controlling where the true quadratic (f1) has its minimum.
# Its range is exactly [-LOC_MAX, LOC_MAX], the interval the bound f2 assumes.
# NOTE: `axfreq` is a leftover name from matplotlib's slider example; the axes
# hosts the location slider, not a frequency control.
axfreq = plt.axes([0.25, 0.1, 0.65, 0.03], facecolor=slider_bg_color)
location_slider = Slider(
    ax=axfreq,
    label="Minimum location",
    valmin=-LOC_MAX,
    valmax=LOC_MAX,
    valinit=init_location,
)
# Vertical slider along the left edge; its value scales both curves.
axamp = plt.axes([0.1, 0.25, 0.0225, 0.63], facecolor=slider_bg_color)
amp_slider = Slider(
    axamp,
    "Amplitude",
    0,
    10,
    orientation="vertical",
    valinit=init_amplitude,
)
def update(val):
    """Recompute both curves from the current slider values and redraw.

    ``val`` is the new value of whichever slider moved. It is ignored because
    both sliders are read directly.
    """
    amplitude = amp_slider.val
    location = location_slider.val
    for curve, func in ((line1, f1), (line2, f2)):
        curve.set_ydata(func(t, amplitude, location))
    # Request a repaint for when the GUI event loop is next idle.
    fig.canvas.draw_idle()


# Moving either slider triggers a redraw.
for _slider in (location_slider, amp_slider):
    _slider.on_changed(update)
def reset(event):
    """Restore both sliders to their initial values.

    Used as the click callback of the "Reset" button below; ``event`` is the
    matplotlib mouse event and is ignored. ``Slider.reset`` sets the slider
    back to its ``valinit``. When the value actually changes, this notifies
    the slider's ``on_changed`` observers, so the curves are redrawn as well.
    """
    location_slider.reset()
    amp_slider.reset()


# Connect ``reset`` to a button so it is reachable from the UI. Keep a
# module-level reference to the button, because matplotlib widgets stop
# responding once they are garbage-collected.
reset_ax = plt.axes([0.8, 0.025, 0.1, 0.04])
reset_button = Button(reset_ax, "Reset", hovercolor="0.975")
reset_button.on_clicked(reset)

plt.show()
# %%
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment