@fulvius31
Created March 25, 2025 13:49

AOT vs JIT
{
"cells": [
{
"cell_type": "markdown",
"id": "0c3c86e4-b5ff-4e36-a9a3-358de57b9932",
"metadata": {},
"source": [
"# AOT vs JIT using torch aot inductor"
]
},
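{
"cell_type": "markdown",
"id": "3f2a1b0c-4d5e-4f60-8a7b-9c0d1e2f3a4b",
"metadata": {},
"source": [
"This notebook compares ahead-of-time (AOT) compilation via `torch._inductor.aoti_compile_and_package` against just-in-time (JIT) compilation via `torch.compile` on resnet18. The AOT path pays its compilation cost once, at export time, and ships the result as a `.pt2` package; the JIT path compiles on the first call. Counting files in a dedicated Triton cache directory shows how many kernels each path compiles at runtime."
]
},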
{
"cell_type": "markdown",
"id": "164f93d6-6092-4267-bbbf-f546130fe463",
"metadata": {},
"source": [
"### Check torch and triton versions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "17d26a5c-6205-4921-bbef-bf972cbb3bed",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mDEPRECATION: Loading egg at /home/alessandrosangiorgi/.conda/envs/pytorch/lib/python3.13/site-packages/sympy-1.13.1-py3.13.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0mName: torch\n",
"Version: 2.8.0a0+git428fc14\n",
"Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration\n",
"Home-page: https://pytorch.org/\n",
"Author: PyTorch Team\n",
"Author-email: [email protected]\n",
"License: BSD-3-Clause\n",
"Location: /home/alessandrosangiorgi/Downloads/redhat/pytorch\n",
"Editable project location: /home/alessandrosangiorgi/Downloads/redhat/pytorch\n",
"Requires: filelock, fsspec, jinja2, networkx, setuptools, sympy, typing-extensions\n",
"Required-by: marker-pdf, surya-ocr, texify, torchaudio, torchvision\n"
]
}
],
"source": [
"!pip show torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "63e36201-0eb1-4811-a4ec-1cccd2fd54f5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6b5e342f-ecb7-461a-be31-df3ae592f0d2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mDEPRECATION: Loading egg at /home/alessandrosangiorgi/.conda/envs/pytorch/lib/python3.13/site-packages/sympy-1.13.1-py3.13.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0mName: triton\n",
"Version: 3.3.0+git6276a781\n",
"Summary: A language and compiler for custom Deep Learning operations\n",
"Home-page: https://github.com/triton-lang/triton/\n",
"Author: Philippe Tillet\n",
"Author-email: [email protected]\n",
"License: \n",
"Location: /home/alessandrosangiorgi/.local/lib/python3.13/site-packages\n",
"Editable project location: /home/alessandrosangiorgi/triton/python\n",
"Requires: setuptools\n",
"Required-by: \n"
]
}
],
"source": [
"!pip show triton"
]
},
{
"cell_type": "markdown",
"id": "6e0f2f72-e8ef-457a-86bc-877188676941",
"metadata": {},
"source": [
"### Change the default triton cache dir"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "ff8e88f0-68ff-4191-8cc0-c27ac0a89f60",
"metadata": {},
"outputs": [],
"source": [
"!export TRITON_CACHE_DIR=~/redhat/aot_test/triton_cache_test"
]
},
{
"cell_type": "markdown",
"id": "cbc6197c-0409-46d4-9099-a75d8c53dd1a",
"metadata": {},
"source": [
"### Run resnet18 model and export it (aot)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b2a4faa3-c836-42dd-bf30-376235a3d052",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/alessandrosangiorgi/Downloads/redhat/pytorch/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /home/alessandrosangiorgi/Downloads/redhat/pytorch/aten/src/ATen/Context.cpp:148.)\n",
" torch._C._set_onednn_allow_tf32(_allow_tf32)\n",
"/home/alessandrosangiorgi/Downloads/redhat/pytorch/torch/_inductor/compile_fx.py:243: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n",
" warnings.warn(\n",
"W0324 15:03:04.133000 79251 torch/_inductor/utils.py:1269] Not enough SMs to use max_autotune_gemm mode\n",
"No ROCm runtime is found, using ROCM_HOME='/usr'\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"import torch._inductor\n",
"from torchvision.models import ResNet18_Weights, resnet18\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT)\n",
"model.eval()\n",
"\n",
"with torch.inference_mode():\n",
" inductor_configs = {}\n",
"\n",
" if torch.cuda.is_available():\n",
" device = \"cuda\"\n",
" inductor_configs[\"max_autotune\"] = True\n",
" \n",
" model = model.to(device=device)\n",
" example_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n",
"\n",
" exported_program = torch.export.export(\n",
" model,\n",
" example_inputs,\n",
" )\n",
" path = torch._inductor.aoti_compile_and_package(\n",
" exported_program,\n",
" package_path=os.path.join(os.getcwd(), \"resnet18.pt2\"),\n",
" inductor_configs=inductor_configs\n",
" )"
]
},
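{
"cell_type": "markdown",
"id": "7c8d9e0f-1a2b-4c3d-8e4f-5a6b7c8d9e0f",
"metadata": {},
"source": [
"`aoti_compile_and_package` writes a `.pt2` package: a zip archive bundling the compiled CUDA binaries (`.cubin`), the generated C++ wrapper, and a shared library, so the model can later be loaded and run without recompiling. The next cells count the Triton kernels compiled during export and inspect the archive contents."
]
},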
{
"cell_type": "markdown",
"id": "2c5bc2a7-6e95-485f-8be9-30014851c1f9",
"metadata": {},
"source": [
"### Check how many triton kernels were stored"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "33bb626f-ce00-46f1-9d68-3cb331c1da3c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"33\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "40269809-6947-457b-b093-8c295e360d85",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Archive: resnet18.pt2\n",
" Length Date Time Name\n",
"--------- ---------- ----- ----\n",
" 1 03-24-2025 15:03 version\n",
" 3 03-24-2025 15:03 archive_format\n",
" 12768 03-24-2025 14:47 data/aotinductor/model/cjbnyeppfzwarwnfg7s4amdubsq74j7t3iwn2yic452uhevmamtv.cubin\n",
" 14304 03-24-2025 14:47 data/aotinductor/model/cxsx4emhy4slp762hb4vcsrjchb2vkuk6xh6r6ilzerpu5nwwgrc.cubin\n",
" 16256 03-24-2025 14:47 data/aotinductor/model/cceguhxn2kahmb4uh27zrb2tzajhggxhx6jlawt6ezahyortlo5j.cubin\n",
" 51176 03-24-2025 15:03 data/aotinductor/model/cbxgxpbplszr7ihxp3dpazdi22bjupxjgobhspxycha355wa5lmg.cubin\n",
" 10592 03-24-2025 14:47 data/aotinductor/model/cegmdbffaopbg4l3qkgubqybsogrpxgubtog4silk32vak7bbhus.cubin\n",
" 30528 03-24-2025 14:52 data/aotinductor/model/cqv6axfamvyav2paihxgcn7izhkckazmrr7vgd2avw25ba5nxsd3.cubin\n",
" 32064 03-24-2025 14:52 data/aotinductor/model/crmzy2ywohrafrnk37xumfrw5zcuucxprew2zcpg7qkoqpgovyha.cubin\n",
" 10592 03-24-2025 14:47 data/aotinductor/model/c5pxp6mj5kmjghmtiiwp2ezpp6aefncx7orby7xdyekz5xsoufbw.cubin\n",
" 30528 03-24-2025 14:52 data/aotinductor/model/cx5ktxnfvoocofiiqe2vyt76h77mhvj3p2jitlepkfysdbo2oeer.cubin\n",
" 11360 03-24-2025 14:47 data/aotinductor/model/cck4zlb767jxkulp6zteda7y3zqfcikfpwmtfukqeiayh45ufzof.cubin\n",
" 27712 03-24-2025 14:47 data/aotinductor/model/ccyvys3iwmppayjp5az3tdnffsw7e35rq25errhg5if3ox55f4lr.cubin\n",
" 31552 03-24-2025 14:52 data/aotinductor/model/coky2nowt64ql6vcvbg5pbm66dwwu47hy2zuyx34sq33ecfpqvjf.cubin\n",
" 15456 03-24-2025 14:48 data/aotinductor/model/ccvm6sxpqwqktet3ksq3verhy3rhlj2t3sqqvb3snrt6eog5vr5y.cubin\n",
" 19520 03-24-2025 14:47 data/aotinductor/model/cljbqwld6jlikrb3dgfogknta2mk336htsw4cihxa42kvpkkwado.cubin\n",
" 12128 03-24-2025 14:47 data/aotinductor/model/ccosdkouajwgh52vnbl4xc2pag5hv3mhnkkzmktecq3jejieb2z4.cubin\n",
" 46016 03-24-2025 14:52 data/aotinductor/model/cukv5d6wyujirn44h72omw7qoygdp22m46xqft4kzo2tnf7cokvg.cubin\n",
" 20544 03-24-2025 14:47 data/aotinductor/model/cnfqetsw2ox47et77yggaohw72uj6zx366hge7wiylv4hb637hdj.cubin\n",
" 16864 03-24-2025 14:52 data/aotinductor/model/cirsy2sxlccocaoesz4jn7iocaxc6y65dy5d56dzcmh7hti2g3yw.cubin\n",
" 26176 03-24-2025 14:47 data/aotinductor/model/cdffub5a3wyrezvgws3w73xkueaw54yk547wozxjf6y4ojk5ktx2.cubin\n",
" 17632 03-24-2025 14:52 data/aotinductor/model/ccxif76gfm7chaj77cpx6m46pfemiahltpu6m37tuksh5fyrdudv.cubin\n",
" 25024 03-24-2025 14:48 data/aotinductor/model/cmjsx35ha6yvxb7sz25f3liuypnrzl7eqrsi6xf6n2bwgk3uaavd.cubin\n",
" 35008 03-24-2025 15:03 data/aotinductor/model/cvrs3tm7jqwbzdlnuvrsckyf6jdcn5yhb7egn2obxwrkb5b4htwe.cubin\n",
" 166603 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper.cpp\n",
" 16980 03-24-2025 15:03 data/aotinductor/model/cqlxtpew53agqx6nvi6psrgo6ajtqwj3m3mhy2bdufzeavy3ndox.kernel.cpp\n",
" 27 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper_metadata.json\n",
" 27 03-24-2025 15:03 data/aotinductor/model/cqlxtpew53agqx6nvi6psrgo6ajtqwj3m3mhy2bdufzeavy3ndox.kernel_metadata.json\n",
" 47230232 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper.so\n",
"--------- -------\n",
" 47927673 29 files\n"
]
}
],
"source": [
"!unzip -l resnet18.pt2"
]
},
{
"cell_type": "markdown",
"id": "28a632a4-2c15-45de-b733-c912eda4aca8",
"metadata": {},
"source": [
"### Use the pre-compiled model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "026ee146-397b-4c1b-8a9f-74aef085d4a2",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import torch\n",
"import torch._inductor\n",
"\n",
"model_path = os.path.join(os.getcwd(), \"resnet18.pt2\")\n",
"\n",
"compiled_model = torch._inductor.aoti_load_package(model_path)\n",
"example_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n",
"\n",
"with torch.inference_mode():\n",
" output = compiled_model(example_inputs)"
]
},
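{
"cell_type": "markdown",
"id": "9a8b7c6d-5e4f-4a3b-8c2d-1e0f9a8b7c6d",
"metadata": {},
"source": [
"As a quick sanity check (an added sketch, assuming `model` still holds the eager resnet18 from the export cell), the AOT-compiled package should match the eager model up to a small numerical tolerance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b3c4d5e-6f70-4182-93a4-b5c6d7e8f901",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: compare the AOT-compiled package against the eager model.\n",
"x = torch.randn(2, 3, 224, 224, device=device)\n",
"with torch.inference_mode():\n",
"    eager_out = model(x)\n",
"    aot_out = compiled_model((x,))\n",
"# max_autotune may select different kernels, so allow a small tolerance\n",
"torch.testing.assert_close(aot_out, eager_out, rtol=1e-3, atol=1e-3)"
]
},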
{
"cell_type": "code",
"execution_count": 9,
"id": "1f0924f3-ed5b-4d15-9a51-4bf2e58c5dda",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"def timed(fn):\n",
" # Returns the result of running `fn()` and the time it took for `fn()` to run,\n",
" # in seconds. We use CUDA events and synchronization for accurate\n",
" # measurement on CUDA enabled devices.\n",
" if torch.cuda.is_available():\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
" start.record()\n",
" else:\n",
" start = time.time()\n",
"\n",
" result = fn()\n",
" if torch.cuda.is_available():\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" else:\n",
" end = time.time()\n",
"\n",
" # Measure time taken to execute the function in miliseconds\n",
" if torch.cuda.is_available():\n",
" duration = start.elapsed_time(end)\n",
" else:\n",
" duration = (end - start) * 1000\n",
"\n",
" return result, duration"
]
},
{
"cell_type": "markdown",
"id": "c612c8ec-22bc-4c5a-996b-8c79a324acac",
"metadata": {},
"source": [
"### Check the time taken for first inference for AOTInductor"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "e4129034-23a8-49de-8117-31365430e4bb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken for first inference for AOTInductor is 3.15 ms\n"
]
}
],
"source": [
"torch._dynamo.reset()\n",
"\n",
"model = torch._inductor.aoti_load_package(model_path)\n",
"example_inputs = (torch.randn(1, 3, 224, 224, device=device),)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for first inference for AOTInductor is {time_taken:.2f} ms\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3c24c978-f40f-4c33-99aa-e526a2e80a73",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"33\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
{
"cell_type": "markdown",
"id": "0faff29c-d67e-456c-a99f-5d6c7e60a2eb",
"metadata": {},
"source": [
"### Flush the Triton cache"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "ffe543ef-61e4-4ff0-ba81-1cb27b44818c",
"metadata": {},
"outputs": [],
"source": [
"!rm -rf ~/redhat/aot_test/triton_cache_test"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "9e014556-8a6f-4755-99d6-1662177d8d88",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ls: cannot access '/home/alessandrosangiorgi/redhat/aot_test/triton_cache_test': No such file or directory\n",
"0\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
{
"cell_type": "markdown",
"id": "bf8a4dee-bba1-4553-b7db-1119cd226355",
"metadata": {},
"source": [
"### Check the time taken for first inference for torch.compile"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "803d87cc-bea0-4904-9c68-32abfb0596c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken for first inference for torch.compile is 1196.46 ms\n"
]
}
],
"source": [
"torch._dynamo.reset()\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n",
"model.eval()\n",
"\n",
"model = torch.compile(model)\n",
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "31514e13-d992-4f9a-85a6-efa7c989226d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"75\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
{
"cell_type": "markdown",
"id": "ceda81e6-2719-4706-8b99-edfb21aa7746",
"metadata": {},
"source": [
"\n",
"### Check the time taken for second inference for torch.compile (using inductor and triton cache)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ac6d429b-e0bf-4c8a-8aa7-54af2a26a225",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken for second inference for torch.compile is 321.69 ms\n"
]
}
],
"source": [
"torch._dynamo.reset()\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n",
"modtorch._dynamo.reset()\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n",
"model.eval()\n",
"\n",
"model = torch.compile(model)\n",
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")el.eval()\n",
"\n",
"model = torch.compile(model)\n",
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for second inference for torch.compile is {time_taken:.2f} ms\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d28295c5-cab7-43c2-a0c4-217f2276f30d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"75\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
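{
"cell_type": "markdown",
"id": "5e6f7a8b-9c0d-4e1f-8a2b-3c4d5e6f7a8b",
"metadata": {},
"source": [
"### Enable max_autotune for torch.compile\n",
"\n",
"The AOT export above used `max_autotune`, so enable it for the JIT path as well to keep the comparison fair. Autotuning compiles additional Triton kernel candidates, which shows up in the cache count below."
]
},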
{
"cell_type": "code",
"execution_count": 18,
"id": "cb62d45c-6715-4cfd-8760-c73d608f6a91",
"metadata": {},
"outputs": [],
"source": [
"from torch._inductor import config\n",
"config.max_autotune = True"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "4447e1d4-318e-47b7-8f71-b6f6a6b98630",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken for first inference for torch.compile is 318.66 ms\n"
]
}
],
"source": [
"torch._dynamo.reset()\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n",
"model.eval()\n",
"\n",
"model = torch.compile(model)\n",
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f8567881-16dd-4cd1-92ee-bd7da85ffb57",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken for second inference for torch.compile is 312.25 ms\n"
]
}
],
"source": [
"torch._dynamo.reset()\n",
"\n",
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n",
"model.eval()\n",
"\n",
"model = torch.compile(model)\n",
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
" _, time_taken = timed(lambda: model(example_inputs))\n",
" print(f\"Time taken for second inference for torch.compile is {time_taken:.2f} ms\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "037d7b1c-6758-4a63-8751-ab68bb00c818",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100\n"
]
}
],
"source": [
"!ls ~/redhat/aot_test/triton_cache_test | wc -l"
]
},
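{
"cell_type": "markdown",
"id": "4c5d6e7f-8a9b-4c0d-8e1f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"### Compare steady-state latency\n",
"\n",
"The timings above measure the first call, which for `torch.compile` is dominated by compilation. As a minimal sketch (reusing `model_path`, `timed`, and `device` from earlier cells), the cell below warms both paths up and averages several runs, so the numbers reflect steady-state inference rather than compile time."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d7e8f90-1a2b-4c3d-8e5f-a0b1c2d3e4f5",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: average post-warmup latency for both the AOT and JIT paths.\n",
"aot_model = torch._inductor.aoti_load_package(model_path)\n",
"jit_model = torch.compile(resnet18(weights=ResNet18_Weights.DEFAULT).to(device).eval())\n",
"x = torch.randn(1, 3, 224, 224, device=device)\n",
"\n",
"with torch.inference_mode():\n",
"    for _ in range(3):  # warmup: triggers JIT compilation and loads AOT kernels\n",
"        aot_model((x,))\n",
"        jit_model(x)\n",
"    aot_ms = sum(timed(lambda: aot_model((x,)))[1] for _ in range(10)) / 10\n",
"    jit_ms = sum(timed(lambda: jit_model(x))[1] for _ in range(10)) / 10\n",
"\n",
"print(f\"AOTInductor: {aot_ms:.2f} ms/iter, torch.compile: {jit_ms:.2f} ms/iter\")"
]
},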
{
"cell_type": "code",
"execution_count": null,
"id": "02046888-da97-4e6e-bb48-0c5c93b366d5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}