Skip to content

Instantly share code, notes, and snippets.

@larsoner
Created October 5, 2023 13:26
Show Gist options
  • Save larsoner/2506793c12f650e5f6bafe9616351c10 to your computer and use it in GitHub Desktop.
Save larsoner/2506793c12f650e5f6bafe9616351c10 to your computer and use it in GitHub Desktop.
"""Attempt to build a mne.viz.Brain-like GUI using magicgui."""
from magicgui import widgets, use_app
import numpy as np
import re
import mne
from matplotlib.figure import Figure
import pyvista
import pyvista.plotting
# Pick the GUI flavour from MNE's configured 3D backend: a Qt desktop window
# ("pyvistaqt") or ipywidgets inside a Jupyter notebook ("notebook").
backend = mne.viz.get_3d_backend()
if backend not in ("pyvistaqt", "notebook"):
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    raise RuntimeError(
        f"Unsupported 3D backend {backend!r}; expected 'pyvistaqt' or 'notebook'"
    )
if backend == "notebook":
    # Make magicgui build its widgets on its ipywidgets ("ipynb") backend.
    use_app("ipynb")
# Two vertical, label-free containers shown side by side: column 0 holds the
# control widgets built here; column 1 is filled later with the 3D view and
# the matplotlib canvas.
columns = [
    widgets.Container(layout="vertical", labels=False),
    widgets.Container(layout="vertical", labels=False),
]

# Control widgets, in display order. No callbacks are connected to them in
# this script, so they are visual placeholders for now.
_controls = [
    widgets.Label(value="Time"),
    widgets.FloatSlider(min=-0.2, max=0.5, value=0.1, step=0.01),
    widgets.ComboBox(choices=["lateral", "medial"]),
    # presumably colormap limits as in mne.viz.Brain — TODO confirm
    widgets.Label(value="fmin/fmid/fmax"),
]
_controls += [
    widgets.FloatSlider(min=1, max=3, value=start, step=0.01)
    for start in (1, 2, 3)
]
# A horizontal row of three push buttons (no handlers attached yet).
_controls.append(
    widgets.Container(
        widgets=[widgets.PushButton(text=glyph) for glyph in ("↺", "+", "-")],
        layout="horizontal",
        labels=False,
    )
)
_controls += [
    widgets.Label(value="Annotation"),
    widgets.ComboBox(choices=["None", "aparc"]),
    widgets.Label(value="Extract mode"),
    widgets.ComboBox(choices=["mean", "max"]),
]
columns[0].extend(_controls)
# %% Create main window
# Wrap the two columns in a top-level window suited to the active backend.
# ``backend`` was already resolved and validated above, so it is not
# re-queried here.
if backend == "pyvistaqt":
    window = widgets.MainWindow(widgets=columns, layout="horizontal")
    # Keep the controls packed at the top of column 0: a stretch item at the
    # bottom absorbs any spare vertical space.
    columns[0].native.layout().addStretch(1)
else:
    from ipywidgets import AppLayout

    # Controls in the left sidebar; 3D view and matplotlib trace in the centre.
    window = AppLayout(
        left_sidebar=columns[0].native,
        center=columns[1].native,
    )
    for col in columns:
        col.native.layout.height = "100%"
    columns[0].native.layout.width = "200px"
    for widget in columns[0]:
        # Slightly narrower than the 200px sidebar.
        widget.native.layout.width = "190px"
    window.layout.align_items = "center"
    window.layout.display = "flex"
    # 200px sidebar + 600px content, plus a 10px buffer.
    window.layout.width = "810px"
    # 400px 3D view + 200px trace, plus a 10px buffer.
    window.layout.height = "610px"
    # NOTE(review): "stretch" is not a valid CSS flex-flow value (it expects a
    # direction and/or wrap, e.g. "row nowrap"), so browsers likely ignore it.
    # Kept as-is pending confirmation of the intended layout.
    window.layout.flex_flow = "stretch"
    window.layout.border = "1px solid black"
# %% PyVista renderer
if backend == "pyvistaqt":
    # Desktop: an interactive Qt render widget parented to the main window,
    # added to column 1 with a stretch factor of 2.
    from pyvistaqt import QtInteractor
    from qtpy.QtWidgets import QSizePolicy

    plotter = QtInteractor(parent=window.native, auto_update=False)
    plotter.resize(600, 400)
    expanding = QSizePolicy.Expanding
    plotter.setSizePolicy(expanding, expanding)
    columns[1].native.layout().addWidget(plotter, stretch=2)
else:
    # Notebook: render off-screen and show the scene via pyvista's trame
    # integration, which yields an HTML widget containing an <iframe>.
    plotter = pyvista.plotting.Plotter(
        off_screen=True, notebook=True, window_size=(600, 400)
    )
    from pyvista.trame import show_trame

    viewer = show_trame(plotter, add_menu=False)
    # Extend the iframe's inline style with "border: none" and disable
    # scrolling, mirroring the values matplotlib's notebook widget uses.
    iframe_style = re.compile(r" style=[\"'](.+)[\"']></iframe>")
    viewer.value = iframe_style.sub(
        r" style='\1; border: none;' scrolling='no'></iframe>", viewer.value
    )
    for attr, size in (
        ("height", "400px"),
        ("min_height", "400px"),
        ("width", "600px"),
        ("min_width", "600px"),
    ):
        setattr(viewer.layout, attr, size)
    columns[1].native.children += (viewer,)

# Demo scene: an ImageData with 5x5x5 points starting at (-2, -2, -2),
# exploded by 0.2 so its cells are drawn separated.
image = pyvista.ImageData(dimensions=(5, 5, 5), origin=(-2, -2, -2))
grid = image.explode(0.2)
plotter.add_mesh(grid)
# %% matplotlib
# A 2D trace panel under the 3D view, sized 600 x 200 px.
dpi = 92
# figsize is in inches, i.e. pixels / dpi.
fig = Figure(figsize=(600 / dpi, 200 / dpi), dpi=dpi, constrained_layout=True)
ax = fig.subplots()
# Placeholder data: 1000 samples of seeded (reproducible) standard-normal noise.
ax.plot(np.random.RandomState(0).randn(1000))
if backend == "pyvistaqt":
    # backend_qtagg works with whichever Qt binding is already loaded;
    # backend_qt5agg is the legacy Qt5-specific alias.
    from matplotlib.backends.backend_qtagg import FigureCanvas

    canvas = FigureCanvas(fig)
    # Stretch 1 here vs. 2 for the 3D view gives the trace a third of the height.
    columns[1].native.layout().addWidget(canvas, stretch=1)
    window.show()
else:
    import ipympl.backend_nbagg
    from IPython.display import display

    canvas = ipympl.backend_nbagg.Canvas(fig)
    # ``manager`` is not referenced again; it is kept because constructing a
    # figure manager attaches it to the canvas (presumably required by
    # ipympl — verify before removing).
    manager = ipympl.backend_nbagg.FigureManager(canvas, 0)
    # Hide ipympl's toolbar, header and footer so only the plot area shows.
    canvas.toolbar_visible = False
    canvas.header_visible = False
    canvas.footer_visible = False
    canvas.layout.height = "200px"
    canvas.layout.min_height = "200px"
    # Re-assert the figure size to match the fixed 600x200px canvas layout.
    fig.set_figheight(200 / dpi)
    fig.set_figwidth(600 / dpi)
    canvas.layout.min_width = "600px"
    canvas.layout.width = "600px"
    canvas.resizable = False
    columns[1].native.children += (canvas,)
    # Zero out padding/margins so default spacing does not enlarge the fixed
    # pixel sizes set above.
    for widget in (window, columns[0].native, columns[1].native, viewer, canvas):
        widget.layout.padding = "0 0 0 0px"
        widget.layout.margin = "0 0 0 0px"
    display(window)
# Aim the camera at the origin from 15 units along -y, with +z as view-up.
# Assignment order (up, position, focal point) is kept as originally written.
camera = plotter.camera
camera.up = (0.0, 0.0, 1.0)
camera.position = (0.0, -15, 0)
camera.focal_point = (0.0, 0.0, 0.0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment