Skip to content

Instantly share code, notes, and snippets.

@akaszynski
Created April 29, 2025 14:19
Show Gist options
  • Save akaszynski/b21258fea6d0b6838b17f7b337bfade4 to your computer and use it in GitHub Desktop.
Save akaszynski/b21258fea6d0b6838b17f7b337bfade4 to your computer and use it in GitHub Desktop.
PyVista - Hover on point and click to place affine widget
import pyvista as pv
import numpy as np
import vtk
import femorph
###############################################################################
from typing import cast, Sequence
import numpy as np
import pyvista
from pyvista.core.errors import VTKVersionError
from pyvista.core.utilities.misc import try_callback
from pyvista.plotting import _vtk
# RGB color (0-1 floats) assigned to a widget handle while it is highlighted
# under the mouse cursor.
DARK_YELLOW = (0.9647058823529412, 0.7450980392156863, 0)
# Rows are the global X, Y and Z unit vectors; row ``i`` is the direction of
# the widget's ``i``-th arrow.
GLOBAL_AXES = np.eye(3)
def _validate_axes(axes):
"""Validate and normalize input axes.
Axes are expected to follow the right-hand rule (e.g. third axis is the
cross product of the first two.
Parameters
----------
axes : sequence
The axes to be validated and normalized. Should be of shape (3, 3).
Returns
-------
dict
The validated and normalized axes.
"""
axes = np.array(axes)
if axes.shape != (3, 3):
msg = "`axes` must be a (3, 3) array."
raise ValueError(msg)
axes = axes / np.linalg.norm(axes, axis=1, keepdims=True)
if not np.allclose(np.cross(axes[0], axes[1]), axes[2]):
msg = "`axes` do not follow the right hand rule."
raise ValueError(msg)
return axes
def _check_callable(func, name="callback"):
"""Check if a variable is callable."""
if func and not callable(func):
msg = f"`{name}` must be a callable, not {type(func)}."
raise TypeError(msg)
return func
def _make_quarter_arc():
    """Build a quarter-circle polyline centered at the origin.

    A 100-point :func:`pyvista.Circle` is created, its faces are discarded
    and a single line cell through its first 26 points is assigned instead.

    Returns
    -------
    pyvista.PolyData
        The circle dataset carrying only the arc line.

    """
    arc = pyvista.Circle(resolution=100)
    # Drop the filled polygon; only the line connectivity below is kept.
    arc.faces = np.empty(0, dtype=int)
    point_ids = np.arange(0, 26)
    # Leading 26 is the number of point ids that follow in the cell.
    arc.lines = np.hstack(([26], point_ids))
    return arc
def get_angle(v1, v2):
    """Compute the angle between two vectors in degrees.

    The dot product is used directly as the cosine, so both inputs are
    expected to be unit length. The cosine is clipped to ``[-1, 1]`` before
    taking the arc cosine to absorb floating point round-off.

    Parameters
    ----------
    v1 : numpy.ndarray
        First input vector.
    v2 : numpy.ndarray
        Second input vector.

    Returns
    -------
    float
        Angle between vectors in degrees, in ``[0, 180]``.

    """
    cosine = np.clip(np.dot(v1, v2), -1.0, 1.0)
    radians = np.arccos(cosine)
    return np.rad2deg(radians)
def ray_plane_intersection(start_point, direction, plane_point, normal):
    """Compute the intersection between a ray and a plane.

    Solves ``normal . (start_point + t * direction - plane_point) = 0`` for
    ``t``. No check is made for a ray parallel to the plane (zero
    denominator).

    Parameters
    ----------
    start_point : ndarray
        Starting point of the ray.
    direction : ndarray
        Direction of the ray.
    plane_point : ndarray
        A point on the plane.
    normal : ndarray
        Normal to the plane.

    Returns
    -------
    ndarray
        Intersection point.

    """
    to_plane = plane_point - start_point
    numerator = np.dot(normal, to_plane)
    denominator = np.dot(normal, direction)
    t_value = numerator / denominator
    return start_point + t_value * direction
class AffineWidget3D:
    """3D affine transform widget positioned at an arbitrary point.

    Unlike a widget bound to an actor, this variant is placed at a given
    ``origin`` and sized from a given ``length``. It is driven with the left
    mouse button. Only the three axis arrows are created by
    :meth:`_init_actors`; the rotation code paths act on ``_circles``, which
    stays empty unless populated elsewhere.

    Parameters
    ----------
    plotter : pyvista.Plotter
        The plotter object.

    origin : sequence[float]
        Origin of the widget.

    length : float
        Reference length used to size the widget and to scale translations.

    start : bool, default: True
        If True, start the widget immediately.

    scale : float, default: 0.15
        Scale factor for the widget relative to ``length``.

    line_radius : float, default: 0.05
        Relative radius of the arrow shafts composing the widget.

    always_visible : bool, default: True
        Make the widget always visible. Setting this to ``False`` will cause
        the widget geometry to be hidden by other actors in the plotter.

    axes_colors : tuple[ColorLike], optional
        Uses the theme by default. Configure the individual axis colors by
        modifying either the theme with ``pyvista.global_theme.axes.x_color =
        <COLOR>`` or setting this with a ``tuple`` as in ``('r', 'g', 'b')``.

    axes : numpy.ndarray, optional
        ``(3, 3)`` Numpy array defining the X, Y, and Z axes. By default this
        matches the default coordinate system.

    release_callback : callable, optional
        Called with no arguments when the left mouse button is released.

    interact_callback : callable, optional
        Called with no arguments while the mouse moves with the left mouse
        button pressed down on a widget handle.

    Notes
    -----
    The widget does not modify any actor or dataset. The transform
    accumulated by dragging is kept on the widget and can be read from
    :attr:`user_matrix`, for example from within the callbacks.

    Requires VTK >= v9.2

    Examples
    --------
    Place the widget at the center of a sphere.

    >>> import pyvista as pv
    >>> pl = pv.Plotter()
    >>> actor = pl.add_mesh(pv.Sphere())
    >>> widget = AffineWidget3D(pl, actor.center, actor.GetLength())
    >>> pl.show()

    """

    def __init__(
        self,
        plotter,
        origin: Sequence[float],
        length: float,
        start: bool = True,
        scale=0.15,
        line_radius=0.05,
        always_visible: bool = True,
        axes_colors=None,
        axes=None,
        release_callback=None,
        interact_callback=None,
    ):
        """Initialize the widget."""
        # needs VTK v9.2.0 due to the hardware picker
        if pyvista.vtk_version_info < (9, 2):
            msg = "AffineWidget3D requires VTK v9.2.0 or newer."
            raise VTKVersionError(msg)

        self._axes = np.eye(4)
        self._axes_inv = np.eye(4)
        self._pl = plotter
        self._selected_actor: pyvista.Actor | None = None
        # World position captured on button press; drags are measured from it.
        self.init_position = None
        self._mouse_move_observer = None
        self._left_press_observer = None
        self._left_release_observer = None
        self._origin = np.array(origin)
        self._actor_length = length
        # ``_user_matrix`` is the live transform during a drag;
        # ``_cached_matrix`` is the transform committed at the last release.
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)
        self._arrows = []  # type: ignore[var-annotated]
        self._circles = []  # type: ignore[var-annotated]
        self._pressing_down = False

        if axes_colors is None:
            axes_colors = (
                pyvista.global_theme.axes.x_color,
                pyvista.global_theme.axes.y_color,
                pyvista.global_theme.axes.z_color,
            )
        self._axes_colors = axes_colors

        self._circ = _make_quarter_arc()
        self._line_radius = line_radius

        self._user_interact_callback = _check_callable(interact_callback)
        self._user_release_callback = _check_callable(release_callback)

        self._init_actors(scale, always_visible)

        # axes must be set after initializing actors
        if axes is not None:
            try:
                _validate_axes(axes)
            except ValueError:
                # Do not leave half-configured handles in the scene.
                for actor in self._arrows + self._circles:
                    self._pl.remove_actor(actor)
                raise
            self.axes = axes

        if start:
            self.enable()

    def show(self):
        """Make all widget handles visible and re-render."""
        for actor in self._arrows + self._circles:
            actor.visibility = True
        self._pl.render()

    def hide(self):
        """Make all widget handles invisible and re-render."""
        for actor in self._arrows + self._circles:
            actor.visibility = False
        self._pl.render()

    def _init_actors(self, scale, always_visible):
        """Initialize the widget's actors."""
        for ii, color in enumerate(self._axes_colors):
            arrow = pyvista.Arrow(
                (0, 0, 0),
                direction=GLOBAL_AXES[ii],
                scale=self._actor_length * scale * 1.15,
                tip_radius=0.1,
                shaft_radius=self._line_radius,
            )
            self._arrows.append(
                self._pl.add_mesh(arrow, color=color, lighting=False, render=False)
            )

        # update origin and assign a default user_matrix
        for actor in self._arrows + self._circles:
            matrix = np.eye(4)
            matrix[:3, -1] = self._origin
            actor.user_matrix = matrix

        if always_visible:
            for actor in self._arrows + self._circles:
                actor.mapper.SetResolveCoincidentTopologyToPolygonOffset()
                actor.mapper.SetRelativeCoincidentTopologyPolygonOffsetParameters(
                    0, -20000
                )

    def _axis_index(self, actor):
        """Return the axis index of a widget handle, or ``None`` if not ours."""
        for handles in (self._arrows, self._circles):
            if actor in handles:
                return handles.index(actor)
        return None

    def _get_world_coord_rot(self, interactor):
        """Get the world coordinates given an interactor.

        Unlike ``_get_world_coord_trans``, these coordinates are physically
        accurate, but sensitive to the position of the camera. Rotation is zoom
        independent.

        """
        x, y = interactor.GetEventPosition()
        coordinate = _vtk.vtkCoordinate()
        coordinate.SetCoordinateSystemToDisplay()
        coordinate.SetValue(x, y, 0)
        ren = interactor.GetRenderWindow().GetRenderers().GetFirstRenderer()
        point = np.array(coordinate.GetComputedWorldValue(ren))
        # Only rotation handles define a plane to project onto.
        if self._selected_actor in self._circles:
            index = self._circles.index(self._selected_actor)
            to_widget = np.array(ren.camera.position - self._origin)
            point = ray_plane_intersection(
                point, to_widget, self._origin, self.axes[index]
            )
        return point

    def _get_world_coord_trans(self, interactor):
        """Get the world coordinates given an interactor.

        This uses a modified scaled approach to get the world coordinates that
        are not physically accurate, but ignores zoom and works for
        translation.

        """
        x, y = interactor.GetEventPosition()
        ren = interactor.GetRenderWindow().GetRenderers().GetFirstRenderer()

        # Get normalized view coordinates (-1, 1)
        width, height = ren.GetSize()
        ndc_x = 2 * (x / width) - 1
        ndc_y = 2 * (y / height) - 1
        ndc_z = 1

        # convert camera coordinates to world coordinates
        camera = ren.GetActiveCamera()
        projection_matrix = pyvista.array_from_vtkmatrix(
            camera.GetProjectionTransformMatrix(ren.GetTiledAspectRatio(), 0, 1),
        )
        inverse_projection_matrix = np.linalg.inv(projection_matrix)
        camera_coords = np.dot(inverse_projection_matrix, [ndc_x, ndc_y, ndc_z, 1])
        modelview_matrix = pyvista.array_from_vtkmatrix(
            camera.GetModelViewTransformMatrix()
        )
        inverse_modelview_matrix = np.linalg.inv(modelview_matrix)
        world_coords = np.dot(inverse_modelview_matrix, camera_coords)

        # Scale by twice actor length (experimentally determined for good UX)
        return world_coords[:3] * self._actor_length * 2

    def _drag_translation(self, interactor):
        """Return the transform for dragging the selected arrow."""
        current_pos = self._get_world_coord_trans(interactor)
        index = self._arrows.index(self._selected_actor)
        axis = self.axes[index]
        diff = current_pos - self.init_position
        trans_matrix = np.eye(4)
        # Only the component of the mouse motion along the arrow counts.
        trans_matrix[:3, -1] = axis * np.dot(diff, axis)
        return trans_matrix @ self._cached_matrix

    def _drag_rotation(self, interactor):
        """Return the transform for dragging the selected rotation handle."""
        current_pos = self._get_world_coord_rot(interactor)
        index = self._circles.index(self._selected_actor)
        normal = self.axes[index]

        # Project both positions into the rotation plane and normalize.
        vec_current = current_pos - self._origin
        vec_init = self.init_position - self._origin
        vec_current = vec_current - np.dot(vec_current, normal) * normal
        vec_init = vec_init - np.dot(vec_init, normal) * normal
        vec_current /= np.linalg.norm(vec_current)
        vec_init /= np.linalg.norm(vec_init)

        angle = get_angle(vec_init, vec_current)
        # The cross product is parallel to the rotation axis; its orientation
        # relative to that axis gives the sign, also for non-global axes.
        if np.dot(np.cross(vec_init, vec_current), normal) < 0:
            angle = -angle

        trans = _vtk.vtkTransform()
        trans.Translate(self._origin)  # type: ignore[call-overload]
        trans.RotateWXYZ(angle, normal[0], normal[1], normal[2])
        trans.Translate(-self._origin)  # type: ignore[call-overload]
        trans.Update()
        rot_matrix = pyvista.array_from_vtkmatrix(trans.GetMatrix())
        return rot_matrix @ self._cached_matrix

    def _move_callback(self, interactor, _event):
        """Process actions for the move mouse event."""
        click_x, click_y = interactor.GetEventPosition()
        click_z = 0
        picker = interactor.GetPicker()
        renderer = (
            interactor.GetInteractorStyle()._parent()._plotter.iren.get_poked_renderer()
        )
        picker.Pick(click_x, click_y, click_z, renderer)
        actor = picker.GetActor()

        if self._pressing_down:
            matrix = None
            if self._selected_actor in self._arrows:
                matrix = self._drag_translation(interactor)
            elif self._selected_actor in self._circles:
                matrix = self._drag_rotation(interactor)
            if matrix is not None:
                self._user_matrix = matrix

            if self._user_interact_callback:
                try_callback(self._user_interact_callback)

        elif self._selected_actor and self._selected_actor is not actor:
            # Return the color of the currently selected actor to normal and
            # deselect it
            index = self._axis_index(self._selected_actor)
            if index is not None:
                self._selected_actor.prop.color = self._axes_colors[index]
            self._selected_actor = None

        # Highlight the actor if there is no selected actor
        if actor and not self._selected_actor:
            if self._axis_index(actor) is not None:
                actor.prop.color = DARK_YELLOW
                self._selected_actor = actor

        self._pl.render()

    def _press_callback(self, interactor, _event):
        """Process actions for the mouse button press event."""
        if self._selected_actor:
            self._pl.enable_trackball_actor_style()
            self._pressing_down = True
            if self._selected_actor in self._circles:
                self.init_position = self._get_world_coord_rot(interactor)
            else:
                self.init_position = self._get_world_coord_trans(interactor)

    def _release_callback(self, _interactor, _event):
        """Process actions for the mouse button release event."""
        self._pl.enable_trackball_style()
        self._pressing_down = False
        # Commit the drag so the next one composes on top of it.
        self._cached_matrix = self._user_matrix
        if self._user_release_callback:
            try_callback(self._user_release_callback)

    def _reset(self):
        """Reset the live and cached transforms to identity."""
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)

    @property
    def user_matrix(self):
        """Return a copy of the transform accumulated by dragging.

        Returns
        -------
        numpy.ndarray
            ``(4, 4)`` transformation matrix.

        """
        return self._user_matrix.copy()

    @property
    def axes(self):
        """Return or set the axes of the widget.

        When set, the axes are validated with ``_validate_axes``; invalid
        axes raise a ``ValueError``.

        Returns
        -------
        numpy.ndarray
            ``(3, 3)`` array of axes.

        """
        return self._axes[:3, :3]

    @axes.setter
    def axes(self, axes):
        mat = np.eye(4)
        mat[:3, :3] = _validate_axes(axes)
        mat[:3, -1] = self.origin
        self._axes = mat
        self._axes_inv = np.linalg.inv(self._axes)
        for actor in self._arrows + self._circles:
            matrix = actor.user_matrix
            # Be sure to use the inverse here
            matrix[:3, :3] = self._axes_inv[:3, :3]
            actor.user_matrix = matrix

    @property
    def origin(self) -> tuple[float, float, float]:
        """Origin of the widget.

        This is where the widget handles are located and the point about
        which rotations are performed.

        Returns
        -------
        tuple
            Widget origin.

        """
        return cast("tuple[float, float, float]", tuple(self._origin))

    @origin.setter
    def origin(self, value):
        value = np.array(value)
        diff = value - self._origin
        # Shift every handle by the same offset the origin moved.
        for actor in self._circles + self._arrows:
            if actor.user_matrix is None:
                actor.user_matrix = np.eye(4)
            matrix = actor.user_matrix
            matrix[:3, -1] += diff
            actor.user_matrix = matrix
        self._origin = value

    def enable(self):
        """Enable the widget.

        Calling this on an already enabled widget is a no-op so observers
        are never registered twice.
        """
        if self._mouse_move_observer is not None:
            return

        if not self._pl._picker_in_use:
            self._pl.enable_mesh_picking(
                show_message=False, show=False, picker="hardware"
            )
        self._mouse_move_observer = self._pl.iren.add_observer(
            "MouseMoveEvent",
            self._move_callback,
        )
        self._left_press_observer = self._pl.iren.add_observer(
            "LeftButtonPressEvent",
            self._press_callback,
            interactor_style_fallback=False,
        )
        self._left_release_observer = self._pl.iren.add_observer(
            "LeftButtonReleaseEvent",
            self._release_callback,
            interactor_style_fallback=False,
        )

    def disable(self):
        """Disable the widget."""
        self._pl.disable_picking()
        # Forget each observer id once removed so the widget can be enabled
        # again and a repeated ``disable`` does not remove a stale id.
        if self._mouse_move_observer is not None:
            self._pl.iren.remove_observer(self._mouse_move_observer)
            self._mouse_move_observer = None
        if self._left_press_observer is not None:
            self._pl.iren.remove_observer(self._left_press_observer)
            self._left_press_observer = None
        if self._left_release_observer is not None:
            self._pl.iren.remove_observer(self._left_release_observer)
            self._left_release_observer = None

    def remove(self):
        """Disable and delete all actors of this widget."""
        self.disable()
        for actor in self._circles + self._arrows:
            self._pl.remove_actor(actor)
        self._circles = []
        self._arrows = []

    def update_origin(self, origin):
        """Move the widget to ``origin`` and discard the accumulated transform."""
        self.origin = origin
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)

    @property
    def selected(self) -> bool:
        """Return ``True`` while a widget handle is highlighted or dragged."""
        return self._selected_actor is not None
###############################################################################
# Demo geometry: a triangulated, twice-subdivided cube and a point cloud made
# of its vertices (the points that can be hovered and clicked).
mesh = pv.Cube().triangulate().subdivide(2)
mesh_points = pv.PolyData(mesh.points)
# Mutable state shared by the interactor callbacks below. A "display-lookup"
# entry is added later by ``project_points_to_display_coords``.
state = {
    "mouse-xy": None,  # event position stored by ``on_press``
    "points": mesh_points,  # candidate points for hover/selection
    "selected-point": None,  # lookup result for the hovered point, or None
    "selected-actor": None,  # actor returned by ``add_points`` for the highlight
    "widget": None,  # AffineWidget3D instance, assigned once created
}
def project_points_to_display_coords(plotter):
    """Project the candidate points to display space and index them.

    Every point in ``state["points"]`` is converted from world to display
    coordinates with the plotter's renderer. The resulting 2D positions are
    stored as a KD-tree in ``state["display-lookup"]`` for nearest-point
    queries against the mouse position.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter whose renderer defines the world-to-display mapping.

    """
    renderer = plotter.renderer

    # One coordinate converter is reused for all points instead of
    # allocating a new VTK object per point.
    converter = vtk.vtkCoordinate()
    converter.SetCoordinateSystemToWorld()

    display_xy = []
    for point in state["points"].points:
        converter.SetValue(point)
        display_xy.append(converter.GetComputedDisplayValue(renderer)[:2])

    projected = np.array(display_xy, dtype=np.float64)
    state["display-lookup"] = femorph.kdtree.KdTree(projected)
def print_cursor_world_position(plotter):
    """Highlight the candidate point nearest to the mouse cursor.

    Registers a ``MouseMoveEvent`` observer on the plotter. On every move the
    closest projected point is looked up in ``state["display-lookup"]``; if
    it lies within 1/50 of the mean window dimension it is marked with a red
    point and recorded in ``state["selected-point"]``, otherwise the marker
    is removed and the selection cleared.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter to attach the hover behavior to.

    """

    def callback(interactor, event):
        # Build the lookup lazily on the first move. The cache is keyed on
        # the entry the projection actually writes; it is not refreshed on
        # its own, so the view must be re-projected after it changes.
        if "display-lookup" not in state:
            project_points_to_display_coords(plotter)

        event_pos = np.array([interactor.GetEventPosition()], dtype=np.float64)
        dist, index = state["display-lookup"].query(event_pos)
        max_dist = float(np.mean(plotter.window_size)) / 50

        if len(index) and dist < max_dist:
            # Only rebuild the marker when the hovered point changes.
            if state["selected-point"] != index:
                pts_to_add = np.array([mesh_points.points[index[0]]])
                plotter.remove_actor(state["selected-actor"])
                state["selected-actor"] = plotter.add_points(
                    pts_to_add, render_points_as_spheres=True, color="r", point_size=10
                )
                state["selected-point"] = index
        else:
            plotter.remove_actor(state["selected-actor"])
            # Drop the stale reference so the removed actor is not handed to
            # ``remove_actor`` again on every subsequent mouse move.
            state["selected-actor"] = None
            state["selected-point"] = None

    plotter.iren.add_observer("MouseMoveEvent", callback)
def add_camera_move_callback(plotter, callback):
    """Invoke ``callback(plotter)`` whenever an interaction ends.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter whose interactor is observed; also the argument handed to
        ``callback``.
    callback : callable
        Function accepting the plotter as its only argument.

    """

    def _on_end_interaction(_caller, _event):
        # The VTK caller and event name are not forwarded.
        callback(plotter)

    interactor = plotter.iren
    interactor.add_observer("EndInteractionEvent", _on_end_interaction)
def on_camera_move(plotter):
    """Rebuild the display-space point lookup for ``plotter``."""
    project_points_to_display_coords(plotter)
import pyvista as pv
pl = pv.Plotter()
pl.add_mesh(mesh, color="w")
points_actor = pl.add_points(
    state["points"],
    point_size=10,
)
print_cursor_world_position(pl)
add_camera_move_callback(pl, on_camera_move)

# The constructor enables the widget itself (``start=True``). Calling
# ``enable()`` a second time would register the mouse observers twice.
# Point 97 is only the initial placement; ``on_release`` moves the widget to
# the clicked point.
widget = AffineWidget3D(
    pl,
    state["points"].points[97],
    points_actor.GetLength() / 4,
    start=True,
    always_visible=False,
)
state["widget"] = widget
# Keep the widget hidden until a point is clicked.
state["widget"].hide()
def on_release(interactor, event):
    """Place or hide the widget when the left button is released.

    A release counts as a click only if the cursor moved at most two pixels
    since the press recorded in ``state["mouse-xy"]``. A click on a hovered
    point moves the widget there and shows it. A click elsewhere hides the
    widget unless one of its handles is selected and not being dragged.
    """
    start_xy = state["mouse-xy"]
    end_xy = pl.iren.get_event_position()
    if start_xy is None:
        return

    delta_x = end_xy[0] - start_xy[0]
    delta_y = end_xy[1] - start_xy[1]
    # Squared distance compared against 4, i.e. a 2 pixel drag threshold.
    if delta_x * delta_x + delta_y * delta_y > 4:
        return

    hovered = state["selected-point"]
    if hovered is not None:
        target = state["points"].points[hovered[0]]
        state["widget"].update_origin(target)
        state["widget"].show()
        return

    if not state["widget"].selected or state["widget"]._pressing_down:
        state["widget"].hide()
def on_press(interactor, event):
    """Record where the left button was pressed for click/drag detection."""
    state["mouse-xy"] = interactor.GetEventPosition()
# Track left-button press/release so ``on_release`` can tell a click from a
# drag, then start the interactive window.
pl.iren.add_observer("LeftButtonPressEvent", on_press)
pl.iren.add_observer("LeftButtonReleaseEvent", on_release)
pl.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment