Skip to content

Instantly share code, notes, and snippets.

@yberreby
Last active April 17, 2025 23:19
Show Gist options
  • Save yberreby/fdcff1c483ad363555d83c498ba9754b to your computer and use it in GitHub Desktop.
Save yberreby/fdcff1c483ad363555d83c498ba9754b to your computer and use it in GitHub Desktop.
Fast Perlin noise in JAX with eccentricity-dependent feature scaling. GPU-ready.
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyqt6", # For matplotlib backend
# "numpy",
# "matplotlib",
# "jax[cuda12]==0.5.2", # Change for CPU.
# ]
# ///
import numpy as np
import time
from functools import partial
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize
# --- Configuration Constants ---
# Workload size
BATCH_SIZE = 256  # images produced per call; the comparison plot needs at least 2
IMAGE_SIZE = 256  # square output resolution, in pixels
GRID_EXTENT = 5.0  # sample coordinates cover [-GRID_EXTENT, GRID_EXTENT) on each axis
# Noise shape
BASE_FREQ = 2.0  # noise frequency at the center (r = 0)
K_FACTOR = 0.8  # eccentricity: 0 keeps features uniform, > 0 makes them grow with radius
OCTAVES = 4  # number of fBm detail layers
PERSISTENCE = 0.5  # amplitude multiplier from one octave to the next
LACUNARITY = 2.0  # frequency multiplier from one octave to the next
# Benchmark / output
WARMUP_RUNS = 2  # untimed calls made first (the first one triggers JIT compilation)
TIMED_RUNS = 10  # calls included in the timing measurement
ENABLE_PLOT = True  # show two images from the final batch when done
SEED = 42  # seed for the permutation-table shuffle
# Per-image shift applied in noise space (after warping) so each batch item
# samples a different region. Offsets grow to BATCH_NOISE_OFFSET_SCALE *
# (BATCH_SIZE - 1); very large float32 coordinates keep fewer fractional bits.
BATCH_NOISE_OFFSET_SCALE = 100.0
# --- Shared Noise Setup ---
# Seeded (legacy) NumPy RNG so the permutation, and hence the noise, is reproducible.
_numpy_rng = np.random.RandomState(SEED)
_permutation_table = _numpy_rng.permutation(256)
# The permutation is stored twice (length 512). First-stage hashes have the
# form p[i] + j with i, j in [0, 255], so they reach index 510; the second copy
# keeps those lookups in range. This matters because JAX clamps out-of-range
# gather indices instead of raising, so a short table would fail silently.
_p_np = np.tile(_permutation_table, 2)
# The 8 gradient directions of 2D Perlin noise: 4 diagonals, then 4 axis-aligned.
_gradients_np = np.array(
    [
        [1, 1], [-1, 1], [1, -1], [-1, -1],
        [1, 0], [-1, 0], [0, 1], [0, -1],
    ],
    dtype=np.float32,
)
# Device-side copies used by the jitted noise kernels.
_p_jax = jnp.array(_p_np, dtype=jnp.int32)
_gradients_jax = jnp.array(_gradients_np, dtype=jnp.float32)
# --- JAX Math & Helper Functions ---
@jax.jit
def fade(t):
    """Quintic smoothstep from improved Perlin noise: 6t^5 - 15t^4 + 10t^3.

    Maps 0 -> 0 and 1 -> 1 with zero first and second derivatives at both ends,
    which hides the lattice seams in the interpolated noise.
    """
    t_cubed = t * t * t
    # Horner form of 6t^2 - 15t + 10
    poly = t * (t * 6 - 15) + 10
    return t_cubed * poly
@jax.jit
def lerp(t, a, b):
    """Blend from ``a`` (t = 0) to ``b`` (t = 1) as ``a + t * (b - a)``."""
    span = b - a
    return a + t * span
@jax.jit
def calculate_gradient_indices(ix0, iy0):
    """First hashing stage for the four corners of each lattice cell.

    Integer lattice coordinates are wrapped into [0, 255] with ``& 255`` (also
    correct for negative int32 values). Returns ``p[i] + j`` for the corners
    (i0, j0), (i1, j0), (i0, j1), (i1, j1) in that order. Values lie in
    [0, 510], which is why the permutation table is stored twice.
    """
    col0 = ix0 & 255
    row0 = iy0 & 255
    col1 = (col0 + 1) & 255
    row1 = (row0 + 1) & 255
    # Each column hash is shared by two corners.
    col0_hash = _p_jax[col0]
    col1_hash = _p_jax[col1]
    return (
        col0_hash + row0,
        col1_hash + row0,
        col0_hash + row1,
        col1_hash + row1,
    )
@jax.jit
def get_gradients_for_corners(idx00, idx10, idx01, idx11):
    """Second hashing stage: map each corner index to a gradient vector.

    Each index goes through the permutation table once more, and the low three
    bits (``& 7``) select one of the 8 rows of ``_gradients_jax``.

    Returns:
        Tuple ``(grad00, grad10, grad01, grad11)``. Each entry has its index's
        shape plus a trailing axis of length 2.
    """
    def _corner_gradient(corner_index):
        direction = _p_jax[corner_index] & 7
        return _gradients_jax[direction]

    corner_indices = (idx00, idx10, idx01, idx11)
    return tuple(_corner_gradient(ci) for ci in corner_indices)
@jax.jit
def compute_dot_products(fx, fy, grad00, grad10, grad01, grad11):
    """Dot each corner gradient with the offset from that corner to the sample.

    ``fx``/``fy`` are the fractional position inside the cell, so the offset to
    corner (cx, cy) is ``(fx - cx, fy - cy)``. Gradients carry a trailing axis
    of length 2; results have the shape of ``fx``.
    """
    def _dot(grad, dx, dy):
        offset = jnp.stack([dx, dy], -1)
        # Batched 2-vector dot product over the trailing axis.
        return jnp.einsum("...i,...i->...", grad, offset)

    fx_minus_1 = fx - 1
    fy_minus_1 = fy - 1
    return (
        _dot(grad00, fx, fy),
        _dot(grad10, fx_minus_1, fy),
        _dot(grad01, fx, fy_minus_1),
        _dot(grad11, fx_minus_1, fy_minus_1),
    )
# --- JAX Perlin Noise Core ---
@partial(jax.jit, static_argnames=["octaves"])
def perlin_noise_jax_octaves(x, y, octaves, persistence, lacunarity):
    """Fractional Brownian motion: a weighted sum of ``octaves`` Perlin layers.

    Octave ``o`` samples the coordinates scaled by ``lacunarity**o`` and is
    weighted by ``persistence**o``. The sum is divided by the total weight
    (plus 1e-9), making it a weighted average of the per-octave values.

    Args:
        x, y: sample coordinates (cast to float32), broadcast-compatible shapes.
        octaves: number of layers; static, because the Python loop below is
            unrolled during tracing.
        persistence: amplitude multiplier between octaves.
        lacunarity: frequency multiplier between octaves.

    Returns:
        float32 array with the shape of ``x``.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)
    gain = jnp.float32(persistence)
    freq_step = jnp.float32(lacunarity)

    def _single_octave(xs, ys):
        # Lattice cell containing each sample, and the position inside it.
        cell_x = jnp.floor(xs).astype(jnp.int32)
        cell_y = jnp.floor(ys).astype(jnp.int32)
        frac_x = xs - cell_x
        frac_y = ys - cell_y
        weight_x, weight_y = fade(frac_x), fade(frac_y)
        corner_idx = calculate_gradient_indices(cell_x, cell_y)
        corner_grads = get_gradients_for_corners(*corner_idx)
        d00, d10, d01, d11 = compute_dot_products(frac_x, frac_y, *corner_grads)
        # Interpolate along x on both rows of the cell, then along y.
        row0 = lerp(weight_x, d00, d10)
        row1 = lerp(weight_x, d01, d11)
        return lerp(weight_y, row0, row1)

    accum = jnp.zeros_like(x)
    amp = jnp.float32(1.0)
    freq = jnp.float32(1.0)
    amp_sum = jnp.float32(0.0)
    for _ in range(octaves):  # unrolled at trace time
        accum = accum + _single_octave(x * freq, y * freq) * amp
        amp_sum = amp_sum + amp
        amp = amp * gain
        freq = freq * freq_step
    # The epsilon keeps octaves == 0 from dividing by zero.
    return accum / (amp_sum + 1e-9)
# --- JAX Eccentric Noise ---
@partial(jax.jit, static_argnames=["octaves"])
def eccentric_noise_jax(
    x, y, noise_coord_offset, base_freq, k, octaves, persistence, lacunarity
):
    """Radially warp the plane, shift it, then evaluate multi-octave Perlin noise.

    Each point keeps its direction, but its distance r from the origin is
    remapped to
        R(r) = base_freq * r                          if k < 1e-9
        R(r) = (base_freq / k) * log(1 + k * r)       otherwise.
    The local frequency dR/dr = base_freq / (1 + k*r) drops with radius, so
    noise features get larger away from the center. ``noise_coord_offset`` is
    added to the warped first coordinate only.

    Args:
        x, y: sample coordinates (cast to float32).
        noise_coord_offset: added to the warped u coordinate; must broadcast
            against ``x``.
        base_freq: noise frequency at r = 0.
        k: eccentricity factor.
        octaves, persistence, lacunarity: passed to perlin_noise_jax_octaves.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)
    noise_coord_offset = jnp.asarray(noise_coord_offset, dtype=jnp.float32)
    base_freq = jnp.float32(base_freq)
    k = jnp.float32(k)
    eps = 1e-9

    radius = jnp.hypot(x, y)
    # Clamp k so the log branch never divides by zero.
    k_clamped = jnp.maximum(k, eps)

    def _linear(rad):
        return base_freq * rad

    def _logarithmic(rad):
        return (base_freq / k_clamped) * jnp.log1p(k_clamped * rad)

    warped_radius = jax.lax.cond(k < eps, _linear, _logarithmic, radius)
    # Scale factor R / r, guarded at the origin where r == 0.
    gain = warped_radius / jnp.maximum(radius, eps)
    away_from_origin = radius > eps
    u_coord = jnp.where(away_from_origin, x * gain, 0.0) + noise_coord_offset
    v_coord = jnp.where(away_from_origin, y * gain, 0.0)
    return perlin_noise_jax_octaves(u_coord, v_coord, octaves, persistence, lacunarity)
# Helper functions
def sync_func(res):
    """Wait for the device work producing ``res`` to finish, then return it."""
    ready = res.block_until_ready()
    return ready
def copy_to_cpu(res):
    """Return a new, writable host-side NumPy copy of ``res``."""
    host_copy = np.array(res, copy=True)
    return host_copy
# --- Main Execution ---
def _plot_batch_pair(result_device):
    """Show images 0 and 1 of a batch side by side on a shared color scale.

    Args:
        result_device: batched noise output indexed along axis 0; must hold at
            least two images.
    """
    images = [copy_to_cpu(result_device[idx]) for idx in (0, 1)]
    # Shared limits so equal values get equal colors in both panels.
    norm = Normalize(
        vmin=min(img.min() for img in images),
        vmax=max(img.max() for img in images),
    )
    cmap = "viridis"
    extent = [-GRID_EXTENT, GRID_EXTENT, -GRID_EXTENT, GRID_EXTENT]
    # constrained_layout also makes room for the shared colorbar.
    fig, axes = plt.subplots(
        1, 2, figsize=(12, 5.5), sharey=True, constrained_layout=True
    )
    fig.suptitle(
        "Eccentric Perlin Noise (JAX) - Images from same batch", fontsize=14
    )
    for idx, (ax, img) in enumerate(zip(axes, images)):
        ax.imshow(img, cmap=cmap, norm=norm, origin="lower", extent=extent)
        ax.set_title(f"Image {idx + 1} (Index {idx})")
        ax.set_xlabel("X coordinate")
        ax.grid(True, alpha=0.2)
    # sharey=True: only the left panel needs a y label.
    axes[0].set_ylabel("Y coordinate")
    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=axes,
        shrink=0.8,
        label="Noise Value",
        location="right",
    )
    plt.show()


if __name__ == "__main__":
    print("--- Running Eccentric Noise (JAX) ---")
    print(
        f"Config: Batch={BATCH_SIZE}, Size={IMAGE_SIZE}x{IMAGE_SIZE}, Octaves={OCTAVES}, k={K_FACTOR}"
    )
    if TIMED_RUNS < 1:
        # Without timed runs the averages below would divide by zero.
        raise SystemExit("TIMED_RUNS must be >= 1 to report timings.")
    if BATCH_SIZE < 2 and ENABLE_PLOT:
        print(
            "Warning: BATCH_SIZE < 2, cannot plot comparison. Set ENABLE_PLOT=False or increase BATCH_SIZE."
        )
        ENABLE_PLOT = False

    # Base grid: IMAGE_SIZE samples per axis over [-GRID_EXTENT, GRID_EXTENT).
    lin = np.linspace(
        -GRID_EXTENT, GRID_EXTENT, IMAGE_SIZE, endpoint=False, dtype=np.float32
    )
    x_base_np, y_base_np = np.meshgrid(lin, lin)
    # Every batch item starts from the same grid; items differ only by offset.
    x_batch_np = np.stack([x_base_np] * BATCH_SIZE, axis=0)
    y_batch_np = np.stack([y_base_np] * BATCH_SIZE, axis=0)
    # Per-item shift of the 'u' noise coordinate, shaped (BATCH_SIZE, 1, 1)
    # so it broadcasts over each image.
    batch_indices = np.arange(BATCH_SIZE, dtype=np.float32)
    u_offsets_np = (batch_indices * BATCH_NOISE_OFFSET_SCALE)[:, None, None]

    # Upload inputs to the default device once. Passing host NumPy arrays into
    # the jitted function would transfer them again on every call, so the
    # benchmark would largely measure host-to-device copies.
    x_batch, y_batch, u_offsets = (
        jax.device_put(arr) for arr in (x_batch_np, y_batch_np, u_offsets_np)
    )
    # noise_coord_offset is the 3rd positional argument.
    call_args = (x_batch, y_batch, u_offsets, BASE_FREQ, K_FACTOR)
    # `octaves` is a static argument (it controls loop unrolling). The other
    # two are ordinary traced arguments, passed by name for readability.
    call_kwargs = {
        "octaves": OCTAVES,
        "persistence": PERSISTENCE,
        "lacunarity": LACUNARITY,
    }

    # Warm-up runs; the first call triggers JIT compilation.
    print(f"Warming up ({WARMUP_RUNS} runs)...")
    result_device = None
    for _ in range(WARMUP_RUNS):
        result_device = eccentric_noise_jax(*call_args, **call_kwargs)
        sync_func(result_device)
    print("Warmup finished.")

    # Timed runs: time the whole loop. Each call is synchronized, because JAX
    # dispatches asynchronously and the call would otherwise return early.
    print(f"Timing ({TIMED_RUNS} runs)...")
    start_time = time.perf_counter()
    for _ in range(TIMED_RUNS):
        result_device = eccentric_noise_jax(*call_args, **call_kwargs)
        sync_func(result_device)
    end_time = time.perf_counter()
    total_time = end_time - start_time

    # Throughput
    avg_time_per_batch = total_time / TIMED_RUNS
    total_pixels_processed = TIMED_RUNS * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE
    throughput_mps = (total_pixels_processed / total_time) / 1_000_000

    # --- Output Results ---
    print("\n--- Results ---")
    print("Framework: JAX")
    print(f"Total time for {TIMED_RUNS} runs: {total_time:.4f} seconds")
    print(f"Avg time/batch: {avg_time_per_batch:.6f} seconds")
    print(f"Throughput: {throughput_mps:.2f} MPixels/sec")

    # --- Plotting (Optional) ---
    if ENABLE_PLOT:
        print("\nPlotting comparison from last batch...")
        assert result_device is not None  # guaranteed by TIMED_RUNS >= 1
        _plot_batch_pair(result_device)
    print("\nDone.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment