@norabelrose
Created October 18, 2024 00:46
Expectation of x * gelu(x) where x ~ N(mu, sigma^2)
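import numpy as np
from scipy.stats import norm


def x_gelu_expectation(mu, sigma):
    """Compute E[x * gelu(x)] for x ~ N(mu, sigma^2) analytically."""
    # Write x = mu + sigma * z with z ~ N(0, 1); Phi and phi are the standard normal CDF and PDF
    s = np.sqrt(1 + sigma**2)
    evCDF = norm.cdf(mu / s)  # E[Phi(x)]
    evPDF = norm.pdf(mu / s) / s  # E[phi(x)]
    evZPDF = -mu * sigma / s**3 * norm.pdf(mu / s)  # E[z * phi(x)]

    # linearity: E[x * phi(x)] = mu * E[phi(x)] + sigma * E[z * phi(x)]
    evXPDF = mu * evPDF + sigma * evZPDF

    # Stein's identity (first time): E[x * g(x)] = mu * E[g(x)] + sigma^2 * E[g'(x)], with g = Phi
    evXCDF = mu * evCDF + sigma**2 * evPDF

    # by defn: gelu(x) = x * Phi(x), so gelu'(x) = Phi(x) + x * phi(x)
    evGELU = evXCDF
    evDGELU = evCDF + evXPDF

    # Stein's identity (second time), with g = gelu
    return mu * evGELU + sigma**2 * evDGELU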
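
One way to sanity-check the closed form is to compare it with a numerical estimate of the same expectation. The sketch below uses Gauss-Hermite quadrature. The helper name `x_gelu_quadrature` and the default of 64 nodes are assumptions made for illustration, not part of the original gist, and it assumes the exact (erf-based) GELU, gelu(x) = x * Phi(x).

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss
from scipy.stats import norm


def x_gelu_quadrature(mu, sigma, n_nodes=64):
    """Approximate E[x * gelu(x)] for x ~ N(mu, sigma^2) by Gauss-Hermite quadrature.

    Hypothetical helper for checking the closed form; not part of the original gist.
    """
    # Nodes and weights for the weight function exp(-z^2 / 2)
    z, w = hermegauss(n_nodes)
    x = mu + sigma * z
    # x * gelu(x) = x^2 * Phi(x) for the exact (erf-based) GELU
    f = x**2 * norm.cdf(x)
    # The weights sum to sqrt(2 * pi), so divide to get an expectation over z ~ N(0, 1)
    return float(np.sum(w * f) / np.sqrt(2 * np.pi))


# For mu = 0, symmetry of Phi gives E[x^2 * Phi(x)] = sigma^2 / 2
print(x_gelu_quadrature(0.0, 1.0))
```

Comparing this estimate with `x_gelu_expectation(mu, sigma)` at a few (mu, sigma) pairs, including some with nonzero mu, is a quick way to catch a sign or factor slip in the analytic expression.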