{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
"kind": "private"
},
"private_outputs": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# High performance LLMs - 2024 (Session 2)\n",
"\n",
"This notebook is a companion to Rafi Witten's class (content & recording on [GitHub](https://github.com/rwitten/HighPerfLLMs2024/tree/main/s02)). In this second session, we use some simple operations to understand performance with a single chip. \n",
"\n",
"We learn:\n",
"1. How to use [TensorBoard](https://www.tensorflow.org/tensorboard/get_started)\n",
"1. How device memory (high bandwidth memory or HBM) is used"
],
"metadata": {
"id": "QwCqI3dxLuYV"
}
},
{
"cell_type": "markdown",
"source": [
"## Extension\n",
"Let's start by loading the TensorBoard notebook extension."
],
"metadata": {
"id": "KqJPVAmQKWLF"
}
},
{
"cell_type": "code",
"source": [
"%load_ext google3.learning.brain.tensorboard.notebook.extension\n",
"\n",
"# for a Colab outside Google\n",
"# %load_ext tensorboard"
],
"metadata": {
"id": "8A6xUGPbKPxZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Basics\n",
"Next, we set up our code with a simple function to profile and time operations."
],
"metadata": {
"id": "gJ9kQAHdKeJi"
}
},
{
"cell_type": "code",
"source": [
"import jax\n",
"from jax.lib import xla_bridge\n",
"import datetime\n",
"import random\n",
"import string\n",
"\n",
"\n",
"def simple_timeit(f, *args, tries = 10, task = None):\n",
" '''Simple utility to time a function for multiple runs'''\n",
" assert task is not None\n",
"\n",
" trace_name = f\"t_{task}_\" + ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))\n",
" trace_dir = f\"/tmp/{trace_name}\"\n",
"\n",
" outcomes_ms = []\n",
" jax.block_until_ready(f(*args)) #warm it up!\n",
" jax.profiler.start_trace(trace_dir)\n",
"\n",
" for _ in range(tries):\n",
" s = datetime.datetime.now()\n",
" jax.block_until_ready(f(*args))\n",
" e = datetime.datetime.now()\n",
" outcomes_ms.append(1000*(e-s).total_seconds())\n",
" jax.profiler.stop_trace()\n",
"\n",
" average_time_ms = sum(outcomes_ms)/len(outcomes_ms)\n",
" print(f\"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}\")\n",
" return average_time_ms"
],
"metadata": {
"id": "H09tie3bGEp8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To be able to find the relevant source for Jax, we need the runtime version. For example, we see ``0.4.26`` below and that corresponds to the [main branch](https://github.com/google/jax)."
],
"metadata": {
"id": "Hq72JNVoM1f3"
}
},
{
"cell_type": "code",
"source": [
"print(jax.version._version)"
],
"metadata": {
"id": "Q2wJYc2v1f0J"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"This notebook was created with a TPU v3-8 (single TPU v3), so we expect to see 8 devices available. We see from [the TPU documentation](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3) that each TPU v3 chip contains two TensorCores with a capacity of up to 32 GiB of HBM2 (high bandwidth memory). This means that each core has access to a maximum of 16 GiB.\n",
"\n",
"Note: this is different in TPU v4 and above where HBM may be [shared](https://github.com/google/jax/discussions/20049) between cores."
],
"metadata": {
"id": "ZUJobxr5OF3u"
}
},
{
"cell_type": "code",
"source": [
"print(\" XLA device type: \", xla_bridge.get_backend().platform)\n",
"print(\" Accelerator devices in the cluster: \", jax.device_count())\n",
"print(\"Accelerator devices attached to this process: \", jax.local_device_count())\n",
"\n",
"print(\"\")\n",
"for _d in jax.devices():\n",
" print(_d.device_kind, \"process\", _d.process_index, \"id\", _d.id, \"local_hardware_id\", _d.local_hardware_id)"
],
"metadata": {
"id": "KewyiApExU1k"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"jax.devices()"
],
"metadata": {
"id": "hYEhSmOWWPJ0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As expected, we see 8 accelerator devices on 4 chips on a single host or process."
],
"metadata": {
"id": "MoEixsz2V3Up"
}
},
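{
"cell_type": "markdown",
"source": [
"As a quick sanity check of the \"16 GiB per core\" figure, we can ask each device for its memory statistics. This is an optional sketch rather than part of the class code: ``memory_stats()`` may return ``None`` on some runtimes, and its keys are backend-dependent."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional check: per-device HBM capacity as reported by the runtime.\n",
"# memory_stats() may return None on some backends; keys vary by platform.\n",
"for _d in jax.local_devices():\n",
"  stats = _d.memory_stats() or {}\n",
"  limit_gib = stats.get(\"bytes_limit\", 0) / 1024**3\n",
"  print(_d.id, _d.device_kind, f\"bytes_limit ~ {limit_gib:.1f} GiB\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},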
{
"cell_type": "markdown",
"source": [
"## Matrix addition\n",
"Let's try a matrix addition example ``A + B``.\n",
"\n",
"We will repeat this multiple times so we can do some meaningful profiling."
],
"metadata": {
"id": "V4i9jRXsMqAn"
}
},
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 16384\n",
"STEPS = 10\n",
"\n",
"# for each operation, we expect 3x the memory to be used (A + B and the output matrix)\n",
"# we expect each jax.numpy.float16 to be 2 bytes long\n",
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 3)/(1024**2))"
],
"metadata": {
"id": "LR8Mxc8MMhK3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"# print(A.on_device_size_in_bytes() / 1024**2)\n",
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"jax.profiler.start_trace(\"/tmp/addition\")\n",
"\n",
"s = datetime.datetime.now()\n",
"for i in range(STEPS):\n",
" O = A + B\n",
"e = datetime.datetime.now()\n",
"\n",
"jax.profiler.stop_trace()\n",
"\n",
"print( f\"Straight addition takes {(e-s).total_seconds()/STEPS:.4f} on average\")\n"
],
"metadata": {
"id": "AcdDu5elnzfc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's take a look at our profile. Be sure to select the correct run. \n",
"\n",
"1. Look at the memory_profile. Notice the familar saw-tooth pattern for heap memory.\n",
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_fn module``? Is it what we expect?\n",
"1. Look at the trace_viewer@. Zoom in until you can see the operations. We expect to see up to 10 ``add.3`` steps.\n",
"1. Look at the op_profile. Notice the TPU FLOPS utilization, the HBM bandwidth utilization, and the wasted time."
],
"metadata": {
"id": "vqooxfGXeCRC"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/addition --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=/tmp/addition"
],
"metadata": {
"id": "4v8PyMEPPO0E"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Adding three matrices\n",
"Let's try adding three matrices ``A + B + C``. This requires two addition steps: first, ``(A + B)`` and then, ``(A + B) + C``.\n",
"\n",
"Again, we will repeat this multiple times so we can do some meaningful profiling."
],
"metadata": {
"id": "G_oL7uWxp4S3"
}
},
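{
"cell_type": "markdown",
"source": [
"Before profiling, we can confirm the two-step structure by printing the jaxpr for a tiny example (a quick sketch; the small shapes below are arbitrary)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: the jaxpr for X + Y + Z shows two sequential add operations.\n",
"_small = jax.numpy.ones((4, 4), dtype=jax.numpy.float16)\n",
"print(jax.make_jaxpr(lambda x, y, z: x + y + z)(_small, _small, _small))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},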
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 16384\n",
"STEPS = 10\n",
"\n",
"# for each operation, we expect 4x the memory to be used\n",
"# (A, B, C, and the resulting sum)\n",
"# we expect the operations to be carried out as:\n",
"# intermediate = (A + B); O = (intermediate + C)\n",
"# we expect each jax.numpy.float16 to be 2 bytes long\n",
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 4)/(1024**2))\n"
],
"metadata": {
"id": "d1f_d858eGBx"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"C = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"jax.profiler.start_trace(\"/tmp/addition3\")\n",
"\n",
"s = datetime.datetime.now()\n",
"for i in range(STEPS):\n",
" O = A + B + C\n",
"e = datetime.datetime.now()\n",
"\n",
"jax.profiler.stop_trace()\n",
"\n",
"print( f\"Two additions take {(e-s).total_seconds()/STEPS:.4f} on average\")\n"
],
"metadata": {
"id": "lK3QghYTobhr"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's take a look at our profile. Be sure to select the correct run. \n",
"\n",
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_fn`` module? Is it what we expect?\n",
"1. What about ``jit_broadcast_in_dim``?\n",
"1. Look at the trace_viewer@. Zoom in until you can see the operations. We expect to see up to 16 ``add.3`` steps.\n",
"1. Look at the op_profile. Notice the TPU FLOPS utilization, the HBM bandwidth utilization, and the wasted time."
],
"metadata": {
"id": "hnynb2Mglvix"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/addition3 --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=/tmp/addition3"
],
"metadata": {
"id": "tixkVo3dl3gT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Introducing just in time compilation (JIT)\n",
"\n",
"To learn more about JIT on Jax, try the [tutorial](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html).\n",
"\n",
"Let's try our ``A + B`` addition with ``jit`` to see a simple example."
],
"metadata": {
"id": "spGWMHrjh9-w"
}
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def f2(A, B):\n",
" return A + B\n",
"f2_jit = jax.jit(f2)\n",
"\n",
"simple_timeit(f2, A, B, task = \"f2\")\n",
"simple_timeit(f2_jit, A, B, task = \"f2_jit\")\n"
],
"metadata": {
"id": "AaIOlFV2m_-C"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"As you can see, compilation doesn't always yield dramatic performance improvements. Now, let's try our ``A + B + C`` addition with ``jit``."
],
"metadata": {
"id": "loJEtkwRrR-4"
}
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"C = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def f3(X,Y,Z):\n",
" return X+Y+Z\n",
"\n",
"f3_jit = jax.jit(f3)\n",
"\n",
"simple_timeit(f3, A, B, C, task = \"f3\")\n",
"simple_timeit(f3_jit, A, B, C, task = \"f3_jit\")\n"
],
"metadata": {
"id": "5RquGMgjqlLf"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"To understand how jit helped, try out the TensorBoard (you'll need to use the correct log directory for the trace from the previous cell's output).\n",
"\n",
"1. Look at the memory_viewer. What's the peak memory allocation for the ``jit_f3`` module? Is it what we expect?\n",
"1. Look at the trace_viewer@. Zoom in until you can see the operations. What's different?\n",
"1. Look at the graph_viewer. Search for the Op Name \"fusion\". What do you see?"
],
"metadata": {
"id": "xGKhjrizq-aL"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_f3_jit_Q4MQHIC0Q8 --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "Wad58cTp4cRu"
},
"execution_count": null,
"outputs": []
},
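{
"cell_type": "markdown",
"source": [
"As a complement to the graph_viewer, we can also dump the optimized HLO from code and search it for fusion ops. This is a sketch: the ``as_text()`` output is backend-dependent, so the exact op names may differ from what TensorBoard shows."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: inspect the optimized HLO for the jitted A + B + C and look for fusion ops.\n",
"# The text format is backend-dependent; \"fusion\" may appear under a different name.\n",
"optimized_hlo = jax.jit(f3).lower(A, B, C).compile().as_text() or \"\"\n",
"fusion_lines = [l for l in optimized_hlo.splitlines() if \"fusion\" in l]\n",
"print(f\"{len(fusion_lines)} line(s) mentioning 'fusion'\")\n",
"for l in fusion_lines[:5]:\n",
"  print(l.strip())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},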
{
"cell_type": "markdown",
"source": [
"## Matrix multiplication\n",
"\n",
"In the right scenarios, the compiler can fuse multiple operations (even loops :) ) to offer performance gains. Let's try out matrix multiplication using ``numpy``."
],
"metadata": {
"id": "awLxSon_80ER"
}
},
{
"cell_type": "code",
"source": [
"MATRIX_SIZE = 4096\n",
"STEPS = 10\n",
"\n",
"# for each operation, we expect 4x the memory to be used\n",
"# (A, B, C, and the resulting sum)\n",
"# we expect the operations to be carried out as:\n",
"# intermediate = (A + B); O = (intermediate + C)\n",
"# we expect each jax.numpy.float16 to be 2 bytes long\n",
"print(\"MB: \", (MATRIX_SIZE * MATRIX_SIZE * 2 * 4)/(1024**2))\n"
],
"metadata": {
"id": "kr1vut-2AFqT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"A = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"B = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"C = jax.numpy.ones ( (MATRIX_SIZE, MATRIX_SIZE), dtype=jax.numpy.float16)\n",
"\n",
"def matmul(A,B):\n",
" return A@B\n",
"jit_matmul = jax.jit(matmul)\n",
"\n",
"simple_timeit(matmul, A, B, task = 'matmul')\n",
"simple_timeit(jit_matmul, A, B, task = 'jit_matmul')\n"
],
"metadata": {
"id": "WXSPmRRCAHHw"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"So what does the profiler tell us about the ``jit_matmul`` approach?\n",
"\n",
"Let's look at the FLOPS utilization.\n"
],
"metadata": {
"id": "h9Vrh_CXJmBm"
}
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_jit_matmul_72ZBQZR4CX --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "9mrIAVPMAGwQ"
},
"execution_count": null,
"outputs": []
},
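{
"cell_type": "markdown",
"source": [
"To compare against the profiler's FLOPS utilization, here is a rough back-of-the-envelope number (a sketch: the ``2 * N**3`` FLOP count ignores any compiler tricks, and ``simple_timeit`` includes some dispatch overhead)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Back-of-the-envelope: an (N x N) @ (N x N) matmul needs roughly 2 * N**3 FLOPs.\n",
"t_ms = simple_timeit(jit_matmul, A, B, task='jit_matmul_flops_check')\n",
"matmul_flops = 2 * MATRIX_SIZE**3\n",
"achieved_tflops = matmul_flops / (t_ms / 1000) / 1e12\n",
"print(f\"~{achieved_tflops:.1f} TFLOP/s achieved on this run\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},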
{
"cell_type": "code",
"source": [
"def matmul_with_activation(A,B,C):\n",
" return jax.nn.relu(A@B)\n",
"jit_matmul_foldin = jax.jit(matmul_with_activation)\n",
"\n",
"simple_timeit(matmul_with_activation, A, B, C, task='matmul_foldin')\n",
"simple_timeit(jit_matmul_foldin, A, B, C, task = 'jit_matmul_foldin')"
],
"metadata": {
"id": "shiBzssAAGfA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir=/tmp/t_jit_matmul_foldin_0BGHOPMA6F --port=0\n",
"# for a Colab outside Google\n",
"# %tensorboard --logdir=[DIR]"
],
"metadata": {
"id": "rqTzDBFzAGEl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Here's an example Rafi shared for roofline analysis."
],
"metadata": {
"id": "0MU4uhS1QM6U"
}
},
{
"cell_type": "code",
"source": [
"# MATMUL_SIZES = [(64000,128), (16000,256), (4000,512), (3000,640), (2000,768), (1000,1024), (250, 2048)]\n",
"MATMUL_SIZES = [(3000,256), (2000,512), (1000,640), (500,768), (300,1024), (100, 1536)]\n",
"\n",
"for num_matmuls, matrix_size in MATMUL_SIZES:\n",
" A = jax.numpy.ones ( (num_matmuls, matrix_size, matrix_size), dtype=jax.numpy.bfloat16)\n",
" B = jax.numpy.ones ( (num_matmuls, matrix_size, matrix_size), dtype=jax.numpy.bfloat16)\n",
"\n",
" @jax.jit\n",
" def f(X,Y):\n",
" return jax.lax.batch_matmul(X,Y)\n",
"\n",
" print(f(A,B).shape)\n",
"\n",
" simple_timeit(f, A, B, task = 'matmul_' + str(matrix_size))"
],
"metadata": {
"id": "vig3Z9alQNXA"
},
"execution_count": null,
"outputs": []
},
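{
"cell_type": "markdown",
"source": [
"For a roofline analysis, the quantity of interest is arithmetic intensity: FLOPs per byte moved to and from HBM. Here is a rough calculation for the batched matmuls above (a sketch that counts ``2 * m**3`` FLOPs per matmul and only the three operand/result arrays)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Rough arithmetic intensity for each (batch, size) pair used above.\n",
"# FLOPs per batch: n * 2 * m**3; HBM bytes: 3 arrays of n * m * m bfloat16 (2 bytes each).\n",
"for n, m in MATMUL_SIZES:\n",
"  flops = n * 2 * m**3\n",
"  hbm_bytes = 3 * n * m * m * 2\n",
"  print(f\"size {m:5d}: arithmetic intensity ~ {flops / hbm_bytes:.0f} FLOPs/byte\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},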
{
"cell_type": "markdown",
"source": [
"Rafi suggested this as an after session exercise:\n",
"> Consider the function f(A,B) = jax.nn.relu(A@B). (Assume A and B are square martrices.)\n",
"What percentage faster will jit(f) be than f? Does it depend on the size of A and B?"
],
"metadata": {
"id": "mMMxg8fAVET1"
}
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "VdWmfd_lVD9K"
},
"execution_count": null,
"outputs": []
},
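{
"cell_type": "markdown",
"source": [
"One way to approach the exercise empirically (a sketch of a measurement harness, not the answer): sweep a few square sizes and compare the timings with ``simple_timeit``."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch for the exercise: measure f vs jit(f) over a few square sizes.\n",
"def relu_matmul(X, Y):\n",
"  return jax.nn.relu(X @ Y)\n",
"relu_matmul_jit = jax.jit(relu_matmul)\n",
"\n",
"for n in (1024, 2048, 4096):\n",
"  X = jax.numpy.ones((n, n), dtype=jax.numpy.bfloat16)\n",
"  Y = jax.numpy.ones((n, n), dtype=jax.numpy.bfloat16)\n",
"  t_eager = simple_timeit(relu_matmul, X, Y, task=f'relu_matmul_{n}')\n",
"  t_jit = simple_timeit(relu_matmul_jit, X, Y, task=f'relu_matmul_jit_{n}')\n",
"  print(f\"n={n}: jit is ~{100 * (t_eager - t_jit) / t_eager:.0f}% faster\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},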
{
"cell_type": "markdown",
"source": [
"## Conclusion\n",
"In this session. we learned some simple techniques to use Jax with TPUs. We also learned how fusing execution steps at compilation can lead to performance improvements."
],
"metadata": {
"id": "XIqZeLm5sVGz"
}
}
]
}