Skip to content

Instantly share code, notes, and snippets.

@tcapelle
Created May 6, 2022 16:38
Show Gist options
  • Save tcapelle/3390c9d7a6bd29f02440675bd2b2e228 to your computer and use it in GitHub Desktop.
Save tcapelle/3390c9d7a6bd29f02440675bd2b2e228 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 269,
"id": "87dcd94a-f70b-4a09-abad-1261be2cdc30",
"metadata": {},
"outputs": [],
"source": [
"import timm\n",
"from timm.models.helpers import group_modules\n",
"from fastai.vision.all import *"
]
},
{
"cell_type": "code",
"execution_count": 270,
"id": "92571bd0-56be-429f-81fd-08a054f3ea5b",
"metadata": {},
"outputs": [],
"source": [
"arch = \"resnet50\"\n",
"m = timm.create_model(arch)\n",
"modules_names = group_modules(m, m.group_matcher(coarse=True))"
]
},
{
"cell_type": "code",
"execution_count": 271,
"id": "22f9ba08-e17d-45ca-b90f-df16ed60664e",
"metadata": {},
"outputs": [],
"source": [
"def get_module_names(m):\n",
" modules_names = group_modules(m, m.group_matcher(coarse=True))\n",
" return list(modules_names.values())[:-1] #cut head"
]
},
{
"cell_type": "code",
"execution_count": 272,
"id": "cee8e140-da83-4084-a727-de6a14b69862",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[['conv1', 'bn1'],\n",
" ['layer1.0.conv1',\n",
" 'layer1.0.bn1',\n",
" 'layer1.0.conv2',\n",
" 'layer1.0.bn2',\n",
" 'layer1.0.conv3',\n",
" 'layer1.0.bn3',\n",
" 'layer1.0.downsample.0',\n",
" 'layer1.0.downsample.1',\n",
" 'layer1.1.conv1',\n",
" 'layer1.1.bn1',\n",
" 'layer1.1.conv2',\n",
" 'layer1.1.bn2',\n",
" 'layer1.1.conv3',\n",
" 'layer1.1.bn3',\n",
" 'layer1.2.conv1',\n",
" 'layer1.2.bn1',\n",
" 'layer1.2.conv2',\n",
" 'layer1.2.bn2',\n",
" 'layer1.2.conv3',\n",
" 'layer1.2.bn3'],\n",
" ['layer2.0.conv1',\n",
" 'layer2.0.bn1',\n",
" 'layer2.0.conv2',\n",
" 'layer2.0.bn2',\n",
" 'layer2.0.conv3',\n",
" 'layer2.0.bn3',\n",
" 'layer2.0.downsample.0',\n",
" 'layer2.0.downsample.1',\n",
" 'layer2.1.conv1',\n",
" 'layer2.1.bn1',\n",
" 'layer2.1.conv2',\n",
" 'layer2.1.bn2',\n",
" 'layer2.1.conv3',\n",
" 'layer2.1.bn3',\n",
" 'layer2.2.conv1',\n",
" 'layer2.2.bn1',\n",
" 'layer2.2.conv2',\n",
" 'layer2.2.bn2',\n",
" 'layer2.2.conv3',\n",
" 'layer2.2.bn3',\n",
" 'layer2.3.conv1',\n",
" 'layer2.3.bn1',\n",
" 'layer2.3.conv2',\n",
" 'layer2.3.bn2',\n",
" 'layer2.3.conv3',\n",
" 'layer2.3.bn3'],\n",
" ['layer3.0.conv1',\n",
" 'layer3.0.bn1',\n",
" 'layer3.0.conv2',\n",
" 'layer3.0.bn2',\n",
" 'layer3.0.conv3',\n",
" 'layer3.0.bn3',\n",
" 'layer3.0.downsample.0',\n",
" 'layer3.0.downsample.1',\n",
" 'layer3.1.conv1',\n",
" 'layer3.1.bn1',\n",
" 'layer3.1.conv2',\n",
" 'layer3.1.bn2',\n",
" 'layer3.1.conv3',\n",
" 'layer3.1.bn3',\n",
" 'layer3.2.conv1',\n",
" 'layer3.2.bn1',\n",
" 'layer3.2.conv2',\n",
" 'layer3.2.bn2',\n",
" 'layer3.2.conv3',\n",
" 'layer3.2.bn3',\n",
" 'layer3.3.conv1',\n",
" 'layer3.3.bn1',\n",
" 'layer3.3.conv2',\n",
" 'layer3.3.bn2',\n",
" 'layer3.3.conv3',\n",
" 'layer3.3.bn3',\n",
" 'layer3.4.conv1',\n",
" 'layer3.4.bn1',\n",
" 'layer3.4.conv2',\n",
" 'layer3.4.bn2',\n",
" 'layer3.4.conv3',\n",
" 'layer3.4.bn3',\n",
" 'layer3.5.conv1',\n",
" 'layer3.5.bn1',\n",
" 'layer3.5.conv2',\n",
" 'layer3.5.bn2',\n",
" 'layer3.5.conv3',\n",
" 'layer3.5.bn3'],\n",
" ['layer4.0.conv1',\n",
" 'layer4.0.bn1',\n",
" 'layer4.0.conv2',\n",
" 'layer4.0.bn2',\n",
" 'layer4.0.conv3',\n",
" 'layer4.0.bn3',\n",
" 'layer4.0.downsample.0',\n",
" 'layer4.0.downsample.1',\n",
" 'layer4.1.conv1',\n",
" 'layer4.1.bn1',\n",
" 'layer4.1.conv2',\n",
" 'layer4.1.bn2',\n",
" 'layer4.1.conv3',\n",
" 'layer4.1.bn3',\n",
" 'layer4.2.conv1',\n",
" 'layer4.2.bn1',\n",
" 'layer4.2.conv2',\n",
" 'layer4.2.bn2',\n",
" 'layer4.2.conv3',\n",
" 'layer4.2.bn3']]"
]
},
"execution_count": 272,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"module_names = get_module_names(m)\n",
"module_names"
]
},
{
"cell_type": "code",
"execution_count": 273,
"id": "4310caec-804e-41f3-a278-fec0533c4cac",
"metadata": {},
"outputs": [],
"source": [
"def get_layers_from_names(module_names, m):\n",
" layers = set()\n",
" for l_name in L(module_names).concat():\n",
" if \".\" not in l_name:\n",
" layers.add(getattr(m, l_name))\n",
" else:\n",
" first_level_name = l_name.split(\".\")[0]\n",
" layers.add(getattr(m, first_level_name))\n",
" return L(layers)"
]
},
{
"cell_type": "code",
"execution_count": 274,
"id": "1278df64-d19d-4500-93bc-f1448ddf48b9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(#6) [Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (4): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (5): Bottleneck(\n",
" (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
"),Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
"),Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
" (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (3): Bottleneck(\n",
" (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
"),Sequential(\n",
" (0): Bottleneck(\n",
" (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" (downsample): Sequential(\n",
" (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" )\n",
" )\n",
" (1): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
" (2): Bottleneck(\n",
" (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act1): ReLU(inplace=True)\n",
" (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
" (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (drop_block): Identity()\n",
" (act2): ReLU(inplace=True)\n",
" (aa): Identity()\n",
" (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
" (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
" (act3): ReLU(inplace=True)\n",
" )\n",
"),BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)]"
]
},
"execution_count": 274,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_layers_from_names(module_names, m)"
]
},
{
"cell_type": "markdown",
"id": "65d332fa-afd5-4a87-a82b-82d39faed827",
"metadata": {},
"source": [
"## Now integrating with fastai:"
]
},
{
"cell_type": "code",
"execution_count": 306,
"id": "b2a89eb2-75b2-4a73-8091-bbbab2f8da99",
"metadata": {},
"outputs": [],
"source": [
"r50 = create_timm_model(\"resnet50\", 20, pretrained=False)"
]
},
{
"cell_type": "code",
"execution_count": 307,
"id": "653f77f6-2e4c-4237-a65f-476a274581c2",
"metadata": {},
"outputs": [],
"source": [
"vit = create_timm_model(\"vit_base_patch16_224\", 20, pretrained=False)"
]
},
{
"cell_type": "code",
"execution_count": 317,
"id": "3ec4277f-5323-4001-8df6-95adb92204b7",
"metadata": {},
"outputs": [],
"source": [
"def timm_split(m, cut=-1):\n",
" body, head = m[0].model, m[1]\n",
" module_names = get_module_names(body)\n",
" groups = L(module_names[0:cut], module_names[cut:]).map(partial(get_layers_from_names, m=body))\n",
" return [g.map(params).concat() for g in groups] + [params(head)]"
]
},
{
"cell_type": "code",
"execution_count": 318,
"id": "da4a6602-6c1c-49a4-99dd-eb51f63c2df9",
"metadata": {},
"outputs": [],
"source": [
"def timm_resnet_split(m): return timm_split(m, cut=-1)\n",
"def timm_vit_split(m): return timm_split(m, cut=-3) # maybe more, don't know...."
]
},
{
"cell_type": "code",
"execution_count": 323,
"id": "788cd742-4ef5-4d1e-aa8d-f606204fe571",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[['conv1', 'bn1'],\n",
" ['layer1.0.conv1',\n",
" 'layer1.0.bn1',\n",
" 'layer1.0.conv2',\n",
" 'layer1.0.bn2',\n",
" 'layer1.0.conv3',\n",
" 'layer1.0.bn3',\n",
" 'layer1.0.downsample.0',\n",
" 'layer1.0.downsample.1',\n",
" 'layer1.1.conv1',\n",
" 'layer1.1.bn1',\n",
" 'layer1.1.conv2',\n",
" 'layer1.1.bn2',\n",
" 'layer1.1.conv3',\n",
" 'layer1.1.bn3',\n",
" 'layer1.2.conv1',\n",
" 'layer1.2.bn1',\n",
" 'layer1.2.conv2',\n",
" 'layer1.2.bn2',\n",
" 'layer1.2.conv3',\n",
" 'layer1.2.bn3'],\n",
" ['layer2.0.conv1',\n",
" 'layer2.0.bn1',\n",
" 'layer2.0.conv2',\n",
" 'layer2.0.bn2',\n",
" 'layer2.0.conv3',\n",
" 'layer2.0.bn3',\n",
" 'layer2.0.downsample.0',\n",
" 'layer2.0.downsample.1',\n",
" 'layer2.1.conv1',\n",
" 'layer2.1.bn1',\n",
" 'layer2.1.conv2',\n",
" 'layer2.1.bn2',\n",
" 'layer2.1.conv3',\n",
" 'layer2.1.bn3',\n",
" 'layer2.2.conv1',\n",
" 'layer2.2.bn1',\n",
" 'layer2.2.conv2',\n",
" 'layer2.2.bn2',\n",
" 'layer2.2.conv3',\n",
" 'layer2.2.bn3',\n",
" 'layer2.3.conv1',\n",
" 'layer2.3.bn1',\n",
" 'layer2.3.conv2',\n",
" 'layer2.3.bn2',\n",
" 'layer2.3.conv3',\n",
" 'layer2.3.bn3'],\n",
" ['layer3.0.conv1',\n",
" 'layer3.0.bn1',\n",
" 'layer3.0.conv2',\n",
" 'layer3.0.bn2',\n",
" 'layer3.0.conv3',\n",
" 'layer3.0.bn3',\n",
" 'layer3.0.downsample.0',\n",
" 'layer3.0.downsample.1',\n",
" 'layer3.1.conv1',\n",
" 'layer3.1.bn1',\n",
" 'layer3.1.conv2',\n",
" 'layer3.1.bn2',\n",
" 'layer3.1.conv3',\n",
" 'layer3.1.bn3',\n",
" 'layer3.2.conv1',\n",
" 'layer3.2.bn1',\n",
" 'layer3.2.conv2',\n",
" 'layer3.2.bn2',\n",
" 'layer3.2.conv3',\n",
" 'layer3.2.bn3',\n",
" 'layer3.3.conv1',\n",
" 'layer3.3.bn1',\n",
" 'layer3.3.conv2',\n",
" 'layer3.3.bn2',\n",
" 'layer3.3.conv3',\n",
" 'layer3.3.bn3',\n",
" 'layer3.4.conv1',\n",
" 'layer3.4.bn1',\n",
" 'layer3.4.conv2',\n",
" 'layer3.4.bn2',\n",
" 'layer3.4.conv3',\n",
" 'layer3.4.bn3',\n",
" 'layer3.5.conv1',\n",
" 'layer3.5.bn1',\n",
" 'layer3.5.conv2',\n",
" 'layer3.5.bn2',\n",
" 'layer3.5.conv3',\n",
" 'layer3.5.bn3']]"
]
},
"execution_count": 323,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_module_names(r50[0].model)"
]
},
{
"cell_type": "code",
"execution_count": 319,
"id": "ab00f109-9ba2-44cb-b3dc-723cdb253970",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 319,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"groups = timm_resnet_split(r50)\n",
"len(groups)"
]
},
{
"cell_type": "code",
"execution_count": 320,
"id": "f863fcb4-1044-4951-9f75-e73d30a38d75",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[['patch_embed.proj'],\n",
" ['blocks.0.norm1',\n",
" 'blocks.0.attn.qkv',\n",
" 'blocks.0.attn.proj',\n",
" 'blocks.0.norm2',\n",
" 'blocks.0.mlp.fc1',\n",
" 'blocks.0.mlp.fc2'],\n",
" ['blocks.1.norm1',\n",
" 'blocks.1.attn.qkv',\n",
" 'blocks.1.attn.proj',\n",
" 'blocks.1.norm2',\n",
" 'blocks.1.mlp.fc1',\n",
" 'blocks.1.mlp.fc2'],\n",
" ['blocks.2.norm1',\n",
" 'blocks.2.attn.qkv',\n",
" 'blocks.2.attn.proj',\n",
" 'blocks.2.norm2',\n",
" 'blocks.2.mlp.fc1',\n",
" 'blocks.2.mlp.fc2'],\n",
" ['blocks.3.norm1',\n",
" 'blocks.3.attn.qkv',\n",
" 'blocks.3.attn.proj',\n",
" 'blocks.3.norm2',\n",
" 'blocks.3.mlp.fc1',\n",
" 'blocks.3.mlp.fc2'],\n",
" ['blocks.4.norm1',\n",
" 'blocks.4.attn.qkv',\n",
" 'blocks.4.attn.proj',\n",
" 'blocks.4.norm2',\n",
" 'blocks.4.mlp.fc1',\n",
" 'blocks.4.mlp.fc2'],\n",
" ['blocks.5.norm1',\n",
" 'blocks.5.attn.qkv',\n",
" 'blocks.5.attn.proj',\n",
" 'blocks.5.norm2',\n",
" 'blocks.5.mlp.fc1',\n",
" 'blocks.5.mlp.fc2'],\n",
" ['blocks.6.norm1',\n",
" 'blocks.6.attn.qkv',\n",
" 'blocks.6.attn.proj',\n",
" 'blocks.6.norm2',\n",
" 'blocks.6.mlp.fc1',\n",
" 'blocks.6.mlp.fc2'],\n",
" ['blocks.7.norm1',\n",
" 'blocks.7.attn.qkv',\n",
" 'blocks.7.attn.proj',\n",
" 'blocks.7.norm2',\n",
" 'blocks.7.mlp.fc1',\n",
" 'blocks.7.mlp.fc2'],\n",
" ['blocks.8.norm1',\n",
" 'blocks.8.attn.qkv',\n",
" 'blocks.8.attn.proj',\n",
" 'blocks.8.norm2',\n",
" 'blocks.8.mlp.fc1',\n",
" 'blocks.8.mlp.fc2'],\n",
" ['blocks.9.norm1',\n",
" 'blocks.9.attn.qkv',\n",
" 'blocks.9.attn.proj',\n",
" 'blocks.9.norm2',\n",
" 'blocks.9.mlp.fc1',\n",
" 'blocks.9.mlp.fc2'],\n",
" ['blocks.10.norm1',\n",
" 'blocks.10.attn.qkv',\n",
" 'blocks.10.attn.proj',\n",
" 'blocks.10.norm2',\n",
" 'blocks.10.mlp.fc1',\n",
" 'blocks.10.mlp.fc2']]"
]
},
"execution_count": 320,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_module_names(vit[0].model)"
]
},
{
"cell_type": "code",
"execution_count": 321,
"id": "22ed3135-660d-4729-a58f-47657ebe8ddd",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 321,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"groups = timm_vit_split(vit)\n",
"len(groups)"
]
},
{
"cell_type": "code",
"execution_count": 322,
"id": "dce44857-098e-42c0-8bf5-036ccc237059",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['convit_base',\n",
" 'convit_small',\n",
" 'convit_tiny',\n",
" 'crossvit_9_240',\n",
" 'crossvit_9_dagger_240',\n",
" 'crossvit_15_240',\n",
" 'crossvit_15_dagger_240',\n",
" 'crossvit_15_dagger_408',\n",
" 'crossvit_18_240',\n",
" 'crossvit_18_dagger_240',\n",
" 'crossvit_18_dagger_408',\n",
" 'crossvit_base_240',\n",
" 'crossvit_small_240',\n",
" 'crossvit_tiny_240',\n",
" 'levit_128',\n",
" 'levit_128s',\n",
" 'levit_192',\n",
" 'levit_256',\n",
" 'levit_256d',\n",
" 'levit_384',\n",
" 'mobilevit_s',\n",
" 'mobilevit_xs',\n",
" 'mobilevit_xxs',\n",
" 'semobilevit_s',\n",
" 'vit_base_patch8_224',\n",
" 'vit_base_patch8_224_dino',\n",
" 'vit_base_patch8_224_in21k',\n",
" 'vit_base_patch16_18x2_224',\n",
" 'vit_base_patch16_224',\n",
" 'vit_base_patch16_224_dino',\n",
" 'vit_base_patch16_224_in21k',\n",
" 'vit_base_patch16_224_miil',\n",
" 'vit_base_patch16_224_miil_in21k',\n",
" 'vit_base_patch16_224_sam',\n",
" 'vit_base_patch16_384',\n",
" 'vit_base_patch16_plus_240',\n",
" 'vit_base_patch16_rpn_224',\n",
" 'vit_base_patch32_224',\n",
" 'vit_base_patch32_224_in21k',\n",
" 'vit_base_patch32_224_sam',\n",
" 'vit_base_patch32_384',\n",
" 'vit_base_patch32_plus_256',\n",
" 'vit_base_r26_s32_224',\n",
" 'vit_base_r50_s16_224',\n",
" 'vit_base_r50_s16_224_in21k',\n",
" 'vit_base_r50_s16_384',\n",
" 'vit_base_resnet26d_224',\n",
" 'vit_base_resnet50_224_in21k',\n",
" 'vit_base_resnet50_384',\n",
" 'vit_base_resnet50d_224',\n",
" 'vit_giant_patch14_224',\n",
" 'vit_gigantic_patch14_224',\n",
" 'vit_huge_patch14_224',\n",
" 'vit_huge_patch14_224_in21k',\n",
" 'vit_large_patch14_224',\n",
" 'vit_large_patch16_224',\n",
" 'vit_large_patch16_224_in21k',\n",
" 'vit_large_patch16_384',\n",
" 'vit_large_patch32_224',\n",
" 'vit_large_patch32_224_in21k',\n",
" 'vit_large_patch32_384',\n",
" 'vit_large_r50_s32_224',\n",
" 'vit_large_r50_s32_224_in21k',\n",
" 'vit_large_r50_s32_384',\n",
" 'vit_relpos_base_patch16_224',\n",
" 'vit_relpos_base_patch16_plus_240',\n",
" 'vit_relpos_base_patch16_rpn_224',\n",
" 'vit_relpos_base_patch32_plus_rpn_256',\n",
" 'vit_small_patch8_224_dino',\n",
" 'vit_small_patch16_18x2_224',\n",
" 'vit_small_patch16_36x1_224',\n",
" 'vit_small_patch16_224',\n",
" 'vit_small_patch16_224_dino',\n",
" 'vit_small_patch16_224_in21k',\n",
" 'vit_small_patch16_384',\n",
" 'vit_small_patch32_224',\n",
" 'vit_small_patch32_224_in21k',\n",
" 'vit_small_patch32_384',\n",
" 'vit_small_r26_s32_224',\n",
" 'vit_small_r26_s32_224_in21k',\n",
" 'vit_small_r26_s32_384',\n",
" 'vit_small_resnet26d_224',\n",
" 'vit_small_resnet50d_s16_224',\n",
" 'vit_tiny_patch16_224',\n",
" 'vit_tiny_patch16_224_in21k',\n",
" 'vit_tiny_patch16_384',\n",
" 'vit_tiny_r_s16_p8_224',\n",
" 'vit_tiny_r_s16_p8_224_in21k',\n",
" 'vit_tiny_r_s16_p8_384']"
]
},
"execution_count": 322,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"timm.list_models(\"*vit*\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d44eaadb-6016-482e-80dc-b13497069ea9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing bat_resnext26ts:\n",
"Testing cspresnet50:\n",
"Testing cspresnet50d:\n",
"Testing cspresnet50w:\n",
"Testing cspresnext50:\n",
"Testing cspresnext50_iabn:\n",
"Testing dla60_res2net:\n",
"Testing dla60_res2next:\n",
"Testing eca_resnet33ts:\n",
"Testing eca_resnext26ts:\n",
"Testing ecaresnet26t:\n",
"Testing ecaresnet50d:\n",
"Testing ecaresnet50d_pruned:\n",
"Testing ecaresnet50t:\n",
"Testing ecaresnet101d:\n",
"Testing ecaresnet101d_pruned:\n",
"Testing ecaresnet200d:\n",
"Testing ecaresnet269d:\n",
"Testing ecaresnetlight:\n",
"Testing ecaresnext26t_32x4d:\n",
"Testing ecaresnext50t_32x4d:\n",
"Testing ens_adv_inception_resnet_v2:\n",
"Testing gcresnet33ts:\n",
"Testing gcresnet50t:\n",
"Testing gcresnext26ts:\n",
"Testing gcresnext50ts:\n",
"Testing gluon_resnet18_v1b:\n",
"Testing gluon_resnet34_v1b:\n",
"Testing gluon_resnet50_v1b:\n",
"Testing gluon_resnet50_v1c:\n",
"Testing gluon_resnet50_v1d:\n",
"Testing gluon_resnet50_v1s:\n",
"Testing gluon_resnet101_v1b:\n",
"Testing gluon_resnet101_v1c:\n",
"Testing gluon_resnet101_v1d:\n",
"Testing gluon_resnet101_v1s:\n",
"Testing gluon_resnet152_v1b:\n",
"Testing gluon_resnet152_v1c:\n",
"Testing gluon_resnet152_v1d:\n"
]
}
],
"source": [
"for arch in timm.list_models(\"*res*\"):\n",
" print(f\"Testing {arch}:\")\n",
" timm_model = create_timm_model(arch, 20, pretrained=False)\n",
" if \"vit\" in arch:\n",
" groups = timm_vit_split(timm_model)\n",
" else:\n",
" groups = timm_resnet_split(timm_model)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment