Skip to content

Instantly share code, notes, and snippets.

@YuigaWada
Created July 23, 2022 12:59
Show Gist options
  • Select an option

  • Save YuigaWada/381844002490eed0e069cfce9c4489c2 to your computer and use it in GitHub Desktop.

Select an option

Save YuigaWada/381844002490eed0e069cfce9c4489c2 to your computer and use it in GitHub Desktop.
Structure from Motion using OpenCV and Open3D
##################################################################################
# Structure from Motion using OpenCV and Open3D
# Author: Yuiga Wada (Keio University)
# Date: 2022/07/24
#
##################################################################################
# License
# This work is licensed under the Creative Commons Attribution-NonCommercial 4.0
# International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
##################################################################################
import glob
import os
import cv2
import numpy as np
import open3d as o3d
from tqdm import tqdm
from dataclasses import dataclass
from typing import Any, List, Tuple
@dataclass
class FLANNparams:
    """Pair of parameter dictionaries handed to ``cv2.FlannBasedMatcher``.

    The two fields are passed positionally, in this order, as the matcher's
    index parameters and search parameters.
    """

    # Index-construction parameters (e.g. ``dict(algorithm=1, trees=5)``).
    index: dict
    # Search-time parameters (e.g. ``dict(checks=50)``).
    search: dict
@dataclass
class View:
    """Per-image record: detected features plus the estimated camera pose."""

    # Keypoints as returned by the feature detector; each element exposes a
    # ``.pt`` (x, y) attribute.  This is a sequence of keypoint objects, not
    # a NumPy array.
    keypoints: Any
    # Descriptor array produced together with ``keypoints``.
    descriptors: np.ndarray
    # 3x4 extrinsic matrix [R|t] of this view.
    Rt: np.ndarray
    # Position of the image in the input list.
    id: int
@dataclass
class Pair:
    """Matching result between two views.

    Every list field has exactly two entries: index 0 refers to the first
    view of the pair, index 1 to the second.
    """

    # Keypoint indices of every putative match, one list of ints per view.
    keyIdxs: List[List[int]]
    # (N, 2) pixel coordinates of all keypoints of each view.
    points: List[np.ndarray]
    # Keypoint indices of the matches kept by the RANSAC fundamental-matrix fit.
    inliers: List[np.ndarray]
    # Fundamental matrix relating the two views.
    F: np.ndarray
class SFM:
    """Incremental structure-from-motion over an ordered list of images.

    The constructor extracts SIFT features for every image and matches every
    image pair.  ``build`` then recovers the camera poses one view at a time
    and triangulates a sparse point cloud.  The camera frame of view 0 is used
    as the world frame.

    Attributes:
        viewdb: One ``View`` per input image, indexed by image position.
        pairdb: ``Pair`` records keyed by ``(i, j)`` with ``i < j``.
        L: Number of input images.
        K: 3x3 camera intrinsic matrix, shared by all views.
        points_3D: (N, 3) array of triangulated points.
        point_counter: Number of rows currently stored in ``points_3D``.
        point_map: Maps ``(view id, keypoint index)`` to a row of ``points_3D``.
        used_descriptors: Descriptor arrays of the views registered so far,
            in view order (list position == view id).
    """

    def __init__(self, imgs: List[np.ndarray], K: np.ndarray):
        """Detect features in every image and match every image pair.

        Args:
            imgs: Input images, in the order the views are to be registered.
            K: 3x3 camera intrinsic matrix applied to all views.
        """
        L = len(imgs)
        detector = cv2.SIFT_create()

        print("Compute view features... ")
        viewdb = []
        for i, img in tqdm(enumerate(imgs), total=L):
            kpts, desc = detector.detectAndCompute(img, None)
            # Every view starts at the identity pose [I|0]; build() replaces
            # it for all views after the first.
            identity_pose = np.hstack((np.eye(3), np.zeros((3, 1))))
            viewdb.append(View(kpts, desc, identity_pose, i))

        print("Compute pair features... ")
        pairdb = {}
        for i in tqdm(range(L)):
            for j in range(i + 1, L):
                kpts_pair = (viewdb[i].keypoints, viewdb[j].keypoints)
                desc_pair = (viewdb[i].descriptors, viewdb[j].descriptors)
                pairdb[(i, j)] = self._compute_matcher(kpts_pair, desc_pair)

        self.viewdb = viewdb
        self.pairdb = pairdb
        self.L = L
        self.K = K
        self.points_3D = np.zeros((0, 3))
        self.point_counter = 0
        self.point_map = {}
        self.used_descriptors = []

    def build(self):
        """Estimate all camera poses and return the triangulated point cloud.

        Views 0 and 1 are bootstrapped from their essential matrix; each later
        view is registered with PnP against the points triangulated so far and
        then triangulated against every earlier view.

        Returns:
            An Open3D point cloud holding every triangulated point.

        Raises:
            ValueError: If fewer than two images were supplied.
            RuntimeError: If no epipolar geometry is available for views 0
                and 1, or a later view cannot be registered.
        """
        print("Reconstructing... ")
        if self.L < 2:
            raise ValueError("SFM.build() needs at least two images")

        # From here on, the camera frame of viewdb[0] is the world origin.
        pair = self.pairdb[(0, 1)]
        if pair.F is None or len(pair.inliers[0]) == 0:
            raise RuntimeError("Could not estimate epipolar geometry between views 0 and 1")

        K = self.K
        E = K.T @ pair.F @ K
        # An essential matrix decomposes into four candidate (R, t) poses.
        # recoverPose selects the one that places the matched points in front
        # of both cameras; taking the first candidate of
        # decomposeEssentialMat blindly can pick a wrong pose.
        pts0 = pair.points[0][pair.inliers[0]]
        pts1 = pair.points[1][pair.inliers[1]]
        _, R, t, _ = cv2.recoverPose(E, pts0, pts1, K)
        self.viewdb[1].Rt = np.hstack((R, t))

        self._triangulate(0, 1)
        self.used_descriptors.extend([self.viewdb[0].descriptors, self.viewdb[1].descriptors])

        for i in tqdm(range(2, self.L)):
            self.viewdb[i].Rt = self._solvePnP(i)
            for j in range(i):
                self._triangulate(j, i)
            # Appended only after registration so that list position keeps
            # matching the view id (relied upon by _solvePnP).
            self.used_descriptors.append(self.viewdb[i].descriptors)

        return self._get_pcd()

    @staticmethod
    def remove_outlier(pcd) -> "o3d.geometry.PointCloud":
        """Return ``pcd`` with statistical outliers removed (point-cloud level)."""
        pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
        return pcd

    def _compute_matcher(self, kpts_pair: Tuple[Any, Any], desc_pair: Tuple[Any, Any]) -> "Pair":
        """Match two views and fit a fundamental matrix with RANSAC.

        Args:
            kpts_pair: Keypoint sequences of the two views.
            desc_pair: Descriptor arrays of the two views (either may be
                ``None`` when a view has no features).

        Returns:
            A ``Pair``.  When matching or the fundamental-matrix fit is not
            possible, ``F`` is ``None`` and both inlier arrays are empty.
        """
        kpts1, kpts2 = kpts_pair
        desc1, desc2 = desc_pair
        pts1 = np.array([kp.pt for kp in kpts1], dtype=float).reshape(-1, 2)
        pts2 = np.array([kp.pt for kp in kpts2], dtype=float).reshape(-1, 2)

        idx1, idx2 = [], []
        # Two neighbours are requested per query, so matching is skipped when
        # either side has no descriptors or the train side has fewer than two.
        if desc1 is not None and desc2 is not None and len(desc1) > 0 and len(desc2) >= 2:
            params = FLANNparams(dict(algorithm=1, trees=5), dict(checks=50))
            flann = cv2.FlannBasedMatcher(params.index, params.search)
            for neighbours in flann.knnMatch(desc1, desc2, k=2):
                if not neighbours:
                    continue  # a query may come back with no neighbour at all
                best = neighbours[0]
                idx1.append(best.queryIdx)
                idx2.append(best.trainIdx)

        F, mask = None, None
        # With fewer than eight correspondences a unique fundamental matrix
        # cannot be estimated, so such a pair is recorded without geometry.
        if len(idx1) >= 8:
            F, mask = cv2.findFundamentalMat(pts1[idx1], pts2[idx2], method=cv2.FM_RANSAC, ransacReprojThreshold=0.9, confidence=0.99)

        if F is None or mask is None or F.shape != (3, 3):
            no_inliers = np.zeros(0, dtype=int)
            return Pair([idx1, idx2], [pts1, pts2], [no_inliers, no_inliers.copy()], None)

        keep = mask[:, 0] == 1
        inliers1 = np.array(idx1, dtype=int)[keep]
        inliers2 = np.array(idx2, dtype=int)[keep]
        return Pair([idx1, idx2], [pts1, pts2], [inliers1, inliers2], F)

    def _triangulate(self, id1: int, id2: int):
        """Triangulate the inlier matches of pair ``(id1, id2)``.

        New points are appended to ``self.points_3D`` and both observing
        keypoints are registered in ``self.point_map``.  A keypoint that was
        already registered by an earlier pair is remapped to the newest point.
        """
        pair = self.pairdb[(id1, id2)]
        inliers1, inliers2 = pair.inliers
        if len(inliers1) == 0:
            return  # nothing to triangulate for this pair

        view1, view2 = self.viewdb[id1], self.viewdb[id2]
        K_inv = np.linalg.inv(self.K)
        pts1 = cv2.convertPointsToHomogeneous(pair.points[0][inliers1])[:, 0, :].T
        pts2 = cv2.convertPointsToHomogeneous(pair.points[1][inliers2])[:, 0, :].T
        # Multiply by the inverse of K to move from pixel coordinates to
        # camera coordinates (s * u = K [R|t] x), so that the bare [R|t]
        # matrices can serve as projection matrices below.
        x, y = K_inv.dot(pts1), K_inv.dot(pts2)
        X = cv2.triangulatePoints(view1.Rt, view2.Rt, x[:2, :], y[:2, :])
        points_3D = (X[:3] / X[3]).T
        self.points_3D = np.concatenate((self.points_3D, points_3D), axis=0)

        for kp1, kp2 in zip(inliers1, inliers2):
            self.point_map[(id1, int(kp1))] = self.point_counter
            self.point_map[(id2, int(kp2))] = self.point_counter
            self.point_counter += 1

    def _solvePnP(self, id: int) -> np.ndarray:
        """Solve PnP for view ``id`` and return its 3x4 pose ``[R|t]``.

        Raises:
            RuntimeError: If too few 2D-3D correspondences are found or the
                RANSAC PnP solver reports failure.
        """
        matcher = cv2.BFMatcher(cv2.NORM_L2, crossCheck=False)
        matcher.add(self.used_descriptors)
        matcher.train()
        matches = matcher.match(queryDescriptors=self.viewdb[id].descriptors)

        # For every match against used_descriptors whose train keypoint has
        # already been triangulated, collect the 3D point <-> 2D point
        # correspondence that is fed to PnP.  match.imgIdx indexes
        # used_descriptors, which is kept in view order, so it is a view id.
        keypoints = self.viewdb[id].keypoints
        object_pts, image_pts = [], []
        for match in matches:
            row = self.point_map.get((match.imgIdx, match.trainIdx))
            if row is None:
                continue
            image_pts.append(keypoints[match.queryIdx].pt)
            object_pts.append(self.points_3D[row])

        points_3D = np.array(object_pts, dtype=float).reshape(-1, 3)
        points_2D = np.array(image_pts, dtype=float).reshape(-1, 2)
        if len(points_3D) < 4:
            raise RuntimeError(
                "View %d: only %d 2D-3D correspondences, at least 4 are needed for PnP" % (id, len(points_3D))
            )

        # SQPnP [Terzakis+, ECCV20]
        ok, Rvec, tvec, _ = cv2.solvePnPRansac(points_3D[:, np.newaxis], points_2D[:, np.newaxis], self.K, None, flags=cv2.SOLVEPNP_SQPNP)
        if not ok:
            raise RuntimeError("View %d: solvePnPRansac failed to find a pose" % id)
        R, _ = cv2.Rodrigues(Rvec)
        return np.hstack((R, tvec))

    def _get_pcd(self) -> "o3d.geometry.PointCloud":
        """Wrap the accumulated 3D points in an Open3D point cloud."""
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(self.points_3D)
        return pcd
if __name__ == '__main__':
    # Prepare: load the intrinsic matrix and every JPEG under ./images,
    # in sorted path order.
    K = np.loadtxt("./images/K.txt")
    paths = sorted(glob.glob("./images/*.jpg"))
    print(paths)

    imgs = []
    for path in paths:
        img = cv2.imread(path)
        # cv2.imread signals an unreadable file by returning None instead of
        # raising, so fail here with the offending path rather than later
        # inside feature detection.
        if img is None:
            raise SystemExit("Could not read image: " + path)
        imgs.append(img)
    if len(imgs) < 2:
        raise SystemExit("Need at least two images in ./images/*.jpg, found %d" % len(imgs))

    # Structure from Motion
    sfm = SFM(imgs, K)
    pcd = sfm.build()

    # Clean up
    pcd = SFM.remove_outlier(pcd)

    # Visualize
    o3d.visualization.draw_geometries([pcd])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment