@nhira
Last active July 26, 2024 05:06
{
"cells": [
{
"cell_type": "markdown",
"id": "6cb94fad-18c5-49c9-8150-477587595589",
"metadata": {},
"source": [
"# Use your notebook with the AI Hypercomputer"
]
},
{
"cell_type": "markdown",
"id": "2bbb2f5a-d40b-4c20-9d1b-94b333a45a48",
"metadata": {},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"id": "6ac3c6f0-8698-4949-ae36-5129a18d50f9",
"metadata": {},
"source": [
"Google has packaged its experience building and operating AI services for billions of users into the [AI Hypercomputer](https://cloud.google.com/solutions/ai-hypercomputer) architecture, which is intended to help customers accelerate their own AI work. The architecture combines performance-optimized infrastructure, open-source software frameworks, and flexible consumption models.\n",
"\n",
"One challenge researchers face with contemporary models is the distributed programming needed to orchestrate work across many accelerators and hosts. This notebook shows how the AI Hypercomputer architecture can help.\n"
]
},
{
"attachments": {
"8561558d-012e-460f-8c42-07bce794c497.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "dc85f2ac-b476-4bae-8df6-30c91c589ed7",
"metadata": {},
"source": [
"![Screenshot 2024-07-21 18.33.29.png](attachment:8561558d-012e-460f-8c42-07bce794c497.png)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "67ebe691-85eb-4b08-aaf1-19c89a88eb7f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>table {float:left}</style>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%%html \n",
"<style>table {float:left}</style>"
]
},
{
"cell_type": "markdown",
"id": "d02ca16e-5273-4879-ad03-afc35362d84a",
"metadata": {},
"source": [
"\n",
"### Performance-optimized infrastructure\n",
"For this example, we chose [Google Cloud TPU v5e](https://cloud.google.com/tpu/docs/v5e) as the accelerator. We also chose Google Kubernetes Engine (GKE) because it enables easy [TPU cluster management](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus). A number of networking and storage enhancements (including Titanium offloads and Multislice training) make this all practical, but for brevity, we will only call out [Cloud Storage (GCS)](https://cloud.google.com/storage) and NFS Filestores.\n",
"\n",
"### Open software\n",
"The code in this notebook relies on [JAX](https://jax.readthedocs.io/en/latest/index.html), a Python library for high-performance numerical computing and large-scale machine learning with accelerators. JAX is integrated with the [OpenXLA compiler](https://github.com/openxla) so the software chooses the most effective implementation and model builders can focus on the math. The crux of the example relies on [MaxText](https://github.com/google/maxtext), a reference implementation for training, tuning, and inference with LLMs.\n",
"\n",
"To interact with the cluster, we use IPython Parallels and some [cell magic](https://ipyparallel.readthedocs.io/en/latest/tutorial/magics.html). IPython Parallels (ipyparallel) is a Python package and collection of CLI scripts for controlling clusters of IPython processes, built on the Jupyter protocol. _While the default settings were adequate for this example, you should review [ipyparallel security details](https://ipyparallel.readthedocs.io/en/latest/reference/security.html) before use in a production environment._\n",
"\n",
"### Flexible consumption options\n",
"In addition to Kaggle and Colab options, customers can also choose Spot VMs (preemptible), Cloud Batch, or Dynamic Workload Scheduler (DWS). For this exercise, we used ``--spot``."
]
},
{
"cell_type": "markdown",
"id": "ae066f9f-1feb-49e3-a75e-5c05e82ad717",
"metadata": {},
"source": [
"## Create the cluster\n",
"\n",
"Again, for brevity, we list the steps involved without going into detail.\n",
"1. Set up custom networking and firewall rules for TPUs (specifically, ``mtu=8896``).\n",
"2. Use [XPK](https://github.com/google/xpk) (Accelerated Processing Kit) to create the GKE cluster.\n",
"3. Create a regional NFS Filestore instance and make it available to the cluster as a persistent volume.\n",
"4. Use MaxText to create a base container image; include ipyparallel before publishing the artifact.\n",
"5. Use [LeaderWorkerSet](https://github.com/kubernetes-sigs/lws) (LWS) to deploy the ipyparallel pods as a group.\n",
"\n",
"### XPK command \n",
"```\n",
"# illustrative example\n",
"python3 xpk/xpk.py cluster create \\\n",
"--cluster ${CLUSTER_NAME} \\\n",
"--tpu-type=v5litepod-16 \\\n",
"--num-slices=2 \\\n",
"--spot --project=${PROJECT_ID} --zone=${PROJECT_ZONE} \\\n",
"--default-pool-cpu-machine-type=n1-standard-16 \\\n",
"--custom-cluster-arguments='--network=mtu9k --subnetwork=mtu9k'\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "af3d36b4-3c53-4061-aafb-69647cce3cfd",
"metadata": {},
"source": [
"### Did it work?\n",
"If it all works out, you should see something like this:\n",
"```\n",
"$ kubectl get nodes\n",
"NAME STATUS ROLES AGE VERSION\n",
"gke-______________-default-pool-03f271f3-0m6z Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-______________-default-pool-03f271f3-cz31 Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-______________-default-pool-03f271f3-p883 Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-tpu-744783eb-0kh0 Ready <none> 2d5h v1.30.1-gke.1329003\n",
"gke-tpu-744783eb-35tz Ready <none> 2d5h v1.30.1-gke.1329003\n",
"gke-tpu-744783eb-647d Ready <none> 2d5h v1.30.1-gke.1329003\n",
"gke-tpu-744783eb-d0jk Ready <none> 2d5h v1.30.1-gke.1329003\n",
"gke-tpu-78fc01ba-dzvw Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-tpu-78fc01ba-mtw4 Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-tpu-78fc01ba-rnlh Ready <none> 2d6h v1.30.1-gke.1329003\n",
"gke-tpu-78fc01ba-v5zd Ready <none> 2d6h v1.30.1-gke.1329003\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "11b714ce-34e3-45d8-aa9e-18437c8ac07b",
"metadata": {},
"source": [
"```\n",
"$ kubectl get services\n",
"NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE\n",
"ipp ClusterIP ___.___.___.64 <none> 8888/TCP 114m\n",
"ipp-deployment ClusterIP None <none> <none> 114m\n",
"kubernetes ClusterIP ___.___.___.1 <none> 443/TCP 2d6h\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "29d4ce5c-1076-4731-acfe-09e702c5e7cc",
"metadata": {},
"source": [
"```\n",
"$ kubectl get pods\n",
"NAME READY STATUS RESTARTS AGE\n",
"ipp-deployment-0 2/2 Running 0 114m\n",
"ipp-deployment-0-1 1/1 Running 0 114m\n",
"ipp-deployment-0-2 1/1 Running 0 114m\n",
"ipp-deployment-0-3 1/1 Running 0 114m\n",
"ipp-deployment-0-4 1/1 Running 0 114m\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "1ae7c7a2-c54a-4ea3-a59b-300020286712",
"metadata": {},
"source": [
"### GKE and multi-host TPU slice node pool\n",
"A [multi-host TPU slice node pool](https://cloud.google.com/tpu/docs/tpus-in-gke#multi-host) is a node pool that contains two or more interconnected TPU VMs. Each VM has a TPU device connected to it. The TPUs in a multi-host slice are connected over a high speed interconnect (ICI). The following diagram shows an example of a v5litepod-16 (v5e) multi-host TPU slice. This slice has four TPU VMs. Each TPU VM has four TPU v5e chips and each TPU v5e chip has one TensorCore.\n",
"\n",
"![image.png](https://cloud.google.com/static/tpu/docs/images/kubernetes-diagram.png)\n"
]
},
{
"cell_type": "markdown",
"id": "1865bcbd-6010-45d5-9d69-cf3ffd0fecaa",
"metadata": {},
"source": [
"### TPU v5e system architecture\n",
"Each v5e chip contains one TensorCore. Each TensorCore has four matrix-multiply units (MXUs), a vector unit, and a scalar unit.\n",
"\n",
"| Key chip specifications | v5e values |\n",
"| :--- | :--- |\n",
"| Peak compute per chip (bf16) | 197 TFLOPs |\n",
"| Peak compute per chip (Int8) | 393 TFLOPs |\n",
"| HBM2 capacity and bandwidth | 16 GB, 819 GBps |\n",
"| Interchip Interconnect bandwidth | 1600 Gbps |\n",
"\n",
"![image.png](https://cloud.google.com/static/tpu/docs/images/v5e-tensorcore.png) "
]
},
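{
"cell_type": "markdown",
"id": "sketch-v5e-slice-totals",
"metadata": {},
"source": [
"As a rough sanity check, the per-chip figures in the table can be scaled to the 16 chips of one v5litepod-16 slice. This is a back-of-the-envelope sketch: the helper name ``slice_totals`` is hypothetical, and it assumes peak numbers add up linearly, which gives an upper bound rather than an expected training throughput.\n",
"\n",
"```python\n",
"# Per-chip values copied from the table above.\n",
"PEAK_BF16_TFLOPS = 197\n",
"HBM_GB = 16\n",
"\n",
"\n",
"def slice_totals(num_chips):\n",
"    # Hypothetical helper: scale per-chip peaks linearly to a slice.\n",
"    return {\n",
"        'peak_bf16_tflops': num_chips * PEAK_BF16_TFLOPS,\n",
"        'hbm_gb': num_chips * HBM_GB,\n",
"    }\n",
"\n",
"\n",
"# One v5litepod-16 slice has 16 chips.\n",
"print(slice_totals(16))\n",
"```\n",
"\n",
"The aggregate HBM figure is the one that matters most for the sharding discussion later on, since it bounds how much model and optimizer state the slice can hold in total."
]
},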
{
"cell_type": "markdown",
"id": "7a28f9c0-a2b2-4ef7-8941-b156560a7b1e",
"metadata": {},
"source": [
"## Confirm cluster connectivity\n",
"Let's try out ipyparallel to confirm we can interact with our instances (\"engines\") and the TPUs. Each of our ``ipp-deployment-0-?`` instances should have access to 4 TPU v5e chips."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7c974580-6f10-449f-afe9-7d64a6a6c7b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ipyparallel in /opt/conda/lib/python3.11/site-packages (8.8.0)\n",
"Requirement already satisfied: decorator in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (5.1.1)\n",
"Requirement already satisfied: entrypoints in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (0.4)\n",
"Requirement already satisfied: ipykernel>=4.4 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (6.25.2)\n",
"Requirement already satisfied: ipython>=4 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (8.16.1)\n",
"Requirement already satisfied: jupyter-client>=5 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (8.4.0)\n",
"Requirement already satisfied: psutil in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (5.9.5)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (2.8.2)\n",
"Requirement already satisfied: pyzmq>=18 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (25.1.1)\n",
"Requirement already satisfied: tornado>=5.1 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (6.3.3)\n",
"Requirement already satisfied: tqdm in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (4.66.1)\n",
"Requirement already satisfied: traitlets>=4.3 in /opt/conda/lib/python3.11/site-packages (from ipyparallel) (5.11.2)\n",
"Requirement already satisfied: comm>=0.1.1 in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (0.1.4)\n",
"Requirement already satisfied: debugpy>=1.6.5 in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (1.8.0)\n",
"Requirement already satisfied: jupyter-core!=5.0.*,>=4.12 in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (5.4.0)\n",
"Requirement already satisfied: matplotlib-inline>=0.1 in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (0.1.6)\n",
"Requirement already satisfied: nest-asyncio in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (1.5.8)\n",
"Requirement already satisfied: packaging in /opt/conda/lib/python3.11/site-packages (from ipykernel>=4.4->ipyparallel) (23.2)\n",
"Requirement already satisfied: backcall in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (0.2.0)\n",
"Requirement already satisfied: jedi>=0.16 in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (0.19.1)\n",
"Requirement already satisfied: pickleshare in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (0.7.5)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (3.0.39)\n",
"Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (2.16.1)\n",
"Requirement already satisfied: stack-data in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (0.6.2)\n",
"Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.11/site-packages (from ipython>=4->ipyparallel) (4.8.0)\n",
"Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.11/site-packages (from python-dateutil>=2.1->ipyparallel) (1.16.0)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.3 in /opt/conda/lib/python3.11/site-packages (from jedi>=0.16->ipython>=4->ipyparallel) (0.8.3)\n",
"Requirement already satisfied: platformdirs>=2.5 in /opt/conda/lib/python3.11/site-packages (from jupyter-core!=5.0.*,>=4.12->ipykernel>=4.4->ipyparallel) (3.11.0)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.11/site-packages (from pexpect>4.3->ipython>=4->ipyparallel) (0.7.0)\n",
"Requirement already satisfied: wcwidth in /opt/conda/lib/python3.11/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=4->ipyparallel) (0.2.8)\n",
"Requirement already satisfied: executing>=1.2.0 in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=4->ipyparallel) (1.2.0)\n",
"Requirement already satisfied: asttokens>=2.1.0 in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=4->ipyparallel) (2.4.0)\n",
"Requirement already satisfied: pure-eval in /opt/conda/lib/python3.11/site-packages (from stack-data->ipython>=4->ipyparallel) (0.2.2)\n"
]
}
],
"source": [
"# install ipyparallel so we can use it here\n",
"!pip install ipyparallel"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "ef3f2db5-3f1b-4f87-a1b4-7400d7915645",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Successfully connected to 4 instances\n"
]
}
],
"source": [
"import ipyparallel as ipp\n",
"\n",
"# ipp access keys\n",
"IPP_FILE_PATH='/home/jovyan/nfs/security/ipcontroller-client.json'\n",
"\n",
"rc = ipp.Client(IPP_FILE_PATH)\n",
"if rc.ids:\n",
"    print(f'Successfully connected to {len(rc.ids)} instances')\n",
"else:\n",
"    print(f'Failed to connect using {IPP_FILE_PATH}')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "ccca53cf-5792-4b00-b53a-48ccc84f15d3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:0] Hello world\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:3] Hello world\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] Hello world\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:1] Hello world\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine \n",
"print('Hello world')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "cd2b34a6-1af4-455c-b0aa-b24609543ec4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:2] Just the even engines\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:0] Just the even engines\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine --targets ::2 \n",
"print('Just the even engines')"
]
},
{
"cell_type": "markdown",
"id": "ce5af357-9d73-48dc-86e8-d16e8c4f9cf8",
"metadata": {},
"source": [
"All good so far. Now, let's see if JAX can see our TPUs..."
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "23a45a92-3fa1-44cb-b69b-607a7bbd01b5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:0] jax process: 01; host: ipp-deployment-0-1; chips on device: 4; on the cluster: 16\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:1] jax process: 02; host: ipp-deployment-0-3; chips on device: 4; on the cluster: 16\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:3] jax process: 03; host: ipp-deployment-0-4; chips on device: 4; on the cluster: 16\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] jax process: 00; host: ipp-deployment-0-2; chips on device: 4; on the cluster: 16\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"import jax\n",
"import socket\n",
"\n",
"print(\n",
" f'jax process: {jax.process_index():02d}; '\n",
" f'host: {socket.gethostname()}; '\n",
" f'chips on device: {jax.local_device_count()}; '\n",
" f'on the cluster: {jax.device_count()}')"
]
},
{
"cell_type": "markdown",
"id": "6b931871-38c3-4f57-bca1-bdf2ba1a63c2",
"metadata": {},
"source": [
"### Sharding and JAX\n",
"\n",
"In JAX, Sharding objects describe distributed memory layouts.\n",
"\n",
"Let's start with some basic array manipulation on one ipyparallel instance."
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "94d813a7-3a91-4565-ab8e-19038d6e9296",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mOut[1:253]: \u001b[0m{TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0)}"
]
},
"metadata": {
"after": null,
"completed": null,
"data": {},
"engine_id": 1,
"engine_uuid": "6e398cb4-45376298b4e4d5665a884fd1",
"error": null,
"execute_input": "\nimport jax\nimport jax.numpy as jnp\narr = jnp.arange(32.0).reshape(4, 8)\narr.devices()\n",
"execute_result": {
"data": {
"text/plain": "{TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0)}"
},
"execution_count": 253,
"metadata": {}
},
"follow": null,
"msg_id": null,
"outputs": [],
"received": null,
"started": null,
"status": null,
"stderr": "",
"stdout": "",
"submitted": "2024-07-26T04:54:20.729874Z"
},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine --target=1\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"arr = jnp.arange(32.0).reshape(4, 8)\n",
"arr.devices()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "ce796891-8ca3-44ad-afb6-969b909bb3c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mOut[1:254]: \u001b[0mSingleDeviceSharding(device=TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0))"
]
},
"metadata": {
"after": null,
"completed": null,
"data": {},
"engine_id": 1,
"engine_uuid": "6e398cb4-45376298b4e4d5665a884fd1",
"error": null,
"execute_input": "\narr.sharding\n",
"execute_result": {
"data": {
"text/plain": "SingleDeviceSharding(device=TpuDevice(id=8, process_index=2, coords=(0,2,0), core_on_chip=0))"
},
"execution_count": 254,
"metadata": {}
},
"follow": null,
"msg_id": null,
"outputs": [],
"received": null,
"started": null,
"status": null,
"stderr": "",
"stdout": "",
"submitted": "2024-07-26T04:54:22.688224Z"
},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine --target=1\n",
"\n",
"arr.sharding"
]
},
{
"cell_type": "markdown",
"id": "82f785c2-1e8b-435b-8d55-d558df5288cf",
"metadata": {},
"source": [
"JAX offers ``debug.visualize_array_sharding`` to visualize the sharding."
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "86db90ee-32d6-4374-8320-d8ea1fd2a473",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[output:1]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌────────────────────────────────────────────────┐\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span> │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"└────────────────────────────────────────────────┘\n",
"</pre>\n"
],
"text/plain": [
"┌────────────────────────────────────────────────┐\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ TPU \u001b[1;36m8\u001b[0m │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"│ │\n",
"└────────────────────────────────────────────────┘\n"
]
},
"metadata": {
"engine": 1
},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine --target=1\n",
"\n",
"jax.debug.visualize_array_sharding(arr)"
]
},
{
"cell_type": "markdown",
"id": "3370eaa5-50dd-4710-b058-0f2787609845",
"metadata": {},
"source": [
"Let's try this with a more interesting example. We can use ``device_put`` to move the data to the accelerators."
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "bdc2bb62-ff04-4612-8386-25a08867fa3b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:1] [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.\n",
" 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[output:1]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┌───────┬───────┬───────┬───────┐\n",
"│ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">8</span> │ TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">9</span> │TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> │TPU <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">13</span> │\n",
"└───────┴───────┴───────┴───────┘\n",
"</pre>\n"
],
"text/plain": [
"┌───────┬───────┬───────┬───────┐\n",
"│ TPU \u001b[1;36m8\u001b[0m │ TPU \u001b[1;36m9\u001b[0m │TPU \u001b[1;36m12\u001b[0m │TPU \u001b[1;36m13\u001b[0m │\n",
"└───────┴───────┴───────┴───────┘\n"
]
},
"metadata": {
"engine": 1
},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine --target=1\n",
"\n",
"from jax.sharding import PositionalSharding\n",
"\n",
"arr = jnp.arange(32.0)\n",
"sharding = PositionalSharding(jax.devices()[8:12])\n",
"arr_sharded = jax.device_put(arr, sharding)\n",
"\n",
"print(arr_sharded)\n",
"jax.debug.visualize_array_sharding(arr_sharded)"
]
},
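{
"cell_type": "markdown",
"id": "2d824e1c-94fa-4b77-b50a-48d1004bce50",
"metadata": {},
"source": [
"Now let's inspect a sharded array across all instances, using ``multihost_utils`` to gather its global shape. The cell below relies on ``load_next_batch``, ``data_iterator``, and ``config`` from the MaxText setup in the training section later in this notebook."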
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "fda4cdc1-ed0a-4175-a915-af81b5090fa5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:3] example_batch={'inputs': Array((16, 4096), dtype=int32), 'inputs_position': Array((16, 4096), dtype=int32), 'inputs_segmentation': Array((16, 4096), dtype=int32), 'targets': Array((16, 4096), dtype=int32), 'targets_position': Array((16, 4096), dtype=int32), 'targets_segmentation': Array((16, 4096), dtype=int32)}\n",
"\n",
"host 3 chip 0 input data shape: (1, 4096)\n",
"host 3 chip 1 input data shape: (1, 4096)\n",
"host 3 chip 2 input data shape: (1, 4096)\n",
"host 3 chip 3 input data shape: (1, 4096)\n",
"global input data shape: (16, 4096)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:0] example_batch={'inputs': Array((16, 4096), dtype=int32), 'inputs_position': Array((16, 4096), dtype=int32), 'inputs_segmentation': Array((16, 4096), dtype=int32), 'targets': Array((16, 4096), dtype=int32), 'targets_position': Array((16, 4096), dtype=int32), 'targets_segmentation': Array((16, 4096), dtype=int32)}\n",
"\n",
"host 1 chip 0 input data shape: (1, 4096)\n",
"host 1 chip 1 input data shape: (1, 4096)\n",
"host 1 chip 2 input data shape: (1, 4096)\n",
"host 1 chip 3 input data shape: (1, 4096)\n",
"global input data shape: (16, 4096)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:1] example_batch={'inputs': Array((16, 4096), dtype=int32), 'inputs_position': Array((16, 4096), dtype=int32), 'inputs_segmentation': Array((16, 4096), dtype=int32), 'targets': Array((16, 4096), dtype=int32), 'targets_position': Array((16, 4096), dtype=int32), 'targets_segmentation': Array((16, 4096), dtype=int32)}\n",
"\n",
"host 2 chip 0 input data shape: (1, 4096)\n",
"host 2 chip 1 input data shape: (1, 4096)\n",
"host 2 chip 2 input data shape: (1, 4096)\n",
"host 2 chip 3 input data shape: (1, 4096)\n",
"global input data shape: (16, 4096)\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] example_batch={'inputs': Array((16, 4096), dtype=int32), 'inputs_position': Array((16, 4096), dtype=int32), 'inputs_segmentation': Array((16, 4096), dtype=int32), 'targets': Array((16, 4096), dtype=int32), 'targets_position': Array((16, 4096), dtype=int32), 'targets_segmentation': Array((16, 4096), dtype=int32)}\n",
"\n",
"host 0 chip 0 input data shape: (1, 4096)\n",
"host 0 chip 1 input data shape: (1, 4096)\n",
"host 0 chip 2 input data shape: (1, 4096)\n",
"host 0 chip 3 input data shape: (1, 4096)\n",
"global input data shape: (16, 4096)\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"from jax.experimental import multihost_utils\n",
"\n",
"example_batch = None\n",
"example_batch = load_next_batch(data_iterator, example_batch, config)\n",
"\n",
"print(f'{example_batch=}')\n",
"data = example_batch['inputs']\n",
"print()\n",
"\n",
"for i, shard in enumerate(data.addressable_shards):\n",
" print(\n",
" f'host {jax.process_index()} chip {i} input data shape: ',\n",
" shard.data.shape)\n",
"\n",
"print('global input data shape:', multihost_utils.process_allgather(data).shape)"
]
},
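{
"cell_type": "markdown",
"id": "sketch-shard-shape",
"metadata": {},
"source": [
"The per-chip shapes printed above follow from dividing each global dimension by the number of ways it is split. The sketch below reproduces that arithmetic in plain Python. ``shard_shape`` is a hypothetical helper written for this note, not a JAX API, and it assumes the batch axis is split evenly across all 16 chips while the sequence axis is left whole.\n",
"\n",
"```python\n",
"def shard_shape(global_shape, partitions):\n",
"    # partitions[i] is how many ways axis i is split (1 means not split).\n",
"    if len(global_shape) != len(partitions):\n",
"        raise ValueError('need one partition count per axis')\n",
"    shape = []\n",
"    for size, parts in zip(global_shape, partitions):\n",
"        if size % parts != 0:\n",
"            raise ValueError(f'axis of size {size} does not divide by {parts}')\n",
"        shape.append(size // parts)\n",
"    return tuple(shape)\n",
"\n",
"\n",
"# Global batch (16, 4096) with the batch axis split 16 ways.\n",
"print(shard_shape((16, 4096), (16, 1)))\n",
"```\n",
"\n",
"With 16 chips and a global batch of 16 sequences, each chip ends up holding a single sequence of 4096 tokens, which matches the ``(1, 4096)`` shards reported by each host."
]
},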
{
"cell_type": "markdown",
"id": "585e256c-4340-41a8-9c93-9fcf97d22f11",
"metadata": {},
"source": [
"The [Distributed Arrays Tutorial](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#sharding-describes-how-array-values-are-laid-out-in-memory-across-devices) has more examples for you to work through."
]
},
{
"cell_type": "markdown",
"id": "9fc9fed4-3556-4f5d-969d-909474cfe6f7",
"metadata": {},
"source": [
"## Training Meta's Llama2-7b\n",
"Let's take our 16 TPU v5e chips for a spin. Suppose you need to train Llama2-7b and want to evaluate your parallelism strategy. Intuitively, 16 data shards make more sense than 16 tensor-parallel parts, but can we try it out?\n"
]
},
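{
"cell_type": "markdown",
"id": "sketch-llama2-7b-param-count",
"metadata": {},
"source": [
"To see why the model has to be sharded at all, it helps to estimate its size. The sketch below uses the model dimensions that MaxText logs for ``llama2-7b`` when the model is configured later in this notebook. The parameter layout is an assumption based on the usual Llama 2 structure (four attention projections, a gated MLP with three matrices, two norm scales per layer, and untied input and output embeddings); ``llama2_param_count`` is a hypothetical helper, not part of MaxText.\n",
"\n",
"```python\n",
"# Model dimensions as logged by MaxText for llama2-7b in this notebook.\n",
"EMB_DIM = 4096\n",
"MLP_DIM = 11008\n",
"NUM_LAYERS = 32\n",
"NUM_HEADS = 32\n",
"HEAD_DIM = 128\n",
"VOCAB_SIZE = 32000\n",
"\n",
"\n",
"def llama2_param_count():\n",
"    # Assumed layout: q, k, v, o projections; gated MLP with three\n",
"    # matrices; two norm scales per layer; untied embeddings.\n",
"    attention = 4 * EMB_DIM * NUM_HEADS * HEAD_DIM\n",
"    mlp = 3 * EMB_DIM * MLP_DIM\n",
"    norms = 2 * EMB_DIM\n",
"    per_layer = attention + mlp + norms\n",
"    embeddings = 2 * VOCAB_SIZE * EMB_DIM\n",
"    final_norm = EMB_DIM\n",
"    return NUM_LAYERS * per_layer + embeddings + final_norm\n",
"\n",
"\n",
"params = llama2_param_count()\n",
"print(f'{params:,} parameters')\n",
"print(f'{params * 2 / 1e9:.1f} GB of weights in bf16')\n",
"```\n",
"\n",
"Under these assumptions the weights alone come to roughly 13.5 GB in bf16, close to the 16 GB of HBM on a single v5e chip, before gradients, optimizer state, or activations are counted."
]
},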
{
"cell_type": "markdown",
"id": "d539df88-4228-4691-8182-c904cb7823f6",
"metadata": {},
"source": [
"### Configure parallelism and compile code to begin training\n",
"\n",
"We'll use [FSDP](https://engineering.fb.com/2021/07/15/open-source/fsdp/) and [tensor parallel](https://huggingface.co/docs/text-generation-inference/en/conceptual/tensor_parallelism) techniques to split our large model across the 16 TPU chips. We start from [MaxText](https://github.com/google/maxtext) to save some development time. Given the parallelism settings, the code works out how to shard the model across the chips.\n",
"\n",
"In this step, we configure and compile the model. With 16 chips, we plan to shard the model using 4-way FSDP and 4-way tensor parallelism by setting `ici_fsdp_parallelism=4 ici_tensor_parallelism=4`. It takes a few seconds for the code to complete the configuration and compile the `p_train_step` function.\n",
"\n",
"Rafi Witten [shared](https://github.com/rwitten/HighPerfLLMs2024/blob/main/s03/Session3Slides.pdf) this illustration."
]
},
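{
"cell_type": "markdown",
"id": "sketch-resolve-parallelism",
"metadata": {},
"source": [
"The parallelism settings have to multiply out to the number of chips in the slice. The sketch below checks that, and also illustrates one way a value of ``-1`` could be resolved to whatever is left over. ``resolve_parallelism`` is a hypothetical helper written for this note; it reflects an assumption about the intent of ``-1`` and is not MaxText's actual implementation.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def resolve_parallelism(num_devices, axes):\n",
"    # axes maps an axis name to its degree; at most one value may be -1.\n",
"    unknown = [name for name, degree in axes.items() if degree == -1]\n",
"    if len(unknown) > 1:\n",
"        raise ValueError('only one axis may be -1')\n",
"    known = math.prod(d for d in axes.values() if d != -1)\n",
"    resolved = dict(axes)\n",
"    if unknown:\n",
"        if num_devices % known != 0:\n",
"            raise ValueError('known axes do not divide the device count')\n",
"        # Give the remaining devices to the single unspecified axis.\n",
"        resolved[unknown[0]] = num_devices // known\n",
"    elif known != num_devices:\n",
"        raise ValueError(f'axes multiply to {known}, expected {num_devices}')\n",
"    return resolved\n",
"\n",
"\n",
"# 4-way FSDP times 4-way tensor parallelism covers all 16 chips.\n",
"print(resolve_parallelism(16, {'fsdp': 4, 'tensor': 4}))\n",
"# With tensor parallelism off, FSDP takes every chip.\n",
"print(resolve_parallelism(16, {'fsdp': -1, 'tensor': 1}))\n",
"```\n",
"\n",
"Checking the product up front is cheap, and it surfaces a mismatched configuration before any time is spent on compilation."
]
},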
{
"attachments": {
"09ec515a-24ed-4ae5-bb0a-50c19a524af2.png": {
"image/png": ""
}
},
"cell_type": "markdown",
"id": "f46e0a4f-8dc4-4a35-ba74-5befc6257f44",
"metadata": {},
"source": [
"![download.png](attachment:09ec515a-24ed-4ae5-bb0a-50c19a524af2.png)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e3031556-48ef-44c6-bb0f-c8a6b24c6c4d",
"metadata": {},
"outputs": [],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"import sys\n",
"# in our container image, MaxText is available in /deps\n",
"sys.path.append('/deps/MaxText')\n",
"from MaxText import pyconfig\n",
"from MaxText.train import train_loop, setup_train_loop, train_step, load_next_batch\n",
"from MaxText import maxtext_utils\n",
"from flax.linen import partitioning as nn_partitioning\n",
"import jax\n",
"import os\n",
"\n",
"def config_llama7b_model(\n",
" ici_fsdp_parallelism=-1,\n",
" ici_fsdp_transpose_parallelism=1,\n",
" ici_tensor_parallelism=1,\n",
" per_device_batch_size=1,\n",
"):\n",
" argv = [\n",
" 'dummy_placeholder',\n",
" 'MaxText/configs/base.yml',\n",
" 'model_name=llama2-7b',\n",
" f'per_device_batch_size={per_device_batch_size}',\n",
" 'steps=5',\n",
" 'dataset_type=tfds',\n",
" 'enable_checkpointing=false',\n",
" 'max_target_length=4096',\n",
" f'ici_fsdp_parallelism={ici_fsdp_parallelism}',\n",
" f'ici_fsdp_transpose_parallelism={ici_fsdp_transpose_parallelism}',\n",
" f'ici_tensor_parallelism={ici_tensor_parallelism}',\n",
" 'base_output_directory=gs://opmusw4/ipp/maxtext/llama2-7b/',\n",
" 'run_name=demo-test',\n",
" 'dataset_path=gs://maxtext-dataset',\n",
" 'attention=dot_product',\n",
" 'remat_policy=full',\n",
" ]\n",
" pyconfig.initialize(argv)\n",
" config = pyconfig.config\n",
" os.environ['TFDS_DATA_DIR'] = config.dataset_path\n",
" os.environ['LIBTPU_INIT_ARGS'] = (\n",
" '--xla_enable_async_all_gather=true TPU_MEGACORE=MEGACORE_DENSE'\n",
" )\n",
" (\n",
" init_rng,\n",
" writer,\n",
" checkpoint_manager,\n",
" state_mesh_annotations,\n",
" model,\n",
" mesh,\n",
" learning_rate_schedule,\n",
" data_iterator,\n",
" eval_data_iterator,\n",
" state,\n",
" ) = setup_train_loop(config)\n",
" (\n",
" functional_train,\n",
" in_shard_train,\n",
" out_shard_train,\n",
" static_argnums_train,\n",
" donate_argnums_train,\n",
" ) = maxtext_utils.get_functional_train_with_signature(\n",
" train_step, mesh, state_mesh_annotations, model, config\n",
" )\n",
" p_train_step = jax.jit(\n",
" functional_train,\n",
" in_shardings=in_shard_train,\n",
" out_shardings=out_shard_train,\n",
" static_argnums=static_argnums_train,\n",
" donate_argnums=donate_argnums_train,\n",
" )\n",
" return config, init_rng, data_iterator, state, p_train_step, mesh"
]
},
{
"cell_type": "markdown",
"id": "3d84c92d-a8e0-49e7-84d8-c83560ae73df",
"metadata": {},
"source": [
"The ``jax.jit`` call near the end of the cell above (line 64 of the cell) is where we invoke the [JAX just-in-time compiler](https://jax.readthedocs.io/en/latest/jit-compilation.html) to compile a JAX Python function so it can be executed efficiently in XLA.\n",
"\n",
"JAX compilation is designed to work with pure functions, which can lead to some [interesting challenges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#), so do read up before diving in. The [MaxText README](https://github.com/google/maxtext/blob/b314957fcfc0410aa0cafb734706f6f27020a67c/README.md#ahead-of-time-compilation-aot-tpu-only) also shares examples of how to use AOT (ahead-of-time) compilation to find issues before the target hardware becomes available (for example, by compiling on your CPU).\n",
"\n",
"Let's try it out!"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "afa8834c-abc8-4500-9345-a08a6a4905b7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:1] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 4\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 4\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 4, 1, 1, 4, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:3] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 4\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 4\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 4, 1, 1, 4, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:0] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 4\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 4\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 4, 1, 1, 4, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 4\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 4\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 4, 1, 1, 4, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:3] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 100%|██████████| 4/4 [00:01<00:00, 2.68tasks/s]\n"
]
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"config, init_rng, data_iterator, state, p_train_step, mesh = (\n",
"    config_llama7b_model(\n",
"        ici_fsdp_parallelism=4,\n",
"        ici_fsdp_transpose_parallelism=1,\n",
"        ici_tensor_parallelism=4,\n",
"        per_device_batch_size=1\n",
"    )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "7cf5e375-1c5e-474d-8cd4-1df58550cbf5",
"metadata": {},
"source": [
"## Run some training steps\n",
"Now that the model, data iterator, and sharded training state are configured, let's run some training steps."
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "cfd7c20b-1509-4b02-91e5-91984154ba1a",
"metadata": {},
"outputs": [],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"import time\n",
"\n",
"# Re-imported here so the cell is self-contained; earlier cells may already provide these names.\n",
"import jax\n",
"from flax.linen import partitioning as nn_partitioning\n",
"\n",
"import max_utils\n",
"import maxtext_utils\n",
"from train import get_first_step, load_next_batch\n",
"\n",
"# Peak bf16 compute per TPU v5e chip, in TFLOP/s (https://cloud.google.com/tpu/docs/v5e).\n",
"TPU_V5e_PEAK_TFLOPS = 197.0\n",
"\n",
"def train(config, init_rng, data_iterator, state, p_train_step, mesh):\n",
"    \"\"\"Run 20 training steps; return (mean TFLOP/s/device over the last 3 steps, state).\"\"\"\n",
"    num_model_parameters = max_utils.calculate_num_params_from_pytree(\n",
"        state.params\n",
"    )\n",
"    # print(f'number parameters: {num_model_parameters/1e9:.3f} billion')\n",
"    # Estimated TFLOPs each device must compute per training step.\n",
"    per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)\n",
"    example_batch = None\n",
"    last_step_finish = time.time()\n",
"    tflops_per_devices = []\n",
"    start_step = get_first_step(state)\n",
"    num_steps = 20\n",
"    for step in range(start_step, start_step + num_steps):\n",
"        with jax.profiler.StepTraceAnnotation(\"train\", step_num=step):\n",
"            nextrng = jax.jit(jax.random.fold_in)(init_rng, step)\n",
"            example_batch = load_next_batch(data_iterator, example_batch, config)\n",
"            with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):\n",
"                state, metrics = p_train_step(state, example_batch, nextrng)\n",
"        # JAX dispatches work asynchronously, so this timestamp can precede the end of the step;\n",
"        # reading the loss in print() below is what waits for the result.\n",
"        step_finish = time.time()\n",
"        step_time = step_finish - last_step_finish\n",
"        tflops_per_devices.append(per_device_tflops / step_time)\n",
"        if step == start_step:\n",
"            # The first step includes compilation, so no throughput is reported for it.\n",
"            print(f\" step={step: 02d} loss={metrics['scalar']['learning/loss']:.5f} step_time={step_time:.2f}s\")\n",
"        else:\n",
"            print(\n",
"                f\" step={step: 02d} loss={metrics['scalar']['learning/loss']:.5f} step_time={step_time:.2f}s\"\n",
"                f\" TFLOP/s/device:{per_device_tflops/step_time:.3f}, eMFU:\"\n",
"                f\" {per_device_tflops/step_time/TPU_V5e_PEAK_TFLOPS*100:.2f}%\")\n",
"        last_step_finish = step_finish\n",
"    # Summary line built from the last step's loss, step time, and eMFU.\n",
"    print(f'Parallelism (fsdp={config.ici_fsdp_parallelism}, '\n",
"          f'transpose={config.ici_fsdp_transpose_parallelism}, '\n",
"          f'tensor={config.ici_tensor_parallelism}); {num_steps} steps, '\n",
"          f\"loss={metrics['scalar']['learning/loss']:.5f}, \"\n",
"          f'step_time={step_time:.2f}s, '\n",
"          f'eMFU={per_device_tflops/step_time/TPU_V5e_PEAK_TFLOPS*100:.2f}%')\n",
"\n",
"    # Average throughput over the last three steps, which excludes compilation and warm-up.\n",
"    return sum(tflops_per_devices[-3:]) / len(tflops_per_devices[-3:]), state"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "c3ae6b27-c69d-4d52-b6fe-ffa449d54c65",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:3] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82477 step_time=9.10s\n",
" step= 1 loss=10.74619 step_time=4.26s TFLOP/s/device:44.318, eMFU: 22.50%\n",
" step= 2 loss=11.10764 step_time=3.18s TFLOP/s/device:59.306, eMFU: 30.10%\n",
" step= 3 loss=10.73281 step_time=3.17s TFLOP/s/device:59.572, eMFU: 30.24%\n",
" step= 4 loss=10.04965 step_time=3.17s TFLOP/s/device:59.611, eMFU: 30.26%\n",
" step= 5 loss=9.77972 step_time=3.17s TFLOP/s/device:59.584, eMFU: 30.25%\n",
" step= 6 loss=9.58509 step_time=3.17s TFLOP/s/device:59.624, eMFU: 30.27%\n",
" step= 7 loss=9.43253 step_time=3.17s TFLOP/s/device:59.603, eMFU: 30.26%\n",
" step= 8 loss=8.99779 step_time=3.17s TFLOP/s/device:59.603, eMFU: 30.26%\n",
" step= 9 loss=9.25011 step_time=3.17s TFLOP/s/device:59.561, eMFU: 30.23%\n",
" step= 10 loss=12.71937 step_time=3.17s TFLOP/s/device:59.612, eMFU: 30.26%\n",
" step= 11 loss=10.92250 step_time=3.17s TFLOP/s/device:59.609, eMFU: 30.26%\n",
" step= 12 loss=9.34341 step_time=3.17s TFLOP/s/device:59.618, eMFU: 30.26%\n",
" step= 13 loss=8.45162 step_time=3.17s TFLOP/s/device:59.602, eMFU: 30.26%\n",
" step= 14 loss=8.24009 step_time=3.17s TFLOP/s/device:59.625, eMFU: 30.27%\n",
" step= 15 loss=8.13112 step_time=3.17s TFLOP/s/device:59.589, eMFU: 30.25%\n",
" step= 16 loss=8.20177 step_time=3.18s TFLOP/s/device:59.346, eMFU: 30.12%\n",
" step= 17 loss=8.12323 step_time=3.17s TFLOP/s/device:59.611, eMFU: 30.26%\n",
" step= 18 loss=8.03848 step_time=3.17s TFLOP/s/device:59.619, eMFU: 30.26%\n",
" step= 19 loss=8.01624 step_time=3.17s TFLOP/s/device:59.563, eMFU: 30.23%\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.23%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:1] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82477 step_time=8.79s\n",
" step= 1 loss=10.74619 step_time=4.57s TFLOP/s/device:41.275, eMFU: 20.95%\n",
" step= 2 loss=11.10764 step_time=3.18s TFLOP/s/device:59.334, eMFU: 30.12%\n",
" step= 3 loss=10.73281 step_time=3.17s TFLOP/s/device:59.615, eMFU: 30.26%\n",
" step= 4 loss=10.04965 step_time=3.17s TFLOP/s/device:59.603, eMFU: 30.26%\n",
" step= 5 loss=9.77972 step_time=3.17s TFLOP/s/device:59.592, eMFU: 30.25%\n",
" step= 6 loss=9.58509 step_time=3.17s TFLOP/s/device:59.599, eMFU: 30.25%\n",
" step= 7 loss=9.43253 step_time=3.17s TFLOP/s/device:59.607, eMFU: 30.26%\n",
" step= 8 loss=8.99779 step_time=3.17s TFLOP/s/device:59.592, eMFU: 30.25%\n",
" step= 9 loss=9.25011 step_time=3.17s TFLOP/s/device:59.605, eMFU: 30.26%\n",
" step= 10 loss=12.71937 step_time=3.17s TFLOP/s/device:59.586, eMFU: 30.25%\n",
" step= 11 loss=10.92250 step_time=3.17s TFLOP/s/device:59.618, eMFU: 30.26%\n",
" step= 12 loss=9.34341 step_time=3.17s TFLOP/s/device:59.600, eMFU: 30.25%\n",
" step= 13 loss=8.45162 step_time=3.17s TFLOP/s/device:59.609, eMFU: 30.26%\n",
" step= 14 loss=8.24009 step_time=3.17s TFLOP/s/device:59.586, eMFU: 30.25%\n",
" step= 15 loss=8.13112 step_time=3.17s TFLOP/s/device:59.620, eMFU: 30.26%\n",
" step= 16 loss=8.20177 step_time=3.18s TFLOP/s/device:59.354, eMFU: 30.13%\n",
" step= 17 loss=8.12323 step_time=3.17s TFLOP/s/device:59.611, eMFU: 30.26%\n",
" step= 18 loss=8.03848 step_time=3.17s TFLOP/s/device:59.573, eMFU: 30.24%\n",
" step= 19 loss=8.01624 step_time=3.17s TFLOP/s/device:59.575, eMFU: 30.24%\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.24%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:0] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82477 step_time=10.19s\n",
" step= 1 loss=10.74619 step_time=3.17s TFLOP/s/device:59.478, eMFU: 30.19%\n",
" step= 2 loss=11.10764 step_time=3.18s TFLOP/s/device:59.329, eMFU: 30.12%\n",
" step= 3 loss=10.73281 step_time=3.17s TFLOP/s/device:59.582, eMFU: 30.24%\n",
" step= 4 loss=10.04965 step_time=3.17s TFLOP/s/device:59.591, eMFU: 30.25%\n",
" step= 5 loss=9.77972 step_time=3.17s TFLOP/s/device:59.621, eMFU: 30.26%\n",
" step= 6 loss=9.58509 step_time=3.17s TFLOP/s/device:59.582, eMFU: 30.24%\n",
" step= 7 loss=9.43253 step_time=3.17s TFLOP/s/device:59.616, eMFU: 30.26%\n",
" step= 8 loss=8.99779 step_time=3.17s TFLOP/s/device:59.620, eMFU: 30.26%\n",
" step= 9 loss=9.25011 step_time=3.17s TFLOP/s/device:59.558, eMFU: 30.23%\n",
" step= 10 loss=12.71937 step_time=3.17s TFLOP/s/device:59.613, eMFU: 30.26%\n",
" step= 11 loss=10.92250 step_time=3.17s TFLOP/s/device:59.615, eMFU: 30.26%\n",
" step= 12 loss=9.34341 step_time=3.17s TFLOP/s/device:59.581, eMFU: 30.24%\n",
" step= 13 loss=8.45162 step_time=3.17s TFLOP/s/device:59.625, eMFU: 30.27%\n",
" step= 14 loss=8.24009 step_time=3.17s TFLOP/s/device:59.621, eMFU: 30.26%\n",
" step= 15 loss=8.13112 step_time=3.17s TFLOP/s/device:59.590, eMFU: 30.25%\n",
" step= 16 loss=8.20177 step_time=3.18s TFLOP/s/device:59.353, eMFU: 30.13%\n",
" step= 17 loss=8.12323 step_time=3.17s TFLOP/s/device:59.595, eMFU: 30.25%\n",
" step= 18 loss=8.03848 step_time=3.17s TFLOP/s/device:59.604, eMFU: 30.26%\n",
" step= 19 loss=8.01624 step_time=3.17s TFLOP/s/device:59.606, eMFU: 30.26%\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.26%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82477 step_time=9.15s\n",
" step= 1 loss=10.74619 step_time=4.21s TFLOP/s/device:44.872, eMFU: 22.78%\n",
" step= 2 loss=11.10764 step_time=3.18s TFLOP/s/device:59.369, eMFU: 30.14%\n",
" step= 3 loss=10.73281 step_time=3.17s TFLOP/s/device:59.592, eMFU: 30.25%\n",
" step= 4 loss=10.04965 step_time=3.17s TFLOP/s/device:59.599, eMFU: 30.25%\n",
" step= 5 loss=9.77972 step_time=3.17s TFLOP/s/device:59.601, eMFU: 30.25%\n",
" step= 6 loss=9.58509 step_time=3.17s TFLOP/s/device:59.589, eMFU: 30.25%\n",
" step= 7 loss=9.43253 step_time=3.17s TFLOP/s/device:59.608, eMFU: 30.26%\n",
" step= 8 loss=8.99779 step_time=3.17s TFLOP/s/device:59.584, eMFU: 30.25%\n",
" step= 9 loss=9.25011 step_time=3.17s TFLOP/s/device:59.610, eMFU: 30.26%\n",
" step= 10 loss=12.71937 step_time=3.17s TFLOP/s/device:59.602, eMFU: 30.25%\n",
" step= 11 loss=10.92250 step_time=3.17s TFLOP/s/device:59.597, eMFU: 30.25%\n",
" step= 12 loss=9.34341 step_time=3.17s TFLOP/s/device:59.622, eMFU: 30.26%\n",
" step= 13 loss=8.45162 step_time=3.17s TFLOP/s/device:59.600, eMFU: 30.25%\n",
" step= 14 loss=8.24009 step_time=3.17s TFLOP/s/device:59.620, eMFU: 30.26%\n",
" step= 15 loss=8.13112 step_time=3.18s TFLOP/s/device:59.295, eMFU: 30.10%\n",
" step= 16 loss=8.20177 step_time=3.16s TFLOP/s/device:59.647, eMFU: 30.28%\n",
" step= 17 loss=8.12323 step_time=3.17s TFLOP/s/device:59.617, eMFU: 30.26%\n",
" step= 18 loss=8.03848 step_time=3.17s TFLOP/s/device:59.590, eMFU: 30.25%\n",
" step= 19 loss=8.01624 step_time=3.17s TFLOP/s/device:59.605, eMFU: 30.26%\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.26%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 100%|██████████| 4/4 [01:11<00:00, 17.89s/tasks]\n"
]
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"avg_tflops_per_devices, state = train(\n",
"    config, init_rng, data_iterator, state, p_train_step, mesh\n",
")\n",
"# jax.profiler.stop_trace()"
]
},
{
"cell_type": "markdown",
"id": "87df7bd2-0135-442f-abe8-f328e7758636",
"metadata": {},
"source": [
"As shown in this output, MaxText estimates the total TFLOPs each device must compute per training step ([code](https://github.com/google/maxtext/blob/b314957fcfc0410aa0cafb734706f6f27020a67c/MaxText/maxtext_utils.py#L99)).\n",
"```\n",
"Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
"```\n",
"We also see that this configuration reaches a typical step time of ``3.17s`` and an effective MFU (eMFU, the achieved TFLOP/s per device as a share of the v5e peak of 197 TFLOP/s) of approximately ``30%``.\n",
"```\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.26%\n",
"```\n"
]
},
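{
"cell_type": "markdown",
"id": "emfu-arithmetic-sketch",
"metadata": {},
"source": [
"The eMFU figure follows from two numbers in the log and the peak constant used in `train`. As a minimal sketch (the helper name `effective_mfu` is illustrative and not part of MaxText, and the inputs are the rounded values printed above), the arithmetic is:\n",
"\n",
"```python\n",
"# Values copied from the training log above (rounded as printed).\n",
"TFLOPS_PER_STEP_PER_DEVICE = 188.76  # 'Total TFLOPs' per train step\n",
"STEP_TIME_S = 3.17                   # typical steady-state step time\n",
"TPU_V5E_PEAK_TFLOPS = 197.0          # same peak value the notebook's train() uses\n",
"\n",
"\n",
"def effective_mfu(tflops_per_step, step_time_s, peak_tflops):\n",
"    # Hypothetical helper: returns (achieved TFLOP/s per device, eMFU in percent).\n",
"    achieved = tflops_per_step / step_time_s\n",
"    return achieved, achieved / peak_tflops * 100\n",
"\n",
"\n",
"achieved, emfu = effective_mfu(\n",
"    TFLOPS_PER_STEP_PER_DEVICE, STEP_TIME_S, TPU_V5E_PEAK_TFLOPS\n",
")\n",
"print(f'TFLOP/s/device: {achieved:.2f}, eMFU: {emfu:.2f}%')\n",
"```\n",
"\n",
"Because the step time is rounded to two decimals here, the result lands slightly below most of the logged per-step values, which are computed from the unrounded step time."
]
},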
{
"cell_type": "markdown",
"id": "bdabf641-a2e5-4438-af34-2214d1e6e14d",
"metadata": {},
"source": [
"## Change parallelism strategy\n",
"So what happens if we change from 4-way FSDP with 4-way tensor parallelism to 16-way data parallelism alone?"
]
},
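{
"cell_type": "markdown",
"id": "mesh-shape-sketch",
"metadata": {},
"source": [
"Before changing the configuration, it can help to confirm that the per-axis values still multiply to the 16 available devices. The sketch below is plain Python; `mesh_shape` is a hypothetical helper, not a MaxText function, and placing all 16 ways on the `fsdp` axis is only an assumption about how the data-parallel variant might be expressed.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Axis order copied from the 'Config param mesh_axes' line in the log above.\n",
"MESH_AXES = ['data', 'stage', 'fsdp', 'fsdp_transpose',\n",
"             'sequence', 'tensor', 'autoregressive']\n",
"\n",
"\n",
"def mesh_shape(num_devices, **parallelism):\n",
"    # Hypothetical helper: build a per-axis shape and check it covers all devices.\n",
"    unknown = set(parallelism) - set(MESH_AXES)\n",
"    if unknown:\n",
"        raise ValueError(f'unknown mesh axes: {sorted(unknown)}')\n",
"    # Axes that are not mentioned default to 1 (no sharding along that axis).\n",
"    shape = tuple(parallelism.get(axis, 1) for axis in MESH_AXES)\n",
"    if math.prod(shape) != num_devices:\n",
"        raise ValueError(f'shape {shape} does not cover {num_devices} devices')\n",
"    return shape\n",
"\n",
"\n",
"current = mesh_shape(16, fsdp=4, tensor=4)  # the configuration used so far\n",
"candidate = mesh_shape(16, fsdp=16)         # assumed 16-way variant\n",
"print(current)\n",
"print(candidate)\n",
"```\n",
"\n",
"The first tuple matches the `Num_devices: 16, shape (1, 1, 4, 1, 1, 4, 1)` line in the log above; a combination such as `fsdp=4, tensor=2` would raise an error because it covers only 8 devices."
]
},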
{
"cell_type": "code",
"execution_count": 37,
"id": "d756894c-11de-4db8-9be0-b9eee2553b81",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:0] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 16\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 1\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 16, 1, 1, 1, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 16\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 1\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 16, 1, 1, 1, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:1] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 16\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 1\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 16, 1, 1, 1, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:3] Updating keys from env and command line: ['run_name', 'model_name', 'enable_checkpointing', 'remat_policy', 'attention', 'base_output_directory', 'ici_fsdp_parallelism', 'ici_fsdp_transpose_parallelism', 'ici_tensor_parallelism', 'per_device_batch_size', 'dataset_type', 'dataset_path', 'steps', 'max_target_length']\n",
"Running Model: llama2-7b\n",
"Updating following parameters in config\n",
"\n",
"base_emb_dim: 4096\n",
"base_num_query_heads: 32\n",
"base_num_kv_heads: 32\n",
"base_mlp_dim: 11008\n",
"base_num_decoder_layers: 32\n",
"head_dim: 128\n",
"mlp_activations: ['silu', 'linear']\n",
"vocab_size: 32000\n",
"enable_dropout: False\n",
"logits_via_embedding: False\n",
"normalization_layer_epsilon: 1e-05\n",
"decoder_block: llama2\n",
"logical_axis_rules: [['norm', 'fsdp']]\n",
"Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'enable_dropout', 'logits_via_embedding', 'normalization_layer_epsilon', 'decoder_block', 'logical_axis_rules']\n",
"System Information: Jax Version: 0.4.30\n",
"System Information: Jaxlib Version: 0.4.30\n",
"System Information: Jax Backend: PJRT C API\n",
"TFRT TPU v5 lite\n",
"Built on Jun 17 2024 03:03:47 (1718618627) cl/643897370\n",
"Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period\n",
"dataset_type set to tfds, will use keys['dataset_path']='gs://maxtext-dataset' and keys['dataset_name']='c4/en:3.0.1'\n",
"Config param adam_b1: 0.9\n",
"Config param adam_b2: 0.95\n",
"Config param adam_eps: 1e-08\n",
"Config param adam_eps_root: 0.0\n",
"Config param adam_weight_decay: 0.1\n",
"Config param allow_split_physical_axes: False\n",
"Config param ar_cache_axis_order: 1,2,0,3\n",
"Config param async_checkpointing: True\n",
"Config param attention: dot_product\n",
"Config param autoregressive_decode_assert: \n",
"Config param base_emb_dim: 4096\n",
"Config param base_mlp_dim: 11008\n",
"Config param base_num_decoder_layers: 32\n",
"Config param base_num_kv_heads: 32\n",
"Config param base_num_query_heads: 32\n",
"Config param base_output_directory: gs://opmusw4/ipp/maxtext/llama2-7b/\n",
"Config param checkpoint_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/checkpoints/\n",
"Config param checkpoint_is_quantized: False\n",
"Config param checkpoint_period: 10000\n",
"Config param collect_stack_trace: False\n",
"Config param compile_topology: \n",
"Config param compile_topology_num_slices: -1\n",
"Config param compiled_trainstep_file: \n",
"Config param compute_axis_order: 0,1,2,3\n",
"Config param cosine_learning_rate_final_fraction: 0.1\n",
"Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'),)\n",
"Config param data_shuffle_seed: 0\n",
"Config param dataset_name: c4/en:3.0.1\n",
"Config param dataset_path: gs://maxtext-dataset\n",
"Config param dataset_type: tfds\n",
"Config param dcn_autoregressive_parallelism: 1\n",
"Config param dcn_data_parallelism: -1\n",
"Config param dcn_fsdp_parallelism: 1\n",
"Config param dcn_fsdp_transpose_parallelism: 1\n",
"Config param dcn_pipeline_parallelism: 1\n",
"Config param dcn_sequence_parallelism: 1\n",
"Config param dcn_tensor_parallelism: 1\n",
"Config param decode_sampling_nucleus_p: -1\n",
"Config param decode_sampling_strategy: greedy\n",
"Config param decode_sampling_temperature: 1.0\n",
"Config param decode_sampling_top_k: 0\n",
"Config param decoder_block: llama2\n",
"Config param dropout_rate: 0\n",
"Config param dtype: bfloat16\n",
"Config param emb_dim: 4096\n",
"Config param enable_checkpoint_cloud_logger: False\n",
"Config param enable_checkpoint_standard_logger: False\n",
"Config param enable_checkpointing: False\n",
"Config param enable_data_shuffling: True\n",
"Config param enable_dropout: False\n",
"Config param enable_emergency_checkpoint: False\n",
"Config param enable_goodput_recording: False\n",
"Config param enable_jax_profiler: False\n",
"Config param enable_single_controller: False\n",
"Config param enable_single_replica_ckpt_restoring: False\n",
"Config param eval_batch_num: -1\n",
"Config param eval_dataset_name: c4/en:3.0.1\n",
"Config param eval_interval: -1\n",
"Config param eval_per_device_batch_size: 0\n",
"Config param eval_split: validation\n",
"Config param expansion_factor_real_data: -1\n",
"Config param force_unroll: False\n",
"Config param fused_mlp: False\n",
"Config param fused_qkv: False\n",
"Config param gcs_metrics: False\n",
"Config param global_batch_size_to_load: 16\n",
"Config param global_batch_size_to_train_on: 16\n",
"Config param global_parameter_scale: 1\n",
"Config param goodput_upload_interval_seconds: 60\n",
"Config param gradient_clipping_threshold: 1.0\n",
"Config param grain_eval_files: \n",
"Config param grain_train_files: \n",
"Config param grain_worker_count: 1\n",
"Config param hardware: tpu\n",
"Config param head_dim: 128\n",
"Config param hf_access_token: \n",
"Config param hf_data_dir: \n",
"Config param hf_eval_files: \n",
"Config param hf_eval_split: \n",
"Config param hf_path: \n",
"Config param hf_train_files: \n",
"Config param ici_autoregressive_parallelism: 1\n",
"Config param ici_data_parallelism: 1\n",
"Config param ici_fsdp_parallelism: 16\n",
"Config param ici_fsdp_transpose_parallelism: 1\n",
"Config param ici_pipeline_parallelism: 1\n",
"Config param ici_sequence_parallelism: 1\n",
"Config param ici_tensor_parallelism: 1\n",
"Config param inference_metadata_file: \n",
"Config param inference_microbenchmark_log_file_path: \n",
"Config param inference_microbenchmark_loop_iters: 10\n",
"Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024\n",
"Config param inference_microbenchmark_stages: prefill,generate\n",
"Config param init_weights_seed: 0\n",
"Config param jax_cache_dir: ~/jax_cache\n",
"Config param jax_profiler_port: 9999\n",
"Config param kv_quant_axis: heads_and_dkv\n",
"Config param kv_quant_dtype: int8\n",
"Config param learning_rate: 3e-05\n",
"Config param learning_rate_schedule_steps: 5\n",
"Config param load_from_prefill_dir: False\n",
"Config param load_full_state_path: \n",
"Config param load_parameters_path: \n",
"Config param local_checkpoint_directory: \n",
"Config param local_checkpoint_period: 0\n",
"Config param log_period: 100\n",
"Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('stage', 'data', 'fsdp', 'fsdp_transpose')), ('activation_heads', ('tensor', 'sequence')), ('activation_kv_heads', ('tensor', 'sequence')), ('activation_length', 'sequence'), ('activation_embed', 'tensor'), ('activation_mlp', 'tensor'), ('activation_kv', 'tensor'), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose')), ('activation_kv_head_dim', 'tensor'), ('activation_vocab', ('tensor', 'sequence')), ('activation_vocab', 'tensor'), ('activation_vocab', 'sequence'), ('activation_stage', 'stage'), ('mlp', ('fsdp_transpose', 'tensor', 'autoregressive')), ('vocab', ('tensor', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence')), ('embed', ('fsdp', 'sequence')), ('heads', ('tensor', 'autoregressive')), ('layers', 'stage'), ('kv', ()), ('kv_heads', ('tensor', 'autoregressive')), ('kv_head_dim', ()), ('cache_batch', ()), ('cache_heads', ('autoregressive', 'tensor')), ('cache_kv', ()), ('cache_sequence', ()), ('norm', 'fsdp'))\n",
"Config param logits_dot_in_fp32: True\n",
"Config param logits_via_embedding: False\n",
"Config param max_checkify: False\n",
"Config param max_corpus_chars: 10000000\n",
"Config param max_prefill_predict_length: 64\n",
"Config param max_target_length: 4096\n",
"Config param megablox: True\n",
"Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']\n",
"Config param metrics_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/metrics/\n",
"Config param metrics_file: \n",
"Config param mlp_activations: ['silu', 'linear']\n",
"Config param mlp_dim: 11008\n",
"Config param model_name: llama2-7b\n",
"Config param monitor_goodput: False\n",
"Config param normalization_layer_epsilon: 1e-05\n",
"Config param normalize_embedding_logits: True\n",
"Config param num_decoder_layers: 32\n",
"Config param num_experts: 1\n",
"Config param num_experts_per_tok: 1\n",
"Config param num_kv_heads: 32\n",
"Config param num_layers_per_pipeline_stage: 1\n",
"Config param num_pipeline_microbatches: -1\n",
"Config param num_pipeline_repeats: -1\n",
"Config param num_query_heads: 32\n",
"Config param num_slices: 1\n",
"Config param opt_type: adamw\n",
"Config param param_scan_axis: 1\n",
"Config param per_device_batch_size: 1.0\n",
"Config param prefill_cache_axis_order: 1,2,0,3\n",
"Config param prefill_cache_dir: \n",
"Config param profiler: \n",
"Config param profiler_steps: 5\n",
"Config param prometheus_port: 0\n",
"Config param prompt: I love to\n",
"Config param quant_cfg_path: \n",
"Config param quantization: \n",
"Config param quantization_local_shard_count: 1\n",
"Config param quantize_kvcache: False\n",
"Config param record_internal_nn_metrics: 0\n",
"Config param remat_policy: full\n",
"Config param reshape_q: False\n",
"Config param reuse_example_batch: 0\n",
"Config param rope_max_timescale: 10000\n",
"Config param rope_min_timescale: 1\n",
"Config param run_name: demo-test\n",
"Config param save_config_to_gcs: False\n",
"Config param save_quantized_params_path: \n",
"Config param scan_layers: True\n",
"Config param scan_pipeline_iterations: True\n",
"Config param skip_first_n_steps_for_profiler: 1\n",
"Config param stack_trace_interval_seconds: 600\n",
"Config param stack_trace_to_cloud: False\n",
"Config param steps: 5\n",
"Config param target_eval_loss: 0.0\n",
"Config param tensorboard_dir: gs://opmusw4/ipp/maxtext/llama2-7b/demo-test/tensorboard/\n",
"Config param tokenizer_path: assets/tokenizer.llama2\n",
"Config param trainable_position_size: -1\n",
"Config param upload_all_profiler_results: False\n",
"Config param use_iota_embed: False\n",
"Config param use_untrainable_positional_embedding: False\n",
"Config param use_vertex_tensorboard: False\n",
"Config param using_pipeline_parallelism: False\n",
"Config param vertex_tensorboard_project: \n",
"Config param vertex_tensorboard_region: \n",
"Config param vocab_size: 32000\n",
"Config param warmup_steps_fraction: 0.1\n",
"Config param weight_dtype: float32\n",
"Num_devices: 16, shape (1, 1, 16, 1, 1, 1, 1)\n",
"Setting up checkpoint logger...\n",
"Checkpointing disabled, not creating checkpoint manager.\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"Tokenizer path: assets/tokenizer.llama2\n",
"No existing checkpoints found, not restoring checkpoint.\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:0] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:1] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:3] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:2] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 100%|██████████| 4/4 [00:01<00:00, 2.82tasks/s]\n"
]
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"config, init_rng, data_iterator, state, p_train_step, mesh = (\n",
" config_llama7b_model(\n",
" ici_fsdp_parallelism=16,\n",
" ici_fsdp_transpose_parallelism=1,\n",
" ici_tensor_parallelism=1,\n",
" per_device_batch_size=1\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ecf5ea9f-9606-4592-9cf3-65e592998d36",
"metadata": {},
"source": [
"Ideally, we'd see our step time go down and our loss stay substantially similar."
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "f04ea0ff-62ea-48a8-9e87-64b06ad874cf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[stdout:1] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82464 step_time=4.40s\n",
" step= 1 loss=10.74638 step_time=2.94s TFLOP/s/device:64.161, eMFU: 32.57%\n",
" step= 2 loss=11.10725 step_time=2.11s TFLOP/s/device:89.570, eMFU: 45.47%\n",
" step= 3 loss=10.73088 step_time=2.10s TFLOP/s/device:89.728, eMFU: 45.55%\n",
" step= 4 loss=10.04863 step_time=2.11s TFLOP/s/device:89.486, eMFU: 45.42%\n",
" step= 5 loss=9.77938 step_time=2.10s TFLOP/s/device:89.793, eMFU: 45.58%\n",
" step= 6 loss=9.58473 step_time=2.10s TFLOP/s/device:89.756, eMFU: 45.56%\n",
" step= 7 loss=9.43211 step_time=2.10s TFLOP/s/device:89.774, eMFU: 45.57%\n",
" step= 8 loss=8.99658 step_time=2.10s TFLOP/s/device:89.742, eMFU: 45.55%\n",
" step= 9 loss=9.29634 step_time=2.10s TFLOP/s/device:89.747, eMFU: 45.56%\n",
" step= 10 loss=12.73707 step_time=2.10s TFLOP/s/device:89.712, eMFU: 45.54%\n",
" step= 11 loss=10.97237 step_time=2.10s TFLOP/s/device:89.750, eMFU: 45.56%\n",
" step= 12 loss=9.27653 step_time=2.10s TFLOP/s/device:89.800, eMFU: 45.58%\n",
" step= 13 loss=8.43188 step_time=2.10s TFLOP/s/device:89.797, eMFU: 45.58%\n",
" step= 14 loss=8.23481 step_time=2.10s TFLOP/s/device:89.808, eMFU: 45.59%\n",
" step= 15 loss=8.13170 step_time=2.10s TFLOP/s/device:89.746, eMFU: 45.56%\n",
" step= 16 loss=8.20371 step_time=2.10s TFLOP/s/device:89.736, eMFU: 45.55%\n",
" step= 17 loss=8.12382 step_time=2.11s TFLOP/s/device:89.624, eMFU: 45.49%\n",
" step= 18 loss=8.03367 step_time=2.10s TFLOP/s/device:89.809, eMFU: 45.59%\n",
" step= 19 loss=8.01734 step_time=2.10s TFLOP/s/device:89.730, eMFU: 45.55%\n",
"Parallelism (fsdp=16, transpose=1, tensor=1); 20 steps, loss=8.01734, step_time=2.10s, eMFU=45.55%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:2] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82464 step_time=5.23s\n",
" step= 1 loss=10.74638 step_time=2.11s TFLOP/s/device:89.355, eMFU: 45.36%\n",
" step= 2 loss=11.10725 step_time=2.11s TFLOP/s/device:89.633, eMFU: 45.50%\n",
" step= 3 loss=10.73088 step_time=2.10s TFLOP/s/device:89.774, eMFU: 45.57%\n",
" step= 4 loss=10.04863 step_time=2.10s TFLOP/s/device:89.723, eMFU: 45.54%\n",
" step= 5 loss=9.77938 step_time=2.11s TFLOP/s/device:89.447, eMFU: 45.40%\n",
" step= 6 loss=9.58473 step_time=2.10s TFLOP/s/device:89.821, eMFU: 45.59%\n",
" step= 7 loss=9.43211 step_time=2.10s TFLOP/s/device:89.701, eMFU: 45.53%\n",
" step= 8 loss=8.99658 step_time=2.10s TFLOP/s/device:89.750, eMFU: 45.56%\n",
" step= 9 loss=9.29634 step_time=2.10s TFLOP/s/device:89.744, eMFU: 45.56%\n",
" step= 10 loss=12.73707 step_time=2.10s TFLOP/s/device:89.786, eMFU: 45.58%\n",
" step= 11 loss=10.97237 step_time=2.10s TFLOP/s/device:89.776, eMFU: 45.57%\n",
" step= 12 loss=9.27653 step_time=2.10s TFLOP/s/device:89.769, eMFU: 45.57%\n",
" step= 13 loss=8.43188 step_time=2.10s TFLOP/s/device:89.786, eMFU: 45.58%\n",
" step= 14 loss=8.23481 step_time=2.10s TFLOP/s/device:89.805, eMFU: 45.59%\n",
" step= 15 loss=8.13170 step_time=2.10s TFLOP/s/device:89.793, eMFU: 45.58%\n",
" step= 16 loss=8.20371 step_time=2.10s TFLOP/s/device:89.768, eMFU: 45.57%\n",
" step= 17 loss=8.12382 step_time=2.11s TFLOP/s/device:89.578, eMFU: 45.47%\n",
" step= 18 loss=8.03367 step_time=2.10s TFLOP/s/device:89.723, eMFU: 45.54%\n",
" step= 19 loss=8.01734 step_time=2.10s TFLOP/s/device:89.769, eMFU: 45.57%\n",
"Parallelism (fsdp=16, transpose=1, tensor=1); 20 steps, loss=8.01734, step_time=2.10s, eMFU=45.57%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:3] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82464 step_time=4.88s\n",
" step= 1 loss=10.74638 step_time=2.46s TFLOP/s/device:76.796, eMFU: 38.98%\n",
" step= 2 loss=11.10725 step_time=2.11s TFLOP/s/device:89.405, eMFU: 45.38%\n",
" step= 3 loss=10.73088 step_time=2.10s TFLOP/s/device:89.779, eMFU: 45.57%\n",
" step= 4 loss=10.04863 step_time=2.10s TFLOP/s/device:89.769, eMFU: 45.57%\n",
" step= 5 loss=9.77938 step_time=2.11s TFLOP/s/device:89.423, eMFU: 45.39%\n",
" step= 6 loss=9.58473 step_time=2.10s TFLOP/s/device:89.775, eMFU: 45.57%\n",
" step= 7 loss=9.43211 step_time=2.10s TFLOP/s/device:89.827, eMFU: 45.60%\n",
" step= 8 loss=8.99658 step_time=2.10s TFLOP/s/device:89.734, eMFU: 45.55%\n",
" step= 9 loss=9.29634 step_time=2.10s TFLOP/s/device:89.738, eMFU: 45.55%\n",
" step= 10 loss=12.73707 step_time=2.10s TFLOP/s/device:89.720, eMFU: 45.54%\n",
" step= 11 loss=10.97237 step_time=2.10s TFLOP/s/device:89.778, eMFU: 45.57%\n",
" step= 12 loss=9.27653 step_time=2.10s TFLOP/s/device:89.788, eMFU: 45.58%\n",
" step= 13 loss=8.43188 step_time=2.10s TFLOP/s/device:89.776, eMFU: 45.57%\n",
" step= 14 loss=8.23481 step_time=2.10s TFLOP/s/device:89.808, eMFU: 45.59%\n",
" step= 15 loss=8.13170 step_time=2.10s TFLOP/s/device:89.740, eMFU: 45.55%\n",
" step= 16 loss=8.20371 step_time=2.11s TFLOP/s/device:89.565, eMFU: 45.46%\n",
" step= 17 loss=8.12382 step_time=2.10s TFLOP/s/device:89.822, eMFU: 45.60%\n",
" step= 18 loss=8.03367 step_time=2.10s TFLOP/s/device:89.797, eMFU: 45.58%\n",
" step= 19 loss=8.01734 step_time=2.10s TFLOP/s/device:89.759, eMFU: 45.56%\n",
"Parallelism (fsdp=16, transpose=1, tensor=1); 20 steps, loss=8.01734, step_time=2.10s, eMFU=45.56%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stdout:0] Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
" step= 0 loss=10.82464 step_time=4.53s\n",
" step= 1 loss=10.74638 step_time=2.81s TFLOP/s/device:67.263, eMFU: 34.14%\n",
" step= 2 loss=11.10725 step_time=2.11s TFLOP/s/device:89.558, eMFU: 45.46%\n",
" step= 3 loss=10.73088 step_time=2.10s TFLOP/s/device:89.721, eMFU: 45.54%\n",
" step= 4 loss=10.04863 step_time=2.10s TFLOP/s/device:89.794, eMFU: 45.58%\n",
" step= 5 loss=9.77938 step_time=2.11s TFLOP/s/device:89.477, eMFU: 45.42%\n",
" step= 6 loss=9.58473 step_time=2.10s TFLOP/s/device:89.753, eMFU: 45.56%\n",
" step= 7 loss=9.43211 step_time=2.10s TFLOP/s/device:89.766, eMFU: 45.57%\n",
" step= 8 loss=8.99658 step_time=2.10s TFLOP/s/device:89.755, eMFU: 45.56%\n",
" step= 9 loss=9.29634 step_time=2.10s TFLOP/s/device:89.749, eMFU: 45.56%\n",
" step= 10 loss=12.73707 step_time=2.10s TFLOP/s/device:89.715, eMFU: 45.54%\n",
" step= 11 loss=10.97237 step_time=2.10s TFLOP/s/device:89.787, eMFU: 45.58%\n",
" step= 12 loss=9.27653 step_time=2.10s TFLOP/s/device:89.732, eMFU: 45.55%\n",
" step= 13 loss=8.43188 step_time=2.10s TFLOP/s/device:89.830, eMFU: 45.60%\n",
" step= 14 loss=8.23481 step_time=2.10s TFLOP/s/device:89.760, eMFU: 45.56%\n",
" step= 15 loss=8.13170 step_time=2.10s TFLOP/s/device:89.787, eMFU: 45.58%\n",
" step= 16 loss=8.20371 step_time=2.10s TFLOP/s/device:89.810, eMFU: 45.59%\n",
" step= 17 loss=8.12382 step_time=2.11s TFLOP/s/device:89.561, eMFU: 45.46%\n",
" step= 18 loss=8.03367 step_time=2.10s TFLOP/s/device:89.811, eMFU: 45.59%\n",
" step= 19 loss=8.01734 step_time=2.10s TFLOP/s/device:89.753, eMFU: 45.56%\n",
"Parallelism (fsdp=16, transpose=1, tensor=1); 20 steps, loss=8.01734, step_time=2.10s, eMFU=45.56%\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 0%| | 0/4 [00:00<?, ?tasks/s]"
]
},
{
"data": {
"text/plain": [
"[stderr:1] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"[stderr:0] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 0%| | 0/4 [00:01<?, ?tasks/s]"
]
},
{
"data": {
"text/plain": [
"[stderr:3] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 0%| | 0/4 [00:01<?, ?tasks/s]"
]
},
{
"data": {
"text/plain": [
"[stderr:2] /usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/site-packages/jax/_src/core.py:678: FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.\n",
" warnings.warn(\n"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"%px: 100%|██████████| 4/4 [00:45<00:00, 11.33s/tasks]\n"
]
}
],
"source": [
"%%px --block --group-outputs=engine\n",
"\n",
"avg_tflops_per_devices, state = train(\n",
" config, init_rng, data_iterator, state, p_train_step, mesh\n",
")"
]
},
{
"cell_type": "markdown",
"id": "363c61c8-97ef-49ab-9611-b7ecd055b016",
"metadata": {},
"source": [
"MaxText's (naive) estimate shows a similar amount of work per train step as in the previous run:\n",
"```\n",
"Per train step:\n",
" Total TFLOPs: 188.76 \n",
" split as 86.02% learnable weight flops and 13.98% attention flops\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "7a6e4f71-4a0b-478c-8b7d-749583ae65d0",
"metadata": {},
"source": [
"## Comparing the two approaches \n",
"So let's look at these two together. We see that our machine utilization (eMFU) has risen from about 30% to about 46% and our step time has dropped by roughly 34%, while the final loss is essentially unchanged.\n",
"```\n",
"Parallelism (fsdp=4, transpose=1, tensor=4); 20 steps, loss=8.01624, step_time=3.17s, eMFU=30.25%\n",
"Parallelism (fsdp=16, transpose=1, tensor=1); 20 steps, loss=8.01734, step_time=2.10s, eMFU=45.57%\n",
"```\n",
"\n",
"To put this in context, if we needed 20,000 steps for training, the first strategy would take approximately (20,000 * 3.17s) or 63,400s (17.61h).\n",
"\n",
"The intuitively attractive approach (fsdp=16, tensor=1) would take approximately (20,000 * 2.10s) or 42,000s (11.67h), a saving of about 5.9 hours."
]
},
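{
"cell_type": "markdown",
"id": "step-time-projection-sketch",
"metadata": {},
"source": [
"The projection above can be reproduced with a few lines of Python. This is a minimal sketch rather than part of MaxText: the step times are copied from the two summary lines, `total_steps` is the hypothetical 20,000-step run, and `projected_hours` is a helper name made up for this illustration.\n",
"\n",
"```python\n",
"# Steady-state step times (seconds) copied from the summary lines above.\n",
"runs = {\n",
"    'fsdp=4, tensor=4': 3.17,\n",
"    'fsdp=16, tensor=1': 2.10,\n",
"}\n",
"total_steps = 20_000  # hypothetical run length used in the text\n",
"\n",
"\n",
"def projected_hours(step_time_s, steps):\n",
"    # Wall-clock hours for `steps` steps, ignoring compilation and warm-up.\n",
"    return step_time_s * steps / 3600\n",
"\n",
"\n",
"for name, step_time in runs.items():\n",
"    print(f'{name}: {projected_hours(step_time, total_steps):.2f} h')\n",
"\n",
"baseline, candidate = runs.values()\n",
"reduction = (baseline - candidate) / baseline\n",
"print(f'step-time reduction: {reduction:.1%}')\n",
"```\n",
"\n",
"Because it starts from the rounded step times, the computed reduction can differ in the last digit from a figure derived from unrounded timings."
]
},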
{
"cell_type": "markdown",
"id": "5384f1d1-5754-4e30-8425-13289002850c",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"We tried two parallelism strategies while training Llama2-7b on a TPU v5e-16. MaxText makes it practical to loop over a selected set of parallelism strategies and compare results.\n",
"\n",
"Most importantly, we saw how we can quickly iterate with a notebook and use some of the same code to scale up."
]
},
{
"cell_type": "markdown",
"id": "29345521-0d31-49cc-8886-6d725b51d482",
"metadata": {},
"source": [
"### Additional references\n",
"Selected JAX resources:\n",
"1. The [JAX tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) that cover important concepts like just-in-time compilation and sharded computation\n",
"2. The [GSPMD paper](https://arxiv.org/pdf/2105.04663) that inspired some of the sharding efforts in MaxText\n",
"3. Rafi Witten's [High Performance LLMs](https://github.com/rwitten/HighPerfLLMs2024/tree/main) course\n",
"\n",
"While this example focused on JAX, it would be interesting to try something similar with [PyTorch/XLA](https://github.com/pytorch/xla)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}