this is a rough beginning of an applied torch.profiler tutorial
{
"cells": [
{
"cell_type": "markdown",
"id": "205de558",
"metadata": {
"run_control": {
"marked": false
}
},
"source": [
"# Practical Performance Analysis with `torch.profiler`"
]
},
{
"cell_type": "markdown",
"id": "bab07384",
"metadata": {
"run_control": {
"marked": false
}
},
"source": [
"This tutorial demonstrates how to use `torch.profiler` to find performance bottlenecks and speed up code execution."
]
},
{
"cell_type": "markdown",
"id": "8de3e3fa",
"metadata": {
"run_control": {
"marked": false
}
},
"source": [
"# Experiment 1\n", | |
"\n", | |
"When running [🤗 Transformers](https://github.com/huggingface/transformers/) model in fp16-mode evaluation I noticed that the performance was slower than running the same model in fp32. \n", | |
"\n", | |
"After installing the profiler context manager around the evaluation loop, I run it once with the fp16 and another with fp32 model.\n", | |
"\n", | |
"Currently, there is no easy way to compare the results and to find out the difference a lot of manual comparison has to be done. Basically, sorting the results tables by CUDA total and then comparing the same function calls to see where the numbers diverge.\n", | |
"\n", | |
"XXX: document how I figured it out. I think I had to step through the different traces in the tensorboard plugin to get there.\n", | |
"\n", | |
"Version 1. Following the slow traces I noticed that this line of code was much slower with the fp16 model.\n", | |
"\n", | |
"```\n", | |
"attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)\n", | |
"```\n", | |
"\n", | |
"Here the code upscales the input to `softmax` to fp32, performs the calculation and returns the results in fp32 and then it convers it back to fp16.\n", | |
"\n", | |
"Version 2. So the first experiment was to compare this line of code with another version which converts to fp16 internally, and that is:\n", | |
"\n", | |
"```\n", | |
"attn_weights = nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype)\n", | |
"```\n", | |
"\n", | |
"Version 3. And then it's also a good idea to check whether the upscaling to fp32 is needed at all, that is what about this version:\n", | |
"\n", | |
"```\n", | |
"attn_weights = nn.functional.softmax(scores, dim=-1)\n", | |
"```\n", | |
"\n", | |
"`softmax` here performs everything in fp16 and returns fp16." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "5b2b079c", | |
"metadata": {}, | |
"source": [ | |
"## Setup\n", | |
"\n", | |
"First, let's import all the needed components" | |
] | |
}, | |
{
"cell_type": "code",
"execution_count": 1,
"id": "e5bc4a17",
"metadata": {
"run_control": {
"marked": false
}
},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"from torch.profiler import profile, record_function, ProfilerActivity\n",
"from pprint import pprint\n",
"import operator"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d041ed69",
"metadata": {},
"outputs": [],
"source": [
"# XXX: hack - workaround to init the profiler before we do any cuda things (otherwise it\n",
"# hangs if cuda has already been initialized before its first run)\n",
"with torch.profiler.profile() as profiler:\n",
"    pass"
]
},
{
"cell_type": "markdown",
"id": "71d5bbd6",
"metadata": {},
"source": [
"When profiling, it's important that the data is of a realistic size. If we were to choose a small tensor, we would hardly see any activity on CUDA and wouldn't be able to get the correct insights."
]
},
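{
"cell_type": "markdown",
"id": "9c1d4e2f",
"metadata": {},
"source": [
"To illustrate (a quick sketch added for demonstration, using a hypothetical tiny input): profiling `softmax` on a 4x4 tensor shows almost no CUDA activity, so the measurements are dominated by launch overhead and noise."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b7a6f5e",
"metadata": {},
"outputs": [],
"source": [
"# tiny input: the CUDA kernels finish in microseconds, so kernel launch\n",
"# overhead dominates and the measurements are mostly noise\n",
"tiny = torch.rand(4, 4).half().cuda()\n",
"with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof_tiny:\n",
"    for i in range(100):\n",
"        attn_weights = nn.functional.softmax(tiny, dim=-1)\n",
"print(prof_tiny.key_averages().table(sort_by=\"cuda_time_total\", row_limit=5))"
]
},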
{
"cell_type": "code",
"execution_count": 3,
"id": "7d2dba55",
"metadata": {
"run_control": {
"marked": false
}
},
"outputs": [],
"source": [
"# attention-score-shaped input: [batch, heads, seq_len, seq_len], fp16, on GPU\n",
"scores = torch.rand(16, 8, 512, 512).half().cuda()"
]
},
{
"cell_type": "markdown",
"id": "d9367224",
"metadata": {},
"source": [
"## Version 1\n",
"\n",
"Let's profile that line of code. Instead of setting things up to exclude the first warm-up iteration, which at times can be slower, we will just run the line 100 times in a row, which gives us a good average.\n",
"\n",
"Here we are going to label the code we are profiling `type_as`, which will allow us to see the total CUDA and CPU times for the whole of the profiled code:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3511df21",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" type_as 1.34% 1.186ms 10.90% 9.636ms 9.636ms 0.000us 0.00% 85.026ms 85.026ms 1 \n",
" aten::to 0.57% 506.000us 7.14% 6.312ms 31.560us 0.000us 0.00% 53.225ms 266.125us 200 \n",
" aten::copy_ 1.35% 1.196ms 4.21% 3.718ms 18.590us 53.225ms 62.60% 53.225ms 266.125us 200 \n",
" aten::_softmax 0.62% 549.000us 2.13% 1.879ms 18.790us 31.801ms 37.40% 31.801ms 318.010us 100 \n",
"void (anonymous namespace)::softmax_warp_forward<flo... 0.00% 0.000us 0.00% 0.000us 0.000us 31.801ms 37.40% 31.801ms 318.010us 100 \n",
" aten::softmax 0.22% 196.000us 2.30% 2.034ms 20.340us 0.000us 0.00% 29.576ms 295.760us 100 \n",
"void at::native::unrolled_elementwise_kernel<at::nat... 0.00% 0.000us 0.00% 0.000us 0.000us 27.668ms 32.54% 27.668ms 276.680us 100 \n",
" aten::type_as 0.22% 197.000us 2.22% 1.963ms 19.630us 0.000us 0.00% 25.557ms 255.570us 100 \n",
"void at::native::unrolled_elementwise_kernel<at::nat... 0.00% 0.000us 0.00% 0.000us 0.000us 25.557ms 30.06% 25.557ms 255.570us 100 \n",
" aten::zeros 0.03% 25.000us 0.04% 36.000us 36.000us 0.000us 0.00% 0.000us 0.000us 1 \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
"Self CPU time total: 88.390ms\n",
"Self CUDA time total: 85.026ms\n",
"\n"
]
}
],
"source": [
"with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False) as prof:\n",
"    with record_function(\"type_as\"):\n",
"        for i in range(100):\n",
"            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)\n",
"print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
]
},
{
"cell_type": "markdown",
"id": "a573918d",
"metadata": {},
"source": [
"Note: one not-so-user-friendly feature of the current `table` function is that it reports some entries in milliseconds and others in microseconds, so be careful when comparing entries: one can be 1000 times bigger than another. (As we will see in the Alternative views section below, the raw accessors report everything in microseconds.)\n",
"\n",
"The other difficulty is that the results table is very wide, which makes it hard to compare columns.\n"
]
},
{
"cell_type": "markdown",
"id": "0b602b8b",
"metadata": {},
"source": [
"## Version 2\n",
"\n",
"Here we are going to label this line of code `auto_dtype`:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b1e0450a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" auto_dtype 1.16% 817.000us 9.45% 6.643ms 6.643ms 0.000us 0.00% 69.257ms 69.257ms 1 \n",
" aten::to 0.67% 470.000us 5.84% 4.108ms 20.540us 0.000us 0.00% 53.195ms 265.975us 200 \n",
" aten::copy_ 1.62% 1.136ms 3.74% 2.627ms 13.135us 53.195ms 76.81% 53.195ms 265.975us 200 \n",
" aten::softmax 0.48% 339.000us 4.92% 3.462ms 34.620us 0.000us 0.00% 41.558ms 415.580us 100 \n",
"void at::native::unrolled_elementwise_kernel<at::nat... 0.00% 0.000us 0.00% 0.000us 0.000us 27.699ms 39.99% 27.699ms 276.990us 100 \n",
"void at::native::unrolled_elementwise_kernel<at::nat... 0.00% 0.000us 0.00% 0.000us 0.000us 25.496ms 36.81% 25.496ms 254.960us 100 \n",
" aten::_softmax 0.72% 504.000us 1.96% 1.377ms 13.770us 16.062ms 23.19% 16.062ms 160.620us 100 \n",
"void (anonymous namespace)::softmax_warp_forward<c10... 0.00% 0.000us 0.00% 0.000us 0.000us 16.062ms 23.19% 16.062ms 160.620us 100 \n",
" aten::zeros 0.01% 9.000us 0.05% 34.000us 34.000us 0.000us 0.00% 0.000us 0.000us 1 \n",
" aten::empty 0.44% 308.000us 0.49% 344.000us 3.373us 0.000us 0.00% 0.000us 0.000us 102 \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
"Self CPU time total: 70.308ms\n",
"Self CUDA time total: 69.257ms\n",
"\n"
]
}
],
"source": [
"with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False) as prof:\n",
"    with record_function(\"auto_dtype\"):\n",
"        for i in range(100):\n",
"            attn_weights = nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype)\n",
"print(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))"
]
},
{
"cell_type": "markdown",
"id": "1dea2c53",
"metadata": {},
"source": [
"If you look closely you will see that the CUDA total of `softmax_warp_forward` is `31.801ms` in Ver1 and `16.062ms` in Ver2. So in Ver2 it's about twice as fast, a `15.739ms` difference.\n",
"\n",
"The difference in the CUDA total for the whole line of code is `85.026ms` (Ver1 `type_as`) vs. `69.257ms` (Ver2 `auto_dtype`), which is about the same absolute difference as in the previous paragraph, or roughly 20% overall. This may or may not be significant, depending on how much this code runs relative to other chunks of code.\n",
"\n",
"Also notice that instead of looking at the CUDA total column of the virtual function `auto_dtype`, we can read the same number at the end of the table, which reports `Self CUDA time total: 69.257ms`."
]
},
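{
"cell_type": "markdown",
"id": "7e6d5c4b",
"metadata": {},
"source": [
"As a quick sanity check, here is the arithmetic behind these numbers (a minimal sketch; the values are copied from the two tables above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d5c4b3a",
"metadata": {},
"outputs": [],
"source": [
"# values in ms, copied from the \"CUDA total\" columns of the tables above\n",
"v1_total, v2_total = 85.026, 69.257      # type_as vs. auto_dtype\n",
"v1_kernel, v2_kernel = 31.801, 16.062    # softmax_warp_forward kernel\n",
"print(f\"kernel diff: {v1_kernel - v2_kernel:.3f}ms\")\n",
"print(f\"total diff:  {v1_total - v2_total:.3f}ms\")\n",
"print(f\"speedup:     {(v1_total - v2_total) / v1_total:.1%}\")"
]
},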
{
"cell_type": "markdown",
"id": "f356430a",
"metadata": {},
"source": [
"## Version 3\n",
"\n",
"Here we are going to label this line of code `no_dtype_change`:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ed1b42df",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
" cudaDeviceSynchronize 84.94% 14.356ms 84.94% 14.356ms 14.356ms 0.000us 0.00% 0.000us 0.000us 1 \n",
" no_dtype_change 2.36% 399.000us 14.79% 2.499ms 2.499ms 0.000us 0.00% 16.104ms 16.104ms 1 \n",
" aten::softmax 1.40% 237.000us 12.42% 2.099ms 20.990us 0.000us 0.00% 12.722ms 127.220us 100 \n",
" aten::_softmax 2.78% 470.000us 11.65% 1.969ms 19.690us 16.104ms 100.00% 16.104ms 161.040us 100 \n",
" cudaLaunchKernel 5.55% 938.000us 5.55% 938.000us 9.380us 0.000us 0.00% 0.000us 0.000us 100 \n",
" aten::empty_like 1.10% 186.000us 2.69% 454.000us 4.540us 0.000us 0.00% 0.000us 0.000us 100 \n",
" aten::empty 1.82% 307.000us 1.92% 325.000us 3.186us 0.000us 0.00% 0.000us 0.000us 102 \n",
" aten::zeros 0.04% 7.000us 0.27% 46.000us 46.000us 0.000us 0.00% 0.000us 0.000us 1 \n",
" aten::zero_ 0.01% 1.000us 0.01% 1.000us 1.000us 0.000us 0.00% 0.000us 0.000us 1 \n",
"void (anonymous namespace)::softmax_warp_forward<c10... 0.00% 0.000us 0.00% 0.000us 0.000us 16.104ms 100.00% 16.104ms 161.040us 100 \n",
"------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ \n",
"Self CPU time total: 16.901ms\n",
"Self CUDA time total: 16.104ms\n",
"\n"
]
}
],
"source": [
"# bypassing the up- and down-casting and doing softmax in fp16\n",
"with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False) as prof:\n",
"    with record_function(\"no_dtype_change\"):\n",
"        for i in range(100):\n",
"            attn_weights = nn.functional.softmax(scores, dim=-1)\n",
"print(prof.key_averages().table(sort_by=\"cpu_time_total\", row_limit=10))"
]
},
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bc3da681", | |
"metadata": {}, | |
"source": [ | |
"We can see that the total CUDA is only `16.133ms` (Ver3 `no_dtype_change`) as compared to `69.265ms` (Ver2 `auto_dtype`) and `85.503ms` (Ver1 `type_as`) vs. So Ver3 is about 5 times faster than Ver1." | |
] | |
}, | |
{
"cell_type": "markdown",
"id": "35bb1ca5",
"metadata": {},
"source": [
"## Wallclock time\n",
"\n",
"Now let's measure the speed of the 3 different versions using a wallclock timer:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "55f4dc3f",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.benchmark import Timer"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "a4a8db89",
"metadata": {},
"outputs": [],
"source": [
"t1 = Timer(stmt=\"[nn.functional.softmax(scores.float(), dim=-1).type_as(scores) for i in range(100)]\", globals=globals())\n",
"t2 = Timer(stmt=\"[nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype) for i in range(100)]\", globals=globals())\n",
"t3 = Timer(stmt=\"[nn.functional.softmax(scores, dim=-1) for i in range(100)]\", globals=globals())"
]
},
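{
"cell_type": "markdown",
"id": "5c4b3a29",
"metadata": {},
"source": [
"For reference, the same measurement can be done by hand (a minimal sketch; the helper name `wallclock_ms` is ours, not a PyTorch API). `torch.cuda.synchronize()` is needed because CUDA kernels run asynchronously, so without it the timer would stop before the queued work has finished:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b3a2918",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def wallclock_ms(fn, n=100):\n",
"    # make sure previously queued kernels are done before starting the clock\n",
"    torch.cuda.synchronize()\n",
"    start = time.perf_counter()\n",
"    for _ in range(n):\n",
"        fn()\n",
"    # wait for all queued kernels to finish before stopping the clock\n",
"    torch.cuda.synchronize()\n",
"    return (time.perf_counter() - start) * 1e3\n",
"\n",
"print(f\"type_as:         {wallclock_ms(lambda: nn.functional.softmax(scores.float(), dim=-1).type_as(scores)):.2f}ms\")\n",
"print(f\"auto_dtype:      {wallclock_ms(lambda: nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype)):.2f}ms\")\n",
"print(f\"no_dtype_change: {wallclock_ms(lambda: nn.functional.softmax(scores, dim=-1)):.2f}ms\")"
]
},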
{
"cell_type": "code",
"execution_count": 24,
"id": "5fafb3c7",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc4c2d8940>\n",
"[nn.functional.softmax(scores.float(), dim=-1).type_as(scores) for x in range(100)]\n",
" 83.75 ms\n",
" 1 measurement, 10 runs , 1 thread"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc4c2d8040>\n",
"[nn.functional.softmax(scores.float(), dim=-1, dtype=scores.dtype) for x in range(100)]\n",
" Median: 68.26 ms\n",
" 3 measurements, 1 runs per measurement, 1 thread"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<torch.utils.benchmark.utils.common.Measurement object at 0x7fcc4c2b29d0>\n",
"[nn.functional.softmax(scores, dim=-1) for x in range(100)]\n",
" Median: 16.32 ms\n",
" 2 measurements, 10 runs per measurement, 1 thread"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t1.blocked_autorange()\n",
"t2.blocked_autorange()\n",
"t3.blocked_autorange()"
]
},
{
"cell_type": "markdown",
"id": "051df6f9",
"metadata": {},
"source": [
"As you can see, we get almost the same results as `torch.profiler` reported under \"Self CPU time total\"."
]
},
{
"cell_type": "markdown",
"id": "3fa53a33",
"metadata": {},
"source": [
"### Correlation of CPU and CUDA times\n",
"\n",
"Normal profilers measure just one thing, but here things are more complex, since we have two devices that perform their work separately, and we somehow want to take both into account.\n",
"\n",
"XXX: continue"
]
},
{
"cell_type": "markdown",
"id": "410c9260",
"metadata": {},
"source": [
"## Alternative views\n",
"\n",
"### Parsing `prof.key_averages()`\n",
"\n",
"Instead of using the provided `table` formatting function, which may not always be convenient, you can access `prof.key_averages()` directly.\n",
"\n",
"It returns a list-like object of `FunctionEventAvg` objects; a single entry looks as follows:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "0a11e85e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<FunctionEventAvg key=aten::zeros self_cpu_time=7.000us cpu_time=46.000us self_cuda_time=0.000us cuda_time=0.000us input_shapes= cpu_memory_usage=0 cuda_memory_usage=0>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prof.key_averages()[0]"
]
},
{
"cell_type": "markdown",
"id": "e64f77d1",
"metadata": {},
"source": [
"For example, let's get the name of each function and its CUDA and CPU execution times in microseconds for the first few records:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e3d334c3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'cpu_time': 46.0, 'cuda_time': 0.0, 'key': 'aten::zeros'},\n",
" {'cpu_time': 3.1862745098039214, 'cuda_time': 0.0, 'key': 'aten::empty'},\n",
" {'cpu_time': 1.0, 'cuda_time': 0.0, 'key': 'aten::zero_'},\n",
" {'cpu_time': 2499.0, 'cuda_time': 16104.0, 'key': 'no_dtype_change'},\n",
" {'cpu_time': 20.99, 'cuda_time': 127.22, 'key': 'aten::softmax'},\n",
" {'cpu_time': 19.69, 'cuda_time': 161.04, 'key': 'aten::_softmax'},\n",
" {'cpu_time': 4.54, 'cuda_time': 0.0, 'key': 'aten::empty_like'},\n",
" {'cpu_time': 9.38, 'cuda_time': 0.0, 'key': 'cudaLaunchKernel'},\n",
" {'cpu_time': 0.0,\n",
"  'cuda_time': 161.04,\n",
"  'key': 'void (anonymous namespace)::softmax_warp_forward<c10::Half, '\n",
"         'c10::Half, float, 9, false>(c10::Half*, c10::Half const*, int, int, '\n",
"         'int)'},\n",
" {'cpu_time': 14356.0, 'cuda_time': 0.0, 'key': 'cudaDeviceSynchronize'}]\n"
]
}
],
"source": [
"attrs = ['key', 'cuda_time', 'cpu_time']\n",
"pprint(list(dict((key, getattr(x, key)) for key in attrs) for x in prof.key_averages())[:10])"
]
},
{
"cell_type": "markdown",
"id": "8c1f2b72",
"metadata": {},
"source": [
"We can even quickly create an alternative table using `tomark`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "72f55c7c",
"metadata": {},
"outputs": [
{
"data": {
"text/markdown": [
"| key | cuda_time | cpu_time |\n",
"|-----|-----|-----|\n",
"| aten::zeros | 0.0 | 46.0 |\n",
"| aten::empty | 0.0 | 3.1862745098039214 |\n",
"| aten::zero_ | 0.0 | 1.0 |\n",
"| no_dtype_change | 16104.0 | 2499.0 |\n",
"| aten::softmax | 127.22 | 20.99 |\n",
"| aten::_softmax | 161.04 | 19.69 |\n",
"| aten::empty_like | 0.0 | 4.54 |\n",
"| cudaLaunchKernel | 0.0 | 9.38 |\n",
"| void (anonymous namespace)::softmax_warp_forward<c10::Half, c10::Half, float, 9, false>(c10::Half*, c10::Half const*, int, int, int) | 161.04 | 0.0 |\n",
"| cudaDeviceSynchronize | 0.0 | 14356.0 |\n"
],
"text/plain": [
"<IPython.core.display.Markdown object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#!pip install tomark\n",
"attrs = ['key', 'cuda_time', 'cpu_time']\n",
"data = list(dict((key, getattr(x, key)) for key in attrs) for x in prof.key_averages())[:10]\n",
"from IPython.display import display, Markdown\n",
"from tomark import Tomark\n",
"display(Markdown(Tomark.table(data)))\n",
"# XXX: auto-format text and floats to a more consistent format"
]
},
{
"cell_type": "markdown",
"id": "6a000d29",
"metadata": {},
"source": [
"As you can see, these are the same records the earlier tables were made of, but this time all the data points use the same scale (microseconds), which makes them much easier to compare.\n",
"\n",
"For example, let's get the average CUDA time per function:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2b176648",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'aten::_softmax': 161.03,\n",
" 'aten::empty': 0.0,\n",
" 'aten::empty_like': 0.0,\n",
" 'aten::softmax': 146.54,\n",
" 'aten::zero_': 0.0,\n",
" 'aten::zeros': 0.0,\n",
" 'cudaDeviceSynchronize': 0.0,\n",
" 'cudaLaunchKernel': 0.0,\n",
" 'no_dtype_change': 16103.0,\n",
" 'void (anonymous namespace)::softmax_warp_forward<c10::Half, c10::Half, float, 9, false>(c10::Half*, c10::Half const*, int, int, int)': 161.03}\n"
]
}
],
"source": [
"pprint({x.key: x.cuda_time for x in prof.key_averages()})"
]
},
{
"cell_type": "markdown",
"id": "3f7ed0f1",
"metadata": {},
"source": [
"Now, let's get the maximum CUDA and CPU times:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a9d900cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"max total cuda time 16103.0us\n",
"max total cpu time 14113.0us\n"
]
}
],
"source": [
"max_cuda_time = max(prof.key_averages(), key=operator.attrgetter(\"cuda_time\")).cuda_time\n",
"max_cpu_time = max(prof.key_averages(), key=operator.attrgetter(\"cpu_time\")).cpu_time\n",
"print(f\"max total cuda time {max_cuda_time}us\")\n",
"print(f\"max total cpu time {max_cpu_time}us\")"
]
},
{
"cell_type": "markdown",
"id": "c96d70dc",
"metadata": {},
"source": [
"### Parsing `prof.events()`\n",
"\n",
"Yet another, even more detailed, view is provided by `prof.events()`.\n",
"\n",
"Here each function call is recorded separately, and you have to do any aggregation math yourself.\n",
"\n",
"A single event entry looks as follows:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6c869086",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<FunctionEvent id=2516 name=aten::zeros device_type=DeviceType.CPU node_id=-1 cpu_time=19.000us start_us=273.5 end_us=292.5 cpu_children=[2517, 2518] cuda_time=0.000us name=aten::zeros thread=1 input_shapes=[] cpu_memory_usage=0 cuda_memory_usage=0 is_async=False is_remote=False seq_nr=-1 is_legacy=False>"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prof.events()[0]"
]
},
{
"cell_type": "markdown",
"id": "71a59fc3",
"metadata": {},
"source": [
"For example, let's get the name of each function and its CUDA and CPU execution times in microseconds for the first few records:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "03c7a029",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'cpu_time': 19.0, 'cuda_time': 0.0, 'name': 'aten::zeros'},\n",
" {'cpu_time': 9.0, 'cuda_time': 0.0, 'name': 'aten::empty'},\n",
" {'cpu_time': 1.0, 'cuda_time': 0.0, 'name': 'aten::zero_'},\n",
" {'cpu_time': 2774.0, 'cuda_time': 16103.0, 'name': 'no_dtype_change'},\n",
" {'cpu_time': 20.0, 'cuda_time': 0.0, 'name': 'aten::empty'}]\n"
]
}
],
"source": [
"attrs = ['name', 'cuda_time', 'cpu_time']\n",
"pprint(list(dict((key, getattr(x, key)) for key in attrs) for x in prof.events())[:5])"
]
},
{
"cell_type": "markdown",
"id": "84b12353",
"metadata": {},
"source": [
"Now let's do aggregate math on a specific function of interest:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "a0dfef6e",
"metadata": {
"run_control": {
"marked": false
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Function: aten::_softmax\n",
"total 100 calls\n",
"total cuda time 16103.0us\n",
"average cuda time 161.03us\n"
]
}
],
"source": [
"f_name = 'aten::_softmax'\n",
"print(f\"Function: {f_name}\")\n",
"\n",
"# how many times it was called:\n",
"calls = sum(1 for x in prof.events() if x.name == f_name)\n",
"print(f\"total {calls} calls\")\n",
"\n",
"# total cuda run time in usecs\n",
"total_cuda_time = sum(x.cuda_time for x in prof.events() if x.name == f_name)\n",
"print(f\"total cuda time {total_cuda_time}us\")\n",
"\n",
"# average cuda run time per call\n",
"average_cuda_time = total_cuda_time / calls\n",
"print(f\"average cuda time {average_cuda_time}us\")"
]
},
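{
"cell_type": "markdown",
"id": "3a291807",
"metadata": {},
"source": [
"The same per-function aggregation generalizes over all the recorded names (a minimal sketch using `collections.defaultdict`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29180736",
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"# accumulate total cuda time (usecs) and call count per function name\n",
"totals = defaultdict(lambda: [0.0, 0])\n",
"for ev in prof.events():\n",
"    totals[ev.name][0] += ev.cuda_time\n",
"    totals[ev.name][1] += 1\n",
"\n",
"# show the 5 biggest cuda time consumers\n",
"for name, (cuda_time, calls) in sorted(totals.items(), key=lambda kv: -kv[1][0])[:5]:\n",
"    print(f\"{name}: {calls} calls, {cuda_time:.1f}us total\")"
]
}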
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": true
}
},
"nbformat": 4,
"nbformat_minor": 5
} |