Created
March 25, 2025 13:49
-
-
Save fulvius31/8a116a6ed8c07dddaaf5dd0dd73cef20 to your computer and use it in GitHub Desktop.
AOT vs JIT
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "0c3c86e4-b5ff-4e36-a9a3-358de57b9932", | |
"metadata": {}, | |
"source": [ | |
"# AOT vs JIT using torch aot inductor" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "164f93d6-6092-4267-bbbf-f546130fe463", | |
"metadata": {}, | |
"source": [ | |
"### Check torch and triton versions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "17d26a5c-6205-4921-bbef-bf972cbb3bed", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[33mDEPRECATION: Loading egg at /home/alessandrosangiorgi/.conda/envs/pytorch/lib/python3.13/site-packages/sympy-1.13.1-py3.13.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", | |
"\u001b[0mName: torch\n", | |
"Version: 2.8.0a0+git428fc14\n", | |
"Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration\n", | |
"Home-page: https://pytorch.org/\n", | |
"Author: PyTorch Team\n", | |
"Author-email: [email protected]\n", | |
"License: BSD-3-Clause\n", | |
"Location: /home/alessandrosangiorgi/Downloads/redhat/pytorch\n", | |
"Editable project location: /home/alessandrosangiorgi/Downloads/redhat/pytorch\n", | |
"Requires: filelock, fsspec, jinja2, networkx, setuptools, sympy, typing-extensions\n", | |
"Required-by: marker-pdf, surya-ocr, texify, torchaudio, torchvision\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip show torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "63e36201-0eb1-4811-a4ec-1cccd2fd54f5", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"import torch\n", | |
"torch.cuda.is_available()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "6b5e342f-ecb7-461a-be31-df3ae592f0d2", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[33mDEPRECATION: Loading egg at /home/alessandrosangiorgi/.conda/envs/pytorch/lib/python3.13/site-packages/sympy-1.13.1-py3.13.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", | |
"\u001b[0mName: triton\n", | |
"Version: 3.3.0+git6276a781\n", | |
"Summary: A language and compiler for custom Deep Learning operations\n", | |
"Home-page: https://github.com/triton-lang/triton/\n", | |
"Author: Philippe Tillet\n", | |
"Author-email: [email protected]\n", | |
"License: \n", | |
"Location: /home/alessandrosangiorgi/.local/lib/python3.13/site-packages\n", | |
"Editable project location: /home/alessandrosangiorgi/triton/python\n", | |
"Requires: setuptools\n", | |
"Required-by: \n" | |
] | |
} | |
], | |
"source": [ | |
"!pip show triton" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6e0f2f72-e8ef-457a-86bc-877188676941", | |
"metadata": {}, | |
"source": [ | |
"### Change the default triton cache dir" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "ff8e88f0-68ff-4191-8cc0-c27ac0a89f60", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# `!export VAR=...` runs in a one-shot subshell, so the variable never\n", | |
"# reaches the kernel process (or the Triton runtime it hosts).\n", | |
"# Set it in the kernel's own environment instead; `~` must be expanded\n", | |
"# explicitly because os.environ performs no shell expansion.\n", | |
"import os\n", | |
"os.environ[\"TRITON_CACHE_DIR\"] = os.path.expanduser(\"~/redhat/aot_test/triton_cache_test\")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cbc6197c-0409-46d4-9099-a75d8c53dd1a", | |
"metadata": {}, | |
"source": [ | |
"### Run resnet18 model and export it (aot)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "b2a4faa3-c836-42dd-bf30-376235a3d052", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/alessandrosangiorgi/Downloads/redhat/pytorch/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /home/alessandrosangiorgi/Downloads/redhat/pytorch/aten/src/ATen/Context.cpp:148.)\n", | |
" torch._C._set_onednn_allow_tf32(_allow_tf32)\n", | |
"/home/alessandrosangiorgi/Downloads/redhat/pytorch/torch/_inductor/compile_fx.py:243: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.\n", | |
" warnings.warn(\n", | |
"W0324 15:03:04.133000 79251 torch/_inductor/utils.py:1269] Not enough SMs to use max_autotune_gemm mode\n", | |
"No ROCm runtime is found, using ROCM_HOME='/usr'\n" | |
] | |
} | |
], | |
"source": [ | |
"import os\n", | |
"import torch\n", | |
"import torch._inductor\n", | |
"from torchvision.models import ResNet18_Weights, resnet18\n", | |
"\n", | |
"# Load pretrained ResNet-18 and switch to eval mode before export.\n", | |
"model = resnet18(weights=ResNet18_Weights.DEFAULT)\n", | |
"model.eval()\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
"    inductor_configs = {}\n", | |
"\n", | |
"    if torch.cuda.is_available():\n", | |
"        device = \"cuda\"\n", | |
"        inductor_configs[\"max_autotune\"] = True\n", | |
"    else:\n", | |
"        # Bug fix: without this branch `device` is undefined on CPU-only\n", | |
"        # machines and `model.to(device=device)` below raises NameError.\n", | |
"        device = \"cpu\"\n", | |
"\n", | |
"    model = model.to(device=device)\n", | |
"    # NOTE(review): exported with a fixed batch size of 2; torch.export uses\n", | |
"    # static shapes unless `dynamic_shapes` is passed, yet later cells feed\n", | |
"    # batch-1 inputs — confirm this is intended.\n", | |
"    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n", | |
"\n", | |
"    exported_program = torch.export.export(\n", | |
"        model,\n", | |
"        example_inputs,\n", | |
"    )\n", | |
"    # Ahead-of-time compile and bundle kernels + wrapper into resnet18.pt2.\n", | |
"    path = torch._inductor.aoti_compile_and_package(\n", | |
"        exported_program,\n", | |
"        package_path=os.path.join(os.getcwd(), \"resnet18.pt2\"),\n", | |
"        inductor_configs=inductor_configs\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "2c5bc2a7-6e95-485f-8be9-30014851c1f9", | |
"metadata": {}, | |
"source": [ | |
"### Check how many triton kernels were stored" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "33bb626f-ce00-46f1-9d68-3cb331c1da3c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"33\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "40269809-6947-457b-b093-8c295e360d85", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Archive: resnet18.pt2\n", | |
" Length Date Time Name\n", | |
"--------- ---------- ----- ----\n", | |
" 1 03-24-2025 15:03 version\n", | |
" 3 03-24-2025 15:03 archive_format\n", | |
" 12768 03-24-2025 14:47 data/aotinductor/model/cjbnyeppfzwarwnfg7s4amdubsq74j7t3iwn2yic452uhevmamtv.cubin\n", | |
" 14304 03-24-2025 14:47 data/aotinductor/model/cxsx4emhy4slp762hb4vcsrjchb2vkuk6xh6r6ilzerpu5nwwgrc.cubin\n", | |
" 16256 03-24-2025 14:47 data/aotinductor/model/cceguhxn2kahmb4uh27zrb2tzajhggxhx6jlawt6ezahyortlo5j.cubin\n", | |
" 51176 03-24-2025 15:03 data/aotinductor/model/cbxgxpbplszr7ihxp3dpazdi22bjupxjgobhspxycha355wa5lmg.cubin\n", | |
" 10592 03-24-2025 14:47 data/aotinductor/model/cegmdbffaopbg4l3qkgubqybsogrpxgubtog4silk32vak7bbhus.cubin\n", | |
" 30528 03-24-2025 14:52 data/aotinductor/model/cqv6axfamvyav2paihxgcn7izhkckazmrr7vgd2avw25ba5nxsd3.cubin\n", | |
" 32064 03-24-2025 14:52 data/aotinductor/model/crmzy2ywohrafrnk37xumfrw5zcuucxprew2zcpg7qkoqpgovyha.cubin\n", | |
" 10592 03-24-2025 14:47 data/aotinductor/model/c5pxp6mj5kmjghmtiiwp2ezpp6aefncx7orby7xdyekz5xsoufbw.cubin\n", | |
" 30528 03-24-2025 14:52 data/aotinductor/model/cx5ktxnfvoocofiiqe2vyt76h77mhvj3p2jitlepkfysdbo2oeer.cubin\n", | |
" 11360 03-24-2025 14:47 data/aotinductor/model/cck4zlb767jxkulp6zteda7y3zqfcikfpwmtfukqeiayh45ufzof.cubin\n", | |
" 27712 03-24-2025 14:47 data/aotinductor/model/ccyvys3iwmppayjp5az3tdnffsw7e35rq25errhg5if3ox55f4lr.cubin\n", | |
" 31552 03-24-2025 14:52 data/aotinductor/model/coky2nowt64ql6vcvbg5pbm66dwwu47hy2zuyx34sq33ecfpqvjf.cubin\n", | |
" 15456 03-24-2025 14:48 data/aotinductor/model/ccvm6sxpqwqktet3ksq3verhy3rhlj2t3sqqvb3snrt6eog5vr5y.cubin\n", | |
" 19520 03-24-2025 14:47 data/aotinductor/model/cljbqwld6jlikrb3dgfogknta2mk336htsw4cihxa42kvpkkwado.cubin\n", | |
" 12128 03-24-2025 14:47 data/aotinductor/model/ccosdkouajwgh52vnbl4xc2pag5hv3mhnkkzmktecq3jejieb2z4.cubin\n", | |
" 46016 03-24-2025 14:52 data/aotinductor/model/cukv5d6wyujirn44h72omw7qoygdp22m46xqft4kzo2tnf7cokvg.cubin\n", | |
" 20544 03-24-2025 14:47 data/aotinductor/model/cnfqetsw2ox47et77yggaohw72uj6zx366hge7wiylv4hb637hdj.cubin\n", | |
" 16864 03-24-2025 14:52 data/aotinductor/model/cirsy2sxlccocaoesz4jn7iocaxc6y65dy5d56dzcmh7hti2g3yw.cubin\n", | |
" 26176 03-24-2025 14:47 data/aotinductor/model/cdffub5a3wyrezvgws3w73xkueaw54yk547wozxjf6y4ojk5ktx2.cubin\n", | |
" 17632 03-24-2025 14:52 data/aotinductor/model/ccxif76gfm7chaj77cpx6m46pfemiahltpu6m37tuksh5fyrdudv.cubin\n", | |
" 25024 03-24-2025 14:48 data/aotinductor/model/cmjsx35ha6yvxb7sz25f3liuypnrzl7eqrsi6xf6n2bwgk3uaavd.cubin\n", | |
" 35008 03-24-2025 15:03 data/aotinductor/model/cvrs3tm7jqwbzdlnuvrsckyf6jdcn5yhb7egn2obxwrkb5b4htwe.cubin\n", | |
" 166603 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper.cpp\n", | |
" 16980 03-24-2025 15:03 data/aotinductor/model/cqlxtpew53agqx6nvi6psrgo6ajtqwj3m3mhy2bdufzeavy3ndox.kernel.cpp\n", | |
" 27 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper_metadata.json\n", | |
" 27 03-24-2025 15:03 data/aotinductor/model/cqlxtpew53agqx6nvi6psrgo6ajtqwj3m3mhy2bdufzeavy3ndox.kernel_metadata.json\n", | |
" 47230232 03-24-2025 15:03 data/aotinductor/model/cwtsm4ojyonjdywfhilxluvi6icjccaafosnll2o7kgqcwuyofyz.wrapper.so\n", | |
"--------- -------\n", | |
" 47927673 29 files\n" | |
] | |
} | |
], | |
"source": [ | |
"!unzip -l resnet18.pt2" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "28a632a4-2c15-45de-b733-c912eda4aca8", | |
"metadata": {}, | |
"source": [ | |
"### Use the pre-compiled model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "026ee146-397b-4c1b-8a9f-74aef085d4a2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import torch\n", | |
"import torch._inductor\n", | |
"\n", | |
"# Derive `device` locally instead of relying on hidden state from the export\n", | |
"# cell (where it is only defined when CUDA is available); this keeps the cell\n", | |
"# working under Restart Kernel -> Run All.\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"\n", | |
"model_path = os.path.join(os.getcwd(), \"resnet18.pt2\")\n", | |
"\n", | |
"compiled_model = torch._inductor.aoti_load_package(model_path)\n", | |
"example_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
"    output = compiled_model(example_inputs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "1f0924f3-ed5b-4d15-9a51-4bf2e58c5dda", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"def timed(fn):\n", | |
"    \"\"\"Run `fn()` and return (result, duration_in_milliseconds).\n", | |
"\n", | |
"    Uses CUDA events plus torch.cuda.synchronize() for accurate measurement\n", | |
"    on CUDA-enabled devices; falls back to wall-clock time.time() otherwise.\n", | |
"    \"\"\"\n", | |
"    if torch.cuda.is_available():\n", | |
"        start = torch.cuda.Event(enable_timing=True)\n", | |
"        end = torch.cuda.Event(enable_timing=True)\n", | |
"        start.record()\n", | |
"    else:\n", | |
"        start = time.time()\n", | |
"\n", | |
"    result = fn()\n", | |
"    if torch.cuda.is_available():\n", | |
"        end.record()\n", | |
"        # Events are recorded asynchronously; synchronize before reading them.\n", | |
"        torch.cuda.synchronize()\n", | |
"    else:\n", | |
"        end = time.time()\n", | |
"\n", | |
"    # Measure time taken to execute the function in milliseconds\n", | |
"    if torch.cuda.is_available():\n", | |
"        duration = start.elapsed_time(end)  # elapsed_time already returns ms\n", | |
"    else:\n", | |
"        duration = (end - start) * 1000\n", | |
"\n", | |
"    return result, duration" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c612c8ec-22bc-4c5a-996b-8c79a324acac", | |
"metadata": {}, | |
"source": [ | |
"### Check the time taken for first inference for AOTInductor" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "e4129034-23a8-49de-8117-31365430e4bb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time taken for first inference for AOTInductor is 3.15 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"# All compilation happened ahead of time, so even the *first* call is fast —\n", | |
"# this is the AOT half of the comparison.\n", | |
"torch._dynamo.reset()\n", | |
"\n", | |
"model = torch._inductor.aoti_load_package(model_path)\n", | |
"# NOTE(review): the package was exported with a static batch size of 2, but\n", | |
"# this feeds batch 1 — confirm the runtime accepts it, or export with\n", | |
"# `dynamic_shapes`.\n", | |
"example_inputs = (torch.randn(1, 3, 224, 224, device=device),)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
"    _, time_taken = timed(lambda: model(example_inputs))\n", | |
"    print(f\"Time taken for first inference for AOTInductor is {time_taken:.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "3c24c978-f40f-4c33-99aa-e526a2e80a73", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"33\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "0faff29c-d67e-456c-a99f-5d6c7e60a2eb", | |
"metadata": {}, | |
"source": [ | |
"### Flush the Triton cache" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "ffe543ef-61e4-4ff0-ba81-1cb27b44818c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"!rm -rf ~/redhat/aot_test/triton_cache_test" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "9e014556-8a6f-4755-99d6-1662177d8d88", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ls: cannot access '/home/alessandrosangiorgi/redhat/aot_test/triton_cache_test': No such file or directory\n", | |
"0\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bf8a4dee-bba1-4553-b7db-1119cd226355", | |
"metadata": {}, | |
"source": [ | |
"### Check the time taken for first inference for torch.compile" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "803d87cc-bea0-4904-9c68-32abfb0596c3", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time taken for first inference for torch.compile is 1196.46 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"# JIT half of the comparison: the first call pays the full compile cost\n", | |
"# (dynamo tracing, inductor codegen, triton kernel builds) on top of the\n", | |
"# actual forward pass — hence the ~1.2 s first-inference time.\n", | |
"torch._dynamo.reset()\n", | |
"\n", | |
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n", | |
"model.eval()\n", | |
"\n", | |
"model = torch.compile(model)\n", | |
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
"    _, time_taken = timed(lambda: model(example_inputs))\n", | |
"    print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "31514e13-d992-4f9a-85a6-efa7c989226d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"75\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "ceda81e6-2719-4706-8b99-edfb21aa7746", | |
"metadata": {}, | |
"source": [ | |
"\n", | |
"### Check the time taken for second inference for torch.compile (using inductor and triton cache)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "ac6d429b-e0bf-4c8a-8aa7-54af2a26a225", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time taken for second inference for torch.compile is 321.69 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"# Same pipeline as the previous cell; this run reuses the warm inductor and\n", | |
"# triton caches, so most of the compilation cost is amortized.\n", | |
"# (Bug fix: the original cell contained a corrupted paste — the first-inference\n", | |
"# cell's code was spliced into the middle of this one, e.g.\n", | |
"# `modtorch._dynamo.reset()` — restored to the clean version matching the\n", | |
"# printed output.)\n", | |
"torch._dynamo.reset()\n", | |
"\n", | |
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n", | |
"model.eval()\n", | |
"\n", | |
"model = torch.compile(model)\n", | |
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
"    _, time_taken = timed(lambda: model(example_inputs))\n", | |
"    print(f\"Time taken for second inference for torch.compile is {time_taken:.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "d28295c5-cab7-43c2-a0c4-217f2276f30d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"75\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "cb62d45c-6715-4cfd-8760-c73d608f6a91", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch._inductor import config\n", | |
"# Enable inductor's autotuned kernel selection for subsequent torch.compile\n", | |
"# runs, mirroring the `max_autotune` flag passed to the AOT export earlier.\n", | |
"config.max_autotune = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "4447e1d4-318e-47b7-8f71-b6f6a6b98630", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time taken for first inference for torch.compile is 318.66 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"torch._dynamo.reset()\n", | |
"\n", | |
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n", | |
"model.eval()\n", | |
"\n", | |
"model = torch.compile(model)\n", | |
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
" _, time_taken = timed(lambda: model(example_inputs))\n", | |
" print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"id": "f8567881-16dd-4cd1-92ee-bd7da85ffb57", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Time taken for second inference for torch.compile is 312.25 ms\n" | |
] | |
} | |
], | |
"source": [ | |
"torch._dynamo.reset()\n", | |
"\n", | |
"model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\n", | |
"model.eval()\n", | |
"\n", | |
"model = torch.compile(model)\n", | |
"example_inputs = torch.randn(1, 3, 224, 224, device=device)\n", | |
"\n", | |
"with torch.inference_mode():\n", | |
" _, time_taken = timed(lambda: model(example_inputs))\n", | |
" print(f\"Time taken for second inference for torch.compile is {time_taken:.2f} ms\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"id": "037d7b1c-6758-4a63-8751-ab68bb00c818", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"100\n" | |
] | |
} | |
], | |
"source": [ | |
"!ls ~/redhat/aot_test/triton_cache_test | wc -l" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "02046888-da97-4e6e-bb48-0c5c93b366d5", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.13.1" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment