import torch
from functools import partial
from dataclasses import dataclass, astuple, asdict
from typing import List, Optional, Tuple, Union
from warnings import warn

from .numutils import unit
from .torch_fix import (
    Matrices3x4,
    Matrices4x4,
    Matrix3x3,
    Matrix4x4,
    prepare_matvec,
    prepend_like,
    Scalars,
    Vec2,
    Vecs2,
    Vecs3,
    Points2,
    Points3,
    assert_broadcastable,
)
class Rays:
    data: torch.Tensor

    def __init__(self, concatenated):
        self.data = concatenated

    def __repr__(self):
        return f"Rays[{self.rays_shape}]"

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [x.data if isinstance(x, Rays) else x for x in args]
        out = func(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            assert out.shape[-1] == self.shape[-1]
            return Rays(out)
        assert type(out) in [list, tuple]
        assert all(
            isinstance(o, torch.Tensor) and o.shape[-1] == self.shape[-1]
            for o in out
        )
        return [Rays(o) for o in out]

    def at(self, t) -> Tuple[Points3, Vecs3]:
        assert_broadcastable(t, self.data)
        d = self.directions
        return self.origins + d * t, d

    def unsqueeze(self, dim) -> "Rays":
        return Rays(self.data.unsqueeze(dim))

    def squeeze(self, dim) -> "Rays":
        return Rays(self.data.squeeze(dim))

    @property
    def origins(self) -> Points3:
        return self.data[..., :3]

    @property
    def directions(self) -> Vecs3:
        return self.data[..., 3:6]

    @property
    def far(self) -> Scalars:
        return self.data[..., 6:]  # NOTE: keeps the dummy dim

    @staticmethod
    def make(origins, directions, far) -> "Rays":
        rays = torch.cat((origins, directions, far), dim=-1)
        return Rays(rays)

    @property
    def batch_shape(self):
        return self.data.shape[:-1]

    @property
    def shape(self):
        return self.data.shape

    def reshape(self, *shape):
        return Rays(self.data.reshape(*shape))

    def expand(self, *shape):
        return Rays(self.data.expand(*shape))

    def view(self, *shape):
        return Rays(self.data.view(*shape))

    @staticmethod
    def from_camera(centre, z_dir, near, far, directions, MIN_DIVISOR=1e-6):
        batch = directions.shape[:-1]
        dirs_proj_z = (
            (directions * z_dir).sum(-1, keepdim=True).clamp(MIN_DIVISOR)
        )
        near = torch.ones(*batch, 1).to(centre).mul(near)
        far = torch.ones(*batch, 1).to(centre).mul(far)
        origins = centre + ((directions / dirs_proj_z) * near)
        far = far / dirs_proj_z
        return Rays.make(origins=origins, directions=directions, far=far)

    @property
    def rays_shape(self):
        return self.data.shape[:-1]

    @property
    def device(self):
        return self.data.device

    def __len__(self):
        return len(self.data)

    def to(self, device):
        return Rays(self.data.to(device))

    def rays_reshape(self, *shape):
        return Rays(self.data.reshape(*shape, self.data.shape[-1]))

    def __getitem__(self, *args):
        # TODO:
        return Rays(self.data.__getitem__(*args))

    @classmethod
    def stack(cls, rays: List["Rays"], dim: int = 0):
        # TODO: can be much simpler
        return cls(torch.stack([r.data for r in rays], dim=dim))
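

# A minimal usage sketch of the Rays container (hypothetical values, not
# part of the original gist): pack origins/directions/far into one tensor
# with `Rays.make` and evaluate points along the rays with `at`.
def _example_rays_usage() -> Rays:
    origins = torch.zeros(5, 3)  # all rays start at the origin
    directions = torch.tensor([0.0, 0.0, 1.0]).expand(5, 3)  # +z unit dirs
    far = torch.full((5, 1), 10.0)  # per-ray far bound (keeps the dummy dim)
    rays = Rays.make(origins=origins, directions=directions, far=far)
    points, dirs = rays.at(torch.full((5, 1), 2.0))  # points 2 units along
    assert points.shape == (5, 3) and dirs.shape == (5, 3)
    return rays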
def transform_vectors(
    mtxs: torch.Tensor, vectors: torch.Tensor
) -> torch.Tensor:
    mtxs, vectors = prepare_matvec(mtxs, vectors)
    # dimension of vector space (2 or 3)
    # for extracting the rotation part of the transform
    D = vectors.size(-1)
    result = torch.einsum("...ij,...j->...i", mtxs[..., :-1, :D], vectors)
    return result
def transform_points(mtxs: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    mtxs, points = prepare_matvec(mtxs, points)
    D = points.size(-1)
    assert mtxs.size(-1) == D + 1, (
        f"Matrix acting on {D}-points from the left"
        f" should have {D + 1} columns, not {mtxs.size(-1)}"
    )
    # - NeRF transform matrices assume column-vector convention
    # - comments instead of names because garbage collection
    # rotate
    result = torch.einsum("...ij,...j->...i", mtxs[..., :-1, :D], points)
    # translate
    result = result + mtxs[..., :-1, D]
    # w-divide
    result = result / (
        mtxs[..., -1, :D].mul(points).sum(-1, keepdim=True)
        + mtxs[..., [-1], [D]]
    )
    return result
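

# A quick sanity sketch (hypothetical, not in the original gist; assumes
# `prepare_matvec` broadcasts an unbatched matrix over a batch of points):
# for a rigid 4x4 transform the w-divisor is 1, so `transform_points`
# reduces to R @ p + t, while `transform_vectors` applies only R.
def _example_transform_points() -> torch.Tensor:
    T = torch.eye(4)
    T[:3, 3] = torch.tensor([1.0, 2.0, 3.0])  # pure translation
    p = torch.zeros(4, 3)  # four points at the origin
    q = transform_points(T, p)  # -> each point becomes (1, 2, 3)
    assert torch.allclose(q, torch.tensor([1.0, 2.0, 3.0]).expand(4, 3))
    return q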
def transform_rays(mtxs: torch.Tensor, rays: Rays) -> Rays:
    return Rays.make(
        transform_points(mtxs, rays.origins),
        transform_vectors(mtxs, rays.directions),
        # TODO: should bounds be transformed?
        # not if matrices are rototranslations
        rays.far,
    )
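

# A hypothetical sketch (not in the original gist): translating rays moves
# their origins but leaves directions and far bounds untouched.
def _example_transform_rays() -> Rays:
    rays = _example_rays_usage()
    T = torch.eye(4)
    T[:3, 3] = torch.tensor([0.0, 0.0, 5.0])  # shift 5 units along +z
    moved = transform_rays(T, rays)
    assert torch.allclose(moved.directions, rays.directions)
    assert torch.allclose(moved.far, rays.far)
    return moved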
def pixels_raster_space(H: int, W: int, dtype=torch.float32) -> torch.Tensor:
    """Return an [H, W, 2]-array of pixel indices (upper-left corners' coords).

    Specifically, :code:`res[i, j] = (j, i)`.

    >>> assert pixels_raster_space(2, 3).shape == (2, 3, 2)

    The UPPER-LEFT image corner -- the origin, :math:`(0, 0)`, of raster space:

    >>> assert pixels_raster_space(2, 3)[0, 0].norm() < 1e-9

    The LOWER-RIGHT image corner -- the :code:`[H-1, W-1]`th pixel --
    has raster space coordinates :math:`(W-1, H-1)`.

    >>> T = torch.tensor
    >>> assert pixels_raster_space(2, 3)[1, 2].sub(T((2., 1.))).norm() < 1e-9
    """
    x = torch.linspace(0, W - 1, W, dtype=dtype)
    y = torch.linspace(0, H - 1, H, dtype=dtype)
    y, x = torch.meshgrid(y, x)
    return torch.stack((x, y), dim=-1)
@dataclass
class SimplePinhole:
    H: int  # all cameras have the same resolution
    W: int
    K: torch.Tensor  # [..., 3, 3]
    nearfar: Vecs2  # [..., 2]
    camera_to_world: Matrix4x4

    def to(self, device: torch.device):
        data = astuple(self)
        data = [
            x.to(device) if isinstance(x, torch.Tensor) else x for x in data
        ]
        return SimplePinhole(*data)

    @property
    def batch_shape(self):
        return self.camera_to_world.shape[:-2]

    @property
    def batch_dims(self):
        return tuple(1 for _ in self.batch_shape)

    @property
    def device(self):
        return self.camera_to_world.device

    @property
    def dtype(self):
        return self.camera_to_world.dtype

    @property
    def camera_centre(self):
        return self.camera_to_world[..., :3, -1]

    @property
    def principal_point(self):
        return torch.stack((self.K[..., 0, 2], self.K[..., 1, 2]), dim=-1)

    @property
    def px_per_m(self):
        return torch.stack((self.K[..., 0, 0], self.K[..., 1, 1]), dim=-1)

    @property
    def K_mask(self):
        return (
            torch.tensor(
                [
                    # fmt: off
                    [1, 0, 1],
                    [0, 1, 1],
                    [0, 0, 1.]
                    # fmt: on
                ]
            )
            .to(self.dtype)
            .to(self.device)
        )

    @property
    def near(self):
        return self.nearfar[..., 0:1]

    @property
    def far(self):
        return self.nearfar[..., 1:2]

    @property
    def all_pixels(self) -> torch.Tensor:
        return (
            pixels_raster_space(H=self.H, W=self.W)
            .reshape(self.H, self.W, *self.batch_dims, 2)
            .expand(self.H, self.W, *self.batch_shape, 2)
            .to(self.device)
        )

    def get_rays(
        self, pixels: Union[None, torch.FloatTensor, torch.LongTensor] = None
    ) -> Rays:
        r"""Generate world-space rays from pixel coordinates.

        If :obj:`pixels` is None, generate :code:`(H, W)` world-space rays,
        one per pixel.
        """
        if pixels is None:
            pixels = self.all_pixels
        if not pixels.dtype.is_floating_point:
            # integer indices address pixel corners; shift to pixel centres
            pixels = pixels.to(self.dtype).add(0.5)
        directions = self.apply_raster_to_cam(pixels)
        c2w, directions = prepare_matvec(self.camera_to_world, directions)
        directions = transform_vectors(c2w, directions)
        directions = unit(directions)
        centre = self.camera_centre.expand(*directions.shape)
        cam_basis_z = c2w[..., :3, 2]
        return Rays.from_camera(
            centre=centre,
            z_dir=cam_basis_z,
            near=self.near,
            far=self.far,
            directions=directions,
        )
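
    # A hypothetical usage sketch (not in the original gist):
    #
    #   >>> cam = SimplePinhole(H=4, W=6, K=K, nearfar=nearfar,
    #   ...                     camera_to_world=c2w)
    #   >>> rays = cam.get_rays()  # one ray per pixel
    #   >>> assert rays.rays_shape[:2] == (4, 6)
    #
    # where K is a [..., 3, 3] calibration matrix, nearfar a [..., 2] tensor
    # of near/far bounds, and c2w a [..., 4, 4] camera-to-world matrix.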
    @property
    def calibration_matrix_4x4(self):
        """Convert (screen-space) meters to pixels.

        The calibration matrix converts the :math:`x, y` coordinates of
        3-points projected onto the :math:`z=1` plane
        into pixel coordinates (raster space).
        Raster space has orientation :code:`[right, down]`
        and its origin in the upper-left image corner
        (the :code:`[0, 0]` pixel).
        """
        K = self.K * self.K_mask
        col3 = (
            torch.tensor([0, 0, 0])
            .reshape(*self.batch_dims, 3, 1)
            .expand(*self.batch_shape, 3, 1)
            .to(K)
        )
        row4 = (
            torch.tensor([0, 0, 0, 1])
            .reshape(*self.batch_dims, 1, 4)
            .expand(*self.batch_shape, 1, 4)
            .to(K)
        )
        K = torch.cat((K, col3), dim=-1)
        K = torch.cat((K, row4), dim=-2)
        return K

    @property
    def calibration_matrix(self):
        return self.K.mul(self.K_mask)

    @property
    def inv_calibration_matrix_4x4(self):
        """Transform from pixels to meters.

        The inverse calibration matrix acts in raster space and converts
        pixel coordinates into :math:`x,~y` coordinates on the
        :math:`z=1` plane. It includes shifting the space origin from the
        top-left image corner to the principal point; the principal point
        is specified in pixel coordinates.
        """
        return torch.inverse(self.calibration_matrix_4x4)

    @property
    def inv_calibration_matrix(self):
        return self.inv_calibration_matrix_4x4[..., :3, :3]

    def rescale(self, *, resx: int, resy: int) -> "SimplePinhole":
        """Make a :class:`SimplePinhole` with a new resolution."""
        resx, resy = int(resx), int(resy)
        assert min(resx, resy) >= 1, f"Invalid resolution: {(resx, resy)}"
        K = self.K
        new_wh = torch.tensor([resx, resy]).reshape(*self.batch_dims, 2).to(K)
        wh = torch.tensor([self.W, self.H]).to(K)
        px_per_m = self.px_per_m
        # unclear which order of multiplication is more stable
        new_px_per_m = new_wh * (px_per_m / wh)
        new_principal = new_wh * (self.principal_point / wh)
        K = (
            torch.eye(2)
            .reshape(*self.batch_dims, 2, 2)
            .to(K)
            .mul(new_px_per_m)
        )
        K = torch.cat((K, new_principal.unsqueeze(-1)), dim=-1)
        K = torch.cat(
            (K, torch.tensor([0, 0, 1]).reshape(*self.batch_dims, 1, 3).to(K)),
            dim=-2,
        )
        return SimplePinhole(**{**asdict(self), "H": resy, "W": resx, "K": K})
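
    # A hypothetical sketch (not in the original gist): halving the
    # resolution halves both the focal lengths and the principal point.
    #
    #   >>> half = cam.rescale(resx=cam.W // 2, resy=cam.H // 2)
    #   >>> assert half.W == cam.W // 2 and half.H == cam.H // 2
    #   >>> assert torch.allclose(half.px_per_m, cam.px_per_m / 2)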
    @property
    def flip_xy(self) -> Matrix4x4:
        switch_orientation = (
            torch.tensor(
                [
                    # fmt: off
                    [-1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0., 1, 0],
                    [0, 0., 0, 1]
                    # fmt: on
                ]
            )
            .reshape(*self.batch_dims, 4, 4)
            .to(self.camera_to_world)
        )
        return switch_orientation

    @property
    def projection_matrix(self) -> Matrix4x4:
        return (
            torch.tensor(
                [
                    # fmt: off
                    [1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, 1],
                    [0, 0, 1, 0.]
                    # fmt: on
                ]
            )
            .reshape(*self.batch_dims, 4, 4)
            .to(self.camera_to_world)
        )

    @property
    def raster_to_screen(self) -> Matrix4x4:
        """Transform from pixels to the image plane.

        Composes the inverse calibration matrix (converts pixels to meters)
        with a switch from the :code:`[right, down]` raster-space
        orientation to the screen-space :code:`[left, up]` one.
        This matrix (projectively) acts on 2-points and produces 2-points.
        """
        r2s: Matrix4x4 = torch.matmul(
            # 2. then we switch back the signs
            self.flip_xy,
            # 1. inv_K centres the principal point
            self.inv_calibration_matrix_4x4,
        )
        return r2s

    @property
    def screen_to_raster(self) -> Matrix4x4:
        """Transform from (screen-space) meters to pixels.

        Composes the calibration matrix (converts meters to pixels)
        with a switch from the screen-space :code:`[left, up]` orientation
        to the raster-space :code:`[right, down]` one.
        This matrix (projectively) acts on 2-points and produces 2-points.
        """
        s2r = torch.matmul(
            # Convert meters to pixels
            self.calibration_matrix_4x4,
            # Switch [left, up] orientation to [right, down]
            self.flip_xy,
        )
        return s2r

    def apply_raster_to_cam(self, pixels: Points2) -> Points3:
        """Apply the :meth:`raster_to_screen` matrix to :obj:`pixels`."""
        if not pixels.dtype.is_floating_point:
            # integer indices address pixel corners; shift to pixel centres
            pixels = pixels.to(self.dtype).add(0.5)
        assert pixels.shape[-1] == 2
        r2s = self.raster_to_screen[..., :3, :3].to(pixels)
        pixels = transform_points(r2s, pixels)
        z = torch.ones(*pixels.shape[:-1], 1).to(pixels)  # z=1
        pixels = torch.cat((pixels, z), dim=-1)
        return pixels
    def apply_ndc_to_cam(self, points: Points3) -> Points3:
        warn(
            "z should be treated as disparity"
            " but instead we linearly interpolate between near/far planes"
        )
        zrange = self.far - self.near
        T = (
            torch.eye(2)
            .to(self.K)
            .mul(torch.tensor([self.W, self.H]).to(self.K))
            .reshape(*self.batch_dims, 2, 2)
        )
        T = torch.cat(
            (T, torch.zeros(*self.batch_dims, 2, 2).to(self.K)), dim=-1
        )
        T = torch.cat(
            (
                T,
                torch.cat(
                    (
                        torch.zeros(*self.batch_dims, 2).to(self.K),
                        zrange,
                        self.near,
                    ),
                    dim=-1,
                ).unsqueeze(-2),
            ),
            dim=-2,
        )
        T = torch.cat(
            (
                T,
                torch.tensor([0, 0, 0, 1])
                .reshape(*self.batch_dims, 1, 4)
                .to(T),
            ),
            dim=-2,
        )
        # # fmt: off
        # points = transform_points(torch.tensor([
        #     [self.W, 0, 0, 0],
        #     [0, self.H, 0, 0],
        #     [0, 0, zrange, self.near],
        #     [0, 0, 0, 1]
        # ]).to(points), points)
        # # fmt: on
        points = transform_points(T, points)
        z = points[..., -1:]
        xy = points[..., :2]
        xy = transform_points(self.inv_calibration_matrix.to(xy), xy)
        xy = -xy * z
        points = torch.cat((xy, z), dim=-1)
        return points

    def apply_cam_to_ndc(self, points: Points3) -> Points3:
        assert points.shape[-1] == 3, points.shape
        z = points[..., 2:]
        xy = points[..., :2] / z
        xy = transform_points(self.screen_to_raster.to(xy)[..., :3, :3], xy)
        T = (
            torch.diag(torch.tensor([1 / self.W, 1 / self.H, 1]))
            .reshape(*self.batch_dims, 3, 3)
            .to(points)
        )
        xy = transform_points(T, xy)
        zrange = self.far - self.near
        # TODO: prepare general broadcasting
        z = (z - self.near) / zrange
        points = torch.cat((xy, z), dim=-1)
        return points
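
    # A hypothetical sanity check (not in the original gist): under the
    # linear-in-z convention above, cam -> ndc -> cam should round-trip
    # for points in front of the camera (an unbatched camera assumed).
    #
    #   >>> p = torch.tensor([[0.1, -0.2, 3.0]])
    #   >>> q = cam.apply_ndc_to_cam(cam.apply_cam_to_ndc(p))
    #   >>> assert torch.allclose(p, q, atol=1e-5)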
    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        return len(self.camera_to_world)

    def __getitem__(self, *i):
        return SimplePinhole(
            H=self.H,
            W=self.W,
            K=self.K.__getitem__(*i),
            nearfar=self.nearfar.__getitem__(*i),
            camera_to_world=self.camera_to_world.__getitem__(*i),
        )
    @staticmethod
    def from_3x4(
        H: int,
        W: int,
        px_per_m: Union[float, Vec2],
        t_near: float,
        t_far: float,
        poses: Matrices3x4,
        pre_rot: Optional[Matrix4x4] = None,
    ) -> "SimplePinhole":
        """Import 3x4-formatted extrinsics (as per the original NeRF)."""
        if not isinstance(poses, torch.Tensor):
            warn(
                "Deprecated: `poses` must be a **tensor** of camera-to-world"
                f" matrices, not {type(poses)}."
                " Implicitly stacking"
            )
            poses = torch.stack(poses)
        batch_shape = poses.shape[:-2]
        batch_dims = [1 for _ in batch_shape]
        assert len(poses.shape) == 3, "poses: batch -> 3x4 matrix"
        if not poses.dtype.is_floating_point:
            warn(
                f"Got poses of {poses.dtype} dtype."
                " Implicitly converting to f32"
            )
            poses = poses.to(torch.float32)
        poses = poses_3x4_to_4x4(poses)
        H, W = int(H), int(W)
        assert min(H, W) > 0, (H, W)
        if isinstance(px_per_m, float):
            focal = float(px_per_m)
            px_per_m = (
                torch.tensor([focal, focal])
                .reshape(*batch_dims, 2)
                .expand(*batch_shape, 2)
                .to(poses)
            )
        elif not isinstance(px_per_m, torch.Tensor):
            # FIXME:
            px_per_m = torch.tensor(px_per_m).expand(*batch_shape, 2).to(poses)
        else:
            px_per_m = px_per_m.expand(*batch_shape, 2).to(poses)
        if pre_rot is not None:
            poses = torch.matmul(poses, pre_rot.to(poses))
        principal_point = (
            torch.tensor([0.5 * W, 0.5 * H])
            .reshape(*batch_dims, 2)
            .expand(*batch_shape, 2)
            .to(poses)
        )
        K = (
            torch.eye(2)
            .reshape(*batch_dims, 2, 2)
            .expand(*batch_shape, 2, 2)
            .mul(px_per_m.unsqueeze(-1))
        )
        K = torch.cat((K, principal_point.unsqueeze(-1)), dim=-1)
        K = torch.cat(
            (
                K,
                torch.tensor([0, 0, 1])
                .reshape(*batch_dims, 1, 3)
                .expand(*batch_shape, 1, 3)
                .to(K),
            ),
            dim=-2,
        ).to(poses)
        if not isinstance(t_near, torch.Tensor):
            t_near = torch.tensor([t_near]).reshape(*batch_dims).to(poses)
        if not isinstance(t_far, torch.Tensor):
            t_far = torch.tensor([t_far]).reshape(*batch_dims).to(poses)
        nearfar = torch.cat((t_near, t_far), dim=-1)
        nearfar = nearfar.expand(*batch_shape, 2)
        return SimplePinhole(
            H=H, W=W, K=K, nearfar=nearfar, camera_to_world=poses,
        )
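
    # A hypothetical sketch (not in the original gist): importing a batch of
    # NeRF-style 3x4 camera-to-world matrices with a shared focal length.
    #
    #   >>> poses = torch.eye(4)[:3].expand(10, 3, 4)  # 10 identity poses
    #   >>> cam = SimplePinhole.from_3x4(
    #   ...     H=400, W=400, px_per_m=555.5,
    #   ...     t_near=2.0, t_far=6.0, poses=poses,
    #   ... )
    #   >>> assert cam.batch_shape == (10,)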
    @classmethod
    def stack(cls, batch: List["SimplePinhole"], dim=0) -> "SimplePinhole":
        if len(batch) == 0:
            return cls(
                H=0,
                W=0,
                K=torch.empty(0),
                nearfar=torch.empty(0),
                camera_to_world=torch.empty(0),
            )
        stack = partial(torch.stack, dim=dim)
        H, W = batch[0].H, batch[0].W
        assert all((c.H, c.W) == (H, W) for c in batch), [
            (c.H, c.W) for c in batch
        ]
        return cls(
            H=H,
            W=W,
            K=stack([c.K for c in batch]),
            nearfar=stack([c.nearfar for c in batch]),
            camera_to_world=stack([c.camera_to_world for c in batch]),
        )

    @staticmethod
    def from_nerf(
        H, W, focal: float, t_near: float, t_far: float, poses: Matrices3x4
    ):
        return SimplePinhole.from_3x4(
            H=H,
            W=W,
            px_per_m=focal,
            t_near=t_near,
            t_far=t_far,
            poses=poses,
            pre_rot=rot_nerf_to_canon(),
        )
def poses_3x4_to_4x4(poses_3x4: Matrices3x4) -> Matrices4x4:
    assert poses_3x4.shape[-2:] == (
        3,
        4,
    ), f"expected (..., 3, 4) got {poses_3x4.shape}"
    bottom = torch.tensor([0, 0, 0, 1]).to(poses_3x4)
    bottom = prepend_like(bottom, poses_3x4)
    bottom = bottom.expand(*poses_3x4.shape[:-2], 1, 4)
    poses = torch.cat((poses_3x4, bottom), dim=-2)
    poses = poses.reshape(*poses_3x4.shape[:-2], 4, 4)
    return poses
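

# A quick doctest-style check (hypothetical, not in the original gist;
# assumes `prepend_like` broadcasts the bottom row over the batch):
# the row [0, 0, 0, 1] is appended to every pose in the batch.
def _example_poses_3x4_to_4x4() -> Matrices4x4:
    poses_3x4 = torch.eye(4)[:3].expand(7, 3, 4)
    poses_4x4 = poses_3x4_to_4x4(poses_3x4)
    assert poses_4x4.shape == (7, 4, 4)
    assert torch.allclose(poses_4x4, torch.eye(4).expand(7, 4, 4))
    return poses_4x4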
def rot_llff_to_canon() -> Matrix4x4:
    """Convert LLFF's :code:`(up, left, backwards)` frame
    into the canonical :code:`(left, up, forward)` frame
    (as in Zisserman or the pbr-book)."""
    # fmt: off
    m = torch.tensor([
        [0, 1, .0, 0],
        [1, 0, .0, 0],
        [0, 0, -1, 0],
        [0, 0, .0, 1],
    ])
    # fmt: on
    return m


def rot_nerf_to_canon() -> Matrix4x4:
    """Convert the original NeRF's :code:`(right, up, backwards)` frame
    into the canonical :code:`(left, up, forward)` frame
    (as in Zisserman or the pbr-book)."""
    # fmt: off
    flip_xz = torch.tensor([
        [-1, 0, 0, .0],
        [0, 1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1.]]
    )
    # fmt: on
    return flip_xz
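

# A hypothetical end-to-end sketch (not in the original gist; assumes
# `prepare_matvec` broadcasts the unbatched intrinsics over the pixel
# grid, as the code above relies on): build a camera from NeRF-style
# inputs and shoot one ray per pixel.
def _example_from_nerf() -> Rays:
    poses = torch.eye(4)[:3].expand(3, 3, 4)  # 3 identity poses
    cam = SimplePinhole.from_nerf(
        H=8, W=8, focal=100.0, t_near=2.0, t_far=6.0, poses=poses
    )
    rays = cam.get_rays()  # expected rays.rays_shape == (8, 8, 3)
    return rays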