Skip to content

Instantly share code, notes, and snippets.

@dfm
Created October 14, 2024 15:40
Show Gist options
  • Save dfm/ddd81aa117c2a9285d4cf7b124d60d44 to your computer and use it in GitHub Desktop.
Save dfm/ddd81aa117c2a9285d4cf7b124d60d44 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "0uruZzkXqWXt"
},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"source": [
"def fun(A, x):\n",
" return jnp.exp(A @ x + 1.0)\n",
"\n",
"A = jnp.eye(5)\n",
"x = jnp.ones(5)"
],
"metadata": {
"id": "2A0G8iSRqXuP"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.make_jaxpr(fun)(A, x)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "-v0wmC5wqd_a",
"outputId": "c0395733-29b8-425c-d0f9-065ae445a1d6"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{ lambda ; a:f32[5,5] b:f32[5]. let\n",
" c:f32[5] = dot_general[\n",
" dimension_numbers=(([1], [0]), ([], []))\n",
" preferred_element_type=float32\n",
" ] a b\n",
" d:f32[5] = add c 1.0\n",
" e:f32[5] = exp d\n",
" in (e,) }"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"source": [
"print(jax.jit(fun).lower(A, x).as_text())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MQ1uFJ05qe27",
"outputId": "448d24b1-29e5-469a-f67a-7cd1570edf2d"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"module @jit_fun attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {\n",
" func.func public @main(%arg0: tensor<5x5xf32> {mhlo.layout_mode = \"default\"}, %arg1: tensor<5xf32> {mhlo.layout_mode = \"default\"}) -> (tensor<5xf32> {jax.result_info = \"\", mhlo.layout_mode = \"default\"}) {\n",
" %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<5x5xf32>, tensor<5xf32>) -> tensor<5xf32>\n",
" %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>\n",
" %1 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<5xf32>\n",
" %2 = stablehlo.add %0, %1 : tensor<5xf32>\n",
" %3 = stablehlo.exponential %2 : tensor<5xf32>\n",
" return %3 : tensor<5xf32>\n",
" }\n",
"}\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"print(jax.jit(fun).lower(A, x).compile().as_text())"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "kjY5tyb6qrva",
"outputId": "d1d6224f-05c8-484f-a2d3-0124f942ae08"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"HloModule jit_fun, is_scheduled=true, entry_computation_layout={(f32[5,5]{1,0}, f32[5]{0})->f32[5]{0}}, allow_spmd_sharding_propagation_to_parameters={true,true}, allow_spmd_sharding_propagation_to_output={true}\n",
"\n",
"%fused_computation (param_0: f32[5]) -> f32[5] {\n",
" %param_0 = f32[5]{0} parameter(0)\n",
" %constant.0 = f32[] constant(1)\n",
" %broadcast.0 = f32[5]{0} broadcast(f32[] %constant.0), dimensions={}\n",
" ROOT %add.0 = f32[5]{0} add(f32[5]{0} %param_0, f32[5]{0} %broadcast.0), metadata={op_name=\"jit(fun)/jit(main)/add\" source_file=\"<ipython-input-8-902eb6e0a694>\" source_line=2}\n",
"}\n",
"\n",
"ENTRY %main.8 (Arg_0.1: f32[5,5], Arg_1.2: f32[5]) -> f32[5] {\n",
" %Arg_0.1 = f32[5,5]{1,0} parameter(0), metadata={op_name=\"A\"}\n",
" %Arg_1.2 = f32[5]{0} parameter(1), metadata={op_name=\"x\"}\n",
" %dot.5 = f32[5]{0} dot(f32[5,5]{1,0} %Arg_0.1, f32[5]{0} %Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name=\"jit(fun)/jit(main)/dot_general\" source_file=\"<ipython-input-8-902eb6e0a694>\" source_line=2}\n",
" %fusion = f32[5]{0} fusion(f32[5]{0} %dot.5), kind=kLoop, calls=%fused_computation, metadata={op_name=\"jit(fun)/jit(main)/add\" source_file=\"<ipython-input-8-902eb6e0a694>\" source_line=2}\n",
" ROOT %exponential.7 = f32[5]{0} exponential(f32[5]{0} %fusion), metadata={op_name=\"jit(fun)/jit(main)/exp\" source_file=\"<ipython-input-8-902eb6e0a694>\" source_line=2}\n",
"}\n",
"\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "xwOMeHKIq1sz"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment