Last active
November 8, 2024 00:43
-
-
Save manzt/f9d2aeb81c95701f392702de275346cd to your computer and use it in GitHub Desktop.
An anywidget that uses WebGPU to simulate and render Conway's Game of Life.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <reference types="npm:@webgpu/types" />

/**
 * Edge length of the square compute workgroup. The simulation shader is
 * compiled with `@workgroup_size(WORKGROUP_SIZE, WORKGROUP_SIZE)` and the
 * dispatch count is derived from it.
 */
const WORKGROUP_SIZE = 8;
/**
 * Acquire a WebGPU device plus a canvas context sized to `width` x `height`.
 *
 * Reuses the first `<canvas>` already inside `el`; otherwise creates one and
 * appends it to `el`. The device is acquired first, so an unsupported browser
 * fails before any DOM changes are made.
 *
 * @param {{ el: HTMLElement, width: number, height: number }} options
 * @returns {Promise<{ device: GPUDevice, context: GPUCanvasContext, canvas: HTMLCanvasElement }>}
 * @throws {Error} "WebGPU not supported" when no adapter, device, or
 *   "webgpu" canvas context can be obtained.
 */
async function get_render_context({ el, width, height }) {
  // `navigator.gpu` is undefined in browsers without WebGPU; dereferencing it
  // directly would throw a TypeError instead of the intended error below.
  const gpu = globalThis.navigator?.gpu;
  const adapter = await gpu?.requestAdapter();
  const device = await adapter?.requestDevice();
  if (!device) {
    throw new Error("WebGPU not supported");
  }

  let canvas = el?.querySelector("canvas");
  if (!canvas) {
    canvas = document.createElement("canvas");
    el?.appendChild(canvas);
  }
  canvas.width = width;
  canvas.height = height;

  const context = canvas.getContext("webgpu");
  if (!context) {
    // Release the device we just acquired; the caller never receives it.
    device.destroy();
    throw new Error("WebGPU not supported");
  }
  return { device, context, canvas };
}
/**
 * Build the initial Game of Life board.
 *
 * @param {number} grid_size - Cells per side of the square board.
 * @param {undefined | null | { buffer: ArrayBufferLike, byteOffset?: number, byteLength?: number }} initial_state_buffer
 *   Optional view (e.g. a `DataView`) whose bytes hold `grid_size * grid_size`
 *   u32 cell values. When absent, each cell starts alive with probability 0.4.
 * @returns {Uint32Array} One entry per cell, row-major (index = y * grid_size + x).
 * @throws {Error} When the provided view does not hold exactly one u32 per cell.
 */
function create_cell_state(grid_size, initial_state_buffer) {
  const cell_count = grid_size * grid_size;
  if (!initial_state_buffer) {
    return Uint32Array.from(
      { length: cell_count },
      () => (Math.random() > 0.6 ? 1 : 0),
    );
  }
  const { buffer } = initial_state_buffer;
  const byte_offset = initial_state_buffer.byteOffset ?? 0;
  const byte_length = initial_state_buffer.byteLength ??
    buffer.byteLength - byte_offset;
  const expected = cell_count * Uint32Array.BYTES_PER_ELEMENT;
  if (byte_length !== expected) {
    throw new Error(
      `initial_state_buffer has ${byte_length} bytes; expected ${expected} ` +
        `(${grid_size}x${grid_size} u32 cells)`,
    );
  }
  // Respect the view's window into `buffer` (the underlying ArrayBuffer may be
  // larger). `slice` copies, which also sidesteps Uint32Array's requirement
  // that the byte offset be 4-byte aligned.
  return new Uint32Array(buffer.slice(byte_offset, byte_offset + byte_length));
}

/**
 * WGSL that draws one instanced quad per cell. The quad is scaled by the
 * cell's state, so dead cells (state 0) collapse to a point and draw nothing.
 */
const CELL_SHADER_CODE = `\
@group(0) @binding(0) var<uniform> grid: vec2f;
@group(0) @binding(1) var<storage> cellstate: array<u32>;

struct vertexinput {
  @location(0) pos: vec2f,
  @builtin(instance_index) instance: u32,
};

struct vertexoutput {
  @builtin(position) pos: vec4f,
  @location(0) cell: vec2f,
};

@vertex
fn vertex_main(input: vertexinput) -> vertexoutput {
  let i = f32(input.instance);
  let cell = vec2f(i % grid.x, floor(i / grid.x));
  let state = f32(cellstate[input.instance]);
  let cell_offset = cell / grid * 2;
  let grid_pos = (input.pos * state + 1) / grid - 1 + cell_offset;
  var output: vertexoutput;
  output.pos = vec4f(grid_pos, 0, 1);
  output.cell = cell;
  return output;
}

struct fragmentinput {
  @location(0) cell: vec2f,
};

@fragment
fn fragment_main(input: fragmentinput) -> @location(0) vec4f {
  let c = input.cell / grid;
  return vec4f(c, 1-c.x, 1);
}
`;

/**
 * WGSL compute shader that advances the board one generation, wrapping at the
 * edges (the board is a torus).
 *
 * @param {number} workgroup_size - Edge length of the square workgroup.
 * @returns {string}
 */
function simulation_shader_code(workgroup_size) {
  return `\
@group(0) @binding(0) var<uniform> grid: vec2f;
@group(0) @binding(1) var<storage> cell_state_in: array<u32>;
@group(0) @binding(2) var<storage, read_write> cell_state_out: array<u32>;

fn cell_index(cell: vec2u) -> u32 {
  return (cell.y % u32(grid.y)) * u32(grid.x) + (cell.x % u32(grid.x));
}

fn cell_active(x: u32, y: u32) -> u32 {
  return cell_state_in[cell_index(vec2(x, y))];
}

@compute @workgroup_size(${workgroup_size}, ${workgroup_size})
fn main(@builtin(global_invocation_id) cell: vec3u) {
  let size = vec2u(grid);
  // The dispatch is rounded up to whole workgroups; skip invocations that fall
  // outside the board.
  if (cell.x >= size.x || cell.y >= size.y) {
    return;
  }
  // Previous-neighbour coordinates are offset by the board size so that
  // cell_index's modulo wraps them. "cell.x - 1" would underflow u32 at 0,
  // and (2^32 - 1) % size equals size - 1 only when size is a power of two.
  let x_prev = cell.x + size.x - 1u;
  let x_next = cell.x + 1u;
  let y_prev = cell.y + size.y - 1u;
  let y_next = cell.y + 1u;
  let active_neighbors =
    cell_active(x_next, y_next) +
    cell_active(x_next, cell.y) +
    cell_active(x_next, y_prev) +
    cell_active(cell.x, y_prev) +
    cell_active(x_prev, y_prev) +
    cell_active(x_prev, cell.y) +
    cell_active(x_prev, y_next) +
    cell_active(cell.x, y_next);
  let i = cell_index(cell.xy);
  switch active_neighbors {
    case 2u: {
      cell_state_out[i] = cell_state_in[i];
    }
    case 3u: {
      cell_state_out[i] = 1u;
    }
    default: {
      cell_state_out[i] = 0u;
    }
  }
}
`;
}

/** @type {import("npm:@anywidget/types").Render} */
async function render({ model, el }) {
  const { grid_size, width, height, initial_state_buffer } = model.get(
    "_options",
  );
  const { device, context } = await get_render_context({ el, width, height });
  const format = navigator.gpu.getPreferredCanvasFormat();
  context.configure({ device, format });

  // Two triangles forming a quad over [-0.8, 0.8]^2. The vertex shader maps it
  // into each cell's slot, leaving a margin between neighbouring cells.
  // @deno-fmt-ignore
  const vertices = new Float32Array([
    -0.8, -0.8,
     0.8, -0.8,
     0.8,  0.8,

    -0.8, -0.8,
     0.8,  0.8,
    -0.8,  0.8,
  ]);
  const vertex_buffer = device.createBuffer({
    label: "Cell vertices",
    size: vertices.byteLength,
    usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(vertex_buffer, 0, vertices);
  /** @type {GPUVertexBufferLayout} */
  const vertex_buffer_layout = {
    arrayStride: 8,
    attributes: [{ format: "float32x2", offset: 0, shaderLocation: 0 }],
  };
  const cell_shader_module = device.createShaderModule({
    label: "Cell shader",
    code: CELL_SHADER_CODE,
  });

  const uniform = new Float32Array([grid_size, grid_size]);
  const uniform_buffer = device.createBuffer({
    label: "Grid uniforms",
    size: uniform.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(uniform_buffer, 0, uniform);

  const cell_state = create_cell_state(grid_size, initial_state_buffer);
  // Ping-pong pair: every simulation step reads one buffer and writes the other.
  const cell_state_storage = ["A", "B"].map((name) =>
    device.createBuffer({
      label: `Cell state ${name}`,
      size: cell_state.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    })
  );
  device.queue.writeBuffer(cell_state_storage[0], 0, cell_state);

  const all_stages = GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT |
    GPUShaderStage.COMPUTE;
  const bind_group_layout = device.createBindGroupLayout({
    label: "Cell Bind Group Layout",
    entries: [
      { binding: 0, visibility: all_stages, buffer: {} },
      {
        binding: 1,
        visibility: all_stages,
        buffer: { type: "read-only-storage" },
      },
      {
        binding: 2,
        visibility: GPUShaderStage.COMPUTE,
        buffer: { type: "storage" },
      },
    ],
  });
  // bind_groups[0] reads A / writes B; bind_groups[1] reads B / writes A.
  // Both groups are used by both the compute and the render pass.
  const bind_groups = [0, 1].map((i) =>
    device.createBindGroup({
      label: `Cell bind group ${"AB"[i]}`,
      layout: bind_group_layout,
      entries: [
        { binding: 0, resource: { buffer: uniform_buffer } },
        { binding: 1, resource: { buffer: cell_state_storage[i] } },
        { binding: 2, resource: { buffer: cell_state_storage[1 - i] } },
      ],
    })
  );

  const pipeline_layout = device.createPipelineLayout({
    label: "Cell pipeline layout",
    bindGroupLayouts: [bind_group_layout],
  });
  const cell_pipeline = device.createRenderPipeline({
    label: "Cell pipeline",
    layout: pipeline_layout,
    vertex: {
      module: cell_shader_module,
      entryPoint: "vertex_main",
      buffers: [vertex_buffer_layout],
    },
    fragment: {
      module: cell_shader_module,
      entryPoint: "fragment_main",
      targets: [{ format }],
    },
  });
  const compute_pipeline = device.createComputePipeline({
    label: "Simulation pipeline",
    layout: pipeline_layout,
    compute: {
      module: device.createShaderModule({
        label: "Game of Life simulation shader",
        code: simulation_shader_code(WORKGROUP_SIZE),
      }),
      entryPoint: "main",
    },
  });

  const workgroup_count = Math.ceil(grid_size / WORKGROUP_SIZE);
  // Generation counter; bind_groups[step % 2] reads the current generation.
  let step = 0;

  /**
   * Record a render pass drawing the current generation.
   * @param {GPUCommandEncoder} encoder
   */
  function encode_draw(encoder) {
    const pass = encoder.beginRenderPass({
      colorAttachments: [{
        view: context.getCurrentTexture().createView(),
        loadOp: "clear",
        clearValue: [0.0666, 0.0666, 0.0666, 1],
        storeOp: "store",
      }],
    });
    pass.setPipeline(cell_pipeline);
    pass.setVertexBuffer(0, vertex_buffer);
    pass.setBindGroup(0, bind_groups[step % 2]);
    pass.draw(vertices.length / 2, grid_size * grid_size);
    pass.end();
  }

  /** Draw the current generation without advancing the simulation. */
  function draw() {
    const encoder = device.createCommandEncoder();
    encode_draw(encoder);
    device.queue.submit([encoder.finish()]);
  }

  /** Advance one generation on the GPU, then draw the result. */
  function update() {
    const encoder = device.createCommandEncoder();
    const pass = encoder.beginComputePass();
    pass.setPipeline(compute_pipeline);
    pass.setBindGroup(0, bind_groups[step % 2]);
    pass.dispatchWorkgroups(workgroup_count, workgroup_count);
    pass.end();
    step++;
    encode_draw(encoder);
    device.queue.submit([encoder.finish()]);
  }

  const sim = (() => {
    /** @type {undefined | number} */
    let id = undefined;
    return {
      start() {
        clearInterval(id);
        id = setInterval(update, model.get("ms"));
      },
      stop() {
        clearInterval(id);
        id = undefined;
      },
    };
  })();

  const on_ms_change = () => {
    // Re-arm the timer with the new period, but never start a paused sim.
    if (model.get("active")) sim.start();
  };
  const on_active_change = () => {
    if (model.get("active")) {
      sim.start();
    } else {
      sim.stop();
    }
  };
  model.on("change:ms", on_ms_change);
  model.on("change:active", on_active_change);

  // Show the initial board even when the simulation starts paused.
  draw();
  if (model.get("active")) sim.start();

  return () => {
    model.off("change:ms", on_ms_change);
    model.off("change:active", on_active_change);
    // Stop ticking before the device goes away so no update hits a dead device.
    sim.stop();
    device.destroy();
  };
}
/** Module shape consumed by anywidget: an object exposing the `render` hook. */
const widget = { render };
export default widget;
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment