Last active
March 15, 2022 14:31
-
-
Save pekkavaa/7e2cc8f37b2446035c27cbc43a561ddc to your computer and use it in GitHub Desktop.
A dumb triangle rasterizer with PyTorch.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
from numpy import array | |
import torch | |
from torch import Tensor | |
""" | |
A dumb triangle rasterizer with PyTorch. | |
It evaluates the barycentrics for all image pixels for each triangle | |
and then picks the "colors" (just barycentrics again) for each pixel | |
based on which triangle rendered to that pixel last. | |
In other words it draws the triangles in order (painter's algorithm). | |
""" | |
width = 300
height = 200

# Render on the GPU when one is available instead of unconditionally calling
# .cuda(), which crashes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The triangle list: N triangles x 3 vertices x 2 coordinates, in the
# normalized [0, 1) image space. Each vertex is stored as a (y, x) pair,
# matching how edgefunc reads positions (index 0 = y, index 1 = x).
d = torch.tensor([
    ((0.9, 0.5), (0.5, 0.8), (0.1, 0.15)),
    ((0.5, 0.1), (0.3, 0.85), (0.8, 0.25)),
    ((0.3, 0.15), (0.1, 0.2), (0.15, 0.05)),
], device=device)
N = d.shape[0]  # The number of triangles.
P = height * width  # The number of pixels in the output image.
# Evaluates the edge function of the directed edge v0 -> v1 at points p.
# The value is twice the signed area of the triangle (v0, v1, p), i.e. it is
# proportional to (not equal to) the signed distance from the edge to p.
def edgefunc(v0, v1, p):
    """
    Evaluate the edge function of edge v0 -> v1 at a batch of sample points.

    Let S be the number of sample points per triangle (H*W for a full image).
    Positions are (y, x) pairs: index 0 is y and index 1 is x.

    v0, v1: [N x 2] tensors holding one vertex position per triangle.
    p: [N x S x 2] tensor of sampling points, one [S x 2] set per triangle.
       A leading dimension of 1 is also accepted and broadcast over all N.

    Returns an [N x S] tensor on the same device as v0. Swapping v0 and v1
    negates the result; points on the edge's line evaluate to 0.
    """
    device = v0.device
    # Split the sample points into x and y coordinates, each [N x S].
    # Follow v0's device rather than hard-coding .cuda(), so this also
    # runs on CPU-only machines.
    px = p[:, :, 1].to(device)
    py = p[:, :, 0].to(device)
    # Per-triangle edge coefficients, shaped [N x 1] so they broadcast
    # across all S sample points without materializing repeated copies.
    y01 = (v0[:, 0] - v1[:, 0]).unsqueeze(1)
    x10 = (v1[:, 1] - v0[:, 1]).unsqueeze(1)
    cross = (v0[:, 1] * v1[:, 0] - v0[:, 0] * v1[:, 1]).unsqueeze(1)
    return y01 * px + x10 * py + cross
# Calculate the area of the parallelogram formed by each triangle: the edge
# function of edge v2 -> v1 evaluated at v0. p holds a single sample per
# triangle, so the result is [N x 1] and broadcasts against [N x P] below.
area = edgefunc(d[:, 2, :], d[:, 1, :], d[:, None, 0, :])

# Create a grid of sampling positions in [0, 1) x [0, 1), one per pixel.
ys = np.linspace(0, 1, height, endpoint=False)
xs = np.linspace(0, 1, width, endpoint=False)
xmesh, ymesh = np.meshgrid(xs, ys)  # each [H x W]
# Pack the positions as (y, x) pairs in row-major pixel order: [P x 2].
gridcpu = np.stack([ymesh, xmesh], axis=-1).reshape(height * width, 2)
grid = torch.as_tensor(gridcpu, dtype=torch.float32, device=d.device)
# Every triangle samples the same points; expand is a view, not a copy.
grid = grid.unsqueeze(0).expand(N, -1, -1)  # [N x P x 2]

# Barycentric coordinates of every pixel with respect to every triangle,
# each [N x P]. Dividing by the signed area makes them independent of the
# triangle's winding order.
w0 = -edgefunc(d[:, 1, :], d[:, 2, :], grid) / area
w1 = -edgefunc(d[:, 2, :], d[:, 0, :], grid) / area
w2 = -edgefunc(d[:, 0, :], d[:, 1, :], grid) / area

# Only pixels inside a triangle (all barycentrics positive) get its color.
mask = (w0 > 0) & (w1 > 0) & (w2 > 0)  # [N x P] bool

# Per-triangle images packed as [P x N x 3], so imgs[p, i] is the color
# that triangle i gives pixel p.
imgs = torch.stack([w0, w1, w2], dim=-1).transpose(0, 1)

# Painter's algorithm: for each pixel take the highest-index triangle that
# covers it. Multiplying the ids by the mask maps "not covered" to 0, so
# pixels that no triangle covers also get id 0; they are blanked out below.
ids = torch.arange(N, device=mask.device).unsqueeze(1)  # [N x 1]
idmask = ids * mask.long()  # [N x P]
idmax = idmask.max(dim=0).values  # [P]

# Pick each pixel's color from the image selected by idmax. The two index
# tensors are paired element-wise, giving [P x 3]; imgs[:, idmax, :] would
# instead take every idmax entry for every pixel and produce [P x P x 3].
pixelrange = torch.arange(P, device=mask.device)
img2 = imgs[pixelrange, idmax, :].t()  # [3 x P]

# Collapse all masks into one and zero out pixels with no coverage.
# `~` inverts a bool mask; `1 - mask` raises for bool tensors.
mask2 = mask.any(dim=0)  # [P]
buf = img2.clone()
buf[:, ~mask2] = 0
buf = buf.cpu()  # [3 x P] float32

# CPU/NumPy copies kept for interactive inspection. Tensors must be moved
# to the CPU before NumPy conversion.
w0cpu = w0.cpu().numpy()
w1cpu = w1.cpu().numpy()
w2cpu = w2.cpu().numpy()
maskcpu = mask.cpu().numpy()
bufcpu = buf.numpy()

# Barycentrics of covered pixels lie in (0, 1); scale them to 8-bit color.
data = (buf.t() * 255).numpy().astype(np.uint8)  # [P x 3]
plt.imshow(data.reshape(height, width, 3), interpolation='nearest')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment