GitHub gist djsutherland/6f317916c46795d4e98ae364666ff0c7 (last active February 1, 2018 21:29)
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"\n", | |
"$$\n", | |
"\\DeclareMathOperator{\\E}{\\mathbb E}\n", | |
"\\DeclareMathOperator{\\Var}{Var}\n", | |
"\\DeclareMathOperator{\\Cov}{Cov}\n", | |
"\\DeclareMathOperator{\\N}{\\mathcal N}\n", | |
"\\newcommand{\\XY}{\\begin{bmatrix}X \\\\ Y\\end{bmatrix}}\n", | |
"\\newcommand{\\tp}{^\\mathsf{T}}\n", | |
"\\newcommand{\\T}{\\tilde}\n", | |
"\\newcommand{\\v}{\\mathcal V}\n", | |
"$$\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib inline\n", | |
"import numpy as np\n", | |
"import matplotlib.pyplot as plt\n", | |
"import seaborn as sns\n", | |
"sns.set(rc={'figure.figsize': [12, 8]})\n", | |
"\n", | |
"from scipy import stats" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We'll use results from [Rosenbaum (1961)](https://www.jstor.org/stable/2984029) on the moments of a truncated bivariate normal distribution." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"If\n", | |
"\n", | |
"$$\n", | |
"\\begin{bmatrix}\\T X \\\\ \\T Y\\end{bmatrix} \\sim \\N\\left( \\begin{bmatrix}0 \\\\ 0\\end{bmatrix}, \\begin{bmatrix}1 & \\rho \\\\ \\rho & 1\\end{bmatrix} \\right)\n", | |
"$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"then Rosenbaum's (1) tells us that\n", | |
"\n", | |
"$$\n", | |
"\\Pr(\\T X \\ge h, \\T Y \\ge k) \\E[\\T X \\mid \\T X \\ge h, \\T Y \\ge k]\n", | |
"= \\phi(h) \\Phi\\left( \\frac{\\rho h - k}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
"+ \\rho \\phi(k) \\Phi\\left( \\frac{\\rho k - h}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
".$$" | |
] | |
}, | |
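As a quick sanity check of (1), here is a minimal sketch that compares its right-hand side with a Monte Carlo estimate of $\Pr(\tilde X \ge h, \tilde Y \ge k)\,\mathbb E[\tilde X \mid \tilde X \ge h, \tilde Y \ge k] = \mathbb E[\tilde X \, 1\{\tilde X \ge h, \tilde Y \ge k\}]$. The helper names and the test point $(h, k, \rho)$ are arbitrary choices for illustration, and it assumes NumPy 1.17 or later for `default_rng`.

```python
import numpy as np
from scipy import stats

def rosenbaum_eq1(h, k, rho):
    """Right-hand side of Rosenbaum's (1): Pr(X >= h, Y >= k) E[X | X >= h, Y >= k]."""
    s = np.sqrt(1 - rho**2)
    return (stats.norm.pdf(h) * stats.norm.cdf((rho * h - k) / s)
            + rho * stats.norm.pdf(k) * stats.norm.cdf((rho * k - h) / s))

def mc_truncated_first_moment(h, k, rho, n=1_000_000, seed=0):
    """Monte Carlo estimate of E[X 1{X >= h, Y >= k}], which equals the left-hand side."""
    rng = np.random.default_rng(seed)
    x, y = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=n).T
    return np.mean(x * ((x >= h) & (y >= k)))

h, k, rho = 0.3, -0.5, 0.6
print(rosenbaum_eq1(h, k, rho), mc_truncated_first_moment(h, k, rho))
```

The two printed values should agree to about three decimal places: since $\mathbb E[\tilde X^2] = 1$, the Monte Carlo standard error is at most $1/\sqrt{n} = 10^{-3}$ here.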
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Rosenbaum's (3) is\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Pr\\left(\\T X \\ge h, \\T Y \\ge k \\right) \\E\\left[\\T X^2 \\mid \\T X \\ge h, \\T Y \\ge k\\right]\n", | |
" &= \\Pr\\left(\\T X \\ge h, \\T Y \\ge k \\right)\n", | |
" + h \\phi(h) \\Phi\\left( \\frac{\\rho h - k}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
" + \\rho^2 k \\phi(k) \\Phi\\left( \\frac{\\rho k - h}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
" + \\frac{\\rho \\sqrt{1-\\rho^2}}{\\sqrt{2 \\pi}} \\phi\\left( \\sqrt{\\frac{h^2 - 2 \\rho h k + k^2}{1 - \\rho^2}} \\right)\n", | |
"\\end{align}" | |
] | |
}, | |
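Formula (3) can be checked without sampling noise by numerically integrating $x^2$ against the standard bivariate normal density over the quadrant $\{x \ge h, y \ge k\}$. A minimal sketch using `scipy.integrate.dblquad`; to stay self-contained it also obtains $\Pr(\tilde X \ge h, \tilde Y \ge k)$ by quadrature, and the helper names and test point are illustrative assumptions.

```python
import numpy as np
from scipy import integrate, stats

def std_bvn_pdf(x, y, rho):
    """Density of the standard bivariate normal with correlation rho."""
    c2 = 1 - rho**2
    return np.exp(-(x**2 - 2 * rho * x * y + y**2) / (2 * c2)) / (2 * np.pi * np.sqrt(c2))

def truncated_integral(g, h, k, rho):
    """Integral of g(x, y) times the density over {x >= h, y >= k}, by 2-D quadrature."""
    # dblquad integrates func(y, x) for x in [h, inf) and y in [gfun(x), hfun(x)]
    val, _ = integrate.dblquad(lambda y, x: g(x, y) * std_bvn_pdf(x, y, rho),
                               h, np.inf, lambda x: k, lambda x: np.inf)
    return val

def rosenbaum_eq3(h, k, rho):
    """Right-hand side of Rosenbaum's (3): Pr(X >= h, Y >= k) E[X^2 | X >= h, Y >= k]."""
    s = np.sqrt(1 - rho**2)
    prob = truncated_integral(lambda x, y: 1.0, h, k, rho)
    return (prob
            + h * stats.norm.pdf(h) * stats.norm.cdf((rho * h - k) / s)
            + rho**2 * k * stats.norm.pdf(k) * stats.norm.cdf((rho * k - h) / s)
            + rho * s / np.sqrt(2 * np.pi)
            * stats.norm.pdf(np.sqrt((h**2 - 2 * rho * h * k + k**2) / (1 - rho**2))))

h, k, rho = 0.3, -0.5, 0.6
print(truncated_integral(lambda x, y: x**2, h, k, rho), rosenbaum_eq3(h, k, rho))
```

Both numbers should agree to many decimal places, limited only by the quadrature tolerances.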
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"and (5) is\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Pr\\left(\\T X \\ge h, \\T Y \\ge k \\right) \\E\\left[\\T X \\T Y \\mid \\T X \\ge h, \\T Y \\ge k\\right]\n", | |
" &= \\rho \\Pr\\left(\\T X \\ge h, \\T Y \\ge k \\right)\n", | |
" + \\rho h \\phi(h) \\Phi\\left( \\frac{\\rho h - k}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
" + \\rho k \\phi(k) \\Phi\\left( \\frac{\\rho k - h}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
" + \\frac{\\sqrt{1-\\rho^2}}{\\sqrt{2 \\pi}} \\phi\\left( \\sqrt{\\frac{h^2 - 2 \\rho h k + k^2}{1 - \\rho^2}} \\right)\n", | |
"\\end{align}" | |
] | |
}, | |
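A similar sketch for (5), this time estimating $\mathbb E[\tilde X \tilde Y \, 1\{\tilde X \ge h, \tilde Y \ge k\}]$ by Monte Carlo and computing the quadrant probability with `scipy.stats.multivariate_normal.cdf`, as the notebook's own code does further down. Helper names and the test point are illustrative.

```python
import numpy as np
from scipy import stats

def rosenbaum_eq5(h, k, rho):
    """Right-hand side of Rosenbaum's (5): Pr(X >= h, Y >= k) E[XY | X >= h, Y >= k]."""
    s = np.sqrt(1 - rho**2)
    cov = [[1.0, rho], [rho, 1.0]]
    # Pr(X >= h, Y >= k) = Pr(-X <= -h, -Y <= -k), and (-X, -Y) has the same covariance
    prob = stats.multivariate_normal.cdf([-h, -k], mean=[0.0, 0.0], cov=cov)
    return (rho * prob
            + rho * h * stats.norm.pdf(h) * stats.norm.cdf((rho * h - k) / s)
            + rho * k * stats.norm.pdf(k) * stats.norm.cdf((rho * k - h) / s)
            + s / np.sqrt(2 * np.pi)
            * stats.norm.pdf(np.sqrt((h**2 - 2 * rho * h * k + k**2) / (1 - rho**2))))

def mc_truncated_cross_moment(h, k, rho, n=1_000_000, seed=0):
    """Monte Carlo estimate of E[XY 1{X >= h, Y >= k}]."""
    rng = np.random.default_rng(seed)
    x, y = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=n).T
    return np.mean(x * y * ((x >= h) & (y >= k)))

h, k, rho = 0.3, -0.5, 0.6
print(rosenbaum_eq5(h, k, rho), mc_truncated_cross_moment(h, k, rho))
```

For a noise-free spot check: at $h = k = 0$ the quadrant probability is $\frac14 + \frac{\arcsin\rho}{2\pi}$, so (5) reduces to $\rho\left(\frac14 + \frac{\arcsin\rho}{2\pi}\right) + \frac{\sqrt{1-\rho^2}}{2\pi}$, which `rosenbaum_eq5(0, 0, rho)` should reproduce up to the tolerance of the CDF routine.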
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Special case of (1) when $k = -\\infty$:\n", | |
"\n", | |
"\\begin{align}\n", | |
"\\Pr(\\T X \\ge h) \\E[\\T X \\mid \\T X \\ge h]\n", | |
" &= \\Pr(\\T X \\ge h, \\T Y \\ge -\\infty) \\E[\\T X \\mid \\T X \\ge h, \\T Y \\ge -\\infty]\n", | |
"\\\\&= \\phi(h) \\left(1 - \\underbrace{\\Phi\\left( \\frac{-\\infty - \\rho h}{\\sqrt{1 - \\rho^2}} \\right)}_0 \\right)\n", | |
" + \\rho \\underbrace{\\phi(k)}_0 \\left(1 - \\Phi\\left( \\frac{h + \\rho \\infty}{\\sqrt{1 - \\rho^2}} \\right) \\right)\n", | |
"\\\\&= \\phi(h)\n", | |
".\\end{align}" | |
] | |
}, | |
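The $k = -\infty$ case says $\int_h^\infty z\,\phi(z)\,dz = \phi(h)$, which also follows directly from $\phi'(z) = -z\,\phi(z)$. A one-dimensional quadrature sketch to confirm it; the helper name is made up for illustration.

```python
import numpy as np
from scipy import integrate, stats

def partial_first_moment(h):
    """Pr(Z >= h) E[Z | Z >= h] for Z ~ N(0, 1), as the integral of z phi(z) over [h, inf)."""
    val, _ = integrate.quad(lambda z: z * stats.norm.pdf(z), h, np.inf)
    return val

for h in [-1.5, 0.0, 0.7, 2.0]:
    print(h, partial_first_moment(h), stats.norm.pdf(h))
```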
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"and of (3) with $k = -\\infty$:\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Pr\\left(\\T X \\ge h \\right) \\E\\left[\\T X^2 \\mid \\T X \\ge h \\right]\n", | |
" &= \\Pr\\left(\\T X \\ge h \\right)\n", | |
" + h \\phi(h) \\underbrace{\\Phi\\left( \\frac{\\rho h - k}{\\sqrt{1 - \\rho^2}} \\right)}_1\n", | |
" + \\rho^2 \\underbrace{k \\phi(k)}_{0} \\Phi\\left( \\frac{\\rho k - h}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
" + \\frac{\\rho \\sqrt{1-\\rho^2}}{\\sqrt{2 \\pi}} \\underbrace{\\phi\\left( \\sqrt{\\frac{h^2 - 2 \\rho h k + k^2}{1 - \\rho^2}} \\right)}_0\n", | |
"\\\\&= \\Phi(-h) + h \\phi(h)\n", | |
"\\end{align}" | |
] | |
}, | |
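For the $k = -\infty$ case of (3), note that $\frac{\rho h - k}{\sqrt{1-\rho^2}} \to +\infty$ as $k \to -\infty$, so the $\Phi$ factor multiplying $h\phi(h)$ tends to 1 rather than 0. The limit is then $\Phi(-h) + h\phi(h)$, which agrees with integrating by parts: $\int_h^\infty z^2 \phi(z)\,dz = h\phi(h) + \Phi(-h)$. A minimal quadrature sketch comparing the two (helper names are illustrative):

```python
import numpy as np
from scipy import integrate, stats

def partial_second_moment(h):
    """Pr(Z >= h) E[Z^2 | Z >= h] for Z ~ N(0, 1), by quadrature."""
    val, _ = integrate.quad(lambda z: z**2 * stats.norm.pdf(z), h, np.inf)
    return val

def partial_second_moment_closed(h):
    # k -> -inf limit of Rosenbaum's (3): the Phi factor on h phi(h) tends to 1
    return stats.norm.cdf(-h) + h * stats.norm.pdf(h)

for h in [-1.5, 0.0, 0.7, 2.0]:
    print(h, partial_second_moment(h), partial_second_moment_closed(h))
```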
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Univariate ReLUs" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let\n", | |
"$$\n", | |
"\\XY = \\begin{bmatrix}\\sigma_x & 0 \\\\ 0 & \\sigma_y\\end{bmatrix} \\begin{bmatrix}\\T X\\\\\\T Y\\end{bmatrix} + \\begin{bmatrix}\\mu_x \\\\ \\mu_y\\end{bmatrix}\n", | |
"\\sim \\N\\left( \\begin{bmatrix}\\mu_x \\\\ \\mu_y\\end{bmatrix}, \\begin{bmatrix}\\sigma_x^2 & \\rho \\sigma_x \\sigma_y \\\\ \\rho \\sigma_x \\sigma_y & \\sigma_y^2\\end{bmatrix} \\right)\n", | |
"= \\N(\\mu, \\Sigma)\n", | |
"$$\n", | |
"and define $X_+ = \\max(X, 0)$, $Y_+ = \\max(Y, 0)$." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Then, letting $Q_x := \\Phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right)$ and $q_x := \\phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right)$\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\E[ X_+ ]\n", | |
" &= \\Pr(X_+=0) 0 + \\Pr(X_+ > 0) \\E[X \\mid X > 0]\n", | |
"\\\\&= \\Pr(X > 0)\\left( \\mu_x + \\sigma_x \\E[\\T X \\mid \\T X \\ge -\\mu_x / \\sigma_x] \\right)\n", | |
"\\\\&= \\Phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right) \\mu_x + \\phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right) \\sigma_x\n", | |
"\\\\&= Q_x \\mu_x + q_x \\sigma_x\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"and\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\E[ X_+^2 ]\n", | |
" &= \\Pr(X_+=0) 0 + \\Pr(X_+ > 0) \\E[X^2 \\mid X > 0]\n", | |
"\\\\&= \\Pr\\left(\\T X \\ge \\frac{-\\mu_x}{\\sigma_x}\\right) \\E\\left[(\\mu_x + \\sigma_x \\T X)^2 \\mid \\T X \\ge -\\mu_x / \\sigma_x\\right]\n", | |
"\\\\&= \\Pr\\left(\\T X \\ge \\frac{-\\mu_x}{\\sigma_x}\\right) \\E\\left[\\mu_x^2 + 2 \\mu_x \\sigma_x \\T X + \\sigma_x^2 \\T X^2 \\mid \\T X \\ge -\\mu_x / \\sigma_x\\right]\n", | |
"\\\\&= \\mu_x^2 \\Phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right)\n", | |
" + 2 \\mu_x \\sigma_x \\phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right)\n", | |
" + \\sigma_x^2 \\left( \\Phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right) - \\frac{\\mu_x}{\\sigma_x} \\phi\\left(\\frac{\\mu_x}{\\sigma_x}\\right) \\right)\n", | |
"\\\\&= Q_x \\mu_x^2\n", | |
" + q_x \\mu_x \\sigma_x\n", | |
" + Q_x \\sigma_x^2\n", | |
"\\end{align}" | |
] | |
}, | |
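Expanding $(\mu_x + \sigma_x \tilde X)^2$ gives a cross term $2\mu_x\sigma_x\tilde X$, and the truncated second moment of $\tilde X$ at $h = -\mu_x/\sigma_x$ is $\Phi(\mu_x/\sigma_x) - \frac{\mu_x}{\sigma_x}\phi(\mu_x/\sigma_x)$; together these give $Q_x\mu_x^2 + q_x\mu_x\sigma_x + Q_x\sigma_x^2$. The Monte Carlo cells further down check only the mean and variance, so here is a deterministic sketch comparing this second moment against quadrature over the standardized variable. Function names and test parameters are illustrative.

```python
import numpy as np
from scipy import integrate, stats

def relu_second_moment(mu, sigma):
    """Closed form E[max(X, 0)^2] = Q mu^2 + q mu sigma + Q sigma^2 for X ~ N(mu, sigma^2)."""
    Q = stats.norm.cdf(mu / sigma)
    q = stats.norm.pdf(mu / sigma)
    return Q * mu**2 + q * mu * sigma + Q * sigma**2

def relu_second_moment_quad(mu, sigma):
    """Same quantity by quadrature over the standardized variable, with X = mu + sigma Z."""
    val, _ = integrate.quad(lambda z: (mu + sigma * z)**2 * stats.norm.pdf(z),
                            -mu / sigma, np.inf)
    return val

for mu, sigma in [(0.3, 2.1), (-1.0, 0.5), (2.0, 1.0)]:
    print(mu, sigma, relu_second_moment(mu, sigma), relu_second_moment_quad(mu, sigma))
```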
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"which yields\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Var[X_+]\n", | |
" &= \\E[X_+^2] - \\E[X_+]^2\n", | |
"\\\\&=\n", | |
" Q_x \\mu_x^2\n", | |
" + q_x \\mu_x \\sigma_x\n", | |
" + Q_x \\sigma_x^2\n", | |
" - Q_x^2 \\mu_x^2\n", | |
" - q_x^2 \\sigma_x^2\n", | |
" - 2 q_x Q_x \\mu_x \\sigma_x\n", | |
"\\\\&= Q_x (1 - Q_x) \\mu_x^2\n", | |
" + (1 - 2 Q_x) q_x \\mu_x \\sigma_x\n", | |
" + (Q_x - q_x^2) \\sigma_x^2\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def relu_1d_normal_mean_var(mu, sigma2):\n", | |
" mu = np.asarray(mu, dtype=float)\n", | |
" sigma2 = np.asarray(sigma2, dtype=float)\n", | |
" sigma = np.sqrt(sigma2)\n", | |
"\n", | |
" Q = stats.norm.cdf(mu / sigma)\n", | |
" q = stats.norm.pdf(mu / sigma)\n", | |
" mn = Q * mu + q * sigma\n", | |
" var = Q * (1 - Q) * mu**2 + (1 - 2 * Q) * q * mu * sigma + (Q - q**2) * sigma2\n", | |
" return mn, var" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mu = .3\n", | |
"sigma = 2.1\n", | |
"mn, vr = relu_1d_normal_mean_var(mu, sigma**2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"samps = np.maximum(0, np.random.normal(mu, sigma, 10**7))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.99671002149472687, 0.99631304289848877)" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"samps.mean(), mn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(1.7627938841827537, 1.7617356043805259)" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"np.var(samps), vr" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multivariate ReLUs" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let $\\T X \\sim \\N(0, I)$ in $\\mathbb R^d$, and let $\\Sigma = C C^T$ so that\n", | |
"$$\n", | |
"X = \\mu + C \\T X \\sim \\N(\\mu, \\Sigma)\n", | |
".$$\n", | |
"Define $X_+ = \\max(X, 0)$ elementwise,\n", | |
"let $\\sigma_i = \\sqrt{\\Sigma_{ii}}$, and define $q_i = \\phi(\\mu_i / \\sigma_i)$, $Q_i = \\Phi(\\mu_i / \\sigma_i)$ as before." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We have that\n", | |
"$$\n", | |
"\\E[X_+]\n", | |
"= \\begin{bmatrix} \\E[(X_i)_+] \\end{bmatrix}_i\n", | |
"= \\begin{bmatrix} Q_i \\mu_i + q_i \\sigma_i \\end{bmatrix}_i\n", | |
".$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"We also have\n", | |
"$$\n", | |
"\\Cov(X_+)\n", | |
"= \\begin{bmatrix} \\Cov\\left( (X_i)_+, (X_j)_+ \\right) \\end{bmatrix}_{ij}\n", | |
",$$\n", | |
"where the diagonal terms are the univariate variances $\\Var[(X_i)_+]$ computed above.\n", | |
"\n", | |
"The off-diagonal terms are given by\n", | |
"$$\n", | |
"\\E[ (X_i)_+ (X_j)_+ ] - \\E[(X_i)_+] \\E[(X_j)_+]\n", | |
",$$\n", | |
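"and we already know the product $\\E[(X_i)_+] \\E[(X_j)_+]$.\n", | |
"It remains to compute $\\E[(X_i)_+ (X_j)_+]$, which depends only on the bivariate marginal of $(X_i, X_j)$; we write it as $\\E[X_+ Y_+]$ in the bivariate notation above." | |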
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Recall that $\\rho = \\Sigma_{xy} / (\\sigma_x \\sigma_y)$,\n", | |
"so\n", | |
"$$\n", | |
"\\sqrt{1 - \\rho^2}\n", | |
"= \\sqrt{\\frac{\\sigma_x^2 \\sigma_y^2 - \\Sigma_{xy}^2}{\\sigma_x^2 \\sigma_y^2}}\n", | |
"= \\frac{\\sqrt{\\sigma_x^2 \\sigma_y^2 - \\Sigma_{xy}^2}}{\\sigma_x \\sigma_y}\n", | |
"$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let the event $\\v$ be $\\{X > 0, Y > 0\\} = \\{ \\T X > \\frac{-\\mu_x}{\\sigma_x}, \\T Y > \\frac{-\\mu_y}{\\sigma_y} \\}$.\n", | |
"Then\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\E[X_+ Y_+]\n", | |
" &= \\Pr(\\v) \\E[ X Y \\mid \\v] + \\Pr(\\lnot\\v) \\, 0\n", | |
"\\\\&= \\Pr(\\v)\n", | |
" \\E\\left[ (\\mu_x + \\sigma_x \\T X) (\\mu_y + \\sigma_y \\T Y) \\mid \\v \\right]\n", | |
"\\\\&= \\mu_x \\mu_y \\Pr(\\v)\n", | |
" + \\mu_y \\sigma_x \\Pr(\\v) \\E[ \\T X \\mid \\v]\n", | |
" + \\mu_x \\sigma_y \\Pr(\\v) \\E[ \\T Y \\mid \\v]\n", | |
" + \\sigma_x \\sigma_y \\Pr(\\v) \\E\\left[ \\T X \\T Y \\mid \\v \\right]\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Let\n", | |
"$h = -\\mu_x / \\sigma_x$, $k = -\\mu_y / \\sigma_y$,\n", | |
"$$\n", | |
"R_{xy}\n", | |
"= \\Phi\\left( \\frac{\\rho h - k}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
",$$\n", | |
"and symmetrically, swapping the roles of $x$ and $y$,\n", | |
"$$\n", | |
"R_{yx}\n", | |
"= \\Phi\\left( \\frac{\\rho k - h}{\\sqrt{1 - \\rho^2}} \\right)\n", | |
"$$" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Then (1) gives (recall that $\\phi$ is even, so $\\phi(h) = q_x$ and $\\phi(k) = q_y$)\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Pr(\\v) \\E[ \\T X \\mid \\v]\n", | |
" &= q_x R_{xy}\n", | |
" + \\rho q_y R_{yx}\n", | |
"\\\\ \\Pr(\\v) \\E[ \\T Y \\mid \\v]\n", | |
" &= \\rho q_x R_{xy}\n", | |
" + q_y R_{yx}\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"and (5) gives\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Pr\\left(\\v\\right) \\E\\left[\\T X \\T Y \\mid \\v\\right]\n", | |
" &= \\rho \\Pr\\left( \\v \\right)\n", | |
" + \\rho h q_x R_{xy}\n", | |
" + \\rho k q_y R_{yx}\n", | |
" + \\underbrace{\\frac{\\sqrt{1-\\rho^2}}{\\sqrt{2 \\pi}} \\phi\\left( \\sqrt{\\frac{h^2 - 2 \\rho h k + k^2}{1 - \\rho^2}} \\right)}_{r_{xy}}\n", | |
"\\end{align}" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Thus we have\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\E[X_+ Y_+]\n", | |
" &= \\mu_x \\mu_y \\Pr(\\v)\n", | |
" + \\mu_y \\sigma_x \\Pr(\\v) \\E[ \\T X \\mid \\v]\n", | |
" + \\mu_x \\sigma_y \\Pr(\\v) \\E[ \\T Y \\mid \\v]\n", | |
" + \\sigma_x \\sigma_y \\Pr(\\v) \\E\\left[ \\T X \\T Y \\mid \\v \\right]\n", | |
"\\\\&= \\mu_x \\mu_y \\Pr(\\v)\n", | |
" + \\mu_y \\sigma_x (q_x R_{xy} + \\rho q_y R_{yx})\n", | |
" + \\mu_x \\sigma_y (\\rho q_x R_{xy} + q_y R_{yx})\n", | |
" + \\sigma_x \\sigma_y \\left(\n", | |
" \\rho \\Pr\\left( \\v \\right)\n", | |
" - \\rho \\mu_x q_x R_{xy} / \\sigma_x\n", | |
" - \\rho \\mu_y q_y R_{yx} / \\sigma_y\n", | |
" + r_{xy}\n", | |
" \\right)\n", | |
"\\\\&= (\\mu_x \\mu_y + \\sigma_x \\sigma_y \\rho) \\Pr(\\v)\n", | |
" + (\\mu_y \\sigma_x + \\mu_x \\sigma_y \\rho - \\rho \\mu_x \\sigma_y) q_x R_{xy}\n", | |
" + (\\mu_y \\sigma_x \\rho + \\mu_x \\sigma_y - \\rho \\mu_y \\sigma_x) q_y R_{yx}\n", | |
" + \\sigma_x \\sigma_y r_{xy}\n", | |
"\\\\&= (\\mu_x \\mu_y + \\Sigma_{xy}) \\Pr(\\v)\n", | |
" + \\mu_y \\sigma_x q_x R_{xy}\n", | |
" + \\mu_x \\sigma_y q_y R_{yx}\n", | |
" + \\sigma_x \\sigma_y r_{xy}\n", | |
"\\end{align}" | |
] | |
}, | |
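The final expression for $\mathbb E[X_+ Y_+]$ can be checked against a direct two-dimensional integral of $(\mu_x + \sigma_x a)(\mu_y + \sigma_y b)$ over $\mathcal V$ in standardized coordinates, since the product $X_+ Y_+$ vanishes outside $\mathcal V$. A sketch, assuming $\Sigma_{xy} = \rho\sigma_x\sigma_y$ as above; to keep it self-contained, $\Pr(\mathcal V)$ is computed by the same quadrature rather than with `scipy.stats.multivariate_normal.cdf` as in `relu_mvn_mean_cov`. Helper names and test parameters are illustrative.

```python
import numpy as np
from scipy import integrate, stats

def std_bvn_pdf(a, b, rho):
    """Density of the standard bivariate normal with correlation rho."""
    c2 = 1 - rho**2
    return np.exp(-(a**2 - 2 * rho * a * b + b**2) / (2 * c2)) / (2 * np.pi * np.sqrt(c2))

def integrate_over_V(g, h, k, rho):
    """Integral of g(a, b) times the standard density over V = {a > h, b > k}."""
    val, _ = integrate.dblquad(lambda b, a: g(a, b) * std_bvn_pdf(a, b, rho),
                               h, np.inf, lambda a: k, lambda a: np.inf)
    return val

def relu_cross_moment(mu_x, mu_y, s_x, s_y, rho):
    """E[X_+ Y_+] from the last line of the derivation; Pr(V) is taken by quadrature."""
    h, k = -mu_x / s_x, -mu_y / s_y
    c = np.sqrt(1 - rho**2)
    q_x, q_y = stats.norm.pdf(h), stats.norm.pdf(k)  # phi is even: equals phi(mu / s)
    R_xy = stats.norm.cdf((rho * h - k) / c)
    R_yx = stats.norm.cdf((rho * k - h) / c)
    r_xy = c / np.sqrt(2 * np.pi) * stats.norm.pdf(np.sqrt(h**2 - 2 * rho * h * k + k**2) / c)
    prob_V = integrate_over_V(lambda a, b: 1.0, h, k, rho)
    return ((mu_x * mu_y + rho * s_x * s_y) * prob_V
            + mu_y * s_x * q_x * R_xy + mu_x * s_y * q_y * R_yx + s_x * s_y * r_xy)

mu_x, mu_y, s_x, s_y, rho = 0.4, -0.3, 1.2, 0.8, -0.35
direct = integrate_over_V(lambda a, b: (mu_x + s_x * a) * (mu_y + s_y * b),
                          -mu_x / s_x, -mu_y / s_y, rho)
print(relu_cross_moment(mu_x, mu_y, s_x, s_y, rho), direct)
```

If the derivation is right, the two printed numbers should agree to quadrature accuracy; a sign error in any of the $R$ or $r$ terms would show up as a much larger discrepancy.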
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Recalling $\\E[X_+] = Q_x \\mu_x + q_x \\sigma_x$,\n", | |
"we get\n", | |
"\n", | |
"\\begin{align}\n", | |
" \\Cov(X_+, Y_+)\n", | |
" &= (\\mu_x \\mu_y + \\Sigma_{xy}) \\Pr(\\v)\n", | |
" + \\mu_y \\sigma_x q_x R_{xy}\n", | |
" + \\mu_x \\sigma_y q_y R_{yx}\n", | |
" + \\sigma_x \\sigma_y r_{xy}\n", | |
" - (Q_x \\mu_x + q_x \\sigma_x) (Q_y \\mu_y + q_y \\sigma_y)\n", | |
".\\end{align}\n", | |
"\n", | |
"I don't know whether this simplifies further." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def relu_mvn_mean_cov(mu, Sigma):\n", | |
" mu = np.asarray(mu, dtype=float)\n", | |
" Sigma = np.asarray(Sigma, dtype=float)\n", | |
" d, = mu.shape\n", | |
" assert Sigma.shape == (d, d)\n", | |
"\n", | |
" x = (slice(None), np.newaxis)\n", | |
" y = (np.newaxis, slice(None))\n", | |
" \n", | |
" sigma2s = np.diagonal(Sigma)\n", | |
" sigmas = np.sqrt(sigma2s)\n", | |
" rhos = Sigma / sigmas[x] / sigmas[y]\n", | |
"\n", | |
" prob = np.empty((d, d)) # prob[i, j] = Pr(X_i > 0, X_j > 0)\n", | |
" zero = np.zeros(d)\n", | |
" for i in range(d):\n", | |
" prob[i, i] = np.nan\n", | |
" for j in range(i + 1, d):\n", | |
" # Pr(X > 0) = Pr(-X < 0); X ~ N(mu, S) => -X ~ N(-mu, S)\n", | |
" s = [i, j]\n", | |
" prob[i, j] = prob[j, i] = stats.multivariate_normal.cdf(\n", | |
" zero[s], mean=-mu[s], cov=Sigma[np.ix_(s, s)])\n", | |
" \n", | |
" mu_sigs = mu / sigmas\n", | |
" \n", | |
" Q = stats.norm.cdf(mu_sigs)\n", | |
" q = stats.norm.pdf(mu_sigs)\n", | |
" mean = Q * mu + q * sigmas\n", | |
" \n", | |
" # rho_cs is sqrt(1 - rhos**2); but don't calculate diagonal, because\n", | |
" # it'll just be zero and we're dividing by it (but not using result)\n", | |
" # use inf instead of nan; stats.norm.cdf doesn't like nan inputs\n", | |
" rho_cs = 1 - rhos**2\n", | |
" np.fill_diagonal(rho_cs, np.inf)\n", | |
" np.sqrt(rho_cs, out=rho_cs)\n", | |
" \n", | |
" R = stats.norm.cdf((mu_sigs[y] - rhos * mu_sigs[x]) / rho_cs)\n", | |
" \n", | |
" mu_sigs_sq = mu_sigs ** 2\n", | |
" r_num = mu_sigs_sq[x] + mu_sigs_sq[y] - 2 * rhos * mu_sigs[x] * mu_sigs[y]\n", | |
" np.fill_diagonal(r_num, 1) # don't want slightly negative numerator here\n", | |
" r = rho_cs / np.sqrt(2 * np.pi) * stats.norm.pdf(np.sqrt(r_num) / rho_cs)\n", | |
" \n", | |
" bit = mu[y] * sigmas[x] * q[x] * R\n", | |
" cov = (\n", | |
" (mu[x] * mu[y] + Sigma) * prob\n", | |
" + bit + bit.T\n", | |
" + sigmas[x] * sigmas[y] * r\n", | |
" - mean[x] * mean[y])\n", | |
" \n", | |
" cov[range(d), range(d)] = Q * (1 - Q) * mu**2 + (1 - 2 * Q) * q * mu * sigmas + (Q - q**2) * sigma2s\n", | |
"\n", | |
" return mean, cov" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(12)\n", | |
"d = 4\n", | |
"mu = np.random.randn(d)\n", | |
"L = np.random.randn(d, d)\n", | |
"Sigma = L.T @ L\n", | |
"dist = stats.multivariate_normal(mu, Sigma)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"mn, cov = relu_mvn_mean_cov(mu, Sigma)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"samps = np.maximum(0, dist.rvs(10**7))\n", | |
"mn_est = samps.mean(axis=0)\n", | |
"cov_est = np.cov(samps, rowvar=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.000572145310512 0.00298692620286\n" | |
] | |
} | |
], | |
"source": [ | |
"print(np.max(np.abs(mn - mn_est)), np.max(np.abs(cov - cov_est)))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python [conda env:py3]", | |
"language": "python", | |
"name": "conda-env-py3-py" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |