Skip to content

Instantly share code, notes, and snippets.

@Jackiexiao
Created September 3, 2022 01:41
Show Gist options
  • Save Jackiexiao/2b053cd52d977d86e07d664688c0a7ee to your computer and use it in GitHub Desktop.
Save Jackiexiao/2b053cd52d977d86e07d664688c0a7ee to your computer and use it in GitHub Desktop.
compare_torch_flops_lib
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "46a95ec2-1e54-4e56-992e-3aa6550dd0f9",
"metadata": {},
"source": [
"# Comparing PyTorch parameter and FLOPs counting libraries\n",
"\n",
"This notebook compares several PyTorch FLOPs-counting libraries. In short, I recommend **fvcore**.\n",
"- https://github.com/TylerYep/torchinfo\n",
"- https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md\n",
"- https://github.com/Lyken17/pytorch-OpCounter\n",
"- https://github.com/sovrasov/flops-counter.pytorch\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "150b22e9-1ab0-477e-adf1-4cf5d8cc37db",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c99bbc67-733a-4518-91f2-8b4a74a0b9ce",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'1.12.1+cu102'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.__version__"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b3546007-c1da-4aad-80d3-5d95eea3f251",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n",
"Collecting torchinfo\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/3b/b8/f5b9770c34189cb633bf3caed3575356dcc64c315dcd00a8a6c952a70deb/torchinfo-1.7.0-py3-none-any.whl (22 kB)\n",
"Installing collected packages: torchinfo\n",
"Successfully installed torchinfo-1.7.0\n"
]
}
],
"source": [
"!pip install torchinfo"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "8c8689b0-feb1-4bed-8c62-7e70010bdc6f",
"metadata": {},
"outputs": [],
"source": [
"from torchinfo import summary"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b9ba2687-7334-4852-8bd2-ef9a7b182866",
"metadata": {},
"outputs": [],
"source": [
"odim = 4\n",
"conv = torch.nn.Sequential(\n",
" torch.nn.Conv2d(1, odim, 3, 2),\n",
" torch.nn.ReLU(),\n",
" torch.nn.Conv2d(odim, odim, 3, 2),\n",
" torch.nn.ReLU(),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "54be29c0-a4e5-424d-af8b-04351ef60c78",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2))\n",
" (1): ReLU()\n",
" (2): Conv2d(4, 4, kernel_size=(3, 3), stride=(2, 2))\n",
" (3): ReLU()\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conv"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b091fb39-37f7-4341-9154-6cc036e3cb03",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"Sequential [4, 4, 4] --\n",
"├─Conv2d: 1-1 [4, 9, 9] 40\n",
"├─ReLU: 1-2 [4, 9, 9] --\n",
"├─Conv2d: 1-3 [4, 4, 4] 148\n",
"├─ReLU: 1-4 [4, 4, 4] --\n",
"==========================================================================================\n",
"Total params: 188\n",
"Trainable params: 188\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 0.00\n",
"==========================================================================================\n",
"Input size (MB): 0.00\n",
"Forward/backward pass size (MB): 0.00\n",
"Params size (MB): 0.00\n",
"Estimated Total Size (MB): 0.01\n",
"=========================================================================================="
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summary(conv, (1, 20, 20))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "d45a131e-23de-4fe0-bf49-924e7658cb07",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n",
"Collecting ptflops\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/d5/16/b6992b799d14bdc7a78fbfff8825777dd92728ed761090852c43f4792ce1/ptflops-0.6.9.tar.gz (12 kB)\n",
"Requirement already satisfied: torch in /home/jackie/.local/lib/python3.8/site-packages (from ptflops) (1.12.1)\n",
"Requirement already satisfied: typing-extensions in /home/jackie/.local/lib/python3.8/site-packages (from torch->ptflops) (4.2.0)\n",
"Building wheels for collected packages: ptflops\n",
" Building wheel for ptflops (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for ptflops: filename=ptflops-0.6.9-py3-none-any.whl size=11697 sha256=179baacb18084488097cdee20f3267a1faae76a7673b3938c4323fb91620c045\n",
" Stored in directory: /home/jackie/.cache/pip/wheels/69/32/5a/9813216f8e545082aa95da1ba4377475458b648288fc97fb76\n",
"Successfully built ptflops\n",
"Installing collected packages: ptflops\n",
"Successfully installed ptflops-0.6.9\n"
]
}
],
"source": [
"!pip install ptflops"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "45d09338-0897-403b-83ee-1364e9cabc23",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" 188, 100.000% Params, 6.0 KMac, 100.000% MACs, \n",
" (0): Conv2d(40, 21.277% Params, 3.24 KMac, 54.036% MACs, 1, 4, kernel_size=(3, 3), stride=(2, 2))\n",
" (1): ReLU(0, 0.000% Params, 324.0 Mac, 5.404% MACs, )\n",
" (2): Conv2d(148, 78.723% Params, 2.37 KMac, 39.493% MACs, 4, 4, kernel_size=(3, 3), stride=(2, 2))\n",
" (3): ReLU(0, 0.000% Params, 64.0 Mac, 1.067% MACs, )\n",
")\n"
]
}
],
"source": [
"from ptflops import get_model_complexity_info\n",
"macs, params = get_model_complexity_info(conv, (1, 20, 20), as_strings=True,\n",
" print_per_layer_stat=True, verbose=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "64a91d22-ab71-4c2e-a8ab-7fc3aabde49e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n",
"Collecting fvcore\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/d6/92/b2ee737429ccc4de3205557fc76f5da40a38708f5c2ec8e0ef2fe849fc60/fvcore-0.1.5.post20220512.tar.gz (50 kB)\n",
"\u001b[K |████████████████████████████████| 50 kB 428 kB/s eta 0:00:011\n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: Pillow in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (9.0.1)\n",
"Collecting iopath>=0.1.7\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz (42 kB)\n",
"\u001b[K |████████████████████████████████| 42 kB 301 kB/s eta 0:00:01\n",
"\u001b[?25hRequirement already satisfied, skipping upgrade: numpy in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (1.22.2)\n",
"Requirement already satisfied, skipping upgrade: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from fvcore) (5.3.1)\n",
"Collecting tabulate\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/92/4e/e5a13fdb3e6f81ce11893523ff289870c87c8f1f289a7369fb0e9840c3bb/tabulate-0.8.10-py3-none-any.whl (29 kB)\n",
"Collecting termcolor>=1.1\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/8a/48/a76be51647d0eb9f10e2a4511bf3ffb8cc1e6b14e9e4fab46173aa79f981/termcolor-1.1.0.tar.gz (3.9 kB)\n",
"Requirement already satisfied, skipping upgrade: tqdm in /home/jackie/.local/lib/python3.8/site-packages (from fvcore) (4.64.0)\n",
"Collecting yacs>=0.1.6\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/38/4f/fe9a4d472aa867878ce3bb7efb16654c5d63672b86dc0e6e953a67018433/yacs-0.1.8-py3-none-any.whl (14 kB)\n",
"Collecting portalocker\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/a9/0a/21422dc681e3e59ce5ec4051015de4c2074bd0e6759099c018471f3dc4e3/portalocker-2.5.1-py2.py3-none-any.whl (15 kB)\n",
"Requirement already satisfied, skipping upgrade: typing_extensions in /home/jackie/.local/lib/python3.8/site-packages (from iopath>=0.1.7->fvcore) (4.2.0)\n",
"Building wheels for collected packages: fvcore, iopath, termcolor\n",
" Building wheel for fvcore (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for fvcore: filename=fvcore-0.1.5.post20220512-py3-none-any.whl size=61288 sha256=ff880fdf9bea028037cf8ff311499ae012f8fb4624cc2e7c667020b0ed8be502\n",
" Stored in directory: /home/jackie/.cache/pip/wheels/67/a7/3b/402c88c19a2d8f925261d0d4121cf36ba756296c76ed483645\n",
" Building wheel for iopath (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for iopath: filename=iopath-0.1.10-py3-none-any.whl size=31541 sha256=e094d88531c43c6289565d71721f978fa9e70eaf2e0ef78b77e0a92576a907d9\n",
" Stored in directory: /home/jackie/.cache/pip/wheels/04/96/fd/ec479c7daea456637e67b949520b8873cc4958683d1e987938\n",
" Building wheel for termcolor (setup.py) ... \u001b[?25ldone\n",
"\u001b[?25h Created wheel for termcolor: filename=termcolor-1.1.0-py3-none-any.whl size=4830 sha256=ce8b97661a02f5874c49e4b4de5c8dd801aa6b05a7ab148905c41e66c46a130a\n",
" Stored in directory: /home/jackie/.cache/pip/wheels/cd/35/93/265a4fb7129ee7643c41ba30ff5075691c2ba1ad13829d37a8\n",
"Successfully built fvcore iopath termcolor\n",
"Installing collected packages: portalocker, iopath, tabulate, termcolor, yacs, fvcore\n",
"Successfully installed fvcore-0.1.5.post20220512 iopath-0.1.10 portalocker-2.5.1 tabulate-0.8.10 termcolor-1.1.0 yacs-0.1.8\n"
]
}
],
"source": [
"!pip install -U fvcore"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7fb048fd-995f-4edc-ae10-13a00c01ea6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" #params: 0.19K, #flops: 5.22K\n",
" (0): Conv2d(\n",
" 1, 4, kernel_size=(3, 3), stride=(2, 2)\n",
" #params: 40, #flops: 2.92K\n",
" )\n",
" (1): ReLU()\n",
" (2): Conv2d(\n",
" 4, 4, kernel_size=(3, 3), stride=(2, 2)\n",
" #params: 0.15K, #flops: 2.3K\n",
" )\n",
" (3): ReLU()\n",
")\n"
]
}
],
"source": [
"from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str\n",
"# input = torch.rand(1, 20, 20)\n",
"input = torch.rand(1, 1, 20, 20)\n",
"flops = FlopCountAnalysis(conv, input)\n",
"print(flop_count_str(flops)) # recommended"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "8e30d207-2ddb-471c-9a7e-6e6f0b999905",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://mirrors.cloud.tencent.com/pypi/simple\n",
"Collecting thop\n",
" Downloading https://mirrors.cloud.tencent.com/pypi/packages/15/d3/b2c788a51c55a26f3785625128285baf9461078a6a5c03836d9c6c7477c5/thop-0.1.1.post2207130030-py3-none-any.whl (15 kB)\n",
"Requirement already satisfied: torch in /home/jackie/.local/lib/python3.8/site-packages (from thop) (1.12.1)\n",
"Requirement already satisfied: typing-extensions in /home/jackie/.local/lib/python3.8/site-packages (from torch->thop) (4.2.0)\n",
"Installing collected packages: thop\n",
"Successfully installed thop-0.1.1.post2207130030\n"
]
}
],
"source": [
"!pip install thop"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "35863c1c-2528-4e80-83f8-ee57f31f9e2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.\n",
"[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.\n",
"5220.0\n",
"188.0\n",
"5.220K\n",
"188.000B\n"
]
}
],
"source": [
"from thop import profile, clever_format\n",
"\n",
"input = torch.rand(1, 1, 20, 20) # must include the batch dimension; with torch.rand(1, 20, 20) the result would be wrong\n",
"macs, params = profile(conv, inputs=(input, ))\n",
"print(macs)\n",
"print(params)\n",
"macs, params = clever_format([macs, params], \"%.3f\")\n",
"print(macs)\n",
"print(params)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c503add-ab85-4240-94d7-8353cdf39576",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment