Created
May 6, 2022 16:38
-
-
Save tcapelle/3390c9d7a6bd29f02440675bd2b2e228 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 269, | |
"id": "87dcd94a-f70b-4a09-abad-1261be2cdc30", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import timm\n", | |
"from timm.models.helpers import group_modules\n", | |
"from fastai.vision.all import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 270, | |
"id": "92571bd0-56be-429f-81fd-08a054f3ea5b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"arch = \"resnet50\"\n", | |
"m = timm.create_model(arch)\n", | |
"modules_names = group_modules(m, m.group_matcher(coarse=True))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 271, | |
"id": "22f9ba08-e17d-45ca-b90f-df16ed60664e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_module_names(m):\n", | |
" modules_names = group_modules(m, m.group_matcher(coarse=True))\n", | |
" return list(modules_names.values())[:-1] #cut head" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 272, | |
"id": "cee8e140-da83-4084-a727-de6a14b69862", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[['conv1', 'bn1'],\n", | |
" ['layer1.0.conv1',\n", | |
" 'layer1.0.bn1',\n", | |
" 'layer1.0.conv2',\n", | |
" 'layer1.0.bn2',\n", | |
" 'layer1.0.conv3',\n", | |
" 'layer1.0.bn3',\n", | |
" 'layer1.0.downsample.0',\n", | |
" 'layer1.0.downsample.1',\n", | |
" 'layer1.1.conv1',\n", | |
" 'layer1.1.bn1',\n", | |
" 'layer1.1.conv2',\n", | |
" 'layer1.1.bn2',\n", | |
" 'layer1.1.conv3',\n", | |
" 'layer1.1.bn3',\n", | |
" 'layer1.2.conv1',\n", | |
" 'layer1.2.bn1',\n", | |
" 'layer1.2.conv2',\n", | |
" 'layer1.2.bn2',\n", | |
" 'layer1.2.conv3',\n", | |
" 'layer1.2.bn3'],\n", | |
" ['layer2.0.conv1',\n", | |
" 'layer2.0.bn1',\n", | |
" 'layer2.0.conv2',\n", | |
" 'layer2.0.bn2',\n", | |
" 'layer2.0.conv3',\n", | |
" 'layer2.0.bn3',\n", | |
" 'layer2.0.downsample.0',\n", | |
" 'layer2.0.downsample.1',\n", | |
" 'layer2.1.conv1',\n", | |
" 'layer2.1.bn1',\n", | |
" 'layer2.1.conv2',\n", | |
" 'layer2.1.bn2',\n", | |
" 'layer2.1.conv3',\n", | |
" 'layer2.1.bn3',\n", | |
" 'layer2.2.conv1',\n", | |
" 'layer2.2.bn1',\n", | |
" 'layer2.2.conv2',\n", | |
" 'layer2.2.bn2',\n", | |
" 'layer2.2.conv3',\n", | |
" 'layer2.2.bn3',\n", | |
" 'layer2.3.conv1',\n", | |
" 'layer2.3.bn1',\n", | |
" 'layer2.3.conv2',\n", | |
" 'layer2.3.bn2',\n", | |
" 'layer2.3.conv3',\n", | |
" 'layer2.3.bn3'],\n", | |
" ['layer3.0.conv1',\n", | |
" 'layer3.0.bn1',\n", | |
" 'layer3.0.conv2',\n", | |
" 'layer3.0.bn2',\n", | |
" 'layer3.0.conv3',\n", | |
" 'layer3.0.bn3',\n", | |
" 'layer3.0.downsample.0',\n", | |
" 'layer3.0.downsample.1',\n", | |
" 'layer3.1.conv1',\n", | |
" 'layer3.1.bn1',\n", | |
" 'layer3.1.conv2',\n", | |
" 'layer3.1.bn2',\n", | |
" 'layer3.1.conv3',\n", | |
" 'layer3.1.bn3',\n", | |
" 'layer3.2.conv1',\n", | |
" 'layer3.2.bn1',\n", | |
" 'layer3.2.conv2',\n", | |
" 'layer3.2.bn2',\n", | |
" 'layer3.2.conv3',\n", | |
" 'layer3.2.bn3',\n", | |
" 'layer3.3.conv1',\n", | |
" 'layer3.3.bn1',\n", | |
" 'layer3.3.conv2',\n", | |
" 'layer3.3.bn2',\n", | |
" 'layer3.3.conv3',\n", | |
" 'layer3.3.bn3',\n", | |
" 'layer3.4.conv1',\n", | |
" 'layer3.4.bn1',\n", | |
" 'layer3.4.conv2',\n", | |
" 'layer3.4.bn2',\n", | |
" 'layer3.4.conv3',\n", | |
" 'layer3.4.bn3',\n", | |
" 'layer3.5.conv1',\n", | |
" 'layer3.5.bn1',\n", | |
" 'layer3.5.conv2',\n", | |
" 'layer3.5.bn2',\n", | |
" 'layer3.5.conv3',\n", | |
" 'layer3.5.bn3'],\n", | |
" ['layer4.0.conv1',\n", | |
" 'layer4.0.bn1',\n", | |
" 'layer4.0.conv2',\n", | |
" 'layer4.0.bn2',\n", | |
" 'layer4.0.conv3',\n", | |
" 'layer4.0.bn3',\n", | |
" 'layer4.0.downsample.0',\n", | |
" 'layer4.0.downsample.1',\n", | |
" 'layer4.1.conv1',\n", | |
" 'layer4.1.bn1',\n", | |
" 'layer4.1.conv2',\n", | |
" 'layer4.1.bn2',\n", | |
" 'layer4.1.conv3',\n", | |
" 'layer4.1.bn3',\n", | |
" 'layer4.2.conv1',\n", | |
" 'layer4.2.bn1',\n", | |
" 'layer4.2.conv2',\n", | |
" 'layer4.2.bn2',\n", | |
" 'layer4.2.conv3',\n", | |
" 'layer4.2.bn3']]" | |
] | |
}, | |
"execution_count": 272, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"module_names = get_module_names(m)\n", | |
"module_names" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 273, | |
"id": "4310caec-804e-41f3-a278-fec0533c4cac", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_layers_from_names(module_names, m):\n", | |
" layers = set()\n", | |
" for l_name in L(module_names).concat():\n", | |
" if \".\" not in l_name:\n", | |
" layers.add(getattr(m, l_name))\n", | |
" else:\n", | |
" first_level_name = l_name.split(\".\")[0]\n", | |
" layers.add(getattr(m, first_level_name))\n", | |
" return L(layers)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 274, | |
"id": "1278df64-d19d-4500-93bc-f1448ddf48b9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(#6) [Sequential(\n", | |
" (0): Bottleneck(\n", | |
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (2): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (3): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (4): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (5): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
"),Sequential(\n", | |
" (0): Bottleneck(\n", | |
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): Bottleneck(\n", | |
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (2): Bottleneck(\n", | |
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
"),Sequential(\n", | |
" (0): Bottleneck(\n", | |
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n", | |
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): Bottleneck(\n", | |
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (2): Bottleneck(\n", | |
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (3): Bottleneck(\n", | |
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
"),Sequential(\n", | |
" (0): Bottleneck(\n", | |
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" (downsample): Sequential(\n", | |
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" )\n", | |
" )\n", | |
" (1): Bottleneck(\n", | |
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
" (2): Bottleneck(\n", | |
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act1): ReLU(inplace=True)\n", | |
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", | |
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (drop_block): Identity()\n", | |
" (act2): ReLU(inplace=True)\n", | |
" (aa): Identity()\n", | |
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", | |
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", | |
" (act3): ReLU(inplace=True)\n", | |
" )\n", | |
"),BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)]" | |
] | |
}, | |
"execution_count": 274, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_layers_from_names(module_names, m)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "65d332fa-afd5-4a87-a82b-82d39faed827", | |
"metadata": {}, | |
"source": [ | |
"## Now integrating with fastai:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 306, | |
"id": "b2a89eb2-75b2-4a73-8091-bbbab2f8da99", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"r50 = create_timm_model(\"resnet50\", 20, pretrained=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 307, | |
"id": "653f77f6-2e4c-4237-a65f-476a274581c2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"vit = create_timm_model(\"vit_base_patch16_224\", 20, pretrained=False)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 317, | |
"id": "3ec4277f-5323-4001-8df6-95adb92204b7", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def timm_split(m, cut=-1):\n", | |
" body, head = m[0].model, m[1]\n", | |
" module_names = get_module_names(body)\n", | |
" groups = L(module_names[0:cut], module_names[cut:]).map(partial(get_layers_from_names, m=body))\n", | |
" return [g.map(params).concat() for g in groups] + [params(head)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 318, | |
"id": "da4a6602-6c1c-49a4-99dd-eb51f63c2df9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def timm_resnet_split(m): return timm_split(m, cut=-1)\n", | |
"def timm_vit_split(m): return timm_split(m, cut=-3) # maybe more, don't know...." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 323, | |
"id": "788cd742-4ef5-4d1e-aa8d-f606204fe571", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[['conv1', 'bn1'],\n", | |
" ['layer1.0.conv1',\n", | |
" 'layer1.0.bn1',\n", | |
" 'layer1.0.conv2',\n", | |
" 'layer1.0.bn2',\n", | |
" 'layer1.0.conv3',\n", | |
" 'layer1.0.bn3',\n", | |
" 'layer1.0.downsample.0',\n", | |
" 'layer1.0.downsample.1',\n", | |
" 'layer1.1.conv1',\n", | |
" 'layer1.1.bn1',\n", | |
" 'layer1.1.conv2',\n", | |
" 'layer1.1.bn2',\n", | |
" 'layer1.1.conv3',\n", | |
" 'layer1.1.bn3',\n", | |
" 'layer1.2.conv1',\n", | |
" 'layer1.2.bn1',\n", | |
" 'layer1.2.conv2',\n", | |
" 'layer1.2.bn2',\n", | |
" 'layer1.2.conv3',\n", | |
" 'layer1.2.bn3'],\n", | |
" ['layer2.0.conv1',\n", | |
" 'layer2.0.bn1',\n", | |
" 'layer2.0.conv2',\n", | |
" 'layer2.0.bn2',\n", | |
" 'layer2.0.conv3',\n", | |
" 'layer2.0.bn3',\n", | |
" 'layer2.0.downsample.0',\n", | |
" 'layer2.0.downsample.1',\n", | |
" 'layer2.1.conv1',\n", | |
" 'layer2.1.bn1',\n", | |
" 'layer2.1.conv2',\n", | |
" 'layer2.1.bn2',\n", | |
" 'layer2.1.conv3',\n", | |
" 'layer2.1.bn3',\n", | |
" 'layer2.2.conv1',\n", | |
" 'layer2.2.bn1',\n", | |
" 'layer2.2.conv2',\n", | |
" 'layer2.2.bn2',\n", | |
" 'layer2.2.conv3',\n", | |
" 'layer2.2.bn3',\n", | |
" 'layer2.3.conv1',\n", | |
" 'layer2.3.bn1',\n", | |
" 'layer2.3.conv2',\n", | |
" 'layer2.3.bn2',\n", | |
" 'layer2.3.conv3',\n", | |
" 'layer2.3.bn3'],\n", | |
" ['layer3.0.conv1',\n", | |
" 'layer3.0.bn1',\n", | |
" 'layer3.0.conv2',\n", | |
" 'layer3.0.bn2',\n", | |
" 'layer3.0.conv3',\n", | |
" 'layer3.0.bn3',\n", | |
" 'layer3.0.downsample.0',\n", | |
" 'layer3.0.downsample.1',\n", | |
" 'layer3.1.conv1',\n", | |
" 'layer3.1.bn1',\n", | |
" 'layer3.1.conv2',\n", | |
" 'layer3.1.bn2',\n", | |
" 'layer3.1.conv3',\n", | |
" 'layer3.1.bn3',\n", | |
" 'layer3.2.conv1',\n", | |
" 'layer3.2.bn1',\n", | |
" 'layer3.2.conv2',\n", | |
" 'layer3.2.bn2',\n", | |
" 'layer3.2.conv3',\n", | |
" 'layer3.2.bn3',\n", | |
" 'layer3.3.conv1',\n", | |
" 'layer3.3.bn1',\n", | |
" 'layer3.3.conv2',\n", | |
" 'layer3.3.bn2',\n", | |
" 'layer3.3.conv3',\n", | |
" 'layer3.3.bn3',\n", | |
" 'layer3.4.conv1',\n", | |
" 'layer3.4.bn1',\n", | |
" 'layer3.4.conv2',\n", | |
" 'layer3.4.bn2',\n", | |
" 'layer3.4.conv3',\n", | |
" 'layer3.4.bn3',\n", | |
" 'layer3.5.conv1',\n", | |
" 'layer3.5.bn1',\n", | |
" 'layer3.5.conv2',\n", | |
" 'layer3.5.bn2',\n", | |
" 'layer3.5.conv3',\n", | |
" 'layer3.5.bn3']]" | |
] | |
}, | |
"execution_count": 323, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_module_names(r50[0].model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 319, | |
"id": "ab00f109-9ba2-44cb-b3dc-723cdb253970", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3" | |
] | |
}, | |
"execution_count": 319, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"groups = timm_resnet_split(r50)\n", | |
"len(groups)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 320, | |
"id": "f863fcb4-1044-4951-9f75-e73d30a38d75", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[['patch_embed.proj'],\n", | |
" ['blocks.0.norm1',\n", | |
" 'blocks.0.attn.qkv',\n", | |
" 'blocks.0.attn.proj',\n", | |
" 'blocks.0.norm2',\n", | |
" 'blocks.0.mlp.fc1',\n", | |
" 'blocks.0.mlp.fc2'],\n", | |
" ['blocks.1.norm1',\n", | |
" 'blocks.1.attn.qkv',\n", | |
" 'blocks.1.attn.proj',\n", | |
" 'blocks.1.norm2',\n", | |
" 'blocks.1.mlp.fc1',\n", | |
" 'blocks.1.mlp.fc2'],\n", | |
" ['blocks.2.norm1',\n", | |
" 'blocks.2.attn.qkv',\n", | |
" 'blocks.2.attn.proj',\n", | |
" 'blocks.2.norm2',\n", | |
" 'blocks.2.mlp.fc1',\n", | |
" 'blocks.2.mlp.fc2'],\n", | |
" ['blocks.3.norm1',\n", | |
" 'blocks.3.attn.qkv',\n", | |
" 'blocks.3.attn.proj',\n", | |
" 'blocks.3.norm2',\n", | |
" 'blocks.3.mlp.fc1',\n", | |
" 'blocks.3.mlp.fc2'],\n", | |
" ['blocks.4.norm1',\n", | |
" 'blocks.4.attn.qkv',\n", | |
" 'blocks.4.attn.proj',\n", | |
" 'blocks.4.norm2',\n", | |
" 'blocks.4.mlp.fc1',\n", | |
" 'blocks.4.mlp.fc2'],\n", | |
" ['blocks.5.norm1',\n", | |
" 'blocks.5.attn.qkv',\n", | |
" 'blocks.5.attn.proj',\n", | |
" 'blocks.5.norm2',\n", | |
" 'blocks.5.mlp.fc1',\n", | |
" 'blocks.5.mlp.fc2'],\n", | |
" ['blocks.6.norm1',\n", | |
" 'blocks.6.attn.qkv',\n", | |
" 'blocks.6.attn.proj',\n", | |
" 'blocks.6.norm2',\n", | |
" 'blocks.6.mlp.fc1',\n", | |
" 'blocks.6.mlp.fc2'],\n", | |
" ['blocks.7.norm1',\n", | |
" 'blocks.7.attn.qkv',\n", | |
" 'blocks.7.attn.proj',\n", | |
" 'blocks.7.norm2',\n", | |
" 'blocks.7.mlp.fc1',\n", | |
" 'blocks.7.mlp.fc2'],\n", | |
" ['blocks.8.norm1',\n", | |
" 'blocks.8.attn.qkv',\n", | |
" 'blocks.8.attn.proj',\n", | |
" 'blocks.8.norm2',\n", | |
" 'blocks.8.mlp.fc1',\n", | |
" 'blocks.8.mlp.fc2'],\n", | |
" ['blocks.9.norm1',\n", | |
" 'blocks.9.attn.qkv',\n", | |
" 'blocks.9.attn.proj',\n", | |
" 'blocks.9.norm2',\n", | |
" 'blocks.9.mlp.fc1',\n", | |
" 'blocks.9.mlp.fc2'],\n", | |
" ['blocks.10.norm1',\n", | |
" 'blocks.10.attn.qkv',\n", | |
" 'blocks.10.attn.proj',\n", | |
" 'blocks.10.norm2',\n", | |
" 'blocks.10.mlp.fc1',\n", | |
" 'blocks.10.mlp.fc2']]" | |
] | |
}, | |
"execution_count": 320, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"get_module_names(vit[0].model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 321, | |
"id": "22ed3135-660d-4729-a58f-47657ebe8ddd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"3" | |
] | |
}, | |
"execution_count": 321, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"groups = timm_vit_split(vit)\n", | |
"len(groups)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 322, | |
"id": "dce44857-098e-42c0-8bf5-036ccc237059", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['convit_base',\n", | |
" 'convit_small',\n", | |
" 'convit_tiny',\n", | |
" 'crossvit_9_240',\n", | |
" 'crossvit_9_dagger_240',\n", | |
" 'crossvit_15_240',\n", | |
" 'crossvit_15_dagger_240',\n", | |
" 'crossvit_15_dagger_408',\n", | |
" 'crossvit_18_240',\n", | |
" 'crossvit_18_dagger_240',\n", | |
" 'crossvit_18_dagger_408',\n", | |
" 'crossvit_base_240',\n", | |
" 'crossvit_small_240',\n", | |
" 'crossvit_tiny_240',\n", | |
" 'levit_128',\n", | |
" 'levit_128s',\n", | |
" 'levit_192',\n", | |
" 'levit_256',\n", | |
" 'levit_256d',\n", | |
" 'levit_384',\n", | |
" 'mobilevit_s',\n", | |
" 'mobilevit_xs',\n", | |
" 'mobilevit_xxs',\n", | |
" 'semobilevit_s',\n", | |
" 'vit_base_patch8_224',\n", | |
" 'vit_base_patch8_224_dino',\n", | |
" 'vit_base_patch8_224_in21k',\n", | |
" 'vit_base_patch16_18x2_224',\n", | |
" 'vit_base_patch16_224',\n", | |
" 'vit_base_patch16_224_dino',\n", | |
" 'vit_base_patch16_224_in21k',\n", | |
" 'vit_base_patch16_224_miil',\n", | |
" 'vit_base_patch16_224_miil_in21k',\n", | |
" 'vit_base_patch16_224_sam',\n", | |
" 'vit_base_patch16_384',\n", | |
" 'vit_base_patch16_plus_240',\n", | |
" 'vit_base_patch16_rpn_224',\n", | |
" 'vit_base_patch32_224',\n", | |
" 'vit_base_patch32_224_in21k',\n", | |
" 'vit_base_patch32_224_sam',\n", | |
" 'vit_base_patch32_384',\n", | |
" 'vit_base_patch32_plus_256',\n", | |
" 'vit_base_r26_s32_224',\n", | |
" 'vit_base_r50_s16_224',\n", | |
" 'vit_base_r50_s16_224_in21k',\n", | |
" 'vit_base_r50_s16_384',\n", | |
" 'vit_base_resnet26d_224',\n", | |
" 'vit_base_resnet50_224_in21k',\n", | |
" 'vit_base_resnet50_384',\n", | |
" 'vit_base_resnet50d_224',\n", | |
" 'vit_giant_patch14_224',\n", | |
" 'vit_gigantic_patch14_224',\n", | |
" 'vit_huge_patch14_224',\n", | |
" 'vit_huge_patch14_224_in21k',\n", | |
" 'vit_large_patch14_224',\n", | |
" 'vit_large_patch16_224',\n", | |
" 'vit_large_patch16_224_in21k',\n", | |
" 'vit_large_patch16_384',\n", | |
" 'vit_large_patch32_224',\n", | |
" 'vit_large_patch32_224_in21k',\n", | |
" 'vit_large_patch32_384',\n", | |
" 'vit_large_r50_s32_224',\n", | |
" 'vit_large_r50_s32_224_in21k',\n", | |
" 'vit_large_r50_s32_384',\n", | |
" 'vit_relpos_base_patch16_224',\n", | |
" 'vit_relpos_base_patch16_plus_240',\n", | |
" 'vit_relpos_base_patch16_rpn_224',\n", | |
" 'vit_relpos_base_patch32_plus_rpn_256',\n", | |
" 'vit_small_patch8_224_dino',\n", | |
" 'vit_small_patch16_18x2_224',\n", | |
" 'vit_small_patch16_36x1_224',\n", | |
" 'vit_small_patch16_224',\n", | |
" 'vit_small_patch16_224_dino',\n", | |
" 'vit_small_patch16_224_in21k',\n", | |
" 'vit_small_patch16_384',\n", | |
" 'vit_small_patch32_224',\n", | |
" 'vit_small_patch32_224_in21k',\n", | |
" 'vit_small_patch32_384',\n", | |
" 'vit_small_r26_s32_224',\n", | |
" 'vit_small_r26_s32_224_in21k',\n", | |
" 'vit_small_r26_s32_384',\n", | |
" 'vit_small_resnet26d_224',\n", | |
" 'vit_small_resnet50d_s16_224',\n", | |
" 'vit_tiny_patch16_224',\n", | |
" 'vit_tiny_patch16_224_in21k',\n", | |
" 'vit_tiny_patch16_384',\n", | |
" 'vit_tiny_r_s16_p8_224',\n", | |
" 'vit_tiny_r_s16_p8_224_in21k',\n", | |
" 'vit_tiny_r_s16_p8_384']" | |
] | |
}, | |
"execution_count": 322, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"timm.list_models(\"*vit*\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "d44eaadb-6016-482e-80dc-b13497069ea9", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Testing bat_resnext26ts:\n", | |
"Testing cspresnet50:\n", | |
"Testing cspresnet50d:\n", | |
"Testing cspresnet50w:\n", | |
"Testing cspresnext50:\n", | |
"Testing cspresnext50_iabn:\n", | |
"Testing dla60_res2net:\n", | |
"Testing dla60_res2next:\n", | |
"Testing eca_resnet33ts:\n", | |
"Testing eca_resnext26ts:\n", | |
"Testing ecaresnet26t:\n", | |
"Testing ecaresnet50d:\n", | |
"Testing ecaresnet50d_pruned:\n", | |
"Testing ecaresnet50t:\n", | |
"Testing ecaresnet101d:\n", | |
"Testing ecaresnet101d_pruned:\n", | |
"Testing ecaresnet200d:\n", | |
"Testing ecaresnet269d:\n", | |
"Testing ecaresnetlight:\n", | |
"Testing ecaresnext26t_32x4d:\n", | |
"Testing ecaresnext50t_32x4d:\n", | |
"Testing ens_adv_inception_resnet_v2:\n", | |
"Testing gcresnet33ts:\n", | |
"Testing gcresnet50t:\n", | |
"Testing gcresnext26ts:\n", | |
"Testing gcresnext50ts:\n", | |
"Testing gluon_resnet18_v1b:\n", | |
"Testing gluon_resnet34_v1b:\n", | |
"Testing gluon_resnet50_v1b:\n", | |
"Testing gluon_resnet50_v1c:\n", | |
"Testing gluon_resnet50_v1d:\n", | |
"Testing gluon_resnet50_v1s:\n", | |
"Testing gluon_resnet101_v1b:\n", | |
"Testing gluon_resnet101_v1c:\n", | |
"Testing gluon_resnet101_v1d:\n", | |
"Testing gluon_resnet101_v1s:\n", | |
"Testing gluon_resnet152_v1b:\n", | |
"Testing gluon_resnet152_v1c:\n", | |
"Testing gluon_resnet152_v1d:\n" | |
] | |
} | |
], | |
"source": [ | |
"for arch in timm.list_models(\"*res*\"):\n", | |
" print(f\"Testing {arch}:\")\n", | |
" timm_model = create_timm_model(arch, 20, pretrained=False)\n", | |
" if \"vit\" in arch:\n", | |
" groups = timm_vit_split(timm_model)\n", | |
" else:\n", | |
" groups = timm_resnet_split(timm_model)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment