@ShawonAshraf
Created April 11, 2025 20:18
Layer Norm in Flax NN
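A from-scratch layer normalization module written with Flax's linen API: it normalizes each input along its last (feature) axis to zero mean and unit variance, then optionally applies a learnable scale and bias.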
import jax.numpy as jnp
from flax import linen as nn


class LayerNorm(nn.Module):
    epsilon: float = 1e-6  # Small value for numerical stability
    use_bias: bool = True
    use_scale: bool = True

    @nn.compact
    def __call__(self, x):
        # Calculate mean and variance along the feature dimension
        mean = jnp.mean(x, axis=-1, keepdims=True)
        variance = jnp.var(x, axis=-1, keepdims=True)

        # Normalize the input
        x = (x - mean) / jnp.sqrt(variance + self.epsilon)

        # Apply the learnable scale and bias (both optional)
        if self.use_scale:
            scale = self.param('scale', nn.initializers.ones, (x.shape[-1],))
            x = x * scale
        if self.use_bias:
            bias = self.param('bias', nn.initializers.zeros, (x.shape[-1],))
            x = x + bias
        return x
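A minimal usage sketch (the batch shape, RNG keys, and variable names below are illustrative assumptions, not part of the gist): initialize the module's parameters with init, then run a dummy batch through apply.

import jax

# Hypothetical example: a batch of 4 vectors with 8 features each.
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8))

layer = LayerNorm()
params = layer.init(jax.random.PRNGKey(1), x)  # creates the 'scale' and 'bias' params
y = layer.apply(params, x)                     # normalized output, shape (4, 8)

Each row of y should have approximately zero mean and unit variance, and the output should roughly match flax.linen.LayerNorm with the same epsilon.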