Skip to content

Instantly share code, notes, and snippets.

@shoyer
Created February 28, 2020 07:14
Show Gist options
  • Save shoyer/0b3221ed0431befdfbfc9884e9353f8e to your computer and use it in GitHub Desktop.
Save shoyer/0b3221ed0431befdfbfc9884e9353f8e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "JAX RNG class.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "UEQGTwgdli1e",
"colab_type": "code",
"colab": {}
},
"source": [
"# Copyright 2020 Google LLC\n",
"\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"\n",
"import jax\n",
"from jax import random\n",
"\n",
"def _with_key(func):\n",
"  def method(self, *args, **kwargs):\n",
"    return func(self.key, *args, **kwargs)\n",
"  return method\n",
"\n",
"class RNG:\n",
"  def __init__(self, key):\n",
"    self.key = key\n",
"  \n",
"  def __repr__(self):\n",
"    return f'{type(self).__name__}({self.key!r})'\n",
"  \n",
"  def split(self, num=2):\n",
"    return [RNG(k) for k in random.split(self.key, num)]\n",
"  \n",
"  uniform = _with_key(random.uniform)\n",
"  normal = _with_key(random.normal)\n",
"\n",
"jax.tree_util.register_pytree_node(\n",
"    RNG,\n",
"    lambda rng: ([rng.key], None),\n",
"    lambda aux, values: RNG(values[0]),\n",
")\n",
"\n",
"def rng(seed):\n",
"  return RNG(random.PRNGKey(seed))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_BxUqpsimlkR",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "b1436361-ed7d-4f3e-9e58-e84168073ab1"
},
"source": [
"rng(10).uniform()"
],
"execution_count": 38,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray(0.08938682, dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 38
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lCH3R8BQoQSD",
"colab_type": "code",
"colab": {}
},
"source": [
"keys = jax.vmap(random.PRNGKey)(jax.numpy.arange(3))\n",
"samples = jax.vmap(random.uniform)(keys)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QU1T2wORolBi",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "503d7d69-eec3-4380-9558-a982f442098e"
},
"source": [
"samples"
],
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([0.41845703, 0.11815023, 0.4240216 ], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 48
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6WTCUSu9opYm",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "96723e21-e685-475f-9611-9e2f9141dba8"
},
"source": [
"jax.vmap(RNG.uniform)(jax.vmap(rng)(jax.numpy.arange(3)))"
],
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"DeviceArray([0.41845703, 0.11815023, 0.4240216 ], dtype=float32)"
]
},
"metadata": {
"tags": []
},
"execution_count": 52
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "yAtMC1dGmLBX",
"colab_type": "code",
"colab": {}
},
"source": [
"@jax.jit\n",
"def split_and_sample(rng):\n",
"  rng, sub_rng = rng.split()\n",
"  val = sub_rng.normal(shape=(3,))\n",
"  return rng, val\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "G-oOx3aJlwcv",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 51
},
"outputId": "4593b14e-7a1b-4e03-b036-4e68ebd5c203"
},
"source": [
"split_and_sample(rng(10))"
],
"execution_count": 41,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(RNG(DeviceArray([3912842007, 31661381], dtype=uint32)),\n",
" DeviceArray([0.47754696, 0.2578578 , 2.4254863 ], dtype=float32))"
]
},
"metadata": {
"tags": []
},
"execution_count": 41
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "lJfgo6I3lysV",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment