Skip to content

Instantly share code, notes, and snippets.

@srush
Created September 23, 2024 01:31
Show Gist options
  • Save srush/d6477c44bf6f6b2645bfe0a9626b944a to your computer and use it in GitHub Desktop.
Save srush/d6477c44bf6f6b2645bfe0a9626b944a to your computer and use it in GitHub Desktop.
import numpy as np
from numpy import ndarray
from chalk import *
from jaxtyping import Float
from chalk.transform import Batched
import numpy.random
def create_t_shaped_data(
    num_points: int, *, rng: "np.random.Generator | None" = None
) -> Float[ndarray, "num_points 2"]:
    """Create random 2D data points shaped like a T.

    The first ``num_points // 2`` rows form the vertical stem
    (x in [-0.5, 0.5), y in [-2, 2)); the remaining rows form the horizontal
    bar across the top (x in [-2, 2), y in [1.5, 2.5)).

    Args:
        num_points: Total number of points to generate.
        rng: Optional random source exposing ``uniform(low, high, size)``
            (e.g. ``np.random.default_rng(seed)``) for reproducible output.
            Defaults to the global ``np.random`` state.

    Returns:
        Array of shape ``(num_points, 2)`` holding (x, y) coordinates.
    """
    if rng is None:
        rng = np.random
    n_vertical = num_points // 2
    n_horizontal = num_points - n_vertical
    t_shape = np.zeros((num_points, 2))
    # Vertical stem of the T: narrow in x, tall in y.
    t_shape[:n_vertical, 0] = rng.uniform(-0.5, 0.5, n_vertical)
    t_shape[:n_vertical, 1] = rng.uniform(-2, 2, n_vertical)
    # Horizontal bar of the T: wide in x, sitting near the top of the stem.
    t_shape[n_vertical:, 0] = rng.uniform(-2, 2, n_horizontal)
    t_shape[n_vertical:, 1] = rng.uniform(1.5, 2.5, n_horizontal)
    return t_shape
def find_closest_points(
    set1: Float[ndarray, "n1 2"], set2: Float[ndarray, "n2 2"]
) -> tuple[Float[ndarray, "n1 2"], Float[ndarray, "n1"]]:
    """For every point in ``set1``, find its nearest neighbour in ``set2``.

    Brute force: the full ``(n1, n2)`` Euclidean distance matrix is built, so
    memory grows as O(n1 * n2). Works for any point dimension, not only 2.

    Args:
        set1: Query points, one per row.
        set2: Candidate points, one per row. Must be non-empty.

    Returns:
        closest_points: Row ``i`` is the point of ``set2`` closest to
            ``set1[i]`` (the first such point on ties).
        min_distances: ``min_distances[i]`` is the distance from ``set1[i]``
            to ``closest_points[i]``.

    Raises:
        ValueError: If ``set2`` is empty (no neighbour exists).
    """
    if len(set2) == 0:
        raise ValueError("set2 must contain at least one point")
    distances = np.linalg.norm(
        set1[:, np.newaxis, :] - set2[np.newaxis, :, :], axis=-1
    )
    closest_indices = np.argmin(distances, axis=1)
    # Read the minima off via the argmin result instead of a second full
    # reduction over the distance matrix.
    min_distances = distances[np.arange(len(set1)), closest_indices]
    return set2[closest_indices], min_distances
def draw_data_points(
    points: Float[ndarray, "*batch 2"], radius: float = 0.1
) -> Batched[Diagram, "*batch"]:
    """Draw data points as circles.

    Args:
        points: Array whose last axis holds (x, y) coordinates; leading axes
            are presumably carried through as chalk batch dimensions (per the
            ``Batched[Diagram, "*batch"]`` return annotation).
        radius: Circle radius in diagram units. Defaults to 0.1.

    Returns:
        A batched chalk diagram: a circle of ``radius`` translated to each
        point's (x, y) position.
    """
    return circle(radius).translate(points[..., 0], points[..., 1])
def draw_connections(
    points1: Float[ndarray, "*batch 2"], points2: Float[ndarray, "*batch 2"]
) -> Batched[Diagram, "*batch"]:
    """Draw a line segment from each point in ``points1`` to its partner in ``points2``.

    Args:
        points1: Segment start points; last axis holds (x, y).
        points2: Segment end points, same shape as ``points1``.

    Returns:
        A batched chalk diagram of stroked segments.
    """
    # Tiny nudge on the end point, presumably so a segment never has zero
    # length (e.g. once a point coincides with its partner). It is applied to
    # the raw x/y coordinates before building the P2 point so that it cannot
    # leak into any additional component the constructed point may carry.
    eps = 1e-5
    start_points = P2(points1[..., 0], points1[..., 1])
    end_points = P2(points2[..., 0] + eps, points2[..., 1] + eps)
    return Path.from_pairs([(start_points, end_points)]).stroke()
def random_rotate_translate(
    points: Float[ndarray, "*batch 2"],
    *,
    rng: "np.random.Generator | None" = None,
) -> Float[ndarray, "*batch 2"]:
    """Apply one random rigid motion (rotation then translation) to all points.

    The rotation angle is drawn uniformly from [0, pi/4) (counter-clockwise,
    at most 45 degrees); each translation component is drawn uniformly from
    [-1, 1). The same motion is applied to every point.

    Args:
        points: Array whose last axis holds (x, y); any leading axes are
            treated as batch dimensions.
        rng: Optional random source exposing ``uniform(low, high, size)``
            (e.g. ``np.random.default_rng(seed)``) for reproducible output.
            Defaults to the global ``np.random`` state.

    Returns:
        The transformed points, same shape as ``points``.
    """
    if rng is None:
        rng = np.random
    # Draw order (angle, then translation) matches the global-state sequence.
    angle = rng.uniform(0, np.pi / 4)
    translation = rng.uniform(-1, 1, 2)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])
    # Computes R @ p for every point p, broadcasting over the batch axes.
    rotated_points = np.einsum("...ij,...j->...i", rotation_matrix, points)
    return rotated_points + translation
def _best_fit_rotation(centered_source: ndarray, centered_target: ndarray) -> ndarray:
    """Return the proper rotation R (det(R) = +1) minimising sum ||R x_i - y_i||^2 (Kabsch)."""
    H = centered_source.T @ centered_target
    U, _, Vt = np.linalg.svd(H)
    # The unconstrained SVD solution Vt.T @ U.T can be a reflection
    # (det = -1), which would mirror the point cloud. In that case flip the
    # axis of the smallest singular value (the last one, since np.linalg.svd
    # returns them in descending order) to get the best proper rotation.
    correction = np.eye(H.shape[0])
    if np.linalg.det(Vt.T @ U.T) < 0:
        correction[-1, -1] = -1.0
    return Vt.T @ correction @ U.T


def iterative_procrustes_icp(
    source: Float[ndarray, "n 2"],
    target: Float[ndarray, "m 2"],
    num_iterations: int = 10,
) -> tuple[
    Float[ndarray, "num_iterations n 2"],
    Float[ndarray, "num_iterations n 2"],
    Float[ndarray, "num_iterations"],
]:
    """Perform iterative Procrustes ICP (point-to-point, rigid).

    Each iteration pairs every current source point with its nearest target
    point, then moves the source by the rotation + translation that best maps
    it onto those partners in the least-squares sense. Rotations are kept
    proper (no reflections).

    Args:
        source: Points to move, one per row. Not modified.
        target: Fixed reference points, one per row.
        num_iterations: Number of ICP steps to run.

    Returns:
        all_transformed: Entry ``i`` is the source after step ``i``'s update.
        all_nearest: Entry ``i`` holds the target points matched to the
            source at the start of step ``i``.
        total_distances: Entry ``i`` is the summed nearest-neighbour distance
            measured at the start of step ``i`` (before its update).
    """
    current_source = source.copy()
    all_transformed = np.zeros((num_iterations, *source.shape))
    all_nearest = np.zeros((num_iterations, *source.shape))
    total_distances = np.zeros(num_iterations)
    for i in range(num_iterations):
        nearest_points, distances = find_closest_points(current_source, target)
        total_distances[i] = np.sum(distances)
        # Solve the rotation about the centroids; the translation then maps
        # the source centroid onto the matched-target centroid.
        centroid_source = np.mean(current_source, axis=0)
        centroid_target = np.mean(nearest_points, axis=0)
        centered_source = current_source - centroid_source
        R = _best_fit_rotation(centered_source, nearest_points - centroid_target)
        current_source = centered_source @ R.T + centroid_target
        all_transformed[i] = current_source
        all_nearest[i] = nearest_points
    return all_transformed, all_nearest, total_distances
# Example usage: align a randomly rotated/translated copy of a T-shaped point
# cloud with ICP and render the optimisation as an animated GIF.
num_points = 200
t_data = create_t_shaped_data(num_points)
rotated_data = random_rotate_translate(t_data)
transformed_points, nearest_points, total_distances = iterative_procrustes_icp(
    t_data, rotated_data, num_iterations=200
)

# Keep every `frame_step`-th ICP iteration for the animation
# (1 keeps every iteration; e.g. 5 gives a shorter GIF).
frame_step = 1
transformed_points_subset = transformed_points[::frame_step]
nearest_points_subset = nearest_points[::frame_step]
distances_subset = total_distances[::frame_step]

# Circles for the moving source at each kept iteration, and for the fixed target.
transformed_diagrams = draw_data_points(transformed_points_subset)
rotated_diagrams = draw_data_points(rotated_data)
# Segments joining each moving point to its currently matched target point.
connection_diagrams = draw_connections(transformed_points_subset, nearest_points_subset)

# Per-frame fill colour from the total matching distance: the largest
# distance maps to blue, the smallest to red (the target's colour).
progress = np.interp(
    distances_subset, [distances_subset.min(), distances_subset.max()], [1, 0]
)
frame_colors = np.stack(
    [progress, np.zeros_like(progress), 1 - progress],
    axis=-1,
)

# Concatenate the points of each frame before colouring them.
transformed_concat = transformed_diagrams.concat().fill_color(frame_colors)
rotated_concat = rotated_diagrams.concat().fill_color("red")

# Overlay target points, connection lines and the moving source.
all_diagrams = rotated_concat + connection_diagrams.concat() + transformed_concat
# Render the animation.
all_diagrams.animate("icp_animation.gif", height=400)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment