Last active: March 31, 2019 04:19
Save fehiepsi/2fdbdd7db13183a0b3fd9fad53287cc8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.numpy as np\n",
    "from jax import jit, lax, random\n",
    "from jax.util import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/fehiepsi/jax/jax/lib/xla_bridge.py:122: UserWarning: No GPU found, falling back to CPU.\n",
      " warnings.warn('No GPU found, falling back to CPU.')\n"
     ]
    }
   ],
   "source": [
    "x = random.normal(random.PRNGKey(0), ())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x, a):\n",
    "    return x + a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 12.6 ms, sys: 49 µs, total: 12.6 ms\n",
      "Wall time: 12 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(1.7941577, dtype=float32)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time f(x, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def g(fn, a):\n",
    "    return lax.fori_loop(0, 1000000, lambda i, v: fn(v), a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit\n",
    "def h(x, a):\n",
    "    pf = lambda a: f(x, a)\n",
    "    return g(pf, a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 20.2 ms, sys: 13 µs, total: 20.2 ms\n",
      "Wall time: 19.3 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(-204124.73, dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time h(x, 3.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.05 ms, sys: 2 µs, total: 4.05 ms\n",
      "Wall time: 3.42 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(-204126.7, dtype=float32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time h(x, 1.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "y = x + 0.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.43 ms, sys: 40 µs, total: 4.47 ms\n",
      "Wall time: 3.76 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(-106981.37, dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time h(y, 3.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.2 ms, sys: 15 µs, total: 4.21 ms\n",
      "Wall time: 3.74 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array(-106980.33, dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time h(y, 4.)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
I think this seems reasonable. Let me check a few other things I had in mind and get back to you on this.