Created
April 10, 2020 02:48
-
-
Save Chandler/80ef67b164cd1a7bb50446adc24f927a to your computer and use it in GitHub Desktop.
Normals w/ jax
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 3D Surface normals with JAX
import jax.numpy as np
import matplotlib.pyplot as plt
from jax import jacfwd, vmap
# Simple Torus surface parameterization f(u,v) -> (x,y,z) | |
# The multivariable analytic function to be differentiated | |
def _f(uv): | |
u,v = uv | |
x = (8 + 2 * np.cos(v)) * np.cos(u) | |
y = (8 + 2 * np.cos(v)) * np.sin(u) | |
z = 2 * np.sin(v) | |
return [x,y,z] | |
# Grid of (u, v) parameter pairs at which surface normals are computed.
# A vectorized meshgrid replaces nested Python loops; 'ij' indexing keeps u as
# the slow-varying axis, i.e. the same row order as a u-outer / v-inner loop.
u_samples = np.linspace(0, 2 * np.pi, 100)
v_samples = np.linspace(0, 2 * np.pi, 100)
u_grid, v_grid = np.meshgrid(u_samples, v_samples, indexing='ij')
coordinates = np.stack([u_grid.ravel(), v_grid.ravel()], axis=1)  # (N x 2)
#================== JAX ==========================
# Vectorized version of _f: applies _f to every row of an (N x 2) batch of
# (u, v) pairs and returns a list [x, y, z] of three length-N arrays.
f = vmap(_f)
# Vectorized differential of f.
# jacfwd(_f) returns the Jacobian at a single point, essentially the
# multi-variable version of a derivative; vmap maps it over the whole batch.
df = vmap(jacfwd(_f))
#=================================================
# 3d points on the surface (N x 3).
# f returns one length-N array per coordinate, so stack them and transpose.
points = np.array(f(coordinates)).swapaxes(0, 1)
# The Jacobian is the (3 x 2) matrix of all partial derivatives of f
#   [dxdu, dxdv]
#   [dydu, dydv]   aka   [dfdu, dfdv]
#   [dzdu, dzdv]
# df returns one (N x 2) array per output coordinate, so the stacked,
# vectorized shape is (3 x N x 2).
jacobian_matrices = np.array(df(coordinates))
# Move the (u, v) axis to the front: (3 x N x 2) -> (2 x N x 3).
tangents = jacobian_matrices.swapaxes(0, 2)
# dfdu: 3d vectors tangent to the surface in the u direction (N x 3)
# dfdv: 3d vectors tangent to the surface in the v direction (N x 3)
dfdu, dfdv = tangents[0], tangents[1]
# The cross product of the two tangents is normal to the surface (N x 3).
normals = np.cross(dfdu, dfdv)
# Normalize each row by its own length. Without axis=1 the norm is taken over
# the whole (N x 3) array and the results are not unit vectors.
unit_normals = normals / np.linalg.norm(normals, axis=1, keepdims=True)
# See for yourself: create a quiver plot of the unit normals.
x, y, z = points.swapaxes(0, 1)  # Vector positions
u, v, w = unit_normals.swapaxes(0, 1)  # Vector directions
# Figure.gca() no longer accepts a projection keyword in current Matplotlib;
# add_subplot(projection='3d') creates the 3D axes explicitly.
ax = plt.figure().add_subplot(projection='3d')
ax.quiver(x, y, z, u, v, w)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment