# /// script
# dependencies = [
#     "numpy",
#     "torch",
#     "kernels",
# ]
# ///
import torch
import torch.nn as nn

DEVICE = "cuda"
DTYPE = torch.float16  # float16 leaves more headroom for custom-kernel speedups

# Simple PyTorch implementation of RMSNorm for baseline comparison
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, variance_epsilon=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = variance_epsilon
        self.hidden_size = hidden_size

    def forward(self, x):
        # Assumes x is (batch_size, ..., hidden_size)
        input_dtype = x.dtype
        # Compute the mean of squares in float32 for numerical stability
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        # Apply the learned scale and cast back to the original dtype
        return (self.weight * x).to(input_dtype)
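
# Hedged sanity check (assumption: torch.nn.functional.rms_norm exists, which
# is true for PyTorch >= 2.4): compare this module with the built-in on CPU.
if hasattr(torch.nn.functional, "rms_norm"):
    _check = RMSNorm(8)
    _x = torch.randn(2, 8)
    _builtin = torch.nn.functional.rms_norm(_x, (8,), weight=_check.weight, eps=_check.eps)
    assert torch.allclose(_check(_x), _builtin, rtol=1e-4, atol=1e-5)
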
class BaselineModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, eps=1e-5):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.norm = RMSNorm(hidden_size, variance_epsilon=eps)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, output_size)
        # Deterministic initialization for testing: all weights 1, all biases 0
        with torch.no_grad():
            self.linear1.weight.fill_(1)
            self.linear1.bias.fill_(0)
            self.linear2.weight.fill_(1)
            self.linear2.bias.fill_(0)
            self.norm.weight.fill_(1)

    def forward(self, x):
        x = self.linear1(x)
        x = self.norm(x)  # Apply RMSNorm
        x = self.activation(x)
        x = self.linear2(x)
        return x
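
# Hedged sketch of the intended comparison: the "kernels" dependency in the
# script header points at Hugging Face's kernels library, whose entry point
# get_kernel(repo_id) downloads a pre-built kernel from the Hub. The repo id
# below is an illustrative placeholder (assumption), so check the Hub's
# kernels-community org for an actual RMSNorm kernel and its call signature
# before swapping it in for self.norm above.
def try_load_hub_kernel(repo_id="kernels-community/triton-layer-norm"):
    try:
        from kernels import get_kernel
        return get_kernel(repo_id)
    except Exception as exc:  # package missing, offline, or placeholder repo id
        print(f"Hub kernel unavailable, keeping the PyTorch baseline: {exc}")
        return None
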
# Example usage
input_size = 128
hidden_size = 256
output_size = 10
eps_val = 1e-5

baseline_model = (
    BaselineModel(input_size, hidden_size, output_size, eps=eps_val)
    .to(DEVICE)
    .to(DTYPE)
)

dummy_input = torch.randn(32, input_size, device=DEVICE, dtype=DTYPE)  # Batch of 32
output = baseline_model(dummy_input)
print("Baseline RMSNorm model output shape:", output.shape)