Last active
June 30, 2021 11:24
-
-
Save chausies/011df759f167b17b5278264454fff379 to your computer and use it in GitHub Desktop.
Numerically stable and accurate PyTorch implementation of the log of the CDF of the standard normal distribution
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Numerically stable and accurate implementation of the natural logarithm
# of the cumulative distribution function (CDF) of the standard
# Normal/Gaussian distribution in PyTorch.
import matplotlib.pylab as P # replace this with numpy if you want | |
import torch as T | |
def norm_cdf(x):
    """Return the standard normal CDF, Phi(x), evaluated elementwise on ``x``.

    Uses the identity Phi(x) = (1 + erf(x / sqrt(2))) / 2.  Because this adds
    erf(...) to 1, it loses relative precision for strongly negative ``x``.
    Far enough into the left tail the result can round to exactly 0, so taking
    its log directly there yields -inf.
    """
    scaled = x / P.sqrt(2)
    return (T.erf(scaled) + 1) / 2
def log_norm_cdf_helper(x):
    """Return the denominator of the Borjesson-Sundberg Q-function approximation.

    For x >= 0 the Gaussian upper-tail probability Q(x) = 1 - Phi(x) is
    approximated by

        Q(x) ~= phi(x) / ((1 - a) * x + a * sqrt(x**2 + b)),

    where phi is the standard normal density and (a, b) = (0.344, 5.334).
    This function returns that denominator, so ``phi(x) / log_norm_cdf_helper(x)``
    approximates Q(x). For x >= 3 the relative error in Q is a fraction of a
    percent, and it shrinks as x grows because the denominator tends to x.
    """
    a = 0.344
    b = 5.334
    # The square root applies only to (x**2 + b). Taking it over the whole
    # expression makes the denominator grow like sqrt(a) * x instead of x.
    # That biases log Q upward by up to -log(sqrt(a)) ~= 0.53 in the tail.
    return (1 - a) * x + a * (x**2 + b).sqrt()
def log_norm_cdf(x, *, thresh=3):
    """Return log(Phi(x)), the natural log of the standard normal CDF, elementwise.

    The real line is split into three regions:

    * ``-thresh <= x <= thresh``: the direct formula ``log(norm_cdf(x))``.
    * ``x < -thresh``: here Phi(x) = Q(-x) is tiny. Its log is evaluated
      analytically from the tail approximation Q(t) ~= phi(t) / D(t), where
      D is given by ``log_norm_cdf_helper``:
      log Q(t) = -(t**2 + log(2*pi)) / 2 - log D(t).
      Phi(x) is never formed explicitly, so it cannot round to 0 and produce
      log(0) = -inf.
    * ``x > thresh``: here Phi(x) = 1 - Q(x) is close to 1. ``log1p(-Q(x))``
      keeps the small negative result accurate.

    Args:
        x: Tensor of inputs. Python numbers, sequences and integer tensors are
            converted to a floating tensor of the default dtype.
        thresh: Non-negative switch point between the direct formula and the
            tail approximations. Defaults to 3.

    Returns:
        Tensor with the same shape as ``x``. NaN inputs map to NaN.

    Raises:
        ValueError: If ``thresh`` is negative. The regions would then overlap.
    """
    if thresh < 0:
        raise ValueError(f"thresh must be non-negative, got {thresh}")
    x = T.as_tensor(x)
    if not x.is_floating_point():
        # erf/exp/log are not defined for integer tensors.
        x = x.to(T.get_default_dtype())

    # x * 0 keeps x's dtype, device and autograd graph. NaN inputs belong to
    # none of the regions below and stay NaN, since NaN * 0 is NaN.
    out = x * 0
    lower = x < -thresh
    upper = x > thresh
    middle = T.logical_and(x >= -thresh, x <= thresh)

    out[middle] = norm_cdf(x[middle]).log()

    # Left tail: log Phi(x) = log Q(-x), computed entirely in log space.
    t = -x[lower]
    out[lower] = -(
        (t**2 + P.log(2 * P.pi)) / 2
        + log_norm_cdf_helper(t).log()
    )

    # Right tail: log Phi(x) = log1p(-Q(x)), with Q from the same approximation.
    s = x[upper]
    q = (-s**2 / 2).exp() / P.sqrt(2 * P.pi) / log_norm_cdf_helper(s)
    out[upper] = T.log1p(-q)
    return out
# Example plot
if __name__ == "__main__":
    # Compare the stable implementation with the naive log(norm_cdf(x)).
    # With the default float32 dtype, the naive curve loses precision and
    # typically reaches -inf in the far left tail of this range.
    grid = T.linspace(-10, 10, 25)
    curves = {
        "improved": log_norm_cdf(grid),
        "original": norm_cdf(grid).log(),
    }
    for label, values in curves.items():
        P.plot(grid, values, label=label)
    P.legend()
    P.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Link to the final example image: https://i.imgur.com/oH0xyR6.png
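This uses the Börjesson–Sundberg approximation of the Gaussian tail function, $Q(x) \approx \frac{\phi(x)}{(1-a)x + a\sqrt{x^2+b}}$ with $a = 0.344$ and $b = 5.334$, as listed at https://en.wikipedia.org/wiki/Q-function#Bounds_and_approximations (original source: https://doi.org/10.1109/TCOM.1979.1094433 ).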