Created February 15, 2023 13:02
-
-
Save insaneyilin/e83a23834afe4901796ecae835ecb79d to your computer and use it in GitHub Desktop.
Chamfer distance
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import numpy as np | |
| from scipy.spatial import KDTree | |
def _mean_min_dist(query_pts, ref_pts):
    """Return the mean nearest-neighbour distance from each query point to ref_pts."""
    # KDTree.query with its default k=1, p=2 gives, for every query point,
    # the Euclidean distance to its closest point in ref_pts.
    dists, _ = KDTree(ref_pts).query(query_pts)
    return np.mean(dists)


def chamfer_distance(s1, s2, direction='s2_to_s1'):
    """Chamfer distance between two point sets.

    Args:
        s1 (np.ndarray): [n_points_s1, n_dims]
        s2 (np.ndarray): [n_points_s2, n_dims]
        direction (str): direction of the Chamfer distance.
            's2_to_s1': average minimal distance from every point in s2 to s1.
            's1_to_s2': average minimal distance from every point in s1 to s2.
            'bi': bi-directional; the SUM (not the average) of the two
                one-directional averages above.

    Returns:
        chamfer_dist: float (a NumPy floating scalar produced by np.mean).
        Distances are Euclidean (scipy KDTree.query's default p=2).

    Raises:
        ValueError: if ``direction`` is not one of the supported values.
    """
    if direction == 's2_to_s1':
        return _mean_min_dist(s2, s1)
    if direction == 's1_to_s2':
        return _mean_min_dist(s1, s2)
    if direction == 'bi':
        # Same evaluation order as the one-directional cases: s2->s1, then s1->s2.
        return _mean_min_dist(s2, s1) + _mean_min_dist(s1, s2)
    raise ValueError(
        "Invalid direction type. Supported types: 's2_to_s1', 's1_to_s2', 'bi'")
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment