Created
August 6, 2017 12:16
-
-
Save yvt/fef3622bc97d57af63d3d98ff3d21410 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#version 310 es
// NOTE(review): this shader performs no floating-point arithmetic; the default
// float precision below has no effect and is kept only for compatibility.
precision mediump float;

// Work-group shape: a 1-D group of 64 invocations, each producing one output
// element in main().
layout(local_size_x = 64) in;

// Number of invocations per work group. Derived from the layout declaration
// above (gl_WorkGroupSize is a compile-time constant once local_size_x has been
// declared), so the shared-memory tile below can never disagree with the real
// work-group size.
const uint local_size = gl_WorkGroupSize.x;
// Number of taps in the convolution kernel.
const uint kernel_size = 4u;

// Shared input tile: one element per invocation plus (kernel_size - 1) halo
// elements, so every invocation can read kernel_size consecutive inputs
// starting at its own local index.
shared uint in_values[local_size + kernel_size - 1u];
// Shared copy of the kernel taps. Note: this name is also used by the member of
// ConvolutionParameter below; that member is only reachable as
// conv_param.kernel_values, so the two do not collide.
shared uint kernel_values[kernel_size];

// Kernel taps (binding 0).
layout(set = 0, binding = 0) readonly buffer ConvolutionParameter {
    uint kernel_values[kernel_size];
} conv_param;

// Input samples (binding 1); runtime-sized.
layout(set = 0, binding = 1) readonly buffer ConvolutionInput {
    uint data[];
} conv_in;

// Output samples (binding 2); runtime-sized.
layout(set = 0, binding = 2) writeonly buffer ConvolutionOutput {
    uint data[];
} conv_out;
// Computes one output element per invocation:
//
//   conv_out.data[g] = sum over i in [0, kernel_size) of
//                      conv_in.data[g + i] * kernel[i]
//
// The kernel is applied without flipping, i.e. as a correlation. The work
// group first stages its input window (local_size + kernel_size - 1 elements)
// and the kernel taps in shared memory. After a barrier, each invocation reads
// only shared memory.
//
// Bounds handling: input elements past the end of conv_in.data are treated as
// 0, and invocations whose output index is past the end of conv_out.data store
// nothing. A dispatch rounded up to a multiple of local_size therefore never
// touches memory outside the bound buffers. In-bounds results are unaffected.
void main()
{
    uint local_id = gl_LocalInvocationID.x;
    uint global_id = gl_GlobalInvocationID.x;

    // length() on a runtime-sized buffer array returns int in GLSL ES 3.10.
    uint in_len = uint(conv_in.data.length());
    uint out_len = uint(conv_out.data.length());

    // Load input data into shared memory. Every invocation loads the element
    // at its own position (zero if it lies past the end of the input).
    uint own_value = 0u;
    if (global_id < in_len) {
        own_value = conv_in.data[global_id];
    }
    in_values[local_id] = own_value;

    // The last (kernel_size - 1) invocations also load the halo, which is
    // in_values[local_size .. local_size + kernel_size - 2].
    if (local_id > local_size - kernel_size) {
        uint halo_id = global_id + kernel_size - 1u;
        uint halo_value = 0u;
        if (halo_id < in_len) {
            halo_value = conv_in.data[halo_id];
        }
        in_values[local_id + kernel_size - 1u] = halo_value;
    }

    // Load the kernel into shared memory.
    if (local_id < kernel_size) {
        kernel_values[local_id] = conv_param.kernel_values[local_id];
    }

    // Wait for all shared-memory writes to be visible to the whole group.
    // barrier() must be reached by every invocation, so no invocation may
    // return before this point.
    groupMemoryBarrier();
    barrier();

    // Invocations beyond the end of the output have nothing to store.
    if (global_id >= out_len) {
        return;
    }

    // Perform the convolution using only shared memory.
    uint sum = 0u;
    for (uint i = 0u; i < kernel_size; ++i) {
        sum += in_values[local_id + i] * kernel_values[i];
    }

    // Store the result.
    conv_out.data[global_id] = sum;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment