This is an optimized implementation of an RMSNorm inference kernel written in Triton, a Python-based GPU programming library. It is a modified version of the excellent RMSNorm kernel from the Unsloth project.
It has two improvements:
- `int64` for pointer offset: We use `int64` instead of the default `int32` to compute the pointer offset value. This prevents overflow at large sequence lengths, where the offset can exceed the maximum `int32` value (2^31 - 1, about 2.1 billion).
- In-place computation: Our kernel writes the result back to the input buffer, eliminating the need for additional memory allocation. This halves the memory usage compared to traditional implementations that use a separate output buffer.
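
To make the two points above concrete, here is a minimal pure-Python sketch. It is not the Triton kernel itself: `wrap_int32`, `row_offset`, and `rms_norm_inplace` are hypothetical helper names introduced only for illustration, the `eps` default is an assumed value, and the 32-bit wraparound is simulated with modular arithmetic. It shows how a `row_idx * row_stride` product can wrap to a negative offset in 32 bits, and how the result can overwrite the input so that no second buffer is allocated.

```python
import math

INT32_MAX = 2**31 - 1  # 2,147,483,647


def wrap_int32(value: int) -> int:
    """Simulate two's-complement wraparound of a signed 32-bit integer."""
    return (value + 2**31) % 2**32 - 2**31


def row_offset(row_idx: int, row_stride: int, use_int64: bool) -> int:
    """Element offset of a row's first entry (hypothetical helper).

    Python integers never overflow, so the int64 path is just the exact
    product; the int32 path wraps the way a 32-bit multiply would.
    """
    product = row_idx * row_stride
    return product if use_int64 else wrap_int32(product)


def rms_norm_inplace(row, weight, eps=1e-6):
    """Reference RMSNorm that overwrites `row` instead of allocating output.

    Computes y_i = x_i / sqrt(mean(x^2) + eps) * w_i. The eps default is an
    assumption for this sketch.
    """
    mean_sq = sum(v * v for v in row) / len(row)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    for i in range(len(row)):
        row[i] = row[i] * inv_rms * weight[i]  # write back to the input
    return row


# 600,000 rows with a stride of 4,096 elements exceeds the int32 range.
safe = row_offset(600_000, 4096, use_int64=True)
wrapped = row_offset(600_000, 4096, use_int64=False)
print(safe, safe > INT32_MAX)   # exact offset, larger than INT32_MAX
print(wrapped)                  # negative: would point outside the tensor

x = [3.0, 4.0]
y = rms_norm_inplace(x, [1.0, 1.0], eps=0.0)
print(y is x)                   # the same list object was normalized
```

In a real kernel the same idea would apply to the product of the program (row) index and the row stride; the sketch only illustrates the arithmetic, not Triton's API.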
import torch
import triton