Created
December 21, 2019 03:22
-
-
Save fehiepsi/f35f9c44e3c1814f7a7ed01dcab69462 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow.compat.v1 as tf1\n", | |
"import tensorflow.compat.v2 as tf\n", | |
"import tensorflow_probability as tfp\n", | |
"from tensorflow_probability import edward2 as ed" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Python version\n", | |
"3.6.9 |Anaconda, Inc.| (default, Jul 30 2019, 19:07:31) \n", | |
"[GCC 7.3.0]\n", | |
"Version info.\n", | |
"sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n" | |
] | |
} | |
], | |
"source": [ | |
"import sys\n", | |
"# Record the interpreter build this notebook was executed with.\n", | |
"print(\"Python version\", sys.version, sep=\"\\n\")\n", | |
"print(\"Version info.\", sys.version_info, sep=\"\\n\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"# An empty CUDA_VISIBLE_DEVICES exposes no CUDA devices to this process.\n", | |
"os.environ.update(CUDA_VISIBLE_DEVICES='')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[_DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 268435456, 14692155504090053292), _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 3186490635567837349)]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show which devices a TF1-style session can see.\n", | |
"with tf1.Session() as sess:\n", | |
"    devices = sess.list_devices()\n", | |
"    print(devices)\n", | |
"\n", | |
"# Flag reported alongside the benchmark configuration.\n", | |
"USE_XLA = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Eager mode: True\n", | |
"XLA: True\n" | |
] | |
} | |
], | |
"source": [ | |
"# Report the execution configuration used for the benchmark.\n", | |
"print(\"Eager mode: {}\".format(tf.executing_eagerly()),\n", | |
"      \"XLA: {}\".format(USE_XLA),\n", | |
"      sep=\"\\n\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def edward_model(features):\n", | |
"    \"\"\"Bayesian logistic regression: draw labels conditioned on `features`.\n", | |
"\n", | |
"    One coefficient per feature column is drawn from a multivariate normal\n", | |
"    with zero mean (the scale is left at the distribution's default). The\n", | |
"    Bernoulli logits are the features contracted with those coefficients\n", | |
"    along the column axis.\n", | |
"    \"\"\"\n", | |
"    num_features = features.shape[1]\n", | |
"    coeffs = ed.MultivariateNormalDiag(loc=tf.zeros(num_features), name=\"coeffs\")\n", | |
"    logits = tf.tensordot(features, coeffs, [[1], [0]])\n", | |
"    return ed.Bernoulli(logits=logits, name=\"labels\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as onp\n", | |
"from numpyro.examples.datasets import COVTYPE, load_dataset\n", | |
"\n", | |
"def get_data():\n", | |
"    \"\"\"Load COVTYPE and turn it into a binary classification problem.\n", | |
"\n", | |
"    Features are standardised column-wise and a column of ones is appended\n", | |
"    as an intercept. Labels become 1 for the most frequent category and 0\n", | |
"    for every other category.\n", | |
"\n", | |
"    Returns:\n", | |
"        (features, labels): the augmented feature matrix and the 0/1 labels\n", | |
"        cast to int.\n", | |
"    \"\"\"\n", | |
"    _, fetch = load_dataset(COVTYPE, shuffle=False)\n", | |
"    features, labels = fetch()\n", | |
"\n", | |
"    # normalize features and add intercept\n", | |
"    features = (features - features.mean(0)) / features.std(0)\n", | |
"    features = onp.hstack([features, onp.ones((features.shape[0], 1))])\n", | |
"\n", | |
"    # make binary labels: most frequent category vs. the rest.\n", | |
"    # onp.unique returns the sorted category values next to their counts, so\n", | |
"    # argmax(counts) is a position in `categories`, not a category value; the\n", | |
"    # two only coincide when the categories are exactly 0..K-1.\n", | |
"    categories, counts = onp.unique(labels, return_counts=True)\n", | |
"    specific_category = categories[onp.argmax(counts)]\n", | |
"    labels = (labels == specific_category)\n", | |
"\n", | |
"    N = features.shape[0]\n", | |
"    num_positive = labels.sum()\n", | |
"    print(\"Data shape:\", features.shape)\n", | |
"    print(\"Label distribution: {} has label 1, {} has label 0\"\n", | |
"          .format(num_positive, N - num_positive))\n", | |
"    return features, labels.astype(int)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Data shape: (581012, 55)\n", | |
"Label distribution: 211840 has label 1, 369172 has label 0\n" | |
] | |
} | |
], | |
"source": [ | |
"# Benchmark configuration: number of draws / chains and parameter dimension.\n", | |
"nsamples, nchains = 40, 1\n", | |
"nd = 55\n", | |
"# Alternative initial state: theta0 = onp.zeros((nd,))\n", | |
"# Seed right before the draw so the initial state is reproducible.\n", | |
"onp.random.seed(0)\n", | |
"theta0 = onp.random.uniform(-2, 2, nd)\n", | |
"dtype = tf.float32\n", | |
"data = get_data()\n", | |
"# The initial state must have one entry per feature column.\n", | |
"assert data[0].shape[1] == nd, \"nd must match the number of feature columns\"\n", | |
"features = tf.cast(data[0], dtype=dtype)\n", | |
"labels = tf.cast(data[1], dtype=dtype)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def run_nuts():\n", | |
"    \"\"\"Draw `nsamples` NUTS samples of the logistic-regression coefficients.\n", | |
"\n", | |
"    Reads the notebook-level `nsamples`, `theta0`, `dtype`, `features` and\n", | |
"    `labels`; takes no arguments so it can be handed directly to a compiler\n", | |
"    wrapper that calls it without inputs.\n", | |
"\n", | |
"    Returns:\n", | |
"        (chain_state, sampler_stat): the sampled states and a 1-tuple holding\n", | |
"        the traced `leapfrogs_taken` field of the kernel results.\n", | |
"    \"\"\"\n", | |
"    initial_state = tf.cast(theta0, dtype=dtype)\n", | |
"\n", | |
"    def trace_fn(_, pkr):\n", | |
"        # Keep only the leapfrog counter from the kernel results.\n", | |
"        return (pkr.leapfrogs_taken,)\n", | |
"\n", | |
"    log_joint = ed.make_log_joint_fn(edward_model)\n", | |
"\n", | |
"    def target_log_prob_fn(coeffs):\n", | |
"        # Log joint of the model with features and labels held fixed.\n", | |
"        return log_joint(features=features, coeffs=coeffs, labels=labels)\n", | |
"\n", | |
"    mc_kernel = tfp.mcmc.NoUTurnSampler(\n", | |
"        target_log_prob_fn=target_log_prob_fn,\n", | |
"        step_size=0.0015,\n", | |
"    )\n", | |
"    [chain_state], sampler_stat = tfp.mcmc.sample_chain(\n", | |
"        num_results=nsamples,\n", | |
"        num_burnin_steps=0,\n", | |
"        current_state=[initial_state],\n", | |
"        kernel=mc_kernel,\n", | |
"        trace_fn=trace_fn)\n", | |
"    return chain_state, sampler_stat" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/fehiepsi/miniconda3/envs/pydata/lib/python3.6/site-packages/ipykernel_launcher.py:9: UserWarning: tfp.edward2 module is deprecated and will be removed on 2019-12-01. Use https://github.com/google/edward2 library instead.\n", | |
" if __name__ == '__main__':\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"time per leapfrog: 0.08295800400194982\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(<tf.Tensor: id=1628, shape=(40,), dtype=int32, numpy=\n", | |
" array([ 1, 64, 71, 72, 75, 78, 81, 84, 87,\n", | |
" 94, 203, 234, 241, 244, 251, 258, 289, 304,\n", | |
" 311, 318, 325, 332, 363, 370, 385, 400, 655,\n", | |
" 1678, 2701, 3724, 4235, 4746, 5769, 6792, 7303, 8326,\n", | |
" 9349, 10372, 11395, 12418], dtype=int32)>,)" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import time\n", | |
"# perf_counter is a monotonic, high-resolution clock meant for measuring\n", | |
"# intervals; time.time() can jump if the system clock is adjusted.\n", | |
"tic = time.perf_counter()\n", | |
"samples, sampler_stat = tf.xla.experimental.compile(run_nuts)\n", | |
"# .numpy() materialises the result before the clock is read again, so the\n", | |
"# interval covers compilation as well as sampling.\n", | |
"# NOTE(review): the last traced entry is used as the total leapfrog count,\n", | |
"# which assumes `leapfrogs_taken` is cumulative - confirm for the TFP version.\n", | |
"num_leapfrogs = sampler_stat[0].numpy()[-1]\n", | |
"elapsed = time.perf_counter() - tic\n", | |
"print(\"time per leapfrog:\", elapsed / num_leapfrogs)\n", | |
"sampler_stat" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"time per leapfrog: 0.08227630637937315\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(<tf.Tensor: id=3252, shape=(40,), dtype=int32, numpy=\n", | |
" array([ 1, 4, 5, 6, 7, 14, 17, 24, 31,\n", | |
" 34, 161, 164, 171, 174, 177, 180, 243, 250,\n", | |
" 253, 260, 263, 294, 301, 316, 331, 362, 369,\n", | |
" 1392, 2415, 3438, 4461, 5484, 6507, 7530, 8041, 9064,\n", | |
" 10087, 11110, 11621, 12644], dtype=int32)>,)" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import time\n", | |
"# Repeat the measurement; perf_counter is a monotonic, high-resolution clock\n", | |
"# meant for intervals, unlike time.time() which follows the system clock.\n", | |
"tic = time.perf_counter()\n", | |
"samples, sampler_stat = tf.xla.experimental.compile(run_nuts)\n", | |
"# .numpy() materialises the result before the clock is read again.\n", | |
"# NOTE(review): the last traced entry is used as the total leapfrog count,\n", | |
"# which assumes `leapfrogs_taken` is cumulative - confirm for the TFP version.\n", | |
"num_leapfrogs = sampler_stat[0].numpy()[-1]\n", | |
"elapsed = time.perf_counter() - tic\n", | |
"print(\"time per leapfrog:\", elapsed / num_leapfrogs)\n", | |
"sampler_stat" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment