Created
August 28, 2024 03:11
-
-
Save nihalpasham/0ed25f2dbcb08278f79d6ceabf38a60b to your computer and use it in GitHub Desktop.
WGSL generated with the CubeCL framework for the gelu example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[START_KERNEL_COMPILATION] | |
name: gelu::gelu_array::GeluArray< | |
cubecl_core::frontend::element::float::F32, | |
cubecl_wgpu::runtime::WgpuRuntime, | |
> | |
cube_dim: (4, 1, 1) | |
shared_memory: 0 bytes | |
info: ( | |
KernelSettings { | |
mappings: [], | |
vectorization_global: None, | |
vectorization_partial: [], | |
cube_dim: CubeDim { | |
x: 4, | |
y: 1, | |
z: 1, | |
}, | |
reading_strategy: [], | |
}, | |
, | |
) | |
source: | |
```wgsl | |
@group(0) | |
@binding(0) | |
var<storage, read_write> input_0_global: array<f32>; | |
@group(0) | |
@binding(1) | |
var<storage, read_write> output_0_global: array<f32>; | |
@group(0) | |
@binding(2) | |
var<storage, read_write> info: array<u32>; | |
const WORKGROUP_SIZE_X = 4u; | |
const WORKGROUP_SIZE_Y = 1u; | |
const WORKGROUP_SIZE_Z = 1u; | |
@compute | |
@workgroup_size(4, 1, 1) | |
fn main( | |
@builtin(global_invocation_id) global_id: vec3<u32>, | |
@builtin(num_workgroups) num_workgroups: vec3<u32>, | |
) {let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x; | |
let rank: u32 = info[0]; | |
var l_0_0: u32; | |
var l_0_1: bool; | |
var l_0_2: f32; | |
var l_0_3: f32; | |
l_0_0 = arrayLength(&input_0_global); | |
l_0_1 = id < l_0_0; | |
if l_0_1 { | |
l_0_2 = input_0_global[id]; | |
l_0_3 = sqrt(2f); | |
l_0_3 = l_0_2 / l_0_3; | |
l_0_3 = erf(l_0_3); | |
l_0_3 = l_0_3 + 1f; | |
l_0_2 = l_0_2 * l_0_3; | |
l_0_2 = l_0_2 / 2f; | |
output_0_global[id] = f32(l_0_2); | |
} | |
} | |
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations | |
/// | |
/// > (maximum error: 1.5×10−7) | |
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). | |
fn erf_positive_scalar(x: f32) -> f32 { | |
let p = 0.3275911; | |
let a1 = 0.254829592; | |
let a2 = -0.284496736; | |
let a3 = 1.421413741; | |
let a4 = -1.453152027; | |
let a5 = 1.061405429; | |
let t = 1.0 / (1.0 + p * abs(x)); | |
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1; | |
return 1.0 - (tmp * t * exp(-x * x)); | |
} | |
fn erf_scalar(x: f32) -> f32 { | |
if (x < 0.0) { | |
return -1.0 * erf_positive_scalar(-1.0 * x); | |
} | |
return erf_positive_scalar(x); | |
} | |
fn erf(x: f32) -> f32 { | |
return erf_scalar(x); | |
} | |
``` | |
[END_KERNEL_COMPILATION] | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment