Skip to content

Instantly share code, notes, and snippets.

@nihalpasham
Created August 28, 2024 03:11
Show Gist options
  • Save nihalpasham/0ed25f2dbcb08278f79d6ceabf38a60b to your computer and use it in GitHub Desktop.
Save nihalpasham/0ed25f2dbcb08278f79d6ceabf38a60b to your computer and use it in GitHub Desktop.
WGSL generated by the CubeCL framework for the GELU (Gaussian Error Linear Unit) example kernel
[START_KERNEL_COMPILATION]
name: gelu::gelu_array::GeluArray<
cubecl_core::frontend::element::float::F32,
cubecl_wgpu::runtime::WgpuRuntime,
>
cube_dim: (4, 1, 1)
shared_memory: 0 bytes
info: (
KernelSettings {
mappings: [],
vectorization_global: None,
vectorization_partial: [],
cube_dim: CubeDim {
x: 4,
y: 1,
z: 1,
},
reading_strategy: [],
},
,
)
source:
```wgsl
// Storage buffers bound to the kernel (bind group 0).

// Source elements read by `main`.
@group(0) @binding(0)
var<storage, read_write> input_0_global: array<f32>;

// Destination elements written by `main`.
@group(0) @binding(1)
var<storage, read_write> output_0_global: array<f32>;

// Kernel metadata words; `main` reads element 0 as `rank`.
@group(0) @binding(2)
var<storage, read_write> info: array<u32>;

// Workgroup extent per axis; matches the `@workgroup_size(4, 1, 1)` attribute on `main`.
const WORKGROUP_SIZE_X: u32 = 4u;
const WORKGROUP_SIZE_Y: u32 = 1u;
const WORKGROUP_SIZE_Z: u32 = 1u;
// Compute entry point: applies GELU element-wise,
//     output[i] = input[i] * (erf(input[i] / sqrt(2)) + 1) / 2,
// with each invocation handling at most one element.
@compute
@workgroup_size(4, 1, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    // Flatten the 3D global invocation id into a single linear element index.
    let row_stride = num_workgroups.x * WORKGROUP_SIZE_X;
    let slice_stride = row_stride * num_workgroups.y * WORKGROUP_SIZE_Y;
    let id = (global_id.z * slice_stride) + (global_id.y * row_stride) + global_id.x;

    // Not used by the computation below, but this read keeps `info` statically
    // referenced by the entry point.
    // NOTE(review): presumably emitted by the generator for every kernel - confirm.
    let rank: u32 = info[0];

    // Invocations whose index falls past the end of the input do nothing.
    if id >= arrayLength(&input_0_global) {
        return;
    }

    let x = input_0_global[id];
    let scaled = x / sqrt(2f);
    let erf_plus_one = erf(scaled) + 1f;
    output_0_global[id] = (x * erf_plus_one) / 2f;
}
/// Polynomial approximation of the error function, taken from
/// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
///
/// > (maximum error: 1.5e-7)
/// > All of these approximations are valid for x >= 0. For negative x, use the
/// > fact that erf is an odd function: erf(x) = -erf(-x).
///
/// Callers are expected to pass a non-negative `x`; `erf_scalar` handles the sign.
fn erf_positive_scalar(x: f32) -> f32 {
    // Coefficients of the approximation.
    let p: f32 = 0.3275911;
    let a1: f32 = 0.254829592;
    let a2: f32 = -0.284496736;
    let a3: f32 = 1.421413741;
    let a4: f32 = -1.453152027;
    let a5: f32 = 1.061405429;

    let t = 1.0 / (1.0 + p * abs(x));

    // Horner evaluation of a5*t^4 + a4*t^3 + a3*t^2 + a2*t + a1.
    var poly = a5 * t + a4;
    poly = poly * t + a3;
    poly = poly * t + a2;
    poly = poly * t + a1;

    let gaussian = exp(-x * x);
    return 1.0 - (poly * t * gaussian);
}
/// Error function for a scalar of either sign.
///
/// `erf_positive_scalar` is only valid for non-negative arguments, so negative
/// inputs are mirrored through the odd symmetry erf(x) = -erf(-x).
fn erf_scalar(x: f32) -> f32 {
    let is_negative = x < 0.0;
    if is_negative {
        let mirrored = erf_positive_scalar(-x);
        return -mirrored;
    }
    return erf_positive_scalar(x);
}
/// `erf` as called from the kernel body; forwards to the scalar implementation.
fn erf(x: f32) -> f32 {
    let result = erf_scalar(x);
    return result;
}
```
[END_KERNEL_COMPILATION]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment