Skip to content

Instantly share code, notes, and snippets.

@pekkavaa
Last active March 15, 2022 14:31
Show Gist options
  • Save pekkavaa/7e2cc8f37b2446035c27cbc43a561ddc to your computer and use it in GitHub Desktop.
Save pekkavaa/7e2cc8f37b2446035c27cbc43a561ddc to your computer and use it in GitHub Desktop.
A dumb triangle rasterizer with PyTorch.
import numpy as np
import matplotlib.pyplot as plt
from numpy import array
import torch
from torch import Tensor
"""
A dumb triangle rasterizer with PyTorch.
It evaluates the barycentrics for all image pixels for each triangle
and then picks the "colors" (just barycentrics again) for each pixel
based on which triangle rendered to that pixel last.
In other words it draws the triangles in order (painter's algorithm).
"""
# Output image size in pixels.
width = 300
height = 200
# Run on the GPU when one is available, otherwise fall back to the CPU so the
# script still works on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The triangle list, shape [N x 3 x 2]: N triangles, 3 vertices each.
# Every vertex is a (y, x) pair in the same normalised [0, 1) coordinates
# as the sampling grid; later triangles are drawn on top of earlier ones.
d = torch.tensor(
    [
        ((0.9, 0.5), (0.5, 0.8), (0.1, 0.15)),
        ((0.5, 0.1), (0.3, 0.85), (0.8, 0.25)),
        ((0.3, 0.15), (0.1, 0.2), (0.15, 0.05)),
    ],
    dtype=torch.float32,
    device=device,
)
N = d.size(0)  # The number of triangles.
P = height * width  # The number of pixels in the output image.
# Evaluates the edge function of the edge v0-v1 at point p: a signed value proportional to the distance of p from the line through v0 and v1 (zero on the line).
def edgefunc(v0, v1, p):
    """
    Evaluate the edge function of the edge v0 -> v1 at the sample points p.

    Coordinates are stored as (y, x) pairs in the last dimension. For each
    triangle the result is

        (v0.y - v1.y) * p.x + (v1.x - v0.x) * p.y + (v0.x * v1.y - v0.y * v1.x)

    which is zero on the line through v0 and v1 and has opposite signs on
    the two sides of that line.

    let S = H*W
    v0 and v1 have vertex positions for all N triangles.
    Their shapes are [N x 2]
    p is a list of sampling points as a [N x S x 2] tensor
    (a leading dimension of 1 broadcasts over all N triangles).
    returns a [N x S] tensor on the device of v0
    """
    # The sample points may live on a different device than the vertices
    # (e.g. a CPU grid with GPU vertices); follow the vertices instead of
    # hard-coding CUDA so the function also works on CPU-only machines.
    p = p.to(v0.device)
    px = p[:, :, 1]  # [N x S]
    py = p[:, :, 0]  # [N x S]
    # Per-triangle line coefficients as [N x 1] columns; broadcasting
    # stretches them over the S sample points without materialising copies.
    y01 = (v0[:, 0] - v1[:, 0]).unsqueeze(1)
    x10 = (v1[:, 1] - v0[:, 1]).unsqueeze(1)
    cross = (v0[:, 1] * v1[:, 0] - v0[:, 0] * v1[:, 1]).unsqueeze(1)
    return y01 * px + x10 * py + cross
# The vertex tensor decides where the work happens. Every intermediate tensor
# below is created on this device so no operation mixes CPU and GPU operands.
dev = d.device

# Calculate the (signed) area of the parallelogram formed by each triangle:
# the edge function of edge v2-v1 evaluated at v0. Shape [N x 1]; it is used
# to normalise the edge functions into barycentric coordinates.
area = edgefunc(d[:, 2, :], d[:, 1, :], d[:, None, 0, :])

# Create a grid of sampling positions in [0, 1) x [0, 1).
ys = np.linspace(0, 1, height, endpoint=False)
xs = np.linspace(0, 1, width, endpoint=False)
xmesh, ymesh = np.meshgrid(xs, ys)  # each [H x W]
# Stack into [H x W x 2] with (y, x) in the last dimension, then flatten the
# pixels row by row into [P x 2].
gridcpu = np.stack([ymesh, xmesh], axis=-1).reshape(height * width, 2)
grid = torch.as_tensor(gridcpu, dtype=torch.float32, device=dev)
# Every triangle is sampled at the same positions; expand() yields the
# [N x P x 2] shape as a view instead of N physical copies.
grid = grid.unsqueeze(0).expand(N, -1, -1)

# Evaluate the edge functions at every position.
# We should get a [N x P] tensor out of each.
w0 = -edgefunc(d[:, 1, :], d[:, 2, :], grid) / area
w1 = -edgefunc(d[:, 2, :], d[:, 0, :], grid) / area
w2 = -edgefunc(d[:, 0, :], d[:, 1, :], grid) / area

# Only pixels strictly inside a triangle (all three barycentrics positive)
# will have color. [N x P]
mask = (w0 > 0) & (w1 > 0) & (w2 > 0)

# Pack the per-triangle images into one tensor for pixel-wise indexing.
# Each triangle has its own image whose three channels are its barycentrics.
# Shape is [P x N x 3] so that a (pixel, triangle) index pair selects a color.
imgs = torch.stack([w0, w1, w2], dim=2).transpose(0, 1)

# Construct a vector of length P that tells for each pixel from which image
# we should fetch the color: the highest index among the triangles covering
# the pixel, i.e. the one drawn last. Uncovered pixels get index 0 here and
# are blanked out further down.
ids = torch.arange(N, dtype=torch.long, device=dev).unsqueeze(1)  # [N x 1]
idmask = ids * mask.long()  # [N x P]
idmax, _ = torch.max(idmask, dim=0)  # [P]

# Pick the rendered pixel from the image selected by idmax for each output
# pixel. Pairing "pixelrange" with idmax selects exactly one triangle per
# pixel; indexing with imgs[:, idmax, :] alone would instead return a
# [P x P x 3] tensor.
pixelrange = torch.arange(P, dtype=torch.long, device=dev)
img2 = imgs[pixelrange, idmax, :].t()  # [3 x P]

buf = torch.zeros((3, P), dtype=torch.float32, device=dev)
buf[:, :] = img2

# Collapse all masks into one and zero out pixels with no coverage.
# `~` is used for the inversion because `1 - mask` is rejected for bool
# tensors by current PyTorch.
mask2 = mask.any(dim=0)  # [P]
buf[:, ~mask2] = 0

# NumPy copies of the intermediate results (handy for interactive
# inspection). Tensors must be moved to the CPU explicitly before conversion.
w0cpu = w0.cpu().numpy()
w1cpu = w1.cpu().numpy()
w2cpu = w2.cpu().numpy()
maskcpu = mask.cpu().numpy()
bufcpu = buf.cpu().numpy()

# Convert to an 8-bit [P x 3] RGB array (values truncated) and display it.
data = (bufcpu.T * 255).astype(np.uint8)
plt.imshow(np.reshape(data, (height, width, 3)), interpolation='nearest')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment