Created
May 8, 2025 14:57
-
-
Save drbh/141373363e83ea0345807d6525e1fb64 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# /// script
# dependencies = [
#   "numpy",
#   "torch",
#   "kernels",
# ]
# ///
import torch
import torch.nn as nn

# Device on which the model and the example input below are placed.
DEVICE = "cuda"
DTYPE = torch.float16  # Use float16 for better kernel performance potential
# Simple PyTorch implementation of RMSNorm for baseline comparison
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    Each vector along the last dimension is divided by the square root of
    (mean of its squared elements + ``eps``) and then multiplied
    element-wise by ``weight``.

    Args:
        hidden_size: Length of the learned ``weight`` vector; it must be
            broadcastable against the last dimension of the input.
        variance_epsilon: Constant added to the mean square before the
            reciprocal square root, stored as ``eps``.
    """

    def __init__(self, hidden_size: int, variance_epsilon: float = 1e-5) -> None:
        super().__init__()
        # Learned per-feature scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = variance_epsilon
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over its last dimension and apply the learned scale.

        Args:
            x: Tensor of shape ``(batch_size, ..., hidden_size)``.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        # Assumes x is (batch_size, ..., hidden_size)
        input_dtype = x.dtype
        # Calculate variance in float32 for stability
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        # Multiplying by the float32 factor promotes a lower-precision x,
        # so the scaling itself is not carried out in half precision.
        x = x * torch.rsqrt(variance + self.eps)
        # Apply weight and convert back to original dtype
        return (self.weight * x).to(input_dtype)
class BaselineModel(nn.Module):
    """Small reference network: Linear -> RMSNorm -> GELU -> Linear.

    All linear weights and the norm weight are set to 1 and all linear
    biases to 0, so the model's output depends only on its input rather
    than on random initialisation.

    Args:
        input_size: Number of features in the input's last dimension.
        hidden_size: Width of the hidden layer that is normalized.
        output_size: Number of features produced by the final linear layer.
        eps: Epsilon forwarded to ``RMSNorm`` as ``variance_epsilon``.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        eps: float = 1e-5,
    ) -> None:
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.norm = RMSNorm(hidden_size, variance_epsilon=eps)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, output_size)
        # ensure all linear layers weights are 1 for testing
        # (in-place fills are done under no_grad so autograd does not track them)
        with torch.no_grad():
            self.linear1.weight.fill_(1)
            self.linear1.bias.fill_(0)
            self.linear2.weight.fill_(1)
            self.linear2.bias.fill_(0)
            self.norm.weight.fill_(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through linear1, RMSNorm, GELU and linear2, in that order."""
        x = self.linear1(x)
        x = self.norm(x)  # Apply RMSNorm
        x = self.activation(x)
        x = self.linear2(x)
        return x
# Example usage
input_size = 128
hidden_size = 256
output_size = 10
eps_val = 1e-5

# Build the model, then move it to the target device and cast it to the target dtype.
baseline_model = (
    BaselineModel(input_size, hidden_size, output_size, eps=eps_val)
    .to(DEVICE)
    .to(DTYPE)
)

dummy_input = torch.randn(32, input_size, device=DEVICE, dtype=DTYPE)  # Batch of 32
output = baseline_model(dummy_input)
print("Baseline RMSNorm model output shape:", output.shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment