@gpantalos
Last active November 7, 2021 00:15
Stein's Identity in TensorFlow
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "stein_identity.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyM+COvZLzmhu/hW65ZSfSP4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/gpantalos/cbec955549f22d8dbbcc41fa4dfb61ad/stein_identity.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HjL76oTdVP_M"
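},
"source": [
"# Stein's Identity in TensorFlow\n",
"Stein's identity states that, for a smooth density $p$ and any smooth test function $\\phi$ such that $\\phi(x)\\,p(x)$ vanishes at the boundary of the support of $p$,\n",
"$$\n",
"\\mathbb{E}_{x \\sim p}\\left[\\phi(x) \\nabla_{x} \\log p(x)^{\\top}+\\nabla_{x} \\phi(x)\\right]=0\n",
"$$\n",
"Let's check it numerically in TensorFlow."
]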
},
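{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before testing, here is why the identity holds, as a one-dimensional sketch (assuming $p$ and $\\phi$ are differentiable and $\\phi(x)\\,p(x) \\to 0$ as $|x| \\to \\infty$). Since $\\nabla_x \\log p(x) = p'(x)/p(x)$, the integrand is the derivative of $\\phi(x)\\,p(x)$:\n",
"$$\n",
"\\mathbb{E}_{x \\sim p}\\left[\\phi(x) \\nabla_{x} \\log p(x)+\\nabla_{x} \\phi(x)\\right]=\\int \\left(\\phi(x)\\, p'(x)+\\phi'(x)\\, p(x)\\right) dx=\\Big[\\phi(x)\\, p(x)\\Big]_{-\\infty}^{\\infty}=0 .\n",
"$$"
]
},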
{
"cell_type": "code",
"metadata": {
"id": "3BTuVKAxU62o"
},
"source": [
"import tensorflow as tf\n",
"import tensorflow_probability as tfp\n",
"tfd = tfp.distributions\n",
"tf.random.set_seed(0)"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dny1VJB9V3nL"
},
"source": [
"Define a distribution $p = \\mathcal N(0,1)$"
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1JM92iB7VDKq",
"outputId": "65abf331-b2db-4730-8b32-280305a98781"
},
"source": [
"p = tfd.Normal(loc=0., scale=1.)  # standard normal N(0, 1), float32\n",
"p"
],
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<tfp.distributions.Normal 'Normal' batch_shape=[] event_shape=[] dtype=float32>"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gKNYVgbxV-iI"
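},
"source": [
"Sample $x \\sim p$ ($N = 10^5$ draws) and wrap the samples in a `tf.Variable`, so that `tf.GradientTape` watches them automatically."
]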
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pO5icgaEVG3p",
"outputId": "2351337b-2c1a-4e7f-c545-f3748ded9481"
},
"source": [
"x = p.sample(100_000)  # N = 100,000 draws from p\n",
"x = tf.Variable(x)     # GradientTape watches trainable variables automatically\n",
"x.shape"
],
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([100000])"
]
},
"metadata": {
"tags": []
},
"execution_count": 4
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5qAhWPeiWpUZ"
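},
"source": [
"Compute the score function $\\nabla_{x} \\log p(x)$ by automatic differentiation. For $p = \\mathcal{N}(0,1)$ it has the closed form $-x$, which makes a handy sanity check."
]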
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uzkWCSh0Wr59",
"outputId": "c1a470f8-d761-4286-87f5-1006ffb24183"
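},
"source": [
"with tf.GradientTape() as tape:\n",
"    logp = p.log_prob(x)  # one log-density per sample\n",
"# For a non-scalar target, tape.gradient differentiates the sum; since each\n",
"# log p(x_i) depends only on x_i, this yields the per-sample score.\n",
"score = tape.gradient(logp, x)\n",
"score.shape"
],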
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([100000])"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZMP3J2HTVuus"
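},
"source": [
"Define the test function $\\phi(x)$ as a list of smooth scalar functions, one per component of $\\phi$. Any number of components can be used, as long as each $\\phi_k(x)\\,p(x) \\to 0$ as $|x| \\to \\infty$; polynomials, $e^{x}$ and $\\sin x$ all qualify under a Gaussian $p$."
]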
},
{
"cell_type": "code",
"metadata": {
"id": "UsvaLrEUVKLa"
},
"source": [
"phi = [\n",
"    lambda t: t,\n",
"    lambda t: t ** 2,\n",
"    lambda t: tf.exp(t),\n",
"    lambda t: tf.sin(t),\n",
"    # add more smooth functions here...\n",
"]"
],
"execution_count": 6,
"outputs": []
},
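{
"cell_type": "markdown",
"metadata": {},
"source": [
"For $p = \\mathcal{N}(0,1)$ the score is $-x$, so each component can also be checked in closed form. A worked example for the four functions above, using $\\mathbb{E}[e^{x}] = \\mathbb{E}[x e^{x}] = e^{1/2}$ and $\\mathbb{E}[\\cos x] = \\mathbb{E}[x \\sin x] = e^{-1/2}$ under $\\mathcal{N}(0,1)$:\n",
"$$\n",
"\\begin{aligned}\n",
"\\phi(x) = x: &\\quad \\mathbb{E}[-x^2 + 1] = -1 + 1 = 0,\\\\\n",
"\\phi(x) = x^2: &\\quad \\mathbb{E}[-x^3 + 2x] = 0 + 0 = 0,\\\\\n",
"\\phi(x) = e^{x}: &\\quad \\mathbb{E}[-x e^{x} + e^{x}] = -e^{1/2} + e^{1/2} = 0,\\\\\n",
"\\phi(x) = \\sin x: &\\quad \\mathbb{E}[-x \\sin x + \\cos x] = -e^{-1/2} + e^{-1/2} = 0.\n",
"\\end{aligned}\n",
"$$"
]
},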
{
"cell_type": "markdown",
"metadata": {
"id": "I-9hirvqXXsL"
},
"source": [
"Compute the gradient of each component with respect to the inputs, $\\nabla_{x} \\phi(x)$, and keep the values $\\phi(x)$ for the final step."
]
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "nH6RfK_bWWYP",
"outputId": "7c1301bb-357a-4de3-827a-6ed74b4bec13"
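},
"source": [
"grad_phi_x = []  # d/dx phi_k(x) for each component k\n",
"phi_xs = []      # phi_k(x) for each component k\n",
"\n",
"for mapping in phi:\n",
"    with tf.GradientTape() as tape:\n",
"        phi_x = mapping(x)\n",
"    phi_xs.append(phi_x)\n",
"    # phi_k acts elementwise, so this is the per-sample derivative phi_k'(x_i)\n",
"    grad_phi_x.append(tape.gradient(phi_x, x))\n",
"\n",
"# stack into tensors of shape [num_functions, num_samples]\n",
"phi_xs = tf.stack(phi_xs)\n",
"grad_phi_x = tf.stack(grad_phi_x)\n",
"grad_phi_x.shape"
],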
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TensorShape([4, 100000])"
]
},
"metadata": {
"tags": []
},
"execution_count": 7
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ONnBztPCX1uc"
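},
"source": [
"Combine the previous results into a Monte Carlo estimate of the left-hand side of Stein's identity. The cell below averages over all samples and all components of $\\phi$ together and squares the result, so it should be close to 0."
]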
},
{
"cell_type": "code",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T8flr6qdXxFn",
"outputId": "261b0cc5-f333-4656-8ba2-5fa12aaa7004"
},
"source": [
"# mean of phi(x) * score(x) + grad phi(x) over samples and components, squared\n",
"tf.reduce_mean(phi_xs * score + grad_phi_x).numpy() ** 2"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"5.552772948637223e-05"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ijNbFJlmYukL"
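},
"source": [
"The squared estimate is about $5.6 \\times 10^{-5}$ (a mean of roughly $0.0075$), so the left-hand side is indeed close to 0. The remaining gap is consistent with Monte Carlo error, which shrinks as $O(1/\\sqrt{N})$ in the number of samples $N$."
]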
}
]
}
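The notebook's final cell averages over all four components of $\phi$ at once, so errors in different components could partially cancel. Below is a framework-free sketch that estimates each component separately. It assumes the closed-form score $\nabla_x \log p(x) = -x$ of $\mathcal{N}(0,1)$ in place of `tf.GradientTape`, and the helper `stein_estimates` and the hand-written derivatives are illustrative, not part of the original gist.

```python
import math
import random


def stein_estimates(phis, dphis, n=100_000, seed=0):
    """Monte Carlo estimate of E[phi_k(x) * score(x) + phi_k'(x)] for each k.

    Assumes p = N(0, 1), whose score d/dx log p(x) = -x is used in closed
    form here instead of automatic differentiation.
    """
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(n)]  # n draws from N(0, 1)
    estimates = []
    for phi, dphi in zip(phis, dphis):
        # phi(x) * score(x) + phi'(x), with score(x) = -x
        total = sum(phi(x) * (-x) + dphi(x) for x in xs)
        estimates.append(total / n)
    return estimates


# Same components as the notebook's phi, with derivatives written by hand.
phis = [lambda x: x, lambda x: x ** 2, math.exp, math.sin]
dphis = [lambda x: 1.0, lambda x: 2 * x, math.exp, math.cos]

names = ["x", "x^2", "exp(x)", "sin(x)"]
for name, est in zip(names, stein_estimates(phis, dphis)):
    print(f"phi(x) = {name:7s} estimate = {est:+.4f}")
```

Each estimate should land near 0. Replacing a derivative with a wrong one, for example `lambda x: 0.0` instead of `lambda x: 1.0` for $\phi(x) = x$, moves that component to about $-1$, since $\mathbb{E}[-x^2] = -1$.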