Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created May 7, 2025 11:27
Show Gist options
  • Save llandsmeer/e4dab6b39ba43711bc7a544436e1d698 to your computer and use it in GitHub Desktop.
Save llandsmeer/e4dab6b39ba43711bc7a544436e1d698 to your computer and use it in GitHub Desktop.
Raymarch JAX
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
def normalize(x):
    """Scale ``x`` to unit Euclidean length.

    The length is ``jnp.linalg.norm(x)`` with default arguments. There is no
    guard against a zero-length input.
    """
    length = jnp.linalg.norm(x)
    return x / length
def vec(x, y, z):
    """Pack three components into a one-dimensional float32 array ``[x, y, z]``."""
    components = (x, y, z)
    return jnp.asarray(components, dtype=jnp.float32)
# Image plane: a square grid covering [-1, 1] along x and y (using
# jnp.linspace's default sample count per axis), with every sample at z = 1.
_axis = jnp.linspace(-1, 1)
screen_x, screen_y = jnp.meshgrid(_axis, _axis)
screen_z = jnp.ones_like(screen_x)

# Shading constants: RGB surface and light colours, and a unit light vector.
base_color = vec(0.5, 0.5, 0.5)
light_color = vec(1, 1, 1)
light = normalize(vec(0, -1, 0.))
def sdf(x):
    """Signed distance from point ``x`` to the scene's single sphere.

    The sphere has radius 0.5 and is centred at (0.2, 0.1, 1.0). The value is
    negative inside the sphere, zero on its surface and positive outside.
    """
    center = vec(0.2, .1, 1.0)
    offset = x - center
    return jnp.sqrt(jnp.sum(offset ** 2)) - 0.5
def sdf_normal(x):
    """Return the normalized gradient of ``sdf`` at ``x`` (the surface normal)."""
    gradient = jax.grad(sdf)(x)
    return normalize(gradient)
# One ray per screen sample: flatten the grid coordinates and stack them as
# the columns of an (N, 3) array of points on the image plane.
rays = jnp.stack(
    [screen_x.ravel(), screen_y.ravel(), screen_z.ravel()],
    axis=1,
)
# Every ray starts at the coordinate origin.
origins = jnp.zeros_like(rays)
# Rescale each row to a unit direction vector.
rays = jax.vmap(normalize)(rays)
def march(origin, ray):
    """Sphere-trace a single ray against ``sdf`` and shade where it ends up.

    Starting at ``origin``, the position is advanced 100 times along ``ray``
    by one tenth of the signed distance at the current position. The ray
    counts as a hit when the signed distance at the final position is below
    0.01.

    Args:
        origin: Start position of the ray (a 3-vector).
        ray: Direction of the ray (a 3-vector; this file passes unit vectors).

    Returns:
        ``base_color + light_color * (normal @ light)`` for a hit, otherwise
        an all-zero (black) colour of the same shape.
    """
    def step(position, _):
        distance = sdf(position)
        # Only the carry is needed afterwards, so emit no per-step output;
        # returning the position here would make scan stack the whole
        # 100-step trajectory for every ray just to throw it away.
        return position + distance * ray * 0.1, None

    end, _ = jax.lax.scan(step, origin, length=100)
    hit = sdf(end) < 0.01
    shade = base_color + light_color * (sdf_normal(end) @ light)
    # Select instead of multiplying by the boolean mask: 0 * NaN is NaN, so a
    # non-finite shade on a missed ray would otherwise leak into the image.
    return jnp.where(hit, shade, jnp.zeros_like(shade))
# Trace every ray, then fold the flat (N, 3) colours back into image layout
# (grid rows, grid columns, RGB).
img = jax.vmap(march)(origins, rays).reshape(*screen_z.shape, 3)
# The shaded colour is base_color (0.5) plus a dot product of two unit
# vectors, so it can leave the [0, 1] range imshow accepts for float RGB.
# Clamp explicitly rather than relying on matplotlib to warn and clip.
img = jnp.clip(img, 0.0, 1.0)
plt.imshow(img)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment