@cloneofsimo
Created October 3, 2024 18:03
Parameterized Fractal Triton
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice


@triton.jit
def fractal_kernel(
    zr_ptr, zi_ptr, cr_ptr, ci_ptr, output_ptr,
    alpha_ptr, beta_ptr, poly0_ptr, poly1_ptr, poly2_ptr, poly3_ptr, p_ptr, R, max_iter,
    H, W,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per (batch element, BLOCK_SIZE pixel tile).
    bid = tl.program_id(0)
    pid = tl.program_id(1)
    grid_offset = pid * BLOCK_SIZE
    offsets = grid_offset + tl.arange(0, BLOCK_SIZE)
    mask = offsets < H * W

    # z and the output are laid out as (B, H * W); the constant plane c is shared across the batch.
    batch_offset = bid * H * W
    zr_ptr += batch_offset
    zi_ptr += batch_offset
    output_ptr += batch_offset

    # Per-image fractal parameters.
    alpha = tl.load(alpha_ptr + bid)
    beta = tl.load(beta_ptr + bid)
    p = tl.load(p_ptr + bid)
    poly0 = tl.load(poly0_ptr + bid)
    poly1 = tl.load(poly1_ptr + bid)
    poly2 = tl.load(poly2_ptr + bid)
    poly3 = tl.load(poly3_ptr + bid)

    zr = tl.load(zr_ptr + offsets, mask=mask)
    zi = tl.load(zi_ptr + offsets, mask=mask)
    cr = tl.load(cr_ptr + offsets, mask=mask)
    ci = tl.load(ci_ptr + offsets, mask=mask)

    iteration = tl.zeros([BLOCK_SIZE], dtype=tl.int8)
    for i in range(max_iter):
        # Apply the polynomial to the real part only; the imaginary part passes through unchanged.
        zr = poly0 + poly1 * libdevice.pow(zr, 2) + poly2 * libdevice.pow(zr, 3) + poly3 * libdevice.pow(zr, 4)
        zi = zi

        # Rotate z by alpha (multiply by e^{i*alpha}).
        eia_zr = zr * tl.cos(alpha) - zi * tl.sin(alpha)
        eia_zi = zr * tl.sin(alpha) + zi * tl.cos(alpha)
        zr = eia_zr
        zi = eia_zi

        # z -> z^p via polar form.
        modulus = tl.sqrt(zr * zr + zi * zi)
        modulus = libdevice.pow(modulus, p)
        angle = libdevice.atan2(zi, zr) * p
        zr_new = modulus * tl.cos(angle)
        zi_new = modulus * tl.sin(angle)

        # Add the pixel constant c rotated by beta.
        cos_beta = tl.cos(beta)
        sin_beta = tl.sin(beta)
        exp_cr = cr * cos_beta - ci * sin_beta
        exp_ci = cr * sin_beta + ci * cos_beta
        zr = zr_new + exp_cr
        zi = zi_new + exp_ci

        # Count how many iterations each pixel stays inside the escape radius R.
        mag_sq = zr * zr + zi * zi
        not_escaped = (mag_sq < R * R)
        iteration += not_escaped.to(tl.int8)

    iter_int8 = tl.cast(iteration, tl.int8)
    tl.store(output_ptr + offsets, iter_int8, mask=mask)
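

# For reference only: a minimal, untested PyTorch sketch of the per-pixel update the kernel applies
# each iteration, written with complex tensors for readability. `_reference_step` is illustrative and
# is not used anywhere else in this file.
def _reference_step(z, c, alpha, beta, p, poly):
    zr = poly[0] + poly[1] * z.real ** 2 + poly[2] * z.real ** 3 + poly[3] * z.real ** 4
    z = torch.complex(zr, z.imag)                                                # polynomial on the real part only
    z = z * torch.polar(torch.ones_like(zr), torch.full_like(zr, alpha))         # rotate by alpha
    z = torch.polar(z.abs() ** p, torch.angle(z) * p)                            # z -> z^p in polar form
    return z + c * torch.polar(torch.ones_like(zr), torch.full_like(zr, beta))   # add c rotated by beta

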
class FractalDataset(torch.utils.data.Dataset):
    def __init__(self, H, W, max_iter=100, R=4.0, device='cuda', colorize=True):
        self.H = H
        self.W = W
        self.max_iter = max_iter
        self.R = R
        self.device = device
        # Stored under a different name so the flag does not shadow the colorize() method below.
        self.do_colorize = colorize

    def batch(self, list_of_seed, polycoeffs=None):
        B = len(list_of_seed)
        H, W = self.H, self.W
        device = self.device

        # Pixel grid of constants c = cr + i*ci over [-1, 1] x [-1, 1], shared by every batch element.
        x = torch.linspace(-1.0, 1.0, W, device=device)
        y = torch.linspace(-1.0, 1.0, H, device=device)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        cr = xx.flatten().contiguous()
        ci = yy.flatten().contiguous()

        zr = torch.zeros((B, H * W), device=device)
        zi = torch.zeros((B, H * W), device=device)
        output = torch.empty((B, H * W), dtype=torch.int8, device=device)

        # Random per-image parameters, deterministic given the seeds.
        torch.manual_seed(sum(list_of_seed))
        alpha = torch.rand(B, device=device) * 2 * torch.pi
        beta = torch.rand(B, device=device) * 0.1
        p = torch.rand(B, device=device) * 1 + 1

        if polycoeffs is None:
            poly0 = torch.rand(B, device=device) * 1
            poly1 = torch.rand(B, device=device) * 1
            poly2 = torch.rand(B, device=device) * 1
            poly3 = torch.rand(B, device=device) * 1
        else:
            poly0 = torch.ones(B, device=device) * polycoeffs[0]
            poly1 = torch.ones(B, device=device) * polycoeffs[1]
            poly2 = torch.ones(B, device=device) * polycoeffs[2]
            poly3 = torch.ones(B, device=device) * polycoeffs[3]

        # Launch grid: axis 0 indexes the batch, axis 1 tiles the H * W pixels.
        BLOCK_SIZE = 1024
        num_blocks = (H * W + BLOCK_SIZE - 1) // BLOCK_SIZE
        grid = (B, num_blocks)
        fractal_kernel[grid](
            zr_ptr=zr,
            zi_ptr=zi,
            cr_ptr=cr,
            ci_ptr=ci,
            output_ptr=output,
            alpha_ptr=alpha,
            beta_ptr=beta,
            p_ptr=p,
            poly0_ptr=poly0,
            poly1_ptr=poly1,
            poly2_ptr=poly2,
            poly3_ptr=poly3,
            R=self.R,
            max_iter=self.max_iter,
            H=H,
            W=W,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=4,
            num_stages=2,
        )

        images = []
        if self.do_colorize:
            for b in range(B):
                fractal_image = output[b].reshape(H, W)
                fractal_image = self.colorize(fractal_image)
                images.append(fractal_image)
            batch_images = torch.stack(images, dim=0)   # (B, 3, H, W) uint8
        else:
            batch_images = output.reshape(B, H, W)      # (B, H, W) int8 iteration counts
        return batch_images

    def colorize(self, fractal_data):
        # Normalize iteration counts to [0, 1] and map them through a matplotlib colormap.
        fractal_data = fractal_data.to(torch.float32)
        fractal_data = (fractal_data - fractal_data.min()) / (fractal_data.max() - fractal_data.min() + 1e-8)
        import matplotlib.pyplot as plt
        cmap = plt.get_cmap('viridis')
        fractal_data_np = fractal_data.cpu().numpy()
        fractal_rgb = cmap(fractal_data_np)[:, :, :3]
        fractal_rgb = torch.from_numpy(fractal_rgb).permute(2, 0, 1)
        fractal_rgb = (fractal_rgb * 255).to(torch.uint8)
        return fractal_rgb.to(self.device)

    def __len__(self):
        # Effectively unbounded; images are generated on the fly from seeds.
        return 2**32

    def __getitem__(self, idx):
        raise NotImplementedError("Use the 'batch' method instead.")
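

# Example usage (a sketch; assumes a CUDA device with Triton and matplotlib available):
if __name__ == "__main__":
    dataset = FractalDataset(H=512, W=512, max_iter=100, R=4.0, device='cuda', colorize=True)
    # One image per seed: (B, 3, H, W) uint8 when colorized, (B, H, W) int8 iteration counts otherwise.
    images = dataset.batch([0, 1, 2, 3])
    print(images.shape, images.dtype)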