Created October 1, 2020 07:53
Iterative Closest Point (ICP) implementation with least squares fit (lstsq) in Pytorch
#!/usr/bin/env python
# -*- coding: UTF8 -*-
# Author: Guillaume Bouvier -- [email protected]
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-10-01 09:51:45 (UTC+0200)

import sys
import torch


def icp(coords, coords_ref, device, n_iter):
    """
    Iterative Closest Point
    """
    for t in range(n_iter):
        cdist = torch.cdist(coords - coords.mean(axis=0),
                            coords_ref - coords_ref.mean(axis=0))
        mindists, argmins = torch.min(cdist, axis=1)
        X, _ = torch.lstsq(coords_ref[argmins], coords)
        coords = coords.mm(X[:3])
        rmsd = torch.sqrt((X[3:]**2).sum(axis=1).mean())
        print_progress(f'{t+1}/{n_iter}: {rmsd}')
    return coords


def print_progress(instr):
    sys.stdout.write(f'{instr}\r')
    sys.stdout.flush()
This implementation is not right. torch.lstsq is deprecated and you should use torch.linalg.lstsq instead.
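For anyone landing here, a minimal sketch of what that change could look like. It is not the author's code: the name `icp_lstsq` is hypothetical, the unused `device` argument is dropped, and the RMSD is computed directly from the fitted points rather than from the extra rows that the old `torch.lstsq` returned. Note that `torch.linalg.lstsq(A, B)` takes its arguments in the opposite order to `torch.lstsq(B, A)`. As in the gist, the fit is an unconstrained 3x3 linear map, not a rigid rotation.

```python
import torch


def icp_lstsq(coords, coords_ref, n_iter=10):
    """ICP sketch: nearest-neighbour matching plus a linear least-squares fit."""
    rmsd = None
    for _ in range(n_iter):
        # Match each point to its closest reference point (both clouds centred).
        cdist = torch.cdist(coords - coords.mean(dim=0),
                            coords_ref - coords_ref.mean(dim=0))
        argmins = cdist.argmin(dim=1)
        target = coords_ref[argmins]
        # torch.linalg.lstsq(A, B) minimises ||A @ X - B||; the argument
        # order is reversed relative to the old torch.lstsq(B, A).
        X = torch.linalg.lstsq(coords, target).solution  # shape (3, 3)
        coords = coords @ X
        # RMSD between the transformed points and their matched targets.
        rmsd = torch.sqrt(((coords - target) ** 2).sum(dim=1).mean())
    return coords, rmsd
```

Call it as `fitted, rmsd = icp_lstsq(coords, coords_ref, n_iter=10)` with two `(n, 3)` and `(m, 3)` float tensors. Since the matching step only compares centred clouds, it probably still needs a reasonable starting alignment to converge to the right correspondences.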
thanks for sharing