Skip to content

Instantly share code, notes, and snippets.

@calebrob6
Created June 4, 2023 20:32
Show Gist options
  • Save calebrob6/427512eaae6f373a429c769b8391ef40 to your computer and use it in GitHub Desktop.
Save calebrob6/427512eaae6f373a429c769b8391ef40 to your computer and use it in GitHub Desktop.
Plot how long a U-Net training step (forward pass plus backpropagation) takes as a function of the number of output classes.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import segmentation_models_pytorch as smp\n",
"from tqdm.notebook import tqdm\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import time\n",
"import torch.optim as optim\n",
"from numpy import polyfit"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Loss shared by every benchmark step (targets are per-pixel integer class indices).\n",
"criterion = nn.CrossEntropyLoss()\n",
"# All benchmarks run on the first CUDA GPU.\n",
"device = torch.device(\"cuda:0\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a5323fc0f77141d28b2718127aeb9a68",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/25 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"20 0.08915581703186035 0.006999001356624025\n",
"40 0.09475312232971192 0.002056478655073468\n",
"60 0.09783670902252198 0.0013926249231563847\n",
"80 0.10109012126922608 0.002229318361452539\n",
"100 0.10586683750152588 0.00185596356632018\n",
"120 0.10596387386322022 0.002398825884291594\n",
"140 0.12127435207366943 0.0016488205761009363\n",
"160 0.1275723934173584 0.0012418827598712984\n",
"180 0.13206076622009277 0.0017583857545726425\n",
"200 0.13514981269836426 0.0015638243120844016\n",
"220 0.13813912868499756 0.001634163007377038\n",
"240 0.1402216672897339 0.0011315485421024998\n",
"260 0.1832258701324463 0.05859004723160491\n",
"280 0.14650993347167968 0.0017165110736168901\n",
"300 0.15093004703521729 0.0014567814719269846\n",
"320 0.14555819034576417 0.0030092876925029723\n",
"340 0.15107147693634032 0.003321300416517111\n",
"360 0.17644402980804444 0.017436851315585353\n",
"380 0.15709831714630126 0.003620791003514509\n",
"400 0.1590643882751465 0.004151207186103787\n",
"420 0.1637261390686035 0.003531650862788712\n",
"440 0.17462034225463868 0.010772465242110072\n",
"460 0.168062162399292 0.004489828939106871\n",
"480 0.16781163215637207 0.004528890855395451\n",
"500 0.18220915794372558 0.016113323857669237\n"
]
}
],
"source": [
"def time_train_steps(num_classes, n_steps=11, n_warmup=1, batch_size=32, image_size=256):\n",
"    \"\"\"Return per-step wall-clock times (s) for forward + backward of a ResNet-50 U-Net.\n",
"\n",
"    A fresh U-Net with `num_classes` output channels is built on the global `device`\n",
"    and run on random data for `n_steps` steps (zero_grad, forward, loss, backward).\n",
"    No optimizer step is taken. The first `n_warmup` steps are run but not recorded.\n",
"\n",
"    CUDA kernels launch asynchronously, so the device is synchronized before starting\n",
"    and before stopping the clock. Without that, the host timer mostly measures kernel\n",
"    launch overhead, and GPU work from one step can be charged to a later step\n",
"    (e.g. when a blocking host-to-device copy has to wait for it).\n",
"\n",
"    The model and optimizer are locals, so they are released when this returns\n",
"    instead of staying alive while the next model is built.\n",
"    \"\"\"\n",
"    model = smp.Unet(\n",
"        encoder_name=\"resnet50\",\n",
"        encoder_weights=\"imagenet\",\n",
"        in_channels=3,\n",
"        classes=num_classes,\n",
"    ).train().to(device)\n",
"    optimizer = optim.Adam(model.parameters())\n",
"    times = []\n",
"    for step in range(n_steps):\n",
"        # Generate inputs directly on the GPU and outside the timed region, so the\n",
"        # measurement covers only the model step, not CPU RNG + host->device copies.\n",
"        x = torch.rand(batch_size, 3, image_size, image_size, device=device)\n",
"        target = torch.randint(0, num_classes, (batch_size, image_size, image_size), device=device)\n",
"        torch.cuda.synchronize(device)  # drain pending GPU work before starting the clock\n",
"        tic = time.perf_counter()\n",
"        optimizer.zero_grad()\n",
"        out = model(x)\n",
"        loss = criterion(out, target)\n",
"        loss.backward()\n",
"        torch.cuda.synchronize(device)  # wait for forward/backward kernels to finish\n",
"        if step >= n_warmup:\n",
"            times.append(time.perf_counter() - tic)\n",
"    return times\n",
"\n",
"\n",
"# Sweep the number of output classes and record per-step times for each setting.\n",
"xs = np.arange(20, 501, 20)\n",
"ys = []\n",
"for num_classes in tqdm(xs):\n",
"    y = time_train_steps(num_classes)\n",
"    ys.append(y)\n",
"    print(num_classes, np.mean(y), np.std(y))\n",
"ys = np.array(ys)  # one row per entry of xs, one column per recorded step"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 0.25271785259246826 0.007840519383439442\n"
]
}
],
"source": [
"# Held-out point beyond the sweep, used to check the linear extrapolation.\n",
"num_classes = 1000\n",
"y_test = time_train_steps(num_classes)\n",
"print(num_classes, np.mean(y_test), np.std(y_test))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Fit a straight line (mean step time vs. number of classes) to the sweep.\n",
"mean_times = ys.mean(axis=1)\n",
"coefficients = polyfit(xs, mean_times, deg=1)\n",
"polynomial = np.poly1d(coefficients)\n",
"ys_hat = polynomial(xs)\n",
"# For a least-squares line with intercept, R^2 equals the squared Pearson correlation\n",
"# between the observations and the fitted values.\n",
"r2 = np.corrcoef(mean_times, ys_hat)[0, 1] ** 2"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"mu = ys.mean(axis=1)\n",
"sigma = ys.std(axis=1)\n",
"x_extrap = np.arange(20, 1001, 20)\n",
"\n",
"plt.figure()\n",
"# Measured mean step time with a +/- 1 std band.\n",
"plt.plot(xs, mu)\n",
"plt.fill_between(xs, mu - sigma, mu + sigma, alpha=0.3)\n",
"\n",
"# Linear fit extrapolated to 1000 classes, plus the held-out measurement there.\n",
"plt.plot(x_extrap, polynomial(x_extrap), color=\"red\", label=f\"Linear fit (R2={r2:.2f})\")\n",
"plt.scatter(1000, np.mean(y_test), color=\"black\", label=\"Test time\")\n",
"\n",
"plt.xlabel(\"Number of classes\")\n",
"plt.ylabel(\"Time (s)\")\n",
"plt.legend()\n",
"\n",
"plt.show()\n",
"plt.close()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "geospatiallib",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment