@bougui505
Created October 1, 2020 07:53
Iterative Closest Point (ICP) implementation with least squares fit (lstsq) in Pytorch
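#!/usr/bin/env python
# -*- coding: UTF8 -*-
# Author: Guillaume Bouvier -- [email protected]
# https://research.pasteur.fr/en/member/guillaume-bouvier/
# 2020-10-01 09:51:45 (UTC+0200)
import sys

import torch


def icp(coords, coords_ref, device, n_iter):
    """
    Iterative Closest Point
    - coords: (n, 3) tensor of points to align
    - coords_ref: (m, 3) tensor of reference points
    - device: torch device on which to run the computation
    - n_iter: number of ICP iterations
    Returns the transformed coords, shape (n, 3).
    """
    coords = coords.to(device)
    coords_ref = coords_ref.to(device)
    for t in range(n_iter):
        # Pair each point with its closest reference point (both sets centered)
        cdist = torch.cdist(coords - coords.mean(dim=0),
                            coords_ref - coords_ref.mean(dim=0))
        mindists, argmins = torch.min(cdist, dim=1)
        target = coords_ref[argmins]
        # torch.lstsq(B, A) is deprecated; torch.linalg.lstsq(A, B) takes the
        # arguments in the opposite order and returns the (3, 3) solution directly
        X = torch.linalg.lstsq(coords, target).solution
        coords = coords.mm(X)
        # RMSD between the transformed points and their matched reference points
        rmsd = torch.sqrt(((coords - target)**2).sum(dim=1).mean())
        print_progress(f'{t+1}/{n_iter}: {rmsd}')
    return coords


def print_progress(instr):
    # Overwrite the same terminal line on every call
    sys.stdout.write(f'{instr}\r')
    sys.stdout.flush()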
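
The closest-point matching step inside the loop can be looked at in isolation. The following is a minimal sketch on toy data, not part of the original gist; the helper name `closest_reference_points` and the example tensors are hypothetical and only serve as illustration.

```python
import torch


def closest_reference_points(coords, coords_ref):
    """For each row of coords, find the nearest row of coords_ref.

    Hypothetical helper for illustration; mirrors the cdist + min step of icp().
    """
    cdist = torch.cdist(coords, coords_ref)        # (n, m) pairwise Euclidean distances
    mindists, argmins = torch.min(cdist, dim=1)    # nearest reference point per row
    return argmins, mindists


# Toy data (assumed for the example): two points and three reference points
coords = torch.tensor([[0.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0]])
coords_ref = torch.tensor([[1.1, 0.0, 0.0],
                           [0.0, 0.1, 0.0],
                           [5.0, 5.0, 5.0]])
argmins, mindists = closest_reference_points(coords, coords_ref)
```

Here the first point is matched to the second reference point and the second point to the first reference point, so `coords_ref[argmins]` gives the targets used in the least squares fit.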
@smiles724

This implementation is not right. torch.lstsq is deprecated and you should use torch.linalg.lstsq instead.
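
For reference, a minimal sketch of the replacement call, assuming point sets stored one point per row. The names `A`, `B` and `X_true` are made up for the example. Note that `torch.linalg.lstsq(A, B)` takes its arguments in the opposite order to the old `torch.lstsq(B, A)`, and exposes the fitted matrix as `.solution` rather than as the first rows of a padded output.

```python
import torch

torch.manual_seed(0)

# Illustrative data (assumed): 50 random 3D points and a known linear map
A = torch.randn(50, 3)
X_true = torch.tensor([[0.0, 1.0, 0.0],
                       [-1.0, 0.0, 0.0],
                       [0.0, 0.0, 1.0]])
B = A.mm(X_true)

# Old call, for comparison: X, _ = torch.lstsq(B, A); X = X[:3]
# New call: solve min ||A X - B|| and read the (3, 3) solution directly
X = torch.linalg.lstsq(A, B).solution

# Fit residual, computed explicitly instead of reading X[3:]
rmsd = torch.sqrt(((A.mm(X) - B) ** 2).sum(dim=1).mean())
```

Because `B` is an exact linear image of `A` in this example, `X` recovers `X_true` up to floating point error and `rmsd` is close to zero.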
