Skip to content

Instantly share code, notes, and snippets.

@insaneyilin
Created February 15, 2023 13:02
Show Gist options
  • Save insaneyilin/e83a23834afe4901796ecae835ecb79d to your computer and use it in GitHub Desktop.
Save insaneyilin/e83a23834afe4901796ecae835ecb79d to your computer and use it in GitHub Desktop.
chamfer distance
import numpy as np
from scipy.spatial import KDTree
def _mean_nearest_distance(query_points, ref_points):
    """Return the mean distance from each query point to its nearest ref point.

    A KD-tree is built over ``ref_points`` and queried once with all of
    ``query_points`` (k=1, default Minkowski p=2, i.e. Euclidean distance).
    """
    nearest_dists, _ = KDTree(ref_points).query(query_points)
    return np.mean(nearest_dists)


def chamfer_distance(s1, s2, direction='s2_to_s1'):
    """Chamfer distance between two point sets.

    Args:
        s1 (np.ndarray): [n_points_s1, n_dims]
        s2 (np.ndarray): [n_points_s2, n_dims]
        direction (str): direction of Chamfer distance.
            's2_to_s1': average minimal distance from every point in s2 to s1.
            's1_to_s2': average minimal distance from every point in s1 to s2.
            'bi': bi-directional; the SUM of the two directed averages above.

    Returns:
        chamfer_dist: float (NumPy scalar).

    Raises:
        ValueError: if ``direction`` is not one of the supported values, or
            if either point set is empty (the nearest-neighbour mean would
            otherwise silently be ``nan`` or ``inf``).
    """
    if direction not in ('s2_to_s1', 's1_to_s2', 'bi'):
        raise ValueError(
            "Invalid direction type. Supported types: 's2_to_s1', 's1_to_s2', 'bi'")

    s1 = np.asarray(s1)
    s2 = np.asarray(s2)
    # An empty reference set gives infinite nearest distances and an empty
    # query set makes np.mean return nan; fail loudly instead.
    if s1.size == 0 or s2.size == 0:
        raise ValueError("Point sets s1 and s2 must both be non-empty.")

    if direction == 's2_to_s1':
        return _mean_nearest_distance(s2, s1)
    if direction == 's1_to_s2':
        return _mean_nearest_distance(s1, s2)
    # direction == 'bi'
    return _mean_nearest_distance(s2, s1) + _mean_nearest_distance(s1, s2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment