Created
April 29, 2025 14:19
-
-
Save akaszynski/b21258fea6d0b6838b17f7b337bfade4 to your computer and use it in GitHub Desktop.
PyVista - Hover on point and click to place affine widget
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyvista as pv | |
import numpy as np | |
import vtk | |
import femorph | |
############################################################################### | |
from typing import cast, Sequence | |
import numpy as np | |
import pyvista | |
from pyvista.core.errors import VTKVersionError | |
from pyvista.core.utilities.misc import try_callback | |
from pyvista.plotting import _vtk | |
# RGB colour (components in the 0-1 range) applied to a widget handle while
# the cursor hovers over it.
DARK_YELLOW = (0.9647058823529412, 0.7450980392156863, 0)

# Rows are the X, Y and Z unit vectors of the global coordinate system.
GLOBAL_AXES = np.eye(3)
def _validate_axes(axes): | |
"""Validate and normalize input axes. | |
Axes are expected to follow the right-hand rule (e.g. third axis is the | |
cross product of the first two. | |
Parameters | |
---------- | |
axes : sequence | |
The axes to be validated and normalized. Should be of shape (3, 3). | |
Returns | |
------- | |
dict | |
The validated and normalized axes. | |
""" | |
axes = np.array(axes) | |
if axes.shape != (3, 3): | |
msg = "`axes` must be a (3, 3) array." | |
raise ValueError(msg) | |
axes = axes / np.linalg.norm(axes, axis=1, keepdims=True) | |
if not np.allclose(np.cross(axes[0], axes[1]), axes[2]): | |
msg = "`axes` do not follow the right hand rule." | |
raise ValueError(msg) | |
return axes | |
def _check_callable(func, name="callback"): | |
"""Check if a variable is callable.""" | |
if func and not callable(func): | |
msg = f"`{name}` must be a callable, not {type(func)}." | |
raise TypeError(msg) | |
return func | |
def _make_quarter_arc():
    """Build a quarter-circle polyline centered at the origin.

    Returns
    -------
    pyvista.PolyData
        The ``pyvista.Circle`` mesh with its faces removed and a single line
        cell running through its first 26 points.

    """
    # 26 of the 100 requested circle points span 25 segments, i.e. a quarter
    # of the circumference.
    n_arc_points = 26
    arc = pyvista.Circle(resolution=100)
    arc.faces = np.empty(0, dtype=int)
    # VTK cell layout: the point count followed by the point ids.
    point_ids = np.arange(0, n_arc_points)
    arc.lines = np.hstack(([n_arc_points], point_ids))
    return arc
def get_angle(v1, v2):
    """Return the angle, in degrees, between two direction vectors.

    The dot product is used directly as the cosine of the angle, so both
    inputs must already be unit length; it is clipped to ``[-1, 1]`` to guard
    against round-off before taking the arc cosine.

    Parameters
    ----------
    v1 : numpy.ndarray
        First (unit) vector.
    v2 : numpy.ndarray
        Second (unit) vector.

    Returns
    -------
    float
        Angle between the vectors in degrees, in ``[0, 180]``.

    """
    cosine = np.dot(v1, v2)
    cosine = np.clip(cosine, -1.0, 1.0)
    radians = np.arccos(cosine)
    return np.rad2deg(radians)
def ray_plane_intersection(start_point, direction, plane_point, normal):
    """Find where a ray meets a plane.

    The ray is ``start_point + t * direction``; ``t`` is solved from the
    plane equation ``normal . (x - plane_point) = 0``. No check is made for a
    ray parallel to the plane (``normal . direction == 0``).

    Parameters
    ----------
    start_point : ndarray
        Starting point of the ray.
    direction : ndarray
        Direction of the ray.
    plane_point : ndarray
        A point on the plane.
    normal : ndarray
        Normal to the plane.

    Returns
    -------
    ndarray
        Intersection point.

    """
    to_plane = plane_point - start_point
    numerator = np.dot(normal, to_plane)
    denominator = np.dot(normal, direction)
    t_value = numerator / denominator
    return start_point + t_value * direction
class AffineWidget3D:
    """3D affine transform widget.

    The widget draws one arrow per axis at ``origin``. Hovering an arrow
    highlights it; dragging it with the left mouse button computes a
    translation along that axis. The widget is not attached to an actor: the
    computed transform is only stored on the widget and no actor is moved.

    This version never creates rotation arcs (``_circles`` stays empty), so
    only the translation arrows are shown. The rotation code path is kept and
    becomes active if ``_circles`` is populated.

    Parameters
    ----------
    plotter : pyvista.Plotter
        The plotter object.

    origin : sequence[float]
        Origin of the widget.

    length : float
        Reference length used to size the widget (for example the diagonal
        length of the actor the widget refers to).

    start : bool, default: True
        If True, start the widget immediately.

    scale : float, default: 0.15
        Scale factor for the widget relative to ``length``.

    line_radius : float, default: 0.05
        Relative radius of the arrow shafts.

    always_visible : bool, default: True
        Make the widget always visible. Setting this to ``False`` will cause
        the widget geometry to be hidden by other actors in the plotter.

    axes_colors : tuple[ColorLike], optional
        Uses the theme by default. Configure the individual axis colors by
        modifying either the theme with ``pyvista.global_theme.axes.x_color =
        <COLOR>`` or setting this with a ``tuple`` as in ``('r', 'g', 'b')``.

    axes : numpy.ndarray, optional
        ``(3, 3)`` Numpy array defining the X, Y, and Z axes. By default this
        matches the default coordinate system.

    release_callback : callable, optional
        Called with no arguments when the left mouse button is released.

    interact_callback : callable, optional
        Called with no arguments when the mouse moves while the left mouse
        button is pressed on a widget handle.

    Notes
    -----
    The transform of the drag in progress is kept in ``_user_matrix`` and is
    committed to ``_cached_matrix`` when the button is released. Both are
    reset by :meth:`update_origin`.

    Requires VTK >= v9.2

    Examples
    --------
    Place the widget at the center of an actor.

    >>> import pyvista as pv
    >>> pl = pv.Plotter()
    >>> actor = pl.add_mesh(pv.Sphere())
    >>> widget = AffineWidget3D(pl, actor.center, actor.GetLength())
    >>> pl.show()

    """

    def __init__(
        self,
        plotter,
        origin: Sequence[float],
        length: float,
        start: bool = True,
        scale=0.15,
        line_radius=0.05,
        always_visible: bool = True,
        axes_colors=None,
        axes=None,
        release_callback=None,
        interact_callback=None,
    ):
        """Initialize the widget."""
        # needs VTK v9.2.0 due to the hardware picker
        if pyvista.vtk_version_info < (9, 2):
            msg = "AffineWidget3D requires VTK v9.2.0 or newer."
            raise VTKVersionError(msg)

        self._axes = np.eye(4)
        self._axes_inv = np.eye(4)
        self._pl = plotter
        self._selected_actor: pyvista.Actor | None = None
        # world position recorded when the left button is pressed on a handle
        self.init_position = None
        self._mouse_move_observer = None
        self._left_press_observer = None
        self._left_release_observer = None
        self._origin = np.array(origin)
        self._actor_length = length
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)
        self._arrows = []  # type: ignore[var-annotated]
        self._circles = []  # type: ignore[var-annotated]
        self._pressing_down = False

        if axes_colors is None:
            axes_colors = (
                pyvista.global_theme.axes.x_color,
                pyvista.global_theme.axes.y_color,
                pyvista.global_theme.axes.z_color,
            )
        self._axes_colors = axes_colors
        self._circ = _make_quarter_arc()
        self._line_radius = line_radius
        self._user_interact_callback = _check_callable(interact_callback)
        self._user_release_callback = _check_callable(release_callback)

        self._init_actors(scale, always_visible)

        # axes must be set after initializing actors
        if axes is not None:
            try:
                _validate_axes(axes)
            except ValueError:
                # do not leave half-configured handles in the scene
                for actor in self._arrows + self._circles:
                    self._pl.remove_actor(actor)
                raise
            self.axes = axes

        if start:
            self.enable()

    def show(self):
        """Make the widget handles visible and re-render."""
        for actor in self._arrows + self._circles:
            actor.visibility = True
        self._pl.render()

    def hide(self):
        """Hide the widget handles and re-render."""
        for actor in self._arrows + self._circles:
            actor.visibility = False
        self._pl.render()

    def _init_actors(self, scale, always_visible):
        """Initialize the widget's actors."""
        for ii, color in enumerate(self._axes_colors):
            arrow = pyvista.Arrow(
                (0, 0, 0),
                direction=GLOBAL_AXES[ii],
                scale=self._actor_length * scale * 1.15,
                tip_radius=0.1,
                shaft_radius=self._line_radius,
            )
            self._arrows.append(
                self._pl.add_mesh(arrow, color=color, lighting=False, render=False)
            )

        # update origin and assign a default user_matrix
        for actor in self._arrows + self._circles:
            matrix = np.eye(4)
            matrix[:3, -1] = self._origin
            actor.user_matrix = matrix

        if always_visible:
            # a large negative polygon offset draws the handles on top of
            # coincident or occluding geometry
            for actor in self._arrows + self._circles:
                actor.mapper.SetResolveCoincidentTopologyToPolygonOffset()
                actor.mapper.SetRelativeCoincidentTopologyPolygonOffsetParameters(
                    0, -20000
                )

    def _get_world_coord_rot(self, interactor):
        """Get the world coordinates given an interactor.

        Unlike ``_get_world_coord_trans``, these coordinates are physically
        accurate, but sensitive to the position of the camera. Rotation is zoom
        independent.
        """
        x, y = interactor.GetEventPosition()
        coordinate = _vtk.vtkCoordinate()
        coordinate.SetCoordinateSystemToDisplay()
        coordinate.SetValue(x, y, 0)
        ren = interactor.GetRenderWindow().GetRenderers().GetFirstRenderer()
        point = np.array(coordinate.GetComputedWorldValue(ren))
        if self._selected_actor:
            # project the cursor onto the plane of the selected rotation arc
            index = self._circles.index(self._selected_actor)
            to_widget = np.array(ren.camera.position - self._origin)
            point = ray_plane_intersection(
                point, to_widget, self._origin, self.axes[index]
            )
        return point

    def _get_world_coord_trans(self, interactor):
        """Get the world coordinates given an interactor.

        This uses a modified scaled approach to get the world coordinates that
        are not physically accurate, but ignores zoom and works for
        translation.
        """
        x, y = interactor.GetEventPosition()
        ren = interactor.GetRenderWindow().GetRenderers().GetFirstRenderer()

        # Get normalized view coordinates (-1, 1)
        width, height = ren.GetSize()
        ndc_x = 2 * (x / width) - 1
        ndc_y = 2 * (y / height) - 1
        ndc_z = 1

        # convert camera coordinates to world coordinates
        camera = ren.GetActiveCamera()
        projection_matrix = pyvista.array_from_vtkmatrix(
            camera.GetProjectionTransformMatrix(ren.GetTiledAspectRatio(), 0, 1),
        )
        inverse_projection_matrix = np.linalg.inv(projection_matrix)
        camera_coords = np.dot(inverse_projection_matrix, [ndc_x, ndc_y, ndc_z, 1])
        modelview_matrix = pyvista.array_from_vtkmatrix(
            camera.GetModelViewTransformMatrix()
        )
        inverse_modelview_matrix = np.linalg.inv(modelview_matrix)
        world_coords = np.dot(inverse_modelview_matrix, camera_coords)

        # Scale by twice actor length (experimentally determined for good UX)
        return world_coords[:3] * self._actor_length * 2

    def _handle_index(self, actor):
        """Return the axis index of a widget handle, or ``None`` if not a handle."""
        if actor in self._arrows:
            return self._arrows.index(actor)
        if actor in self._circles:
            return self._circles.index(actor)
        return None

    def _translation_matrix(self, interactor):
        """Return the accumulated transform for dragging the selected arrow."""
        index = self._arrows.index(self._selected_actor)
        axis = self.axes[index]
        diff = self._get_world_coord_trans(interactor) - self.init_position
        trans_matrix = np.eye(4)
        # only the component of the mouse motion along the arrow's axis counts
        trans_matrix[:3, -1] = axis * np.dot(diff, axis)
        return trans_matrix @ self._cached_matrix

    def _rotation_matrix(self, interactor):
        """Return the accumulated transform for dragging the selected arc.

        Returns ``None`` when the angle is undefined because the cursor
        projects exactly onto the widget origin.
        """
        index = self._circles.index(self._selected_actor)
        normal = self.axes[index]
        vec_current = self._get_world_coord_rot(interactor) - self._origin
        vec_init = self.init_position - self._origin

        # keep only the in-plane component of both vectors
        vec_current = vec_current - np.dot(vec_current, normal) * normal
        vec_init = vec_init - np.dot(vec_init, normal) * normal
        norm_current = np.linalg.norm(vec_current)
        norm_init = np.linalg.norm(vec_init)
        if norm_current == 0 or norm_init == 0:
            return None
        vec_current = vec_current / norm_current
        vec_init = vec_init / norm_init

        angle = get_angle(vec_init, vec_current)
        # The sign comes from the cross product's orientation relative to the
        # rotation axis; testing a single component is only right for the
        # global axes.
        if np.dot(np.cross(vec_init, vec_current), normal) < 0:
            angle = -angle

        # rotate about the widget origin rather than the world origin
        trans = _vtk.vtkTransform()
        trans.Translate(self._origin)  # type: ignore[call-overload]
        trans.RotateWXYZ(angle, normal[0], normal[1], normal[2])
        trans.Translate(-self._origin)  # type: ignore[call-overload]
        trans.Update()
        rot_matrix = pyvista.array_from_vtkmatrix(trans.GetMatrix())
        return rot_matrix @ self._cached_matrix

    def _move_callback(self, interactor, _event):
        """Process actions for the move mouse event."""
        click_x, click_y = interactor.GetEventPosition()
        click_z = 0
        picker = interactor.GetPicker()
        renderer = (
            interactor.GetInteractorStyle()._parent()._plotter.iren.get_poked_renderer()
        )
        picker.Pick(click_x, click_y, click_z, renderer)
        actor = picker.GetActor()

        if self._pressing_down:
            matrix = None
            if self._selected_actor in self._arrows:
                matrix = self._translation_matrix(interactor)
            elif self._selected_actor in self._circles:
                matrix = self._rotation_matrix(interactor)
            if matrix is not None:
                # keep the transform of the drag in progress; it is committed
                # to ``_cached_matrix`` when the button is released
                self._user_matrix = matrix
            if self._user_interact_callback:
                try_callback(self._user_interact_callback)
        elif self._selected_actor and self._selected_actor is not actor:
            # Return the color of the currently selected actor to normal and
            # deselect it
            index = self._handle_index(self._selected_actor)
            if index is not None:
                self._selected_actor.prop.color = self._axes_colors[index]
            self._selected_actor = None

        # Highlight the actor if there is no selected actor
        if actor and not self._selected_actor:
            if self._handle_index(actor) is not None:
                actor.prop.color = DARK_YELLOW
                self._selected_actor = actor

        self._pl.render()

    def _press_callback(self, interactor, _event):
        """Process actions for the mouse button press event."""
        if self._selected_actor:
            # stop the camera from moving while a handle is being dragged
            self._pl.enable_trackball_actor_style()
            self._pressing_down = True
            if self._selected_actor in self._circles:
                self.init_position = self._get_world_coord_rot(interactor)
            else:
                self.init_position = self._get_world_coord_trans(interactor)

    def _release_callback(self, _interactor, _event):
        """Process actions for the mouse button release event."""
        self._pl.enable_trackball_style()
        self._pressing_down = False
        # the finished drag becomes the base for the next one
        self._cached_matrix = self._user_matrix
        if self._user_release_callback:
            try_callback(self._user_release_callback)

    def _reset(self):
        """Reset the stored and cached transform to identity."""
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)

    @property
    def axes(self):
        """Return or set the axes of the widget.

        The axes will be checked for orthogonality. Non-orthogonal axes will
        raise a ``ValueError``

        Returns
        -------
        numpy.ndarray
            ``(3, 3)`` array of axes.

        """
        return self._axes[:3, :3]

    @axes.setter
    def axes(self, axes):
        mat = np.eye(4)
        mat[:3, :3] = _validate_axes(axes)
        mat[:3, -1] = self.origin
        self._axes = mat
        self._axes_inv = np.linalg.inv(self._axes)
        for actor in self._arrows + self._circles:
            matrix = actor.user_matrix
            # Be sure to use the inverse here
            matrix[:3, :3] = self._axes_inv[:3, :3]
            actor.user_matrix = matrix

    @property
    def origin(self) -> tuple[float, float, float]:
        """Origin of the widget.

        This is where the widget handles are located and the point rotations
        are computed about.

        Returns
        -------
        tuple
            Widget origin.

        """
        return cast("tuple[float, float, float]", tuple(self._origin))

    @origin.setter
    def origin(self, value):
        value = np.array(value)
        # shift every handle by the change in origin
        diff = value - self._origin
        for actor in self._circles + self._arrows:
            if actor.user_matrix is None:
                actor.user_matrix = np.eye(4)
            matrix = actor.user_matrix
            matrix[:3, -1] += diff
            actor.user_matrix = matrix
        self._origin = value

    def enable(self):
        """Enable the widget.

        Calling this on an already enabled widget does nothing, so the
        observers are never registered twice.
        """
        if self._mouse_move_observer is not None:
            return
        if not self._pl._picker_in_use:
            self._pl.enable_mesh_picking(
                show_message=False, show=False, picker="hardware"
            )
        self._mouse_move_observer = self._pl.iren.add_observer(
            "MouseMoveEvent",
            self._move_callback,
        )
        self._left_press_observer = self._pl.iren.add_observer(
            "LeftButtonPressEvent",
            self._press_callback,
            interactor_style_fallback=False,
        )
        self._left_release_observer = self._pl.iren.add_observer(
            "LeftButtonReleaseEvent",
            self._release_callback,
            interactor_style_fallback=False,
        )

    def disable(self):
        """Disable the widget."""
        self._pl.disable_picking()
        if self._mouse_move_observer:
            self._pl.iren.remove_observer(self._mouse_move_observer)
        if self._left_press_observer:
            self._pl.iren.remove_observer(self._left_press_observer)
        if self._left_release_observer:
            self._pl.iren.remove_observer(self._left_release_observer)
        # forget the tags so a later ``enable`` registers fresh observers
        self._mouse_move_observer = None
        self._left_press_observer = None
        self._left_release_observer = None

    def remove(self):
        """Disable and delete all actors of this widget."""
        self.disable()
        for actor in self._circles + self._arrows:
            self._pl.remove_actor(actor)
        self._circles = []
        self._arrows = []

    def update_origin(self, origin):
        """Move the widget to ``origin`` and reset its stored transform."""
        self.origin = origin
        self._user_matrix = np.eye(4)
        self._cached_matrix = np.eye(4)

    @property
    def selected(self) -> bool:
        """Return ``True`` while a widget handle is hovered or being dragged."""
        return self._selected_actor is not None
############################################################################### | |
# Demo geometry: a triangulated, subdivided cube plus a point cloud made of
# its vertices, which are the points the cursor can snap to.
mesh = pv.Cube().triangulate().subdivide(2)
mesh_points = pv.PolyData(mesh.points)

# Mutable state shared by the interactor callbacks defined below.
state = {
    # display position recorded on left-button press
    "mouse-xy": None,
    # point cloud that is projected to the screen for hover lookups
    "points": mesh_points,
    # lookup result for the point currently under the cursor, if any
    "selected-point": None,
    # actor drawn on top of the point currently under the cursor, if any
    "selected-actor": None,
    # the AffineWidget3D instance, assigned once it has been created
    "widget": None,
}
def project_points_to_display_coords(plotter):
    """Project the snap points to the screen and rebuild the 2D lookup tree.

    Every point of ``state["points"]`` is converted from world to display
    coordinates with the plotter's renderer, and a KD-tree over the resulting
    ``(x, y)`` pairs is stored in ``state["display-lookup"]``.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter whose renderer is used for the projection.

    """
    renderer = plotter.renderer

    # One coordinate converter is reused for all points instead of
    # allocating a new VTK object per point.
    world = vtk.vtkCoordinate()
    world.SetCoordinateSystemToWorld()

    coords = []
    for point in state["points"].points:
        world.SetValue(point)
        coords.append(world.GetComputedDisplayValue(renderer)[:2])

    # ``reshape`` keeps the array two-dimensional even when there are no points
    display_xy = np.array(coords, dtype=np.float64).reshape(-1, 2)
    state["display-lookup"] = femorph.kdtree.KdTree(display_xy)
def print_cursor_world_position(plotter):
    """Highlight the snap point nearest to the cursor while the mouse moves.

    Registers a ``MouseMoveEvent`` observer on the plotter's interactor. On
    every move the cursor position is looked up in
    ``state["display-lookup"]``; when the nearest projected point is closer
    than 1/50 of the mean window size, a red sphere is drawn on it and the
    lookup result is stored in ``state["selected-point"]``. Otherwise the
    highlight is removed and the selection is cleared.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter to observe and to draw the highlight in.

    """

    def callback(interactor, event):
        # Build the screen-space lookup lazily on the first move only. The
        # key tested here must be the one written by
        # ``project_points_to_display_coords``; otherwise the projection and
        # the KD-tree are rebuilt on every mouse move.
        if "display-lookup" not in state:
            project_points_to_display_coords(plotter)

        event_pos = np.array([interactor.GetEventPosition()], dtype=np.float64)
        dist, index = state["display-lookup"].query(event_pos)
        max_dist = float(np.mean(plotter.window_size)) / 50

        if len(index) and dist < max_dist:
            if state["selected-point"] != index:
                # the cursor reached a different point: move the highlight
                pts_to_add = np.array([mesh_points.points[index[0]]])
                if state["selected-actor"] is not None:
                    plotter.remove_actor(state["selected-actor"])
                state["selected-actor"] = plotter.add_points(
                    pts_to_add,
                    render_points_as_spheres=True,
                    color="r",
                    point_size=10,
                )
                state["selected-point"] = index
        else:
            # nothing close enough: drop the highlight once, not on every move
            if state["selected-actor"] is not None:
                plotter.remove_actor(state["selected-actor"])
                state["selected-actor"] = None
            state["selected-point"] = None

    plotter.iren.add_observer("MouseMoveEvent", callback)
def add_camera_move_callback(plotter, callback):
    """Call ``callback(plotter)`` each time an interaction ends.

    Parameters
    ----------
    plotter : pyvista.Plotter
        Plotter whose interactor is observed; it is also the only argument
        handed to ``callback``.
    callback : callable
        Function invoked on every ``EndInteractionEvent``.

    """
    # The observer arguments supplied by the interactor are ignored; only
    # the plotter is forwarded.
    plotter.iren.add_observer(
        "EndInteractionEvent",
        lambda _caller, _event: callback(plotter),
    )
def on_camera_move(plotter):
    """Refresh the screen-space point lookup for ``plotter``."""
    project_points_to_display_coords(plotter)
import pyvista as pv | |
# Scene: the white cube plus its vertices drawn as points.
pl = pv.Plotter()
pl.add_mesh(mesh, color="w")
points_actor = pl.add_points(
    state["points"],
    point_size=10,  # render_points_as_spheres=True
)

# Hover highlighting, and a refresh of the screen-space lookup whenever an
# interaction ends.
print_cursor_world_position(pl)
add_camera_move_callback(pl, on_camera_move)

# The widget starts at an arbitrary point (index 97) and stays hidden until
# a point is clicked. The constructor already enables it (``start`` defaults
# to True), so it must not be enabled a second time here: that would register
# every observer twice.
widget = AffineWidget3D(
    pl, state["points"].points[97], points_actor.GetLength() / 4, always_visible=False
)
state["widget"] = widget
state["widget"].hide()
# def on_left_click_show_widget(plotter): | |
def on_release(interactor, event):
    """Handle a left-button release: place or hide the widget on a click.

    A release counts as a click only when the cursor moved at most 2 pixels
    since the matching press (squared distance of 4 or less). A click on a
    highlighted point moves the widget there and shows it. Any other click
    hides the widget unless one of its handles is selected and not pressed.
    """
    start_xy = state["mouse-xy"]
    end_xy = pl.iren.get_event_position()
    if start_xy is None:
        return

    delta_x = end_xy[0] - start_xy[0]
    delta_y = end_xy[1] - start_xy[1]
    squared_dist = delta_x * delta_x + delta_y * delta_y
    if squared_dist > 4:  # squared pixel distance: this was a drag
        return

    picked = state["selected-point"]
    if picked is not None:
        target = state["points"].points[picked[0]]
        state["widget"].update_origin(target)
        state["widget"].show()
        return

    if not state["widget"].selected or state["widget"]._pressing_down:
        state["widget"].hide()
def on_press(interactor, event):
    """Record the display position at which the left button was pressed."""
    press_xy = interactor.GetEventPosition()
    state["mouse-xy"] = press_xy
# Track left-button press/release pairs so that a click can be told apart
# from a drag, then open the interactive window.
pl.iren.add_observer("LeftButtonPressEvent", on_press)
pl.iren.add_observer("LeftButtonReleaseEvent", on_release)

pl.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment