#gpu #kernel #rust
- GPU kernels in Rust
- Comptime
- Automatic vectorization
- Instruction and shape specialization
- Loop unrolling
- Autotuning
- WGSL - WebGPU Shading Language
- GLSL - OpenGL Shading Language
- HLSL - High-Level Shading Language
- MSL - Metal Shading Language
- CUDA
- ROCm
- SYCL
- Rust -> WGSL
- Rust -> CUDA
CubeCL provides runtimes (`cubecl_wgpu` and `cubecl_cuda`) built on top of the Wgpu and Cuda backends, respectively.
From my understanding, the current implementation includes the following constructs: a `ComputeClient`, a `ComputeServer`, and a `Channel`, which serves as the abstraction for sending requests from the client to the server.
Instantiating a `ComputeClient` involves two steps:
- Setting up the necessary data structures for each backend (e.g., `wgpu_setup` for Wgpu).
- Creating a client from the setup's data structures, along with instantiating a `MemoryManagement` type to manage GPU memory allocation and deallocation strategies.

The client essentially wraps a `Channel` and a `FeatureSet`, which is a list of features supported by each runtime.
Once we have a `ComputeClient`, we can perform various tasks, such as creating or accessing resources (e.g., GPU buffers) and executing kernels. Note that invoking methods on the client eventually routes them to the `ComputeServer`, which holds the Wgpu structures needed to actually create and access these resources.
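As a rough sketch of this client/server flow, the snippet below obtains a `ComputeClient` from the Wgpu runtime, creates a GPU buffer from host data, and reads it back. It assumes the public API shown in CubeCL's examples (`Runtime::client`, `create`, `read`, `f32::as_bytes`/`from_bytes`) with the `wgpu` feature enabled; exact method signatures may vary across CubeCL versions.

```rust
use cubecl::prelude::*;
use cubecl::wgpu::WgpuRuntime;

fn main() {
    // Asking the runtime for a client performs the backend setup (adapter,
    // device, queue for wgpu) and wires up the Channel -> ComputeServer path.
    let client = WgpuRuntime::client(&Default::default());

    // Create a GPU buffer from host data; the returned handle is backed by
    // memory managed by the server's MemoryManagement strategy.
    let handle = client.create(f32::as_bytes(&[1.0, 2.0, 3.0, 4.0]));

    // Reads are routed through the same client -> channel -> server path.
    let bytes = client.read(handle.binding());
    println!("read back: {:?}", f32::from_bytes(&bytes));
}
```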
```rust
use cubecl::prelude::*;

#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
    }
}

#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
    x * (F::erf(x / F::sqrt(2.0.into())) + 1.0) / 2.0
}
```

CubeCL's unique selling point (USP) is the ability to write GPU kernels in Rust, as demonstrated above. However, there are a few things to keep in mind:
- All types used in a CubeCL function must implement the `CubeType` trait. In the example above, both `F` and `Array<F>` are CubeCL types: they both implement the `CubeType` trait, while `F` also implements the `Float` trait.
- CubeCL kernels are procedural macros that expand into Rust functions. These generated functions, which are semantically similar to the original ones, produce the Intermediate Representation (IR) when invoked.
Key point: Instead of directly generating the IR, the macro first creates a new Rust function.
The Flow:
- In the above example, the CubeCL function annotated with the `#[cube(launch_unchecked)]` macro expands into a module containing a `GeluArray` struct that implements the `Kernel` trait.
```rust
pub struct GeluArray<F: Float, __R: cubecl::prelude::Runtime> {
    settings: cubecl::prelude::KernelSettings,
    __ty: ::core::marker::PhantomData<(__R, F)>,
}
```

- The `GeluArray` struct holds a `KernelSettings` struct. `KernelSettings` allows us to configure various parameters, including the vectorization factor for kernel inputs and outputs.
- Once we configure our `KernelSettings`, we instantiate a `KernelLauncher` and register the associated kernel inputs and outputs for the kernel launch.
- Kernel launching involves several levels of indirection (a concrete launch sketch follows below):
  - The `KernelLauncher` invokes the `ComputeClient`'s `execute` method to initiate kernel execution.
  - This method uses a `Channel` to route the call to the `ComputeServer` (in our case, the `WgpuServer`), which executes the kernel with the provided bindings.
- Kernel execution involves preparing the pipeline state.
- At this stage, the kernel is compiled into source code (i.e., WGSL).
- Remember, the kernel is simply our `GeluArray` struct, which implements the `Kernel` trait. The `Kernel` trait requires two methods:
```rust
pub trait Kernel: Send + Sync + 'static + Sized {
    /// Convert to a kernel definition.
    fn define(&self) -> KernelDefinition;
    /// Identifier for the kernel, used for caching kernel compilation.
    fn id(&self) -> KernelId {
        KernelId::new::<Self>()
    }
}
```

Vectorization factor: For example, `Elem::Float(FloatKind)` with a vectorization factor of 4 represents a 4-element vector of floating-point numbers, which could be processed in a SIMD manner.
`Binding` struct: a memory binding that connects a tensor handle to the actual memory (storage) on the compute server.
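To make the launch path described above concrete, here is a sketch of launching the `gelu_array` kernel, adapted from CubeCL's published examples. The generated `launch_unchecked` function builds the `KernelSettings`, registers the arguments with a `KernelLauncher`, and routes execution through the client, channel, and server; the last argument of `ArrayArg::from_raw_parts` is the vectorization factor. Exact signatures (e.g., `from_raw_parts`, `client.read`) may differ between CubeCL versions, so treat this as an assumption-laden sketch rather than the exact API.

```rust
use cubecl::prelude::*;

fn launch<R: Runtime>(device: &R::Device) {
    let client = R::client(device);
    let input = &[-1.0f32, 0.0, 1.0, 5.0];
    let vectorization = 4;

    // GPU buffers (handles) for the kernel's input and output arrays.
    let input_handle = client.create(f32::as_bytes(input));
    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());

    // The generated `launch_unchecked` wraps the KernelLauncher + ComputeClient
    // plumbing described above; the trailing `vectorization` is the factor
    // discussed in the vectorization note above.
    unsafe {
        gelu_array::launch_unchecked::<f32, R>(
            &client,
            CubeCount::Static(1, 1, 1),
            CubeDim::new(input.len() as u32 / vectorization, 1, 1),
            ArrayArg::from_raw_parts(&input_handle, input.len(), vectorization as u8),
            ArrayArg::from_raw_parts(&output_handle, input.len(), vectorization as u8),
        )
    };

    // Read the result back through the binding that connects the handle to
    // the server-side storage.
    let bytes = client.read(output_handle.binding());
    println!("gelu output => {:?}", f32::from_bytes(&bytes));
}

fn main() {
    launch::<cubecl::wgpu::WgpuRuntime>(&Default::default());
}
```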
Kernel preparation involves two main steps:
- Kernel Expansion
- Kernel Definition
In the example above:
- Kernel definition begins with instantiating the `KernelBuilder` struct and populating it with the kernel's inputs, outputs, context, and the number of inputs and outputs.
- Two ordered maps are required to convert and store the inputs and outputs as `Variable`s. The order of insertion is crucial. Expanding a kernel input means registering the input and returning the element to be used for kernel expansion. Here, "element" refers to either an `ExpandElement` or `ExpandElementTyped`, which are simply wrapper types for `Variable`s.
- Now that we have a fully initialized `KernelBuilder` and expanded kernel inputs/outputs, we proceed to actual kernel expansion.
In this phase, the body of the kernel function is expanded. In the gelu example:
- Several important data structures are involved in this process:
  - `Operation`: CubeCL operations that can be legally used in a GPU compute shader.
  - `Variable`: Holds data or CubeCL values that can be referenced during GPU compute shader operations.
  - `Scope`: A container that holds CubeCL operations and variables.
  - `CubeContext`: A wrapper type for `Scope`, containing root and non-root scopes and a `VariablePool`.
  - `ExpandElement`: A wrapper type for CubeCL `Variable`s.
  - `ExpandElementTyped`: The typed version of `ExpandElement`.
CubeCL operations behave like conventional operations, taking input operands and returning a result. This behavior is modeled in CubeCL IR.
```rust
#[cube(launch_unchecked)]
fn gelu_array<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
    }
}
```

- In our `gelu` example, the `if` condition `ABSOLUTE_POS < input.len()` expands to:
```rust
/// Expanded Cube function
pub fn __expand<F: Float>(
    context: &mut cubecl::frontend::CubeContext,
    input: <Array<F> as cubecl::frontend::CubeType>::ExpandType,
    output: <Array<F> as cubecl::frontend::CubeType>::ExpandType,
) -> () {
    let _cond = {
        let _lhs = ABSOLUTE_POS::expand(context);
        let _rhs = input.clone().__expand_len_method(context);
        cubecl::frontend::lt::expand(context, _lhs, _rhs)
    };
    // ...
}
```

- `ABSOLUTE_POS` (or `_lhs`) is a `Variable`.
- `input.len()` (or `_rhs`) is also a `Variable`.
- The less-than operator (`<`) expands into the `lt::expand` operation, with `_lhs` and `_rhs` as inputs, along with the `context`.
- All operations (and their operands) are added to the provided context (`Scope`).
- The order in which they are pushed onto a `CubeContext` (i.e., scope) is crucial.

Note: `_lhs` and `_rhs` are actually `ExpandElementTyped<UInt>`s.
Once the kernel function is expanded, the next step is creating a kernel definition. The main data structures involved are:
- `KernelIntegrator`: Enables the creation of a `KernelDefinition` based on a `KernelExpansion` and `KernelSettings`.
- `KernelExpansion`: Contains the necessary information to generate a `KernelDefinition`.
- `KernelDefinition`: Represents the finalized kernel after expansion and integration, functioning as CubeCL's intermediate representation.
The first step is to instantiate a `KernelIntegrator` and invoke its `integrate` method, passing in the `KernelSettings`. This method combines the inputs and outputs (from the kernel expansion) into input/output bindings and returns a `KernelDefinition`.
As mentioned earlier, a `KernelDefinition` is the intermediate representation (IR) in CubeCL.
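Putting these pieces together, the macro-generated `define` method for `GeluArray` (see the macro-expansion gist linked below) roughly follows the shape sketched here. This is a hand-written paraphrase for illustration only: helper names such as `input_array`/`output_array` and the exact `KernelBuilder` fields are assumptions, not the verbatim generated code.

```rust
// Illustrative paraphrase; KernelBuilder method names are assumed, not verbatim.
impl<F: Float, __R: Runtime> Kernel for GeluArray<F, __R> {
    fn define(&self) -> KernelDefinition {
        // 1. Kernel definition: register inputs/outputs, which the builder
        //    stores as Variables in insertion order (hypothetical helpers).
        let mut builder = KernelBuilder::default();
        let input = builder.input_array(Item::new(F::as_elem()));
        let output = builder.output_array(Item::new(F::as_elem()));

        // 2. Kernel expansion: run the expanded function, recording operations
        //    into the builder's CubeContext (scope).
        gelu_array::__expand::<F>(&mut builder.context, input.into(), output.into());

        // 3. Integration: combine the expansion with KernelSettings to produce
        //    the final KernelDefinition (CubeCL's IR).
        builder.build(self.settings.clone())
    }
}
```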
- The final step is to map this IR to the target compute shader source code. In our case, this is WGSL.
- Essentially, we map all variables and operations in CubeCL to the target shader source using the corresponding shader compiler, specifically the `WgslCompiler` in our case.

In other words, the `KernelDefinition` (IR) is lowered to the target compute shader source: the `WgslCompiler` translates each IR variable, operation, and input/output binding into its WGSL equivalent.
- Example CubeCL Macro Expansion: https://gist.github.com/nihalpasham/133a935304e22054b0fe92efde43caec
- Example CubeCL IR: https://gist.github.com/nihalpasham/6e4c0edf5b1a0b199c05c186a5a75b2d
- Example CubeCL Generated WGSL Shader: https://gist.github.com/nihalpasham/0ed25f2dbcb08278f79d6ceabf38a60b
Created a video playlist for future reference: https://youtube.com/playlist?list=PLIUa1VcxJuwlI5sg8M8MH6FgzBzUWuAFI&si=ObxnUaUgUZxSebdq