Created
September 3, 2022 01:41
-
-
Save Jackiexiao/2b053cd52d977d86e07d664688c0a7ee to your computer and use it in GitHub Desktop.
compare_torch_flops_lib
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "46a95ec2-1e54-4e56-992e-3aa6550dd0f9", | |
"metadata": {}, | |
"source": [ | |
"# Comparing PyTorch parameter/FLOP counting libraries\n", | |
"\n", | |
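"This notebook compares several PyTorch libraries for counting parameters and FLOPs. In short, I recommend **fvcore** (note: fvcore counts one fused multiply-add as one flop, so its flop numbers are effectively MAC counts):\n", | |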
"- https://github.com/TylerYep/torchinfo\n", | |
"- https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md\n", | |
"- https://github.com/Lyken17/pytorch-OpCounter\n", | |
"- https://github.com/sovrasov/flops-counter.pytorch\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "150b22e9-1ab0-477e-adf1-4cf5d8cc37db", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "c99bbc67-733a-4518-91f2-8b4a74a0b9ce", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'1.12.1+cu102'" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.__version__" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "b3546007-c1da-4aad-80d3-5d95eea3f251", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n", | |
"Collecting torchinfo\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/3b/b8/f5b9770c34189cb633bf3caed3575356dcc64c315dcd00a8a6c952a70deb/torchinfo-1.7.0-py3-none-any.whl (22 kB)\n", | |
"Installing collected packages: torchinfo\n", | |
"Successfully installed torchinfo-1.7.0\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install torchinfo" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "8c8689b0-feb1-4bed-8c62-7e70010bdc6f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torchinfo import summary" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "b9ba2687-7334-4852-8bd2-ef9a7b182866", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"odim = 4\n", | |
"conv = torch.nn.Sequential(\n", | |
" torch.nn.Conv2d(1, odim, 3, 2),\n", | |
" torch.nn.ReLU(),\n", | |
" torch.nn.Conv2d(odim, odim, 3, 2),\n", | |
" torch.nn.ReLU(),\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "54be29c0-a4e5-424d-af8b-04351ef60c78", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Sequential(\n", | |
" (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2))\n", | |
" (1): ReLU()\n", | |
" (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))\n", | |
" (3): ReLU()\n", | |
")" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"conv" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "b091fb39-37f7-4341-9154-6cc036e3cb03", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"==========================================================================================\n", | |
"Layer (type:depth-idx) Output Shape Param #\n", | |
"==========================================================================================\n", | |
"Sequential [4, 4, 4] --\n", | |
"├─Conv2d: 1-1 [4, 9, 9] 40\n", | |
"├─ReLU: 1-2 [4, 9, 9] --\n", | |
"├─Conv2d: 1-3 [4, 4, 4] 148\n", | |
"├─ReLU: 1-4 [4, 4, 4] --\n", | |
"==========================================================================================\n", | |
"Total params: 188\n", | |
"Trainable params: 188\n", | |
"Non-trainable params: 0\n", | |
"Total mult-adds (M): 0.00\n", | |
"==========================================================================================\n", | |
"Input size (MB): 0.00\n", | |
"Forward/backward pass size (MB): 0.00\n", | |
"Params size (MB): 0.00\n", | |
"Estimated Total Size (MB): 0.01\n", | |
"==========================================================================================" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"summary(conv, (1, 20, 20))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d45a131e-23de-4fe0-bf49-924e7658cb07", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n", | |
"Collecting ptflops\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/d5/16/b6992b799d14bdc7a78fbfff8825777dd92728ed761090852c43f4792ce1/ptflops-0.6.9.tar.gz (12 kB)\n", | |
"Requirement already satisfied: torch in /home/jackie/.local/lib/python3.8/site-packages (from ptflops) (1.12.1)\n", | |
"Requirement already satisfied: typing-extensions in /home/jackie/.local/lib/python3.8/site-packages (from torch->ptflops) (4.2.0)\n", | |
"Building wheels for collected packages: ptflops\n", | |
" Building wheel for ptflops (setup.py) ... \u001b[?25ldone\n", | |
"\u001b[?25h Created wheel for ptflops: filename=ptflops-0.6.9-py3-none-any.whl size=11697 sha256=179baacb18084488097cdee20f3267a1faae76a7673b3938c4323fb91620c045\n", | |
" Stored in directory: /home/jackie/.cache/pip/wheels/69/32/5a/9813216f8e545082aa95da1ba4377475458b648288fc97fb76\n", | |
"Successfully built ptflops\n", | |
"Installing collected packages: ptflops\n", | |
"Successfully installed ptflops-0.6.9\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install ptflops" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "45d09338-0897-403b-83ee-1364e9cabc23", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sequential(\n", | |
" 188, 100.000% Params, 6.0 KMac, 100.000% MACs, \n", | |
" (0): Conv2d(40, 21.277% Params, 3.24 KMac, 54.036% MACs, 1, 4, kernel_size=(3, 3), stride=(2, 2))\n", | |
" (1): ReLU(0, 0.000% Params, 324.0 Mac, 5.404% MACs, )\n", | |
" (2): Conv2d(148, 78.723% Params, 2.37 KMac, 39.493% MACs, 4, 4, kernel_size=(3, 3), stride=(2, 2))\n", | |
" (3): ReLU(0, 0.000% Params, 64.0 Mac, 1.067% MACs, )\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"from ptflops import get_model_complexity_info\n", | |
"macs, params = get_model_complexity_info(conv, (1, 20, 20), as_strings=True,\n", | |
" print_per_layer_stat=True, verbose=True)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "64a91d22-ab71-4c2e-a8ab-7fc3aabde49e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n", | |
"Collecting fvcore\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/d6/92/b2ee737429ccc4de3205557fc76f5da40a38708f5c2ec8e0ef2fe849fc60/fvcore-0.1.5.post20220512.tar.gz (50 kB)\n", | |
"\u001b[K |████████████████████████████████| 50 kB 428 kB/s eta 0:00:011\n", | |
"\u001b[?25hRequirement already satisfied, skipping upgrade: Pillow in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (9.0.1)\n", | |
"Collecting iopath>=0.1.7\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz (42 kB)\n", | |
"\u001b[K |████████████████████████████████| 42 kB 301 kB/s eta 0:00:01\n", | |
"\u001b[?25hRequirement already satisfied, skipping upgrade: numpy in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (1.22.2)\n", | |
"Requirement already satisfied, skipping upgrade: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from fvcore) (5.3.1)\n", | |
"Collecting tabulate\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/92/4e/e5a13fdb3e6f81ce11893523ff289870c87c8f1f289a7369fb0e9840c3bb/tabulate-0.8.10-py3-none-any.whl (29 kB)\n", | |
"Collecting termcolor>=1.1\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz (3.9 kB)\n", | |
"Requirement already satisfied, skipping upgrade: tqdm in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (4.64.0)\n", | |
"Collecting yacs>=0.1.6\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/38/4f/fe9a4d472aa867878ce3bb7efb16654c5d63672b86dc0e6e953a67018433/yacs-0.1.8-py3-none-any.whl (14 kB)\n", | |
"Collecting portalocker\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/a9/0a/21422dc681e3e59ce5ec4051015de4c2074bd0e6759099c018471f3dc4e3/portalocker-2.5.1-py2.py3-none-any.whl (15 kB)\n", | |
"Requirement already satisfied, skipping upgrade: typing_extensions in /home/jackie/.local/lib/python3.8/site-packages (from iopath>=0.1.7->fvcore) (4.2.0)\n", | |
"Building wheels for collected packages: fvcore, iopath, termcolor\n", | |
" Building wheel for fvcore (setup.py) ... \u001b[?25ldone\n", | |
"\u001b[?25h Created wheel for fvcore: filename=fvcore-0.1.5.post20220512-py3-none-any.whl size=61288 sha256=ff880fdf9bea028037cf8ff311499ae012f8fb4624cc2e7c667020b0ed8be502\n", | |
" Stored in directory: /home/jackie/.cache/pip/wheels/67/a7/3b/402c88c19a2d8f925261d0d4121cf36ba756296c76ed483645\n", | |
" Building wheel for iopath (setup.py) ... \u001b[?25ldone\n", | |
"\u001b[?25h Created wheel for iopath: filename=iopath-0.1.10-py3-none-any.whl size=31541 sha256=e094d88531c43c6289565d71721f978fa9e70eaf2e0ef78b77e0a92576a907d9\n", | |
" Stored in directory: /home/jackie/.cache/pip/wheels/04/96/fd/ec479c7daea456637e67b949520b8873cc4958683d1e987938\n", | |
" Building wheel for termcolor (setup.py) ... \u001b[?25ldone\n", | |
"\u001b[?25h Created wheel for termcolor: filename=termcolor-1.1.0-py3-none-any.whl size=4830 sha256=ce8b97661a02f5874c49e4b4de5c8dd801aa6b05a7ab148905c41e66c46a130a\n", | |
" Stored in directory: /home/jackie/.cache/pip/wheels/cd/35/93/265a4fb7129ee7643c41ba30ff5075691c2ba1ad13829d37a8\n", | |
"Successfully built fvcore iopath termcolor\n", | |
"Installing collected packages: portalocker, iopath, tabulate, termcolor, yacs, fvcore\n", | |
"Successfully installed fvcore-0.1.5.post20220512 iopath-0.1.10 portalocker-2.5.1 tabulate-0.8.10 termcolor-1.1.0 yacs-0.1.8\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install -U fvcore" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "7fb048fd-995f-4edc-ae10-13a00c01ea6a", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Sequential(\n", | |
" #params: 0.19K, #flops: 5.22K\n", | |
" (0): Conv2d(\n", | |
" 1, 4, kernel_size=(3, 3), stride=(2, 2)\n", | |
" #params: 40, #flops: 2.92K\n", | |
" )\n", | |
" (1): ReLU()\n", | |
" (2): Conv2d(\n", | |
" 4, 4, kernel_size=(3, 3), stride=(2, 2)\n", | |
" #params: 0.15K, #flops: 2.3K\n", | |
" )\n", | |
" (3): ReLU()\n", | |
")\n" | |
] | |
} | |
], | |
"source": [ | |
"from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str\n", | |
"# input = torch.rand(1, 20, 20)\n", | |
"input = torch.rand(1, 1, 20, 20)\n", | |
"flops = FlopCountAnalysis(conv, input)\n", | |
"print(flop_count_str(flops)) # recommended" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "8e30d207-2ddb-471c-9a7e-6e6f0b999905", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n", | |
"Collecting thop\n", | |
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/15/d3/b2c788a51c55a26f3785625128285baf9461078a6a5c03836d9c6c7477c5/thop-0.1.1.post2207130030-py3-none-any.whl (15 kB)\n", | |
"Requirement already satisfied: torch in /home/jackie/.local/lib/python3.8/site-packages (from thop) (1.12.1)\n", | |
"Requirement already satisfied: typing-extensions in /home/jackie/.local/lib/python3.8/site-packages (from torch->thop) (4.2.0)\n", | |
"Installing collected packages: thop\n", | |
"Successfully installed thop-0.1.1.post2207130030\n" | |
] | |
} | |
], | |
"source": [ | |
"!pip install thop" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "35863c1c-2528-4e80-83f8-ee57f31f9e2c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n", | |
"[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n", | |
"[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n", | |
"5220.0\n", | |
"188.0\n", | |
"5.220K\n", | |
"188.000B\n" | |
] | |
} | |
], | |
"source": [ | |
"from thop import profile, clever_format\n", | |
"\n", | |
"input = torch.rand(1, 1, 20, 20) # must include the batch dimension; with torch.rand(1, 20, 20) the result would be wrong\n", | |
"macs, params = profile(conv, inputs=(input, ))\n", | |
"print(macs)\n", | |
"print(params)\n", | |
"macs, params = clever_format([macs, params], \"%.3f\")\n", | |
"print(macs)\n", | |
"print(params)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0c503add-ab85-4240-94d7-8353cdf39576", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment