Created
April 18, 2024 01:08
-
-
Save PWhiddy/3ede893728850f16b6f5212f8f454778 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
<!DOCTYPE html> | |
<html lang="en"> | |
<head> | |
<meta charset="UTF-8"> | |
<meta name="viewport" content="width=device-width, initial-scale=1.0"> | |
<title>WebGPU Image Convolution</title> | |
</head> | |
<body> | |
<h1>WebGPU Image Convolution</h1> | |
<button id="run">Run Convolution</button> | |
<script> | |
const runButton = document.getElementById('run'); | |
runButton.addEventListener('click', async () => { | |
if (!navigator.gpu) { | |
console.error('WebGPU is not supported. Make sure you are on a supported browser with experimental flags enabled.'); | |
return; | |
} | |
const adapter = await navigator.gpu.requestAdapter(); | |
const device = await adapter.requestDevice(); | |
// Problem size: a width x height single-channel image of random values in [0, 1).
const width = 2048;
const height = 2048;
const pixels = new Float32Array(width * height).map(() => Math.random());

// 1D Gaussian taps, one per offset in [-radius, radius], with sigma = radius / 2.
const radius = 7;
const kernelSize = 2 * radius + 1;
const sigma = radius / 2;
const gaussianTaps = new Float32Array(kernelSize).map((_, i) => Math.exp(-((i - radius) ** 2) / (2 * sigma ** 2)));

// Normalize so the taps sum to 1. The sum is computed once, outside the map
// callback, instead of being re-reduced for every tap.
const tapSum = gaussianTaps.reduce((a, b) => a + b, 0);
const kernel = gaussianTaps.map((k) => k / tapSum);

// The shader declares the kernel as an array of vec4<f32>, so pad the taps
// with zeros up to a multiple of four floats before uploading.
const vec4KernelSize = Math.ceil(kernelSize / 4);
const packedKernel = new Float32Array(vec4KernelSize * 4);
packedKernel.set(kernel);
/**
 * CPU reference implementation of the 2D convolution (serial).
 *
 * Each output pixel is the direct sum, over a (2 * radius + 1)^2 window, of
 * image * kernel[ky + radius] * kernel[kx + radius]. Neighbours that fall
 * outside the image are clamped to the nearest edge pixel.
 *
 * @param {ArrayLike<number>} image - Row-major pixels; needs at least width * height entries.
 * @param {number} width - Image width in pixels.
 * @param {number} height - Image height in pixels.
 * @param {ArrayLike<number>} kernel - 1D taps; must have exactly 2 * radius + 1 entries.
 * @param {number} radius - Half-width of the kernel window.
 * @returns {Float32Array} Convolved image with the same length as `image`.
 * @throws {RangeError} If the kernel length does not match the radius, or the
 *   image is shorter than width * height (either would otherwise yield NaN or
 *   mis-centred output without any error).
 */
function cpuConvolution(image, width, height, kernel, radius) {
  const expectedTaps = 2 * radius + 1;
  if (kernel.length !== expectedTaps) {
    throw new RangeError(`kernel must have ${expectedTaps} taps for radius ${radius}, got ${kernel.length}`);
  }
  if (image.length < width * height) {
    throw new RangeError(`image has ${image.length} pixels, expected at least ${width * height}`);
  }

  const result = new Float32Array(image.length);
  const maxX = width - 1;
  const maxY = height - 1;
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      let sum = 0;
      for (let ky = -radius; ky <= radius; ky++) {
        // The clamped row and the vertical weight do not depend on kx,
        // so they are computed once per window row.
        const rowStart = Math.min(Math.max(y + ky, 0), maxY) * width;
        const weightY = kernel[ky + radius];
        for (let kx = -radius; kx <= radius; kx++) {
          const nx = Math.min(Math.max(x + kx, 0), maxX);
          // Multiply left to right: (pixel * weightY) * weightX.
          sum += image[rowStart + nx] * weightY * kernel[kx + radius];
        }
      }
      result[y * width + x] = sum;
    }
  }
  return result;
}
// Time the serial CPU reference; its output is what the GPU result is checked against.
console.time('CPU Convolution');
const cpuResult = cpuConvolution(pixels, width, height, kernel, radius);
console.timeEnd('CPU Convolution');

// Allocate the three buffers the shader binds: the source image (uploaded),
// the destination image (copied out afterwards), and the packed kernel taps.
const imageByteLength = pixels.byteLength;
const gpuInputBuffer = device.createBuffer({
  size: imageByteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST,
});
const gpuOutputBuffer = device.createBuffer({
  size: imageByteLength,
  usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC,
});
const gpuKernelBuffer = device.createBuffer({
  size: packedKernel.byteLength,
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST,
});

// Upload the host data into the two input-side buffers.
device.queue.writeBuffer(gpuInputBuffer, 0, pixels);
device.queue.writeBuffer(gpuKernelBuffer, 0, packedKernel);
// WGSL compute shader: one invocation per output pixel. Image size, radius and
// kernel length are baked into the source via template interpolation. Kernel
// taps live in a uniform array of vec4<f32>, so tap t is kernel[t / 4][t % 4].
// The early return skips invocations outside the image: the dispatch rounds up
// to whole 16x16 workgroups, and without the guard a size that is not a
// multiple of 16 would write through y * width + x into other pixels.
const wgslCode = `
@group(0) @binding(0) var<storage, read> inputImage: array<f32>;
@group(0) @binding(1) var<storage, read_write> outputImage: array<f32>;
@group(0) @binding(2) var<uniform> kernel: array<vec4<f32>, ${vec4KernelSize}>;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) global_id : vec3<u32>) {
  if (global_id.x >= ${width}u || global_id.y >= ${height}u) {
    return;
  }
  let x = i32(global_id.x);
  let y = i32(global_id.y);
  var sum: f32 = 0.0;
  for (var j: i32 = -${radius}; j <= ${radius}; j++) {
    for (var i: i32 = -${radius}; i <= ${radius}; i++) {
      let nx = clamp(x + i, 0, ${width - 1});
      let ny = clamp(y + j, 0, ${height - 1});
      let kernelValueX = kernel[(i + ${radius}) / 4][(i + ${radius}) % 4];
      let kernelValueY = kernel[(j + ${radius}) / 4][(j + ${radius}) % 4];
      sum += inputImage[ny * ${width} + nx] * kernelValueX * kernelValueY;
    }
  }
  outputImage[y * ${width} + x] = sum;
}
`;
// Explicit layout for bind group 0: read-only input image, writable output
// image, and the uniform kernel, all visible to the compute stage only.
const computeStage = GPUShaderStage.COMPUTE;
const bindGroupLayout = device.createBindGroupLayout({
  entries: [
    { binding: 0, visibility: computeStage, buffer: { type: 'read-only-storage' } },
    { binding: 1, visibility: computeStage, buffer: { type: 'storage' } },
    { binding: 2, visibility: computeStage, buffer: { type: 'uniform', hasDynamicOffset: false } },
  ],
});
const pipelineLayout = device.createPipelineLayout({ bindGroupLayouts: [bindGroupLayout] });

// Compile the shader and build the compute pipeline around its `main` entry point.
const shaderModule = device.createShaderModule({ code: wgslCode });
const pipeline = device.createComputePipeline({
  layout: pipelineLayout,
  compute: { module: shaderModule, entryPoint: 'main' },
});

// Bind the buffers in binding order: index in this list == @binding number.
const boundBuffers = [gpuInputBuffer, gpuOutputBuffer, gpuKernelBuffer];
const bindGroup = device.createBindGroup({
  layout: pipeline.getBindGroupLayout(0),
  entries: boundBuffers.map((buffer, binding) => ({ binding, resource: { buffer } })),
});
// Record the compute pass; the grid is rounded up to whole 16x16 workgroups.
const encoder = device.createCommandEncoder();
const computePass = encoder.beginComputePass();
computePass.setPipeline(pipeline);
computePass.setBindGroup(0, bindGroup);
const groupsX = Math.ceil(width / 16);
const groupsY = Math.ceil(height / 16);
computePass.dispatchWorkgroups(groupsX, groupsY);
computePass.end();

// Copy from the compute output buffer to a readable buffer
const readBuffer = device.createBuffer({
  size: pixels.byteLength,
  usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST,
});
encoder.copyBufferToBuffer(gpuOutputBuffer, 0, readBuffer, 0, pixels.byteLength);
const recordedCommands = encoder.finish();

// The timed region covers submission until the queue reports the work as done.
console.time('GPU Convolution');
device.queue.submit([recordedCommands]);
await device.queue.onSubmittedWorkDone();
console.timeEnd('GPU Convolution');
// Read the GPU output back and count the pixels that disagree with the CPU reference.
try {
  await readBuffer.mapAsync(GPUMapMode.READ);
  // View over the mapped range; it is only valid until the buffer is unmapped
  // or destroyed, so the comparison must finish inside this try block.
  const gpuResult = new Float32Array(readBuffer.getMappedRange());
  console.log('Comparing CPU and GPU results...');
  const tolerance = 0.00001;
  let discrepancies = 0;
  for (let i = 0; i < gpuResult.length; i++) {
    // Negated `<=` rather than `>` so a NaN on either side counts as a
    // mismatch (every comparison with NaN is false).
    if (!(Math.abs(gpuResult[i] - cpuResult[i]) <= tolerance)) {
      discrepancies++;
    }
  }
  console.log(`Discrepancies: ${discrepancies}`);
} finally {
  // Release GPU memory explicitly instead of waiting for garbage collection;
  // otherwise every click keeps another set of image-sized buffers alive.
  // destroy() also unmaps a buffer that is still mapped.
  readBuffer.destroy();
  gpuInputBuffer.destroy();
  gpuOutputBuffer.destroy();
  gpuKernelBuffer.destroy();
}
}); | |
</script> | |
</body> | |
</html> |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment