Skip to content

Instantly share code, notes, and snippets.

@ekzhang
Created April 12, 2025 20:11
Show Gist options
  • Save ekzhang/6597e0027d087ae078dcc6fa1dde8bdd to your computer and use it in GitHub Desktop.
Save ekzhang/6597e0027d087ae078dcc6fa1dde8bdd to your computer and use it in GitHub Desktop.
Synchronously read data from a WebGPU buffer
/**
 * Graphics state used to synchronously read data from WebGPU buffers.
 *
 * This trick is borrowed from TensorFlow.js. Basically, the idea is to create
 * an offscreen canvas with one pixel for every 4 bytes ("device storage"), then
 * configure it with a WebGPU context. Copy the buffer to a texture, then draw
 * the canvas onto another offscreen canvas with '2d' context ("host storage").
 *
 * Once it's on host storage, we can use `getImageData()` to read the pixels
 * from the image directly.
 *
 * We use 256x256 canvases here (256 KiB). The performance of this is bad
 * because it involves multiple data copies, but it still works. We also
 * actually need to copy the image twice: once in "opaque" mode for the RGB
 * values, and once in "premultiplied" mode for the alpha channel.
 *
 * Canvases and contexts are created lazily on the first non-empty `read()`.
 *
 * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
 */
class SyncReader {
  /** Canvas alpha modes; one device canvas is created per mode, in this order. */
  static readonly alphaModes: GPUCanvasAlphaMode[] = [
    "opaque",
    "premultiplied",
  ];
  /** Width of every staging canvas, in pixels (4 bytes each). */
  static readonly width = 256;
  /** Height of every staging canvas, in pixels. */
  static readonly height = 256;

  /** True once `#init()` has created all canvases and contexts. */
  initialized = false;
  /** WebGPU-backed canvases, parallel to `alphaModes`. */
  deviceStorage?: OffscreenCanvas[];
  /** WebGPU contexts of `deviceStorage`, same order. */
  deviceContexts?: GPUCanvasContext[];
  /** 2D canvas the device canvases are drawn onto for readback. */
  hostStorage?: OffscreenCanvas;
  /** 2D context of `hostStorage`. */
  hostContext?: OffscreenCanvasRenderingContext2D;

  /** @param device Device used to configure the canvases and submit copies. */
  constructor(readonly device: GPUDevice) {}

  /**
   * Creates the device canvases (one per alpha mode) and the host canvas.
   *
   * Fields are only published once everything succeeded, so a failed attempt
   * leaves the reader uninitialized and the next `read()` retries.
   *
   * @throws Error if a WebGPU or 2D canvas context cannot be obtained.
   */
  #init(): void {
    const makeCanvas = () =>
      new OffscreenCanvas(SyncReader.width, SyncReader.height);

    const deviceStorage: OffscreenCanvas[] = [];
    const deviceContexts: GPUCanvasContext[] = [];
    for (const alphaMode of SyncReader.alphaModes) {
      const canvas = makeCanvas();
      const context = canvas.getContext("webgpu");
      if (!context) {
        throw new Error("Failed to create a WebGPU canvas context");
      }
      context.configure({
        device: this.device,
        // rgba8unorm is not supported on Chrome for macOS.
        // https://bugs.chromium.org/p/chromium/issues/detail?id=1298618
        format: "bgra8unorm",
        usage: GPUTextureUsage.COPY_DST,
        alphaMode,
      });
      deviceStorage.push(canvas);
      deviceContexts.push(context);
    }

    const hostStorage = makeCanvas();
    const hostContext = hostStorage.getContext("2d", {
      willReadFrequently: true,
    });
    if (!hostContext) {
      throw new Error("Failed to create a 2D canvas context");
    }

    this.deviceStorage = deviceStorage;
    this.deviceContexts = deviceContexts;
    this.hostStorage = hostStorage;
    this.hostContext = hostContext;
    this.initialized = true;
  }

  /**
   * Synchronously reads `count` bytes of `buffer`, beginning at byte `start`.
   *
   * The buffer is used as the source of `copyBufferToTexture`, so it must
   * have been created with `GPUBufferUsage.COPY_SRC`.
   *
   * @param buffer Buffer to read from.
   * @param start Byte offset of the first byte to read; multiple of 4.
   * @param count Number of bytes to read; multiple of 4.
   * @returns A new `ArrayBuffer` of `count` bytes with the buffer contents.
   * @throws Error if `count` or `start` is not a multiple of 4 bytes, or if
   *   the canvases could not be set up.
   * @throws RangeError if the requested range is negative or extends past
   *   the end of `buffer`.
   */
  read(buffer: GPUBuffer, start: number, count: number): ArrayBuffer {
    // Validate up front: an invalid copy is reported through WebGPU's
    // asynchronous error mechanism rather than thrown, so without these
    // checks the caller would silently get back stale canvas contents.
    if (count % 4 !== 0) {
      throw new Error("Read size must be a multiple of 4 bytes");
    }
    if (start % 4 !== 0) {
      throw new Error("Read offset must be a multiple of 4 bytes");
    }
    if (start < 0 || count < 0 || start + count > buffer.size) {
      throw new RangeError(
        `Read range [${start}, ${start + count}) is out of bounds ` +
          `for a buffer of ${buffer.size} bytes`,
      );
    }

    const valsGPU = new ArrayBuffer(count);
    if (count === 0) return valsGPU;

    if (!this.initialized) this.#init();
    const { deviceStorage, deviceContexts, hostContext } = this;
    if (!deviceStorage || !deviceContexts || !hostContext) {
      throw new Error("SyncReader is not initialized");
    }

    // Split the read into rectangles that fit the canvas: whole canvases,
    // then one block of whole rows, then a partial first row. The layout is
    // the same for every alpha mode, so it is computed once.
    const bytesPerRow = SyncReader.width * 4;
    const bytesPerCanvas = bytesPerRow * SyncReader.height;
    const regions: { width: number; height: number; offset: number }[] = [];
    let offset = 0;
    while (count - offset >= bytesPerCanvas) {
      regions.push({
        width: SyncReader.width,
        height: SyncReader.height,
        offset,
      });
      offset += bytesPerCanvas;
    }
    const wholeRows = Math.floor((count - offset) / bytesPerRow);
    if (wholeRows > 0) {
      regions.push({ width: SyncReader.width, height: wholeRows, offset });
      offset += wholeRows * bytesPerRow;
    }
    const tailPixels = (count - offset) / 4;
    if (tailPixels > 0) {
      regions.push({ width: tailPixels, height: 1, offset });
    }

    for (let i = 0; i < deviceContexts.length; i++) {
      const texture = deviceContexts[i].getCurrentTexture();
      // The "opaque" pass supplies the three colour bytes, the
      // "premultiplied" pass supplies only the alpha byte.
      const alphaOnly = SyncReader.alphaModes[i] === "premultiplied";

      for (const region of regions) {
        const encoder = this.device.createCommandEncoder();
        encoder.copyBufferToTexture(
          { buffer, bytesPerRow, offset: start + region.offset },
          { texture },
          { width: region.width, height: region.height, depthOrArrayLayers: 1 },
        );
        this.device.queue.submit([encoder.finish()]);

        // Drawing the WebGPU canvas onto the 2D canvas makes its pixels
        // available to getImageData().
        hostContext.clearRect(0, 0, region.width, region.height);
        hostContext.drawImage(deviceStorage[i], 0, 0);
        const pixels = hostContext.getImageData(
          0,
          0,
          region.width,
          region.height,
        ).data;

        const span = new Uint8ClampedArray(
          valsGPU,
          region.offset,
          4 * region.width * region.height,
        );
        for (let k = 0; k < span.length; k += 4) {
          if (alphaOnly) {
            span[k + 3] = pixels[k + 3];
          } else {
            // The texture is BGRA while getImageData() yields RGBA, so the
            // first and third bytes are swapped back.
            span[k] = pixels[k + 2];
            span[k + 1] = pixels[k + 1];
            span[k + 2] = pixels[k];
          }
        }
      }
    }
    return valsGPU;
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment