Last active
April 17, 2025 23:19
-
-
Save yberreby/fdcff1c483ad363555d83c498ba9754b to your computer and use it in GitHub Desktop.
Fast Perlin noise in JAX with eccentricity-dependent feature scaling. GPU-ready.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env -S uv run --script --quiet | |
# /// script | |
# requires-python = ">=3.10" | |
# dependencies = [ | |
# "pyqt6", # For matplotlib backend | |
# "numpy", | |
# "matplotlib", | |
# "jax[cuda12]==0.5.2", # Change for CPU. | |
# ] | |
# /// | |
import numpy as np | |
import time | |
from functools import partial | |
import jax | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
import matplotlib.cm as cm | |
from matplotlib.colors import Normalize | |
# --- Configuration Constants --- | |
BATCH_SIZE = 256 # Number of images to process in parallel (Ensure >= 2 for plot) | |
IMAGE_SIZE = 256 # Output image resolution (pixels) | |
GRID_EXTENT = 5.0 # Coordinate range [-extent, extent] | |
BASE_FREQ = 2.0 # Noise frequency at the center (r=0) | |
K_FACTOR = 0.8 # Eccentricity factor (0 = uniform, >0 grows size outwards) | |
OCTAVES = 4 # Noise detail levels | |
PERSISTENCE = 0.5 # Amplitude multiplier per octave | |
LACUNARITY = 2.0 # Frequency multiplier per octave | |
WARMUP_RUNS = 2 # Runs before timing starts (includes JIT compilation) | |
TIMED_RUNS = 10 # Number of batches processed for timing | |
ENABLE_PLOT = True # Generate and show plot | |
SEED = 42 # Seed for permutation table shuffle | |
# Scale for offsetting coordinates *within the noise space* per batch item | |
BATCH_NOISE_OFFSET_SCALE = 100.0 | |
# --- Shared Noise Setup --- | |
_numpy_rng = np.random.RandomState(SEED) | |
_permutation_table = _numpy_rng.permutation(256) | |
# Double the table to avoid buffer overflows with indices like ii+1 during hashing | |
_p_np = np.concatenate([_permutation_table, _permutation_table]) | |
# Standard Perlin gradient directions | |
_gradients_np = np.array( | |
[[1, 1], [-1, 1], [1, -1], [-1, -1], [1, 0], [-1, 0], [0, 1], [0, -1]], | |
dtype=np.float32, | |
) | |
# Convert shared resources to JAX arrays | |
_p_jax = jnp.array(_p_np, dtype=jnp.int32) | |
_gradients_jax = jnp.array(_gradients_np, dtype=jnp.float32) | |
# --- JAX Math & Helper Functions ---
@jax.jit
def fade(t):
    """Perlin's quintic fade curve 6t^5 - 15t^4 + 10t^3 ("smootherstep").

    Maps 0 -> 0 and 1 -> 1, and its first derivative 30 t^2 (t - 1)^2 and
    second derivative 60 t (2t - 1)(t - 1) both vanish at t = 0 and t = 1.
    Interpolation weights built from it therefore stay C2-continuous across
    lattice-cell boundaries. Note this is NOT the cubic smoothstep
    3t^2 - 2t^3, which is only C1 at the ends.

    Evaluated in Horner form, t^3 * (t * (6t - 15) + 10), to save multiplies.
    """
    return t * t * t * (t * (t * 6 - 15) + 10)
@jax.jit
def lerp(t, a, b):
    """Blend from `a` (at t = 0) towards `b` (at t = 1), computing a + t * (b - a)."""
    span = b - a
    return a + t * span
@jax.jit
def calculate_gradient_indices(ix0, iy0):
    """Hash the four lattice corners of each sample's cell.

    Integer cell coordinates are wrapped with `& 255`, which maps them to their
    value mod 256 (negative int32 values included, via two's complement). Each
    corner is then hashed as p[x] + y. With permutation values and wrapped y
    both in 0..255, the results lie in 0..510, inside the 512-entry doubled
    table `_p_jax`.

    Returns:
        (idx00, idx10, idx01, idx11): hashes for corners (x0, y0), (x1, y0),
        (x0, y1), (x1, y1), where x1 = x0 + 1 and y1 = y0 + 1 (mod 256).
    """
    mask = 255
    col_lo = ix0 & mask
    row_lo = iy0 & mask
    col_hi = (col_lo + 1) & mask
    row_hi = (row_lo + 1) & mask

    perm_lo = _p_jax[col_lo]
    perm_hi = _p_jax[col_hi]
    return (
        perm_lo + row_lo,
        perm_hi + row_lo,
        perm_lo + row_hi,
        perm_hi + row_hi,
    )
@jax.jit
def get_gradients_for_corners(idx00, idx10, idx01, idx11):
    """Look up the pseudo-random gradient vector for each of a cell's 4 corners.

    Each corner hash is passed through the permutation table once more. Its
    low three bits (`& 7`) then pick one of the 8 rows of `_gradients_jax`.

    Returns:
        (grad00, grad10, grad01, grad11): each has the shape of its index array
        plus a trailing axis of size 2 holding (gx, gy).
    """

    def _pick(corner_hash):
        return _gradients_jax[_p_jax[corner_hash] & 7]

    return tuple(_pick(corner) for corner in (idx00, idx10, idx01, idx11))
@jax.jit
def compute_dot_products(fx, fy, grad00, grad10, grad01, grad11):
    """Dot each corner gradient with the offset from that corner to the sample.

    For a sample at fractional position (fx, fy) inside its cell, the offsets
    from corners 00, 10, 01, 11 are (fx, fy), (fx - 1, fy), (fx, fy - 1) and
    (fx - 1, fy - 1).

    The 2-component dot products are written out explicitly rather than built
    with `jnp.stack` + `einsum`. The stacking approach requires `fx` and `fy`
    to have identical shapes. The explicit form broadcasts, so e.g. `fx` of
    shape (B, H, W) can be combined with `fy` of shape (1, H, W).

    Args:
        fx, fy: Fractional position of each sample within its cell
            (broadcast-compatible arrays).
        grad00, grad10, grad01, grad11: Corner gradients whose trailing axis
            (size 2) holds (gx, gy); leading axes broadcast against fx / fy.

    Returns:
        (dot00, dot10, dot01, dot11) with the broadcast shape of the inputs.
    """
    fx_m1 = fx - 1
    fy_m1 = fy - 1

    def _dot(grad, dx, dy):
        return grad[..., 0] * dx + grad[..., 1] * dy

    return (
        _dot(grad00, fx, fy),
        _dot(grad10, fx_m1, fy),
        _dot(grad01, fx, fy_m1),
        _dot(grad11, fx_m1, fy_m1),
    )
# --- JAX Perlin Noise Core ---
@partial(jax.jit, static_argnames=["octaves"])
def perlin_noise_jax_octaves(x, y, octaves, persistence, lacunarity):
    """Multi-octave 2D Perlin noise (fractional Brownian motion).

    Octave i samples the point (x, y) * lacunarity**i and is weighted by
    persistence**i. The weighted sum is divided by the sum of the weights
    (plus 1e-9, so `octaves == 0` yields an all-zero array instead of NaN).

    Args:
        x, y: Sample coordinates; cast to float32.
        octaves: Number of octaves. Static, so the loop below is unrolled
            at trace time and each distinct value triggers a recompile.
        persistence: Amplitude ratio between consecutive octaves.
        lacunarity: Frequency ratio between consecutive octaves.

    Returns:
        float32 noise array with the shape of `x`.
    """
    px = jnp.asarray(x, dtype=jnp.float32)
    py = jnp.asarray(y, dtype=jnp.float32)
    gain = jnp.float32(persistence)
    step = jnp.float32(lacunarity)

    accum = jnp.zeros_like(px)
    weight = jnp.float32(1.0)
    scale = jnp.float32(1.0)
    weight_sum = jnp.float32(0.0)

    for _octave in range(octaves):
        sx = px * scale
        sy = py * scale
        cell_x = jnp.floor(sx).astype(jnp.int32)
        cell_y = jnp.floor(sy).astype(jnp.int32)
        # Position of the sample inside its lattice cell.
        frac_x = sx - cell_x
        frac_y = sy - cell_y

        corner_hashes = calculate_gradient_indices(cell_x, cell_y)
        corner_grads = get_gradients_for_corners(*corner_hashes)
        d00, d10, d01, d11 = compute_dot_products(frac_x, frac_y, *corner_grads)

        # Bilinear blend of the corner contributions with faded weights.
        wx = fade(frac_x)
        wy = fade(frac_y)
        edge_y0 = lerp(wx, d00, d10)
        edge_y1 = lerp(wx, d01, d11)
        accum = accum + lerp(wy, edge_y0, edge_y1) * weight

        weight_sum = weight_sum + weight
        weight = weight * gain
        scale = scale * step

    return accum / (weight_sum + 1e-9)
# --- JAX Eccentric Noise ---
@partial(jax.jit, static_argnames=["octaves"])
def eccentric_noise_jax(
    x, y, noise_coord_offset, base_freq, k, octaves, persistence, lacunarity
):
    """Radially warp (x, y), offset the result, then evaluate fBm Perlin noise.

    The radius r = hypot(x, y) is remapped to
        R = (base_freq / k) * log1p(k * r)    (or R = base_freq * r if k < 1e-9),
    and each point is moved along its ray to radius R. Since
    dR/dr = base_freq / (1 + k * r), the effective noise frequency is
    `base_freq` at the center and decreases with distance, so features grow
    larger towards the periphery when k > 0.

    Args:
        x, y: Sample coordinates (same shape); cast to float32.
        noise_coord_offset: Added to the warped u coordinate only (v is left
            untouched), e.g. shape (B, 1, 1) to give each image in a
            (B, H, W) batch its own region of noise space.
        base_freq: Noise frequency at r = 0.
        k: Eccentricity factor. May be a Python scalar or an array that
            broadcasts against `x` (e.g. (B, 1, 1) for per-image values);
            entries below 1e-9 use the linear mapping. `base_freq` may
            likewise be a broadcastable array.
        octaves: Static octave count, forwarded to `perlin_noise_jax_octaves`.
        persistence, lacunarity: Forwarded to `perlin_noise_jax_octaves`.

    Returns:
        float32 noise array.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    y = jnp.asarray(y, dtype=jnp.float32)
    noise_coord_offset = jnp.asarray(noise_coord_offset, dtype=jnp.float32)
    base_freq = jnp.asarray(base_freq, dtype=jnp.float32)
    k = jnp.asarray(k, dtype=jnp.float32)

    # 1. Radial warp: distance r from the origin -> warped radius R.
    r = jnp.hypot(x, y)
    # Clamping keeps the (unselected) log branch finite where k is ~0 or
    # negative, so neither values nor gradients pick up inf/NaN.
    safe_k = jnp.maximum(k, 1e-9)
    # Elementwise select rather than lax.cond: lax.cond needs a scalar
    # predicate, which would rule out per-image / per-pixel k.
    R = jnp.where(
        k < 1e-9,
        base_freq * r,
        (base_freq / safe_k) * jnp.log1p(safe_k * r),
    )
    r_safe = jnp.maximum(r, 1e-9)  # avoid 0/0 at the origin
    scale = R / r_safe
    u_warped = jnp.where(r > 1e-9, x * scale, 0.0)  # u = R * (x / r)
    v_warped = jnp.where(r > 1e-9, y * scale, 0.0)  # v = R * (y / r)

    # 2. Per-image offset applied *after* warping, so every image shares the
    #    same warp but samples a different region of noise space.
    u_final = u_warped + noise_coord_offset
    v_final = v_warped  # only u is offset

    # 3. Evaluate multi-octave noise at the warped, offset coordinates.
    return perlin_noise_jax_octaves(u_final, v_final, octaves, persistence, lacunarity)
# Helper functions
def sync_func(res):
    """Wait until the asynchronously dispatched JAX array `res` has been computed.

    Returns whatever `res.block_until_ready()` returns (the array itself), so
    the call can be used inline.
    """
    finished = res.block_until_ready()
    return finished
def copy_to_cpu(res):
    """Return a fresh host-side NumPy copy of `res` (np.array copies by default)."""
    host_copy = np.array(res)
    return host_copy
# --- Main Execution ---
def _build_batch_inputs():
    """Build the host-side input batch for `eccentric_noise_jax`.

    Returns:
        x_batch, y_batch: float32 arrays of shape
            (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE). Every image gets the same
            coordinate grid over [-GRID_EXTENT, GRID_EXTENT).
        u_offsets: float32 array of shape (BATCH_SIZE, 1, 1) holding
            index * BATCH_NOISE_OFFSET_SCALE, broadcast over the pixel axes so
            each image samples a different region of noise space.
    """
    lin = np.linspace(
        -GRID_EXTENT, GRID_EXTENT, IMAGE_SIZE, endpoint=False, dtype=np.float32
    )
    x_base, y_base = np.meshgrid(lin, lin)
    x_batch = np.stack([x_base] * BATCH_SIZE, axis=0)
    y_batch = np.stack([y_base] * BATCH_SIZE, axis=0)
    batch_indices = np.arange(BATCH_SIZE, dtype=np.float32)
    u_offsets = (batch_indices * BATCH_NOISE_OFFSET_SCALE)[:, None, None]
    return x_batch, y_batch, u_offsets


def _plot_comparison(img1_np, img2_np):
    """Show two noise images side by side with a shared color scale and colorbar."""
    # Shared limits so identical values map to identical colors in both panels.
    vmin = min(img1_np.min(), img2_np.min())
    vmax = max(img1_np.max(), img2_np.max())
    norm = Normalize(vmin=vmin, vmax=vmax)
    cmap = "viridis"
    extent = [-GRID_EXTENT, GRID_EXTENT, -GRID_EXTENT, GRID_EXTENT]

    # constrained_layout handles spacing, including the shared colorbar.
    fig, axes = plt.subplots(
        1, 2, figsize=(12, 5.5), sharey=True, constrained_layout=True
    )
    fig.suptitle(
        "Eccentric Perlin Noise (JAX) - Images from same batch", fontsize=14
    )
    panels = (
        (img1_np, "Image 1 (Index 0)"),
        (img2_np, "Image 2 (Index 1)"),
    )
    for ax, (img, title) in zip(axes, panels):
        # Row index follows y (meshgrid 'xy' indexing), so origin="lower"
        # puts y = -GRID_EXTENT at the bottom.
        ax.imshow(img, cmap=cmap, norm=norm, origin="lower", extent=extent)
        ax.set_title(title)
        ax.set_xlabel("X coordinate")
        ax.grid(True, alpha=0.2)
    axes[0].set_ylabel("Y coordinate")

    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=cmap),
        ax=axes,
        shrink=0.8,
        label="Noise Value",
        location="right",
    )
    plt.show()


if __name__ == "__main__":
    print("--- Running Eccentric Noise (JAX) ---")
    print(
        f"Config: Batch={BATCH_SIZE}, Size={IMAGE_SIZE}x{IMAGE_SIZE}, Octaves={OCTAVES}, k={K_FACTOR}"
    )
    if BATCH_SIZE < 2 and ENABLE_PLOT:
        print(
            "Warning: BATCH_SIZE < 2, cannot plot comparison. Set ENABLE_PLOT=False or increase BATCH_SIZE."
        )
        ENABLE_PLOT = False

    x_batch_np, y_batch_np, u_offsets_np = _build_batch_inputs()
    # Upload the inputs once and wait for the copy to finish. A jitted function
    # called with NumPy arrays transfers them host->device on every call
    # (~2 * BATCH_SIZE * IMAGE_SIZE**2 float32 values here), and the timing
    # below would then include that transfer rather than just noise generation.
    inputs_device = jax.block_until_ready(
        jax.device_put((x_batch_np, y_batch_np, u_offsets_np))
    )
    # noise_coord_offset is the 3rd positional argument.
    call_args = (*inputs_device, BASE_FREQ, K_FACTOR)
    # Only `octaves` is static (see static_argnames); persistence and
    # lacunarity are ordinary traced arguments.
    call_kwargs = {
        "octaves": OCTAVES,
        "persistence": PERSISTENCE,
        "lacunarity": LACUNARITY,
    }

    # Warmup runs (the first call includes JIT compilation).
    print(f"Warming up ({WARMUP_RUNS} runs)...")
    result_device = None
    for _ in range(WARMUP_RUNS):
        result_device = eccentric_noise_jax(*call_args, **call_kwargs)
        sync_func(result_device)
    print("Warmup finished.")

    # Timed runs: each call is synchronized so the wall time covers device
    # execution, not just asynchronous dispatch.
    print(f"Timing ({TIMED_RUNS} runs)...")
    start_time = time.perf_counter()
    for _ in range(TIMED_RUNS):
        result_device = eccentric_noise_jax(*call_args, **call_kwargs)
        sync_func(result_device)
    total_time = time.perf_counter() - start_time

    # --- Output Results ---
    print("\n--- Results ---")
    print("Framework: JAX")
    if TIMED_RUNS > 0:
        avg_time_per_batch = total_time / TIMED_RUNS
        total_pixels_processed = TIMED_RUNS * BATCH_SIZE * IMAGE_SIZE * IMAGE_SIZE
        throughput_mps = (total_pixels_processed / total_time) / 1_000_000
        print(f"Total time for {TIMED_RUNS} runs: {total_time:.4f} seconds")
        print(f"Avg time/batch: {avg_time_per_batch:.6f} seconds")
        print(f"Throughput: {throughput_mps:.2f} MPixels/sec")
    else:
        print("TIMED_RUNS is 0: no timing results to report.")

    # --- Plotting (Optional) ---
    if ENABLE_PLOT:
        if result_device is None:
            # Explicit check instead of `assert`, which is stripped under -O.
            print(
                "\nNo batch was computed (WARMUP_RUNS and TIMED_RUNS are 0); skipping plot."
            )
        else:
            print("\nPlotting comparison from last batch...")
            _plot_comparison(
                copy_to_cpu(result_device[0]), copy_to_cpu(result_device[1])
            )
    print("\nDone.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment