{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
"kind": "private"
},
"private_outputs": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# High performance LLMs - 2024 (Session 2)\n", | |
"\n", | |
"This notebook is a companion to Rafi Witten's class (content & recording on [GitHub](https://github.com/rwitten/HighPerfLLMs2024/tree/main/s02)). In this second session, we use some simple operations to understand performance with a single chip. \n", | |
"\n", | |
"We learn:\n", | |
"1. How to use [TensorBoard](https://www.tensorflow.org/tensorboard/get_started)\n", | |
"1. How device memory (high bandwidth memory or HBM) is used" | |
],
"metadata": {
"id": "QwCqI3dxLuYV"
}
},
{
"cell_type": "markdown",
"source": [
"## Extension\n",
"Let's start by loading the TensorBoard notebook extension."
],
"metadata": {
"id": "KqJPVAmQKWLF"
}
},
{
"cell_type": "code",
"source": [
"%load_ext google3.learning.brain.tensorboard.notebook.extension\n",
"\n",
"# for a Colab outside Google\n",
"# %load_ext tensorboard"
],
"metadata": {
"id": "8A6xUGPbKPxZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Basics\n",
"Next, we set up our code with a simple function to profile and time operations."
],
"metadata": {
"id": "gJ9kQAHdKeJi"
}
},
{
"cell_type": "code",
"source": [
"import jax\n",
"from jax.lib import xla_bridge\n",
"import datetime\n",
"import random\n",
"import string\n",
"\n",
"\n",
"def simple_timeit(f, *args, tries = 10, task = None):\n",
"  '''Simple utility to time a function for multiple runs'''\n",
"  assert task is not None\n",
"\n",
"  trace_name = f\"t_{task}_\" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))\n",
"  trace_dir = f\"/tmp/{trace_name}\"\n",
"\n",
"  outcomes_ms = []\n",
"  jax.block_until_ready(f(*args))  # warm it up!\n",
"  jax.profiler.start_trace(trace_dir)\n",
"\n",
"  for _ in range(tries):\n",
"    s = datetime.datetime.now()\n",
"    jax.block_until_ready(f(*args))\n",
"    e = datetime.datetime.now()\n",
"    outcomes_ms.append(1000*(e-s).total_seconds())\n",
"  jax.profiler.stop_trace()\n",
"\n",
"  average_time_ms = sum(outcomes_ms)/len(outcomes_ms)\n",
"  print(f\"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}\")\n",
"  return average_time_ms"
],
"metadata": {
"id": "H09tie3bGEp8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To be able to find the relevant source for Jax, we need the runtime version. For example, we see ``0.4.26`` below and that corresponds to the [main branch](https://github.com/google/jax)." | |
],
"metadata": {
"id": "Hq72JNVoM1f3"
}
},
{
"cell_type": "code",
"source": [
"print(jax.version._version)"
],
"metadata": {
"id": "Q2wJYc2v1f0J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"This notebook was created with a TPU v3-8 (single TPU v3), so we expect to see 8 devices available. We see from [the TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3) that each TPU v3 chip contains two TensorCores with a capacity of up to 32 GiB of HBM2 (high bandwidth memory). This means that each core has access to a maximum of 16 GiB.\n", | |
"\n", | |
"Note: this is different in TPU v4 and above where HBM may be [shared](https://github.com/google/jax/discussions/20049) between cores." | |
],
"metadata": {
"id": "ZUJobxr5OF3u"
}
},
{
"cell_type": "code",
"source": [
"print(\" XLA device type: \", xla_bridge.get_backend().platform)\n",
"print(\" Accelerator devices in the cluster: \", jax.device_count())\n",
"print(\"Accelerator devices attached to this process: \", jax.local_device_count())\n",
"\n",
"print(\"\")\n",
"for _d in jax.devices():\n",
"  print(_d.device_kind, \"process\", _d.process_index, \"id\", _d.id, \"local_hardware_id\", _d.local_hardware_id)"
],
"metadata": {
"id": "KewyiApExU1k"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.devices()"
],
"metadata": {
"id": "hYEhSmOWWPJ0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As expected, we see 8 accelerator devices on 4 chips on a single host or process."
],
"metadata": {
"id": "MoEixsz2V3Up"
}
},
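{
"cell_type": "markdown",
"source": [
"We can sanity-check the per-core HBM capacity from inside the notebook. This is a minimal sketch, assuming ``Device.memory_stats()`` is supported on this backend (it may return ``None`` on some platforms) and that the backend reports keys like ``bytes_limit`` and ``bytes_in_use``."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# a minimal sketch: query per-device memory statistics\n",
"# assumption: the backend implements Device.memory_stats(); it may return None\n",
"for _d in jax.local_devices():\n",
"  stats = _d.memory_stats()\n",
"  if stats is None:\n",
"    print(_d.id, \"memory_stats() not supported on this backend\")\n",
"  else:\n",
"    # 'bytes_limit' and 'bytes_in_use' are backend-dependent key names\n",
"    limit_gib = stats.get(\"bytes_limit\", 0) / 1024**3\n",
"    in_use_mib = stats.get(\"bytes_in_use\", 0) / 1024**2\n",
"    print(f\"device {_d.id}: ~{limit_gib:.1f} GiB HBM limit, {in_use_mib:.1f} MiB in use\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},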
{
"cell_type": "markdown",
"source": [
"## Matrix addition\n",
"Let's try a matrix addition example ``A + B``.\n",
"\n",
"We will repeat this multiple times so we can do some meaningful profiling."
],
"metadata": {
"id": "V4i9jRXsMqAn"
}
},
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 16384\n",
"STEPS = 10\n",
"\n",
"# for each operation, we expect 3x the memory to be used (A + B and the output matrix)\n", | |
"# we expect each jax.numpy.float16 to be 2 bytes long\n", | |
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 3)/(1024**2))" | |
], | |
"metadata": { | |
"id": "LR8Mxc8MMhK3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"\n", | |
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n", | |
"# print(A.on_device_size_in_bytes() / 1024**2)\n", | |
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n", | |
"\n", | |
"jax.profiler.start_trace(\"/tmp/addition\")\n", | |
"\n", | |
"s = datetime.datetime.now()\n", | |
"for i in range(STEPS):\n", | |
" O = A + B\n", | |
"e = datetime.datetime.now()\n", | |
"\n", | |
"jax.profiler.stop_trace()\n", | |
"\n", | |
"print( f\"Straight addition takes {(e-s).total_seconds()/STEPS:.4f} on average\")\n" | |
],
"metadata": {
"id": "AcdDu5elnzfc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's take a look at our profile. Be sure to select the correct run. \n", | |
"\n", | |
"1. Look at the memory_profile. Notice the familar saw-tooth pattern for heap memory.\n", | |
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_fn module``? Is it what we expect?\n", | |
"1. Look at the trace_viewer@. Zoom in until you can see the operations. We expect to see up to 10 ``add.3`` steps.\n", | |
"1. Look at the op_profile. Notice the TPU FLOPS utilization, the HBM bandwidth utilization, and the wasted time." | |
], | |
"metadata": { | |
"id": "vqooxfGXeCRC" | |
} | |
}, | |
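{
"cell_type": "markdown",
"source": [
"Before opening TensorBoard, we can put a rough lower bound on each addition from memory traffic alone: the add must read ``A`` and ``B`` and write ``O``. This is a back-of-envelope sketch, assuming roughly 900 GB/s of HBM bandwidth per TPU v3 chip (the approximate figure in Google's TPU documentation); treat the result as an estimate, not a spec."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# back-of-envelope: an elementwise add is memory-bandwidth bound\n",
"# assumption: ~900e9 bytes/s of HBM bandwidth per TPU v3 chip\n",
"HBM_BANDWIDTH = 900e9  # bytes/s, approximate\n",
"bytes_moved = 3 * MATRIX_SIZE * MATRIX_SIZE * 2  # read A, read B, write O (2 bytes per float16)\n",
"print(f\"bytes moved per add: {bytes_moved / 1024**2:.0f} MiB\")\n",
"print(f\"lower-bound time per add: {1000 * bytes_moved / HBM_BANDWIDTH:.3f} ms\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},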
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/addition --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=/tmp/addition"
],
"metadata": {
"id": "4v8PyMEPPO0E"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Adding three matrices\n",
"Let's try adding three matrices ``A + B + C``. This requires two addition steps: first, ``(A + B)`` and then, ``(A + B) + C``.\n",
"\n",
"Again, we will repeat this multiple times so we can do some meaningful profiling."
],
"metadata": {
"id": "G_oL7uWxp4S3"
}
},
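{
"cell_type": "markdown",
"source": [
"We can confirm the two-step structure by tracing the expression with ``jax.make_jaxpr``, which prints the primitive operations Jax would execute; we expect two ``add`` equations. (Small matrices are used here just to keep the trace cheap.)"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# trace the expression: we expect to see two add equations in the jaxpr\n",
"small = jax.numpy.ones((2, 2), dtype=jax.numpy.float16)\n",
"print(jax.make_jaxpr(lambda x, y, z: x + y + z)(small, small, small))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},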
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 16384\n",
"STEPS = 10\n",
"\n",
"# for each operation, we expect 4x the memory to be used\n",
"# (A, B, C, and the resulting sum)\n",
"# we expect the operations to be carried out as:\n",
"# intermediate = (A + B); O = (intermediate + C)\n",
"# we expect each jax.numpy.float16 to be 2 bytes long\n",
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 4)/(1024**2))"
],
"metadata": {
"id": "d1f_d858eGBx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n", | |
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n", | |
"C = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n", | |
"\n", | |
"jax.profiler.start_trace(\"/tmp/addition3\")\n", | |
"\n", | |
"s = datetime.datetime.now()\n", | |
"for i in range(STEPS):\n", | |
" O = A + B + C\n", | |
"e = datetime.datetime.now()\n", | |
"\n", | |
"jax.profiler.stop_trace()\n", | |
"\n", | |
"print( f\"Two additions take {(e-s).total_seconds()/STEPS:.4f} on average\")\n" | |
],
"metadata": {
"id": "lK3QghYTobhr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's take a look at our profile. Be sure to select the correct run.\n",
"\n",
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_fn`` module? Is it what we expect?\n",
"1. What about ``jit_broadcast_in_dim``?\n",
"1. Look at the trace_viewer@. Zoom in until you can see the operations. We expect to see up to 16 ``add.3`` steps.\n",
"1. Look at the op_profile. Notice the TPU FLOPS utilization, the HBM bandwidth utilization, and the wasted time."
],
"metadata": {
"id": "hnynb2Mglvix"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/addition3 --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=/tmp/addition3"
],
"metadata": {
"id": "tixkVo3dl3gT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Introducing just-in-time compilation (JIT)\n",
"\n",
"To learn more about JIT in Jax, try the [tutorial](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html).\n",
"\n",
"Let's try our ``A + B`` addition with ``jit`` to see a simple example."
],
"metadata": {
"id": "spGWMHrjh9-w"
}
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def f2(A, B):\n",
"  return A + B\n",
"f2_jit = jax.jit(f2)\n",
"\n",
"simple_timeit(f2, A, B, task = \"f2\")\n",
"simple_timeit(f2_jit, A, B, task = \"f2_jit\")"
],
"metadata": {
"id": "AaIOlFV2m_-C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As you can see, compilation doesn't always yield dramatic performance improvements. Now, let's try our ``A + B + C`` addition with ``jit``."
],
"metadata": {
"id": "loJEtkwRrR-4"
}
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"C = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def f3(X, Y, Z):\n",
"  return X + Y + Z\n",
"\n",
"f3_jit = jax.jit(f3)\n",
"\n",
"simple_timeit(f3, A, B, C, task = \"f3\")\n",
"simple_timeit(f3_jit, A, B, C, task = \"f3_jit\")"
],
"metadata": {
"id": "5RquGMgjqlLf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To understand how ``jit`` helped, open TensorBoard again (use the log directory printed for the trace in the previous cell's output).\n",
"\n",
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_f3`` module? Is it what we expect?\n",
"1. Look at the trace_viewer@. Zoom in until you can see the operations. What's different?\n",
"1. Look at the graph_viewer. Search for the Op Name \"fusion\". What do you see?"
],
"metadata": {
"id": "xGKhjrizq-aL"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_f3_jit_Q4MQHIC0Q8 --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "Wad58cTp4cRu"
},
"execution_count": null,
"outputs": []
},
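{
"cell_type": "markdown",
"source": [
"If TensorBoard isn't handy, we can also inspect the compiled program directly. This is a sketch using Jax's ahead-of-time lowering API (``lower`` / ``compile`` / ``as_text``, available in recent Jax versions, so this is an assumption about your runtime): dump the optimized HLO and look for a ``fusion`` op that folds both adds together."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# a sketch using the AOT API: dump the optimized HLO for f3 and grep for fusion/add\n",
"hlo_text = jax.jit(f3).lower(A, B, C).compile().as_text()\n",
"print(\"\\n\".join(line for line in hlo_text.splitlines() if \"fusion\" in line or \"add\" in line))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},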
{
"cell_type": "markdown",
"source": [
"## Matrix multiplication\n",
"\n",
"In the right scenarios, the compiler can fuse multiple operations (even loops :) ) to offer performance gains. Let's try out matrix multiplication using ``jax.numpy``."
],
"metadata": {
"id": "awLxSon_80ER"
}
},
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 4096\n",
"STEPS = 10\n",
"\n",
"# we expect 4x the memory to be used\n",
"# (A, B, C, and the matmul output; C is used in the activation example below)\n",
"# we expect each jax.numpy.float16 to be 2 bytes long\n",
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 4)/(1024**2))"
],
"metadata": {
"id": "kr1vut-2AFqT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"C = jax.numpy.ones((MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def matmul(A, B):\n",
"  return A @ B\n",
"jit_matmul = jax.jit(matmul)\n",
"\n",
"simple_timeit(matmul, A, B, task = 'matmul')\n",
"simple_timeit(jit_matmul, A, B, task = 'jit_matmul')"
],
"metadata": {
"id": "WXSPmRRCAHHw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"So what does the profiler tell us about the ``jit_matmul`` approach?\n",
"\n",
"Let's look at the FLOPS utilization."
],
"metadata": {
"id": "h9Vrh_CXJmBm"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_jit_matmul_72ZBQZR4CX --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "9mrIAVPMAGwQ"
},
"execution_count": null,
"outputs": []
},
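{
"cell_type": "markdown",
"source": [
"As a cross-check on the op_profile numbers, here is a rough utilization estimate. A square matmul needs about 2 * N^3 FLOPs. The peak figure below (~123 bf16 TFLOP/s per TPU v3 chip, the approximate figure in Google's TPU documentation) is an assumption for illustration."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# rough roofline check for the 4096 x 4096 matmul\n",
"# assumption: ~123e12 FLOP/s peak per TPU v3 chip\n",
"PEAK_FLOPS = 123e12\n",
"matmul_flops = 2 * MATRIX_SIZE**3\n",
"print(f\"FLOPs per matmul: {matmul_flops:.3e}\")\n",
"print(f\"lower-bound time at peak: {1000 * matmul_flops / PEAK_FLOPS:.3f} ms\")\n",
"# dividing this bound by the measured time from simple_timeit estimates FLOPS utilization"
],
"metadata": {},
"execution_count": null,
"outputs": []
},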
{
"cell_type": "code",
"source": [
"# note: C is accepted but unused, so both calls below time the same work\n",
"def matmul_with_activation(A, B, C):\n",
"  return jax.nn.relu(A @ B)\n",
"jit_matmul_foldin = jax.jit(matmul_with_activation)\n",
"\n",
"simple_timeit(matmul_with_activation, A, B, C, task = 'matmul_foldin')\n",
"simple_timeit(jit_matmul_foldin, A, B, C, task = 'jit_matmul_foldin')"
],
"metadata": {
"id": "shiBzssAAGfA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_jit_matmul_foldin_0BGHOPMA6F --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "rqTzDBFzAGEl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Here's an example Rafi shared for roofline analysis."
],
"metadata": {
"id": "0MU4uhS1QM6U"
}
},
{
"cell_type": "code",
"source": [
"# MATMUL_SIZES = [(64000,128), (16000,256), (4000,512), (3000,640), (2000,768), (1000,1024), (250, 2048)]\n",
"MATMUL_SIZES = [(3000,256), (2000,512), (1000,640), (500,768), (300,1024), (100, 1536)]\n",
"\n",
"for num_matmuls, matrix_size in MATMUL_SIZES:\n",
"  A = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=jax.numpy.bfloat16)\n",
"  B = jax.numpy.ones((num_matmuls, matrix_size, matrix_size), dtype=jax.numpy.bfloat16)\n",
"\n",
"  @jax.jit\n",
"  def f(X, Y):\n",
"    return jax.lax.batch_matmul(X, Y)\n",
"\n",
"  print(f(A, B).shape)\n",
"\n",
"  simple_timeit(f, A, B, task = 'matmul_' + str(matrix_size))"
],
"metadata": {
"id": "vig3Z9alQNXA"
},
"execution_count": null,
"outputs": []
},
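{
"cell_type": "markdown",
"source": [
"For each of these shapes we can compute the arithmetic intensity (FLOPs per byte of HBM traffic). A batch of n matmuls of size m x m does about 2 * n * m^3 FLOPs and moves about 3 * n * m^2 * 2 bytes in bfloat16, so the intensity is roughly m/3 FLOPs per byte. Below is a sketch of the prediction, reusing the assumed peak numbers from earlier (~123 TFLOP/s and ~900 GB/s per chip)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# predicted roofline regime per matrix size\n",
"# assumptions: ~123e12 FLOP/s peak and ~900e9 bytes/s HBM bandwidth per TPU v3 chip\n",
"PEAK_FLOPS, HBM_BANDWIDTH = 123e12, 900e9\n",
"ridge = PEAK_FLOPS / HBM_BANDWIDTH  # intensity where the two bounds cross\n",
"for _, m in MATMUL_SIZES:\n",
"  intensity = m / 3  # (2*n*m**3 FLOPs) / (3*n*m**2*2 bytes)\n",
"  regime = \"compute-bound\" if intensity > ridge else \"memory-bound\"\n",
"  print(f\"m={m}: ~{intensity:.0f} FLOPs/byte vs ridge ~{ridge:.0f} -> expect {regime}\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},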
{
"cell_type": "markdown",
"source": [
"Rafi suggested this as an after-session exercise:\n",
"> Consider the function f(A,B) = jax.nn.relu(A@B). (Assume A and B are square matrices.)\n",
"> What percentage faster will jit(f) be than f? Does it depend on the size of A and B?"
],
"metadata": {
"id": "mMMxg8fAVET1"
}
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "VdWmfd_lVD9K"
},
"execution_count": null,
"outputs": []
},
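{
"cell_type": "markdown",
"source": [
"A possible starting point (not a full solution): sweep a few square sizes and compare ``f`` against ``jit(f)`` with the ``simple_timeit`` utility from earlier. The sizes below are arbitrary."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# sweep a few (arbitrary) square sizes and compare f vs jit(f)\n",
"def f(A, B):\n",
"  return jax.nn.relu(A @ B)\n",
"\n",
"f_jit = jax.jit(f)\n",
"\n",
"for n in (512, 2048, 8192):\n",
"  A = jax.numpy.ones((n, n), dtype=jax.numpy.bfloat16)\n",
"  B = jax.numpy.ones((n, n), dtype=jax.numpy.bfloat16)\n",
"  t = simple_timeit(f, A, B, task = f\"relu_matmul_{n}\")\n",
"  t_jit = simple_timeit(f_jit, A, B, task = f\"relu_matmul_jit_{n}\")\n",
"  print(f\"n={n}: jit is {100 * (t - t_jit) / t:.1f}% faster\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},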
{
"cell_type": "markdown",
"source": [
"## Conclusion\n",
"In this session, we learned some simple techniques for using Jax with TPUs. We also learned how fusing operations at compile time can lead to performance improvements."
],
"metadata": {
"id": "XIqZeLm5sVGz"
}
}
]
} |