Created
October 3, 2024 18:03
-
-
Save cloneofsimo/e2a783a8781442936d44e7d2e677f86b to your computer and use it in GitHub Desktop.
Parameterized Fractal Triton
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import triton | |
import triton.language as tl | |
from triton.language.extra import libdevice | |
@triton.jit
def fractal_kernel(
    zr_ptr, zi_ptr, cr_ptr, ci_ptr, output_ptr,
    alpha_ptr, beta_ptr, poly0_ptr, poly1_ptr, poly2_ptr, poly3_ptr, p_ptr, R, max_iter,
    H, W,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute a per-pixel escape count for a batch of parameterized fractals.

    Launch grid: axis 0 indexes the batch element, axis 1 indexes a
    BLOCK_SIZE-wide chunk of the H * W flattened pixels.

    Per batch element ``bid``:
      * zr/zi (initial z) and output are rows of length H * W starting at
        ``bid * H * W``.
      * cr/ci (the pixel coordinate c) are a single H * W row shared by all
        batch elements.
      * alpha, beta, p, poly0..poly3 are scalars read at index ``bid``.

    Each of the ``max_iter`` steps does the following:
      1. Apply a polynomial to the real part: zr <- poly0 + poly1*zr^2 + poly2*zr^3 + poly3*zr^4.
      2. Rotate z by alpha.
      3. Raise z to the real power p in polar form.
      4. Add c rotated by beta.
    The stored value is the number of steps after which |z|^2 < R^2. It is
    saturated at 127 so it fits the int8 output.
    """
    bid = tl.program_id(0)
    pid = tl.program_id(1)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < H * W

    # Per-image rows; cr/ci are deliberately NOT offset (shared across the batch).
    # NOTE(review): bid * H * W is 32-bit arithmetic; very large B*H*W may overflow.
    batch_offset = bid * H * W
    zr_ptr += batch_offset
    zi_ptr += batch_offset
    output_ptr += batch_offset

    alpha = tl.load(alpha_ptr + bid)
    beta = tl.load(beta_ptr + bid)
    p = tl.load(p_ptr + bid)
    poly0 = tl.load(poly0_ptr + bid)
    poly1 = tl.load(poly1_ptr + bid)
    poly2 = tl.load(poly2_ptr + bid)
    poly3 = tl.load(poly3_ptr + bid)

    zr = tl.load(zr_ptr + offsets, mask=mask)
    zi = tl.load(zi_ptr + offsets, mask=mask)
    cr = tl.load(cr_ptr + offsets, mask=mask)
    ci = tl.load(ci_ptr + offsets, mask=mask)

    # Loop-invariant quantities, computed once instead of on every iteration.
    cos_alpha = tl.cos(alpha)
    sin_alpha = tl.sin(alpha)
    cos_beta = tl.cos(beta)
    sin_beta = tl.sin(beta)
    exp_cr = cr * cos_beta - ci * sin_beta  # Re(c * e^{i*beta})
    exp_ci = cr * sin_beta + ci * cos_beta  # Im(c * e^{i*beta})
    r_sq = R * R

    # Count in int32 so max_iter > 127 cannot wrap; saturate when storing as int8.
    iteration = tl.zeros([BLOCK_SIZE], dtype=tl.int32)
    for _ in range(max_iter):
        # Polynomial acts on the real part only; zi is left as-is.
        zr = poly0 + poly1 * libdevice.pow(zr, 2) + poly2 * libdevice.pow(zr, 3) + poly3 * libdevice.pow(zr, 4)

        # z <- z * e^{i*alpha}
        rot_zr = zr * cos_alpha - zi * sin_alpha
        rot_zi = zr * sin_alpha + zi * cos_alpha

        # z <- z^p via polar form (|z|^p, p * arg z), then add rotated c.
        modulus = libdevice.pow(tl.sqrt(rot_zr * rot_zr + rot_zi * rot_zi), p)
        angle = libdevice.atan2(rot_zi, rot_zr) * p
        zr = modulus * tl.cos(angle) + exp_cr
        zi = modulus * tl.sin(angle) + exp_ci

        # NOTE(review): escaped orbits are not frozen, so an orbit that leaves
        # the radius-R disc and later re-enters it is counted again.
        not_escaped = (zr * zr + zi * zi) < r_sq
        iteration += not_escaped.to(tl.int32)

    tl.store(output_ptr + offsets, tl.minimum(iteration, 127).to(tl.int8), mask=mask)
class FractalDataset(torch.utils.data.Dataset):
    """Generate batches of randomly parameterized fractal escape-count images.

    Images are produced only through ``batch`` (which launches
    ``fractal_kernel``); ``__getitem__`` is intentionally unsupported.

    Args:
        H, W: image height and width in pixels.
        max_iter: number of kernel iterations per pixel. Counts are written as
            int8, so values above 127 do not fit in the output.
        R: escape radius passed to the kernel.
        device: torch device on which all tensors are created.
        colorize: if True, ``batch`` returns uint8 RGB tensors of shape
            (B, 3, H, W); otherwise raw int8 counts of shape (B, H, W).
    """

    def __init__(self, H, W, max_iter=100, R=4.0, device='cuda', colorize=True):
        self.H = H
        self.W = W
        self.max_iter = max_iter
        self.R = R
        self.device = device
        # NOTE: this boolean instance attribute shadows the ``colorize`` method
        # on every instance. ``batch`` therefore calls ``_colorize_image``.
        self.colorize = colorize

    @staticmethod
    def _complex_plane(H, W, device):
        """Return flattened (cr, ci) coordinates of an H x W grid over [-1, 1]^2.

        Both tensors are 1-D, of length H * W, in row-major (H, W) order.
        Pixel (row, col) is therefore at flat index ``row * W + col``, which
        matches the ``reshape(H, W)`` applied to the kernel output.
        The real part varies along rows and the imaginary part along columns.
        This reproduces the original orientation for square images.
        """
        re_axis = torch.linspace(-1.0, 1.0, H, device=device)
        im_axis = torch.linspace(-1.0, 1.0, W, device=device)
        cr, ci = torch.meshgrid(re_axis, im_axis, indexing='ij')
        return cr.flatten().contiguous(), ci.flatten().contiguous()

    def _sample_params(self, B, seed, polycoeffs=None):
        """Draw per-image kernel parameters from a private RNG seeded with ``seed``.

        A dedicated ``torch.Generator`` leaves the global torch RNG untouched.
        Draws happen in the same order as before (alpha, beta, p, then
        poly0..poly3), so the values equal those produced after
        ``torch.manual_seed(seed)``.

        Returns:
            dict with keys 'alpha', 'beta', 'p', 'poly0'..'poly3', each a
            length-B tensor on ``self.device``.
        """
        device = self.device
        gen = torch.Generator(device=device)
        gen.manual_seed(int(seed))

        def draw():
            return torch.rand(B, device=device, generator=gen)

        # Dict literals evaluate left to right, preserving the draw order.
        params = {
            'alpha': draw() * 2 * torch.pi,  # rotation angle in [0, 2*pi)
            'beta': draw() * 0.1,            # rotation of c, in [0, 0.1)
            'p': draw() + 1,                 # power in [1, 2)
        }
        if polycoeffs is None:
            coeffs = [draw() for _ in range(4)]
        else:
            coeffs = [torch.ones(B, device=device) * polycoeffs[k] for k in range(4)]
        for k, coeff in enumerate(coeffs):
            params[f'poly{k}'] = coeff
        return params

    def batch(self, list_of_seed, polycoeffs=None):
        """Render one fractal image per entry of ``list_of_seed``.

        Args:
            list_of_seed: sequence of integer seeds; its length is the batch
                size B. All parameters come from one RNG seeded with
                ``sum(list_of_seed)``, so e.g. [1, 2] and [3, 0] give the
                same batch.
            polycoeffs: optional 4 coefficients (poly0..poly3) shared by every
                image. When None, they are sampled uniformly in [0, 1).

        Returns:
            uint8 tensor (B, 3, H, W) if ``self.colorize`` is truthy, otherwise
            the int8 count tensor (B, H, W).
        """
        B = len(list_of_seed)
        H, W = self.H, self.W
        device = self.device
        n_pixels = H * W

        cr, ci = self._complex_plane(H, W, device)
        zr = torch.zeros((B, n_pixels), device=device)
        zi = torch.zeros((B, n_pixels), device=device)
        output = torch.empty((B, n_pixels), dtype=torch.int8, device=device)
        params = self._sample_params(B, sum(list_of_seed), polycoeffs)

        BLOCK_SIZE = 1024
        grid = (B, (n_pixels + BLOCK_SIZE - 1) // BLOCK_SIZE)
        fractal_kernel[grid](
            zr_ptr=zr,
            zi_ptr=zi,
            cr_ptr=cr,
            ci_ptr=ci,
            output_ptr=output,
            alpha_ptr=params['alpha'],
            beta_ptr=params['beta'],
            p_ptr=params['p'],
            poly0_ptr=params['poly0'],
            poly1_ptr=params['poly1'],
            poly2_ptr=params['poly2'],
            poly3_ptr=params['poly3'],
            R=self.R,
            max_iter=self.max_iter,
            H=H,
            W=W,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=4,
            num_stages=2,
        )

        if not self.colorize:
            return output.reshape(B, H, W)
        images = [self._colorize_image(output[b].reshape(H, W)) for b in range(B)]
        return torch.stack(images, dim=0)

    def colorize(self, fractal_data):
        """Backward-compatible alias for ``_colorize_image``.

        On instances this name is shadowed by the boolean ``colorize``
        attribute set in ``__init__``. Call it as
        ``FractalDataset.colorize(ds, data)``.
        """
        return self._colorize_image(fractal_data)

    def _colorize_image(self, fractal_data):
        """Map a 2-D count tensor to a uint8 RGB tensor of shape (3, H, W).

        Values are min-max normalized per image, passed through matplotlib's
        'viridis' colormap, and returned on ``self.device``.
        """
        import matplotlib.pyplot as plt

        data = fractal_data.to(torch.float32)
        data = (data - data.min()) / (data.max() - data.min() + 1e-8)
        cmap = plt.get_cmap('viridis')
        rgb = cmap(data.cpu().numpy())[:, :, :3]  # drop the alpha channel
        rgb = torch.from_numpy(rgb).permute(2, 0, 1)
        rgb = (rgb * 255).to(torch.uint8)
        return rgb.to(self.device)

    def __len__(self):
        """Return a nominal, very large length (2**32); samples come from ``batch``."""
        return 2**32

    def __getitem__(self, idx):
        """Unsupported: images are only generated in batches."""
        raise NotImplementedError("Use the 'batch' method instead.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment