{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import gc\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"\n",
"from sklearn.datasets import make_classification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(1024,)"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X, y = make_classification(\n",
"    n_samples=1024,\n",
"    n_features=256,\n",
"    n_informative=128,\n",
"    n_redundant=0,\n",
"    n_repeated=0,\n",
"    n_classes=2,\n",
"    n_clusters_per_class=2,\n",
"    flip_y=0.01,\n",
"    class_sep=1.0,\n",
"    hypercube=True,\n",
"    shuffle=True,\n",
"    random_state=42\n",
")\n",
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def get_model(swish_module):\n",
"    # Deliberately make the model very large\n",
"    width = 2 ** 19\n",
"    return nn.Sequential(\n",
"        nn.Linear(256, width),\n",
"        swish_module(),\n",
"        nn.BatchNorm1d(width),\n",
"        nn.Linear(width, 1)\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.BCEWithLogitsLoss()\n",
"batch_size = 128"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def print_parameter_count(model):\n",
"    print(\"# of parameters: {:,d}\".format(\n",
"        np.sum(list(p.numel() for p in model.parameters()))))\n",
"    print(\"# of trainable parameters: {:,d}\".format(\n",
"        np.sum(list(p.numel() for p in model.parameters() if p.requires_grad))))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plain Swish Version\n",
"\n",
"The naive module computes `x * torch.sigmoid(x)` with standard autograd ops, so both the input and the intermediate `sigmoid(x)` are kept alive for the backward pass."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"class PlainSwish(nn.Module):\n",
"    def forward(self, input_tensor):\n",
"        return input_tensor * torch.sigmoid(input_tensor)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"# of parameters: 136,314,881\n",
"# of trainable parameters: 136,314,881\n"
]
},
{
"data": {
"text/plain": [
"524.0009765625"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = get_model(PlainSwish).cuda()\n",
"print_parameter_count(model)\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
"optimizer.zero_grad()\n",
"torch.cuda.memory_allocated() / 1024 ** 2"
]
},
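{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the numbers above: `nn.Linear(256, 2**19)` contributes 256 × 524,288 + 524,288 = 134,742,016 parameters, `nn.BatchNorm1d` adds 2 × 524,288 = 1,048,576, and the final `nn.Linear(2**19, 1)` adds 524,289, for 136,314,881 in total. In float32 that is roughly 520 MB of weights; the extra ~4 MB reported by `torch.cuda.memory_allocated()` comes from BatchNorm's `running_mean`/`running_var` buffers, which are not parameters."
]
},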
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 524.12646484375\n",
"forw: 1552.126953125\n",
"loss: 1552.12744140625\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 2072.1279296875\n",
"loss: 2072.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
}
],
"source": [
"for i in range(0, 1024, batch_size):\n",
"    Xt, yt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda(), torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()\n",
"    print(\"data:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    pred = model(Xt)[:, 0]\n",
"    print(\"forw:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss = criterion(pred, yt)\n",
"    # print(loss)\n",
"    print(\"loss:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss.backward()\n",
"    print(\"back:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"    print(\"step:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    print(\"=\" * 20)"
]
},
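{
"cell_type": "markdown",
"metadata": {},
"source": [
"The steady-state value of roughly 1044 MB after `backward()` is the ~524 MB of weights and buffers plus another ~520 MB of float32 gradients (one per parameter). `zero_grad()` zeroes the gradient tensors but does not free them, which is why later batches start from 1044 MB rather than 524 MB."
]
},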
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"del optimizer, model, Xt, yt, loss, pred\n",
"gc.collect()\n",
"torch.cuda.memory_allocated() / 1024"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Swish Version\n",
"\n",
"The custom `autograd.Function` saves only the input tensor and recomputes the sigmoid during `backward`, so one fewer activation-sized tensor stays alive through the forward pass."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"class Swish(torch.autograd.Function):\n",
"    @staticmethod\n",
"    def forward(ctx, i):\n",
"        result = i * torch.sigmoid(i)\n",
"        ctx.save_for_backward(i)  # keep only the input; the sigmoid is recomputed in backward\n",
"        return result\n",
"\n",
"    @staticmethod\n",
"    def backward(ctx, grad_output):\n",
"        i = ctx.saved_variables[0]  # deprecated accessor (see warning below); saved_tensors is the current name\n",
"        sigmoid_i = torch.sigmoid(i)\n",
"        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))\n",
"\n",
"class CustomSwish(nn.Module):\n",
"    def forward(self, input_tensor):\n",
"        return Swish.apply(input_tensor)"
]
},
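{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expression in `backward` is the Swish derivative, rebuilt from the saved input:\n",
"\n",
"$$\\frac{d}{dx}\\bigl[x\\,\\sigma(x)\\bigr] = \\sigma(x) + x\\,\\sigma(x)\\bigl(1 - \\sigma(x)\\bigr) = \\sigma(x)\\bigl(1 + x\\,(1 - \\sigma(x))\\bigr)$$"
]
},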
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"536577.0"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = get_model(CustomSwish).cuda()\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n",
"optimizer.zero_grad()\n",
"torch.cuda.memory_allocated() / 1024"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 524.12646484375\n",
"forw: 1296.126953125\n",
"loss: 1296.12744140625\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/ceshine/miniconda3/envs/deep/lib/python3.7/site-packages/ipykernel_launcher.py:10: DeprecationWarning: 'saved_variables' is deprecated; use 'saved_tensors'\n",
"  # Remove the CWD from sys.path while we load stuff.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n",
"data: 1044.1279296875\n",
"forw: 1816.1279296875\n",
"loss: 1816.1279296875\n",
"back: 1044.1279296875\n",
"step: 1044.1279296875\n",
"====================\n"
]
}
],
"source": [
"for i in range(0, 1024, batch_size):\n",
"    Xt, yt = torch.tensor(X[i:i+batch_size], dtype=torch.float).cuda(), torch.tensor(y[i:i+batch_size], dtype=torch.float).cuda()\n",
"    print(\"data:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    pred = model(Xt)[:, 0]\n",
"    print(\"forw:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss = criterion(pred, yt)\n",
"    # print(loss)\n",
"    print(\"loss:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    loss.backward()\n",
"    print(\"back:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    optimizer.step()\n",
"    optimizer.zero_grad()\n",
"    print(\"step:\", torch.cuda.memory_allocated() / 1024 ** 2)\n",
"    print(\"=\" * 20)"
]
},
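{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compared with the plain version, only the forward-pass peak changes: 1296.13 MB vs 1552.13 MB on the first batch, and 1816.13 MB vs 2072.13 MB once gradients are allocated. The 256 MB gap is exactly one float32 tensor of shape (128, 2**19), i.e. 128 × 524,288 × 4 bytes = 256 MiB: the `sigmoid(x)` activation that plain autograd keeps for the backward pass and the custom `Function` recomputes instead."
]
},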
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"del optimizer, model, Xt, yt, loss, pred\n",
"gc.collect()\n",
"torch.cuda.memory_allocated() / 1024"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}