import torch
from functools import partial
from dataclasses import dataclass, astuple, asdict
from typing import List, Optional, Tuple, Union
from warnings import warn
from .numutils import unit
from .torch_fix import (
    Matrices3x4,
    Matrices4x4,
    Matrix3x3,
    Matrix4x4,
    prepare_matvec,
    prepend_like,
    Scalars,
    Vec2,
    Vecs2,
    Vecs3,
    Points2,
    Points3,
    assert_broadcastable,
)


class Rays:
    data: torch.Tensor

    def __repr__(self):
        return f"Rays[{self.rays_shape}]"

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        args = [x.data if isinstance(x, Rays) else x for x in args]
        out = func(*args, **kwargs)
        if isinstance(out, torch.Tensor):
            assert out.shape[-1] == self.shape[-1]
            return Rays(out)
        assert type(out) in [list, tuple]
        assert all(
            isinstance(o, torch.Tensor) and o.shape[-1] == self.shape[-1]
            for o in out
        )
        return [Rays(o) for o in out]

    def __init__(self, concatenated):
        self.data = concatenated

    def at(self, t) -> Tuple[Points3, Vecs3]:
        assert_broadcastable(t, self.data)
        d = self.directions
        return self.origins + d * t, d

    def unsqueeze(self, dim) -> "Rays":
        return Rays(self.data.unsqueeze(dim))

    def squeeze(self, dim) -> "Rays":
        return Rays(self.data.squeeze(dim))

    @property
    def origins(self) -> Points3:
        return self.data[..., :3]

    @property
    def directions(self) -> Vecs3:
        return self.data[..., 3:6]

    @property
    def far(self) -> Scalars:
        return self.data[..., 6:]  # NOTE: keeps the dummy dim

    @staticmethod
    def make(origins, directions, far) -> "Rays":
        rays = torch.cat((origins, directions, far), dim=-1)
        return Rays(rays)

    @property
    def batch_shape(self):
        return self.data.shape[:-1]

    @property
    def shape(self):
        return self.data.shape

    def reshape(self, *shape):
        return Rays(self.data.reshape(*shape))

    def expand(self, *shape):
        return Rays(self.data.expand(*shape))

    def view(self, *shape):
        return Rays(self.data.view(*shape))

    @staticmethod
    def from_camera(centre, z_dir, near, far, directions, MIN_DIVISOR=1e-6):
        batch = directions.shape[:-1]
        dirs_proj_z = (
            (directions * z_dir).sum(-1, keepdim=True).clamp(MIN_DIVISOR)
        )
        near = torch.ones(*batch, 1).to(centre).mul(near)
        far = torch.ones(*batch, 1).to(centre).mul(far)
        origins = centre + ((directions / dirs_proj_z) * near)
        far = far / dirs_proj_z
        return Rays.make(origins=origins, directions=directions, far=far)

    @property
    def rays_shape(self):
        return self.data.shape[:-1]

    @property
    def device(self):
        return self.data.device

    def __len__(self):
        return len(self.data)

    def to(self, device):
        return Rays(self.data.to(device))

    def rays_reshape(self, *shape):
        return Rays(self.data.reshape(*shape, self.data.shape[-1]))

    def __getitem__(self, *args):
        # TODO:
        return Rays(self.data.__getitem__(*args))

    @classmethod
    def stack(cls, rays: List["Rays"], dim: int = 0):
        # TODO: can be much simpler
        return cls(torch.stack([r.data for r in rays], dim=dim))
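
# Usage sketch (illustrative; the shapes below are assumptions, not checked here):
#     origins = torch.zeros(8, 3)                        # rays start at the origin
#     directions = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
#     far = torch.full((8, 1), 6.0)                      # per-ray far bound
#     rays = Rays.make(origins=origins, directions=directions, far=far)
#     points, dirs = rays.at(torch.full((8, 1), 2.0))    # march 2 units along each ray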


def transform_vectors(
    mtxs: torch.Tensor, vectors: torch.Tensor
) -> torch.Tensor:
    mtxs, vectors = prepare_matvec(mtxs, vectors)
    # dimension of vector space (2 or 3)
    # for extracting the rotation part of the transform
    D = vectors.size(-1)
    result = torch.einsum("...ij,...j->...i", mtxs[..., :-1, :D], vectors)
    return result


def transform_points(mtxs: torch.Tensor, points: torch.Tensor) -> torch.Tensor:
    mtxs, points = prepare_matvec(mtxs, points)
    D = points.size(-1)
    assert mtxs.size(-1) == D + 1, (
        f"Matrix acting on {D}-points from the left"
        f" should have {D + 1} columns, not {mtxs.size(-1)}"
    )
    # - NeRF transform matrices assume column-vector convention
    # - comments instead of names because garbage collection
    # rotate
    result = torch.einsum("...ij,...j->...i", mtxs[..., :-1, :D], points)
    # translate
    result = result + mtxs[..., :-1, D]
    # w-divide
    result = result / (
        mtxs[..., -1, :D].mul(points).sum(-1, keepdim=True)
        + mtxs[..., [-1], [D]]
    )
    return result
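
# Sketch (illustrative; a pure-translation example with assumed shapes):
#     T = torch.eye(4)
#     T[:3, 3] = torch.tensor([1.0, 0.0, 0.0])       # translate +1 along x
#     moved = transform_points(T, torch.zeros(5, 3))  # -> all points at (1, 0, 0)
#     dirs = transform_vectors(T, torch.ones(5, 3))   # translation leaves vectors unchanged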


def transform_rays(mtxs: torch.Tensor, rays: Rays) -> Rays:
    return Rays.make(
        transform_points(mtxs, rays.origins),
        transform_vectors(mtxs, rays.directions),
        # TODO: should bounds be transformed?
        # not if matrices are rototranslations
        rays.far,
    )


def pixels_raster_space(H: int, W: int, dtype=torch.float32) -> torch.Tensor:
    """Return [H, W, 2]-array of pixel indices (LU corners' coords).

    Specifically, :code:`res[i, j] = (j, i)`, i.e. x first, then y.

    >>> assert pixels_raster_space(2, 3).shape == (2, 3, 2)

    The UPPER-LEFT image corner -- the origin, :math:`(0, 0)`, of raster space:

    >>> assert pixels_raster_space(2, 3)[0, 0].norm() < 1e-9

    The LOWER-RIGHT image corner -- the :code:`[H-1, W-1]`th pixel --
    has raster space coordinates :math:`(W-1, H-1)`:

    >>> T = torch.tensor
    >>> assert pixels_raster_space(2, 3)[1, 2].sub(T((2., 1.))).norm() < 1e-9
    """
    x = torch.linspace(0, W - 1, W, dtype=dtype)
    y = torch.linspace(0, H - 1, H, dtype=dtype)
    y, x = torch.meshgrid(y, x)
    return torch.stack((x, y), dim=-1)


@dataclass
class SimplePinhole:
    H: int  # all cameras have same resolution
    W: int
    K: torch.Tensor  # [..., 3, 3]
    nearfar: Vecs2  # [..., 2]
    camera_to_world: Matrix4x4

    def to(self, device: torch.device):
        data = astuple(self)
        data = [
            x.to(device) if isinstance(x, torch.Tensor) else x for x in data
        ]
        return SimplePinhole(*data)

    @property
    def batch_shape(self):
        return self.camera_to_world.shape[:-2]

    @property
    def batch_dims(self):
        return tuple(1 for _ in self.batch_shape)

    @property
    def device(self):
        return self.camera_to_world.device

    @property
    def dtype(self):
        return self.camera_to_world.dtype

    @property
    def camera_centre(self):
        return self.camera_to_world[..., :3, -1]

    @property
    def principal_point(self):
        return torch.stack((self.K[..., 0, 2], self.K[..., 1, 2]), dim=-1)

    @property
    def px_per_m(self):
        return torch.stack((self.K[..., 0, 0], self.K[..., 1, 1]), dim=-1)

    @property
    def K_mask(self):
        return (
            torch.tensor(
                [
                    # fmt: off
                    [1, 0, 1],
                    [0, 1, 1],
                    [0, 0, 1.]
                    # fmt: on
                ]
            )
            .to(self.dtype)
            .to(self.device)
        )

    @property
    def near(self):
        return self.nearfar[..., 0:1]

    @property
    def far(self):
        return self.nearfar[..., 1:2]

    @property
    def all_pixels(self) -> torch.Tensor:
        # NOTE: float pixel coordinates (LU corners), not a LongTensor
        return (
            pixels_raster_space(H=self.H, W=self.W)
            .reshape(self.H, self.W, *self.batch_dims, 2)
            .expand(self.H, self.W, *self.batch_shape, 2)
            .to(self.device)
        )

    def get_rays(
        self, pixels: Union[None, torch.FloatTensor, torch.LongTensor] = None
    ) -> Rays:
        r"""Generate world-space rays from pixel coordinates.

        If :obj:`pixels` is None: generate :code:`(H, W)` (world-space) rays,
        one for each pixel.
        """
        if pixels is None:
            pixels = self.all_pixels.to(self.device)
        if not pixels.dtype.is_floating_point:
            # integer indices address LU corners; shift to pixel centres
            pixels = pixels.to(self.dtype).add(0.5)
        directions = self.apply_raster_to_cam(pixels)
        c2w, directions = prepare_matvec(self.camera_to_world, directions)
        directions = transform_vectors(c2w, directions)
        directions = unit(directions)
        centre = self.camera_centre.expand(*directions.shape)
        cam_basis_z = c2w[..., :3, 2]
        return Rays.from_camera(
            centre=centre,
            z_dir=cam_basis_z,
            near=self.near,
            far=self.far,
            directions=directions,
        )
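
    # Usage sketch (illustrative, for some camera `cam`):
    #     rays = cam.get_rays()                         # one ray per pixel
    #     sub = cam.get_rays(cam.all_pixels[::4, ::4])  # every 4th pixel only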

    @property
    def calibration_matrix_4x4(self):
        """Convert (screen-space) meters to pixels.

        The calibration matrix converts the :math:`x, y` coordinates of
        3-points projected onto the :math:`z=1` plane
        into pixel coordinates (raster space).
        Raster space has orientation :code:`[right, down]`
        and its origin in the upper-left image corner
        (the :code:`[0, 0]` pixel).
        """
        K = self.K * self.K_mask
        col3 = (
            torch.tensor([0, 0, 0])
            .reshape(*self.batch_dims, 3, 1)
            .expand(*self.batch_shape, 3, 1)
            .to(K)
        )
        row4 = (
            torch.tensor([0, 0, 0, 1])
            .reshape(*self.batch_dims, 1, 4)
            .expand(*self.batch_shape, 1, 4)
            .to(K)
        )
        K = torch.cat((K, col3), dim=-1)
        K = torch.cat((K, row4), dim=-2)
        return K
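
    # Resulting layout (unbatched view; f_x, f_y in px/m, principal point c_x, c_y):
    #     [[f_x,   0, c_x, 0],
    #      [  0, f_y, c_y, 0],
    #      [  0,   0,   1, 0],
    #      [  0,   0,   0, 1]]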

    @property
    def calibration_matrix(self):
        return self.K.mul(self.K_mask)

    @property
    def inv_calibration_matrix_4x4(self):
        """Transform from pixels to meters.

        The inverse calibration matrix acts in raster space and converts pixel
        coordinates into :math:`x,~y` coordinates on the :math:`z=1` plane.
        It includes shifting the space origin from the top-left image corner
        to the principal point (which is specified in pixel coordinates).
        """
        return torch.inverse(self.calibration_matrix_4x4)

    @property
    def inv_calibration_matrix(self):
        return self.inv_calibration_matrix_4x4[..., :3, :3]

    def rescale(self, *, resx: int, resy: int) -> "SimplePinhole":
        """Make a :class:`SimplePinhole` with a new resolution."""
        resx, resy = int(resx), int(resy)
        assert min(resx, resy) >= 1, f"Invalid resolution: {(resx, resy)}"
        K = self.K
        new_wh = torch.tensor([resx, resy]).reshape(*self.batch_dims, 2).to(K)
        wh = torch.tensor([self.W, self.H]).to(K)
        px_per_m = self.px_per_m
        # unclear which order of multiplication is more numerically stable
        new_px_per_m = new_wh * (px_per_m / wh)
        new_principal = new_wh * (self.principal_point / wh)
        K = (
            torch.eye(2)
            .reshape(*self.batch_dims, 2, 2)
            .expand(*self.batch_shape, 2, 2)
            .to(K)
            .mul(new_px_per_m.unsqueeze(-1))
        )
        K = torch.cat((K, new_principal.unsqueeze(-1)), dim=-1)
        K = torch.cat(
            (
                K,
                torch.tensor([0, 0, 1])
                .reshape(*self.batch_dims, 1, 3)
                .expand(*self.batch_shape, 1, 3)
                .to(K),
            ),
            dim=-2,
        )
        return SimplePinhole(**{**asdict(self), "H": resy, "W": resx, "K": K})
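
    # Usage sketch (illustrative): downscale a camera `cam` by 2x:
    #     small = cam.rescale(resx=cam.W // 2, resy=cam.H // 2)
    #     assert (small.H, small.W) == (cam.H // 2, cam.W // 2)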

    @property
    def flip_xy(self) -> Matrix4x4:
        switch_orientation = (
            torch.tensor(
                [
                    # fmt: off
                    [-1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0., 1, 0],
                    [0, 0., 0, 1]
                    # fmt: on
                ]
            )
            .reshape(*self.batch_dims, 4, 4)
            .to(self.camera_to_world)
        )
        return switch_orientation

    @property
    def projection_matrix(self) -> Matrix4x4:
        return (
            torch.tensor(
                [
                    # fmt: off
                    [1, 0, 0, 0],
                    [0, 1, 0, 0],
                    [0, 0, 0, 1],
                    [0, 0, 1, 0.]
                    # fmt: on
                ]
            )
            .reshape(*self.batch_dims, 4, 4)
            .to(self.camera_to_world)
        )

    @property
    def raster_to_screen(self) -> Matrix4x4:
        """Transform from pixels to the image plane.

        Composes the inverse calibration matrix (converts pixels to meters)
        with a switch from :code:`[right, down]` raster-space orientation to
        :code:`[left, up]` screen-space orientation.
        This matrix acts on 2-points and produces 2-points.
        """
        r2s: Matrix4x4 = torch.matmul(
            # 2. then we switch back the signs
            self.flip_xy,
            # 1. inv_K centers the principal point
            self.inv_calibration_matrix_4x4,
        )
        return r2s

    @property
    def screen_to_raster(self) -> Matrix4x4:
        """Transform from (screen-space) meters to pixels.

        Composes the calibration matrix (converts meters to pixels)
        with a switch from :code:`[left, up]` screen-space orientation to
        :code:`[right, down]` raster-space orientation.
        This matrix (projectively) acts on 2-points and produces 2-points.
        """
        s2r = torch.matmul(
            # Convert meters to pixels
            self.calibration_matrix_4x4,
            # Switch [left, up] orientation to [right, down]
            self.flip_xy,
        )
        return s2r
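
    # Sanity sketch (illustrative, unbatched `cam`): the two matrices above are
    # mutual inverses, since flip_xy is an involution:
    #     torch.matmul(cam.screen_to_raster, cam.raster_to_screen)  # ~= torch.eye(4)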

    def apply_raster_to_cam(self, pixels: Points2) -> Points3:
        """Apply the :meth:`raster_to_screen` matrix to :obj:`pixels`."""
        if not pixels.dtype.is_floating_point:
            # integer indices address LU corners; shift to pixel centres
            pixels = pixels.to(self.dtype).add(0.5)
        assert pixels.shape[-1] == 2
        r2s = self.raster_to_screen[..., :3, :3].to(pixels)
        pixels = transform_points(r2s, pixels)
        z = torch.ones(*pixels.shape[:-1], 1).to(pixels)  # z=1
        pixels = torch.cat((pixels, z), dim=-1)
        return pixels

    def apply_ndc_to_cam(self, points: Points3) -> Points3:
        warn(
            "z should be treated as disparity"
            " but instead we linearly interpolate between near/far planes"
        )
        zrange = self.far - self.near
        T = (
            torch.eye(2)
            .to(self.K)
            .mul(torch.tensor([self.W, self.H]).to(self.K))
            .reshape(*self.batch_dims, 2, 2)
        )
        T = torch.cat(
            (T, torch.zeros(*self.batch_dims, 2, 2).to(self.K)), dim=-1
        )
        T = torch.cat(
            (
                T,
                torch.cat(
                    (
                        torch.zeros(*self.batch_dims, 2).to(self.K),
                        zrange,
                        self.near,
                    ),
                    dim=-1,
                ).unsqueeze(-2),
            ),
            dim=-2,
        )
        T = torch.cat(
            (
                T,
                torch.tensor([0, 0, 0, 1])
                .reshape(*self.batch_dims, 1, 4)
                .to(T),
            ),
            dim=-2,
        )
        # The T constructed above is equivalent to (unbatched):
        # # fmt: off
        # points = transform_points(torch.tensor([
        #     [self.W, 0, 0, 0],
        #     [0, self.H, 0, 0],
        #     [0, 0, zrange, self.near],
        #     [0, 0, 0, 1]
        # ]).to(points), points)
        # # fmt: on
        points = transform_points(T, points)
        z = points[..., -1:]
        xy = points[..., :2]
        xy = transform_points(self.inv_calibration_matrix.to(xy), xy)
        xy = -xy * z
        points = torch.cat((xy, z), dim=-1)
        return points

    def apply_cam_to_ndc(self, points: Points3) -> Points3:
        assert points.shape[-1] == 3, points.shape
        z = points[..., 2:]
        xy = points[..., :2] / z
        xy = transform_points(self.screen_to_raster.to(xy)[..., :3, :3], xy)
        T = (
            torch.diag(torch.tensor([1 / self.W, 1 / self.H, 1]))
            .reshape(*self.batch_dims, 3, 3)
            .to(points)
        )
        xy = transform_points(T, xy)
        zrange = self.far - self.near
        # TODO: prepare general broadcasting
        z = (z - self.near) / zrange
        points = torch.cat((xy, z), dim=-1)
        return points
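
    # Round-trip sketch (illustrative, unbatched `cam`; note the warning above:
    # z is linearly interpolated between the near/far planes, not a disparity):
    #     cam_pts = torch.tensor([[0.0, 0.0, 2.0]])
    #     back = cam.apply_ndc_to_cam(cam.apply_cam_to_ndc(cam_pts))  # ~= cam_pts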

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        return len(self.camera_to_world)

    def __getitem__(self, *i):
        return SimplePinhole(
            H=self.H,
            W=self.W,
            K=self.K.__getitem__(*i),
            nearfar=self.nearfar.__getitem__(*i),
            camera_to_world=self.camera_to_world.__getitem__(*i),
        )

    @staticmethod
    def from_3x4(
        H: int,
        W: int,
        px_per_m: Union[float, Vec2],
        t_near: float,
        t_far: float,
        poses: Matrices3x4,
        pre_rot: Optional[Matrix4x4] = None,
    ) -> "SimplePinhole":
        """Import a 3x4-formatted extrinsic (as per original NeRF)."""
        if not isinstance(poses, torch.Tensor):
            warn(
                "Deprecated: `poses` must be a **tensor** of camera-to-world"
                f" matrices, not {type(poses)}."
                " Implicitly stack-ing"
            )
            poses = torch.stack(poses)
        batch_shape = poses.shape[:-2]
        batch_dims = [1 for _ in batch_shape]
        assert len(poses.shape) == 3, "poses: batch -> 3x4 matrix"
        if not poses.dtype.is_floating_point:
            warn(
                f"Got poses of {poses.dtype} dtype."
                " Implicitly converting to f32"
            )
            poses = poses.to(torch.float32)
        poses = poses_3x4_to_4x4(poses)
        H, W = int(H), int(W)
        assert min(H, W) > 0, (H, W)
        if isinstance(px_per_m, (float, int)):
            focal = float(px_per_m)
            px_per_m = (
                torch.tensor([focal, focal])
                .reshape(*batch_dims, 2)
                .expand(*batch_shape, 2)
                .to(poses)
            )
        elif not isinstance(px_per_m, torch.Tensor):
            # FIXME:
            px_per_m = torch.tensor(px_per_m).expand(*batch_shape, 2).to(poses)
        else:
            px_per_m = px_per_m.expand(*batch_shape, 2).to(poses)
        if pre_rot is not None:
            poses = torch.matmul(poses, pre_rot.to(poses))
        principal_point = (
            torch.tensor([0.5 * W, 0.5 * H])
            .reshape(*batch_dims, 2)
            .expand(*batch_shape, 2)
            .to(poses)
        )
        K = (
            torch.eye(2)
            .reshape(*batch_dims, 2, 2)
            .expand(*batch_shape, 2, 2)
            .mul(px_per_m.unsqueeze(-1))
        )
        K = torch.cat((K, principal_point.unsqueeze(-1)), dim=-1)
        K = torch.cat(
            (
                K,
                torch.tensor([0, 0, 1])
                .reshape(*batch_dims, 1, 3)
                .expand(*batch_shape, 1, 3)
                .to(K),
            ),
            dim=-2,
        ).to(poses)
        if not isinstance(t_near, torch.Tensor):
            t_near = torch.tensor([t_near]).reshape(*batch_dims).to(poses)
        if not isinstance(t_far, torch.Tensor):
            t_far = torch.tensor([t_far]).reshape(*batch_dims).to(poses)
        nearfar = torch.cat((t_near, t_far), dim=-1)
        nearfar = nearfar.expand(*batch_shape, 2)
        return SimplePinhole(
            H=H, W=W, K=K, nearfar=nearfar, camera_to_world=poses,
        )
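
    # Usage sketch (illustrative; focal length given in pixels):
    #     poses = torch.eye(4)[:3].unsqueeze(0)  # [1, 3, 4] identity extrinsic
    #     cam = SimplePinhole.from_3x4(
    #         H=480, W=640, px_per_m=500.0, t_near=0.1, t_far=10.0, poses=poses,
    #     )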

    @classmethod
    def stack(cls, batch: List["SimplePinhole"], dim=0) -> "SimplePinhole":
        if len(batch) == 0:
            # empty batch: tensors with a zero-length leading dim
            return cls(
                H=0,
                W=0,
                K=torch.empty(0, 3, 3),
                nearfar=torch.empty(0, 2),
                camera_to_world=torch.empty(0, 4, 4),
            )
        stack = partial(torch.stack, dim=dim)
        H, W = batch[0].H, batch[0].W
        assert all((c.H, c.W) == (H, W) for c in batch), [
            (c.H, c.W) for c in batch
        ]
        return cls(
            H=H,
            W=W,
            K=stack([c.K for c in batch]),
            nearfar=stack([c.nearfar for c in batch]),
            camera_to_world=stack([c.camera_to_world for c in batch]),
        )

    @staticmethod
    def from_nerf(
        H, W, focal: float, t_near: float, t_far: float, poses: Matrices3x4
    ):
        return SimplePinhole.from_3x4(
            H=H,
            W=W,
            px_per_m=focal,
            t_near=t_near,
            t_far=t_far,
            poses=poses,
            pre_rot=rot_nerf_to_canon(),
        )


def poses_3x4_to_4x4(poses_3x4: Matrices3x4) -> Matrices4x4:
    assert poses_3x4.shape[-2:] == (
        3,
        4,
    ), f"expected (..., 3, 4) got {poses_3x4.shape}"
    # bottom row [0, 0, 0, 1], broadcast to the batch shape
    bottom = torch.tensor([0, 0, 0, 1]).to(poses_3x4)
    bottom = prepend_like(bottom, poses_3x4)
    bottom = bottom.expand(*poses_3x4.shape[:-2], 1, 4)
    poses = torch.cat((poses_3x4, bottom), dim=-2)
    poses = poses.reshape(*poses_3x4.shape[:-2], 4, 4)
    return poses
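
# Sketch (illustrative):
#     p34 = torch.eye(4)[:3].expand(5, 3, 4)  # five identity extrinsics, [5, 3, 4]
#     p44 = poses_3x4_to_4x4(p34)             # -> [5, 4, 4]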


def rot_llff_to_canon() -> Matrix4x4:
    """Converts :code:`(up, left, backwards)` LLFF's frame
    into the canonical :code:`(left, up, forward)` frame
    (like in Zisserman or pbr-book)."""
    # fmt: off
    m = torch.tensor([
        [0, 1, .0, 0],
        [1, 0, .0, 0],
        [0, 0, -1, 0],
        [0, 0, .0, 1],
    ])
    # fmt: on
    return m


def rot_nerf_to_canon() -> Matrix4x4:
    """Converts :code:`(right, up, backwards)` original NeRF's frame
    into the canonical :code:`(left, up, forward)` frame
    (like in Zisserman or pbr-book)."""
    # fmt: off
    flip_xz = torch.tensor([
        [-1, 0, 0, .0],
        [0, 1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1.],
    ])
    # fmt: on
    return flip_xz
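
# End-to-end sketch (illustrative; NeRF-style inputs with focal in pixels):
#     cam = SimplePinhole.from_nerf(
#         H=400, W=400, focal=555.0, t_near=2.0, t_far=6.0,
#         poses=torch.eye(4)[:3].unsqueeze(0),
#     )
#     rays = cam.get_rays()                # a [H, W, 1]-shaped bundle of rays
#     pts, dirs = rays.at(rays.far * 0.5)  # sample halfway to the far bound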