Skip to content

Instantly share code, notes, and snippets.

@manzt
Last active November 8, 2024 00:43
Show Gist options
  • Save manzt/f9d2aeb81c95701f392702de275346cd to your computer and use it in GitHub Desktop.
Save manzt/f9d2aeb81c95701f392702de275346cd to your computer and use it in GitHub Desktop.
An anywidget that uses WebGPU to simulate and render Conway's Game of Life.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ad51a0d3",
"metadata": {
"jupyter": {
"source_hidden": true
}
},
"outputs": [],
"source": [
"# /// script\n",
"# requires-python = \">=3.13\"\n",
"# dependencies = [\n",
"# \"anywidget[dev]==0.9.13\",\n",
"# \"numpy==2.1.3\",\n",
"# ]\n",
"# ///"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1134ae5",
"metadata": {},
"outputs": [],
"source": [
"import anywidget\n",
"import traitlets\n",
"\n",
"class GameOfLife(anywidget.AnyWidget):\n",
" _esm = \"index.js\"\n",
" _options = traitlets.Dict().tag(sync=True)\n",
" active = traitlets.Bool(True).tag(sync=True)\n",
" ms = traitlets.Int(200).tag(sync=True)\n",
" \n",
" def __init__(self,\n",
" *,\n",
" width = 500,\n",
" height = 500,\n",
" grid_size = 128,\n",
" initial_state = None,\n",
" **kwargs,\n",
" ):\n",
" options = { \"width\": width, \"height\": height }\n",
" if initial_state is not None:\n",
" arr = np.asarray(initial_state, dtype=np.uint32)\n",
" assert len(arr.shape) == 2\n",
" assert arr.shape[0] == arr.shape[1]\n",
" options[\"grid_size\"] = arr.shape[0]\n",
" options[\"initial_state_buffer\"] = arr.tobytes()\n",
" else:\n",
" options[\"grid_size\"] = grid_size\n",
" super().__init__(_options=options, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ffc565d",
"metadata": {},
"outputs": [],
"source": [
"import ipywidgets\n",
"\n",
"game = GameOfLife()\n",
"slider = ipywidgets.IntSlider(min=10, max=200, value=game.ms)\n",
"ipywidgets.dlink((slider, \"value\"), (game, \"ms\"))\n",
"ipywidgets.VBox([slider, game])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "283f9ca3",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def initialize_grid(shape=(128, 128)):\n",
" grid = np.zeros(shape, dtype=np.uint32)\n",
"\n",
" glider = np.array([\n",
" [0, 1, 0],\n",
" [0, 0, 1],\n",
" [1, 1, 1]\n",
" ])\n",
"\n",
" lwss = np.array([\n",
" [0, 1, 0, 0, 1],\n",
" [1, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 1],\n",
" [1, 1, 1, 1, 0]\n",
" ])\n",
" \n",
" pulsar = np.array([\n",
" [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1],\n",
" [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
" [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0]\n",
" ])\n",
"\n",
" toad = np.array([\n",
" [0, 1, 1, 1],\n",
" [1, 1, 1, 0]\n",
" ])\n",
"\n",
" beehive = np.array([\n",
" [0, 1, 1, 0],\n",
" [1, 0, 0, 1],\n",
" [0, 1, 1, 0]\n",
" ])\n",
"\n",
" # Place patterns on the grid\n",
" grid[10:13, 10:13] = glider\n",
" grid[20:24, 20:25] = lwss\n",
" grid[30:43, 30:43] = pulsar\n",
" grid[50:52, 50:54] = toad\n",
" grid[60:63, 60:64] = beehive\n",
"\n",
" # Add some random gliders...\n",
" for _ in range(5):\n",
" x, y = np.random.randint(0, shape[0]-3), np.random.randint(0, shape[1]-3)\n",
" grid[x:x+3, y:y+3] = glider\n",
"\n",
" return grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ca88c6b",
"metadata": {},
"outputs": [],
"source": [
"grid = initialize_grid()\n",
"grid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cf57b87",
"metadata": {},
"outputs": [],
"source": [
"GameOfLife(initial_state=grid, ms=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e985fc07",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
/// <reference types="npm:@webgpu/types" />
// Edge length of the square compute workgroup: interpolated into the compute
// shader's @workgroup_size(X, X) and used in render's update() to compute how
// many workgroups to dispatch along each axis (ceil(grid_size / X)).
const WORKGROUP_SIZE = 8;
/**
 * Acquire a WebGPU device plus a sized canvas and its "webgpu" context.
 *
 * Reuses an existing <canvas> inside `el` when there is one; otherwise a new
 * canvas is created and appended to `el`. The canvas is resized to
 * `width` x `height` either way.
 *
 * @param {{ el: HTMLElement, width: number, height: number }} options
 * @returns {Promise<{ device: GPUDevice, context: GPUCanvasContext, canvas: HTMLCanvasElement }>}
 * @throws {Error} "WebGPU not supported" when the environment exposes no
 *   WebGPU API, no adapter/device can be obtained, or the canvas cannot
 *   provide a "webgpu" context.
 */
async function get_render_context({ el, width, height }) {
  // Check for the API up front: without this guard, environments lacking
  // WebGPU fail with an opaque TypeError on `navigator.gpu.requestAdapter`.
  const gpu = globalThis.navigator?.gpu;
  if (!gpu) {
    throw new Error("WebGPU not supported");
  }
  const adapter = await gpu.requestAdapter();
  const device = await adapter?.requestDevice();
  // Bail out before touching the DOM so no blank canvas is left behind.
  if (!device) {
    throw new Error("WebGPU not supported");
  }
  const canvas = el?.querySelector("canvas") ?? (() => {
    const created = document.createElement("canvas");
    el?.appendChild(created);
    return created;
  })();
  canvas.width = width;
  canvas.height = height;
  const context = canvas.getContext("webgpu");
  if (!context) {
    // The caller never receives this device, so release it here.
    device.destroy();
    throw new Error("WebGPU not supported");
  }
  return { device, context, canvas };
}
/**
 * anywidget `render` entry point: simulates and draws Conway's Game of Life
 * with WebGPU.
 *
 * Reads `_options` ({ grid_size, width, height, initial_state_buffer? }) from
 * the model and sets up a WebGPU canvas inside `el`. While `active` is true,
 * every `ms` milliseconds it runs one compute pass that writes the next
 * generation, then one render pass that draws it.
 *
 * Returns a cleanup function that stops the timer and destroys the device.
 *
 * @type {import("npm:@anywidget/types").Render}
 */
async function render({ model, el }) {
  const { grid_size, width, height, initial_state_buffer } = model.get(
    "_options",
  );
  const { device, context } = await get_render_context({ el, width, height });
  const format = navigator.gpu.getPreferredCanvasFormat();
  context.configure({ device, format });

  // One cell as two triangles. The +/-0.8 extent (instead of +/-1) leaves a
  // gap between neighbouring cells once vertex_main scales into cell space.
  // @deno-fmt-ignore
  const vertices = new Float32Array([
    -0.8, -0.8,
    0.8, -0.8,
    0.8, 0.8,
    -0.8, -0.8,
    0.8, 0.8,
    -0.8, 0.8,
  ]);
  const vertex_buffer = device.createBuffer({
    label: "Cell vertices",
    size: vertices.byteLength,
    usage: GPUBufferUsage.VERTEX | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(vertex_buffer, 0, vertices);
  const vertex_buffer_layout = {
    arrayStride: 8,
    attributes: [{
      format: "float32x2",
      offset: 0,
      shaderLocation: 0,
    }],
  };

  // Draws one instance per cell. Dead cells (state 0) scale every vertex to
  // the same point, producing a degenerate (invisible) quad.
  const cell_shader_module = device.createShaderModule({
    label: "Cell shader",
    code: `\
@group(0) @binding(0) var<uniform> grid: vec2f;
@group(0) @binding(1) var<storage> cellstate: array<u32>;
struct vertexinput {
  @location(0) pos: vec2f,
  @builtin(instance_index) instance: u32,
};
struct vertexoutput {
  @builtin(position) pos: vec4f,
  @location(0) cell: vec2f,
};
@vertex
fn vertex_main(input: vertexinput) -> vertexoutput {
  let i = f32(input.instance);
  let cell = vec2f(i % grid.x, floor(i / grid.x));
  let state = f32(cellstate[input.instance]);
  let cell_offset = cell / grid * 2;
  let grid_pos = (input.pos * state + 1) / grid - 1 + cell_offset;
  var output: vertexoutput;
  output.pos = vec4f(grid_pos, 0, 1);
  output.cell = cell;
  return output;
}
struct fragmentinput {
  @location(0) cell: vec2f,
};
@fragment
fn fragment_main(input: fragmentinput) -> @location(0) vec4f {
  let c = input.cell / grid;
  return vec4f(c, 1-c.x, 1);
}
`,
  });

  const uniform = new Float32Array([grid_size, grid_size]);
  const uniform_buffer = device.createBuffer({
    label: "Grid uniforms",
    size: uniform.byteLength,
    usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
  });
  device.queue.writeBuffer(uniform_buffer, 0, uniform);

  /** @type {Uint32Array} */
  let cell_state;
  if (initial_state_buffer) {
    // The synced bytes arrive as a view (e.g. a DataView) that need not start
    // at offset 0 of its ArrayBuffer or span all of it. Copy exactly the
    // viewed bytes so the Uint32Array is aligned and holds only grid data.
    const { buffer, byteOffset, byteLength } = initial_state_buffer;
    cell_state = new Uint32Array(
      buffer.slice(byteOffset, byteOffset + byteLength),
    );
  } else {
    cell_state = new Uint32Array(grid_size * grid_size);
    for (let i = 0; i < cell_state.length; ++i) {
      cell_state[i] = Math.random() > 0.6 ? 1 : 0;
    }
  }

  // Ping-pong pair: each simulation step reads one buffer and writes the other.
  const cell_state_storage = [
    device.createBuffer({
      label: "Cell state A",
      size: cell_state.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    }),
    device.createBuffer({
      label: "Cell state B",
      size: cell_state.byteLength,
      usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
    }),
  ];
  device.queue.writeBuffer(cell_state_storage[0], 0, cell_state);

  const bind_group_layout = device.createBindGroupLayout({
    label: "Cell Bind Group Layout",
    entries: [{
      binding: 0,
      visibility: GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT |
        GPUShaderStage.COMPUTE,
      buffer: {},
    }, {
      binding: 1,
      visibility: GPUShaderStage.VERTEX | GPUShaderStage.FRAGMENT |
        GPUShaderStage.COMPUTE,
      buffer: { type: "read-only-storage" },
    }, {
      binding: 2,
      visibility: GPUShaderStage.COMPUTE,
      buffer: { type: "storage" },
    }],
  });

  // bind_groups[0] reads A and writes B; bind_groups[1] reads B and writes A.
  const bind_groups = [
    device.createBindGroup({
      label: "Cell render bind group",
      layout: bind_group_layout,
      entries: [{
        binding: 0,
        resource: { buffer: uniform_buffer },
      }, {
        binding: 1,
        resource: { buffer: cell_state_storage[0] },
      }, {
        binding: 2,
        resource: { buffer: cell_state_storage[1] },
      }],
    }),
    device.createBindGroup({
      label: "Cell compute bind group",
      layout: bind_group_layout,
      entries: [{
        binding: 0,
        resource: { buffer: uniform_buffer },
      }, {
        binding: 1,
        resource: { buffer: cell_state_storage[1] },
      }, {
        binding: 2,
        resource: { buffer: cell_state_storage[0] },
      }],
    }),
  ];

  const pipeline_layout = device.createPipelineLayout({
    label: "Cell pipeline layout",
    bindGroupLayouts: [bind_group_layout],
  });
  const cell_pipeline = device.createRenderPipeline({
    label: "Cell pipeline",
    layout: pipeline_layout,
    vertex: {
      module: cell_shader_module,
      entryPoint: "vertex_main",
      // @ts-expect-error - type not recognized?
      buffers: [vertex_buffer_layout],
    },
    fragment: {
      module: cell_shader_module,
      entryPoint: "fragment_main",
      targets: [{ format }],
    },
  });

  const compute_shader_module = device.createShaderModule({
    label: "Game of Life simulation shader",
    code: `\
@group(0) @binding(0) var<uniform> grid: vec2f;
@group(0) @binding(1) var<storage> cell_state_in: array<u32>;
@group(0) @binding(2) var<storage, read_write> cell_state_out: array<u32>;
fn cell_index(cell: vec2u) -> u32 {
  return (cell.y % u32(grid.y)) * u32(grid.x) + (cell.x % u32(grid.x));
}
fn cell_active(x: u32, y: u32) -> u32 {
  return cell_state_in[cell_index(vec2(x, y))];
}
@compute @workgroup_size(${WORKGROUP_SIZE}, ${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) cell: vec3u) {
  let dims = vec2u(u32(grid.x), u32(grid.y));
  // The dispatch is rounded up to whole workgroups, so some invocations fall
  // past the grid edge; without this guard cell_index would wrap them onto
  // real cells and overwrite those cells' results.
  if (cell.x >= dims.x || cell.y >= dims.y) {
    return;
  }
  // Toroidal neighbours. "Minus one" is written as "plus (dims - 1)" because
  // u32 subtraction at 0 underflows to 2^32 - 1, which only wraps to the
  // opposite edge when the grid size is a power of two.
  let prev_x = cell.x + dims.x - 1u;
  let next_x = cell.x + 1u;
  let prev_y = cell.y + dims.y - 1u;
  let next_y = cell.y + 1u;
  let active_neighbors = (
    cell_active(next_x, next_y) +
    cell_active(next_x, cell.y) +
    cell_active(next_x, prev_y) +
    cell_active(cell.x, prev_y) +
    cell_active(prev_x, prev_y) +
    cell_active(prev_x, cell.y) +
    cell_active(prev_x, next_y) +
    cell_active(cell.x, next_y)
  );
  let i = cell_index(cell.xy);
  switch active_neighbors {
    case 2u: {
      cell_state_out[i] = cell_state_in[i];
    }
    case 3u: {
      cell_state_out[i] = 1u;
    }
    default: {
      cell_state_out[i] = 0u;
    }
  }
}
`,
  });
  const compute_pipeline = device.createComputePipeline({
    label: "Simulation pipeline",
    layout: pipeline_layout,
    compute: {
      module: compute_shader_module,
      entryPoint: "main",
    },
  });

  let step = 0;
  /** Advance one generation on the GPU, then draw it. */
  function update() {
    const encoder = device.createCommandEncoder();
    // Compute pass
    {
      const workgroup_count = Math.ceil(grid_size / WORKGROUP_SIZE);
      const pass = encoder.beginComputePass();
      pass.setPipeline(compute_pipeline);
      pass.setBindGroup(0, bind_groups[step % 2]);
      pass.dispatchWorkgroups(workgroup_count, workgroup_count);
      pass.end();
    }
    // After incrementing, bind_groups[step % 2] reads (binding 1) the buffer
    // the compute pass just wrote, so the render pass draws the new generation.
    step++;
    // Render pass
    {
      const pass = encoder.beginRenderPass({
        colorAttachments: [{
          view: context.getCurrentTexture().createView(),
          loadOp: "clear",
          clearValue: [0.0666, 0.0666, 0.0666, 1],
          storeOp: "store",
        }],
      });
      pass.setPipeline(cell_pipeline);
      pass.setVertexBuffer(0, vertex_buffer);
      pass.setBindGroup(0, bind_groups[step % 2]);
      pass.draw(vertices.length / 2, grid_size * grid_size);
      pass.end();
    }
    device.queue.submit([encoder.finish()]);
  }

  const sim = (() => {
    /** @type {undefined | number} */
    let id = undefined;
    return {
      start() {
        clearInterval(id);
        id = setInterval(update, model.get("ms"));
      },
      stop() {
        clearInterval(id);
        id = undefined;
      },
    };
  })();

  // Start/stop the simulation and follow model changes.
  if (model.get("active")) {
    sim.start();
  }
  model.on("change:ms", () => {
    // Restart with the new period only while running; otherwise changing `ms`
    // would silently resume a paused simulation.
    if (model.get("active")) {
      sim.start();
    }
  });
  model.on("change:active", () => {
    if (model.get("active")) {
      sim.start();
    } else {
      sim.stop();
    }
  });

  return () => {
    // Stop the timer first so no tick can run against a destroyed device.
    sim.stop();
    device.destroy();
  };
}
export default { render };
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment