Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Created February 10, 2020 13:06
Show Gist options
  • Save iacolippo/d7e7c26e980cbd4a9b86c24db6ca922e to your computer and use it in GitHub Desktop.
Save iacolippo/d7e7c26e980cbd4a9b86c24db6ca922e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from time import time\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"nb = 3\n",
"bs = 3000\n",
"dim = 200\n",
"X = torch.randn(bs*nb, dim).to(\"cuda:0\")\n",
"target = (torch.randn(bs*nb) > 0).float().to(\"cuda:1\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"n_components = 50000"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"R = torch.randn(dim, n_components) / (n_components**0.5)\n",
"R = R.to(\"cuda:0\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# artificially amount of operations performed for more evident timing differences\n",
"\n",
"class DumbModule(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" \n",
" def forward(self, input):\n",
" for i in range(3):\n",
" torch.mm(input, input.t())\n",
" return input\n",
" \n",
"model = nn.Sequential(DumbModule(), nn.Linear(n_components, 1)).to(\"cuda:1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simulating external accelerator with device 0 and moving to device 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the following an external accelerator is simulated with device 0 and device 1 is the GPU supposed to run asynchronously. The accelerator returns output is a synchronization point, so I call synchronize before moving the tensor to the other device. What happens after moving the tensor `y` to device 1 should happen asynchronously with the first line of the next loop, if I'm not mistaken."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.0008285045623779297\n"
]
}
],
"source": [
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" y = y.to(\"cuda:1\")\n",
"print(time()-t)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.03906726837158203\n"
]
}
],
"source": [
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" torch.cuda.synchronize()\n",
"print(time()-t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding processing through model on device 1"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.26718616485595703\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
"print(time()-t)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9821622371673584\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" torch.cuda.synchronize()\n",
"print(time()-t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adding explicit loss computation on device 1"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.2649667263031006\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = ((y_hat - target[i*bs:(i+1)*bs].view_as(y_hat)) ** 2).mean()\n",
"print(time()-t)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.9794397354125977\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = ((y_hat - target[i*bs:(i+1)*bs].view_as(y_hat)) ** 2).mean()\n",
" torch.cuda.synchronize()\n",
"print(time()-t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparing with MSELoss"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7607848644256592\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = F.mse_loss(y_hat, target[i*bs:(i+1)*bs].view_as(y_hat))\n",
"print(time()-t)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.760429859161377\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = F.mse_loss(y_hat, target[i*bs:(i+1)*bs].view_as(y_hat))\n",
" torch.cuda.synchronize()\n",
"print(time()-t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using nn.MSELoss()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"criterion = nn.MSELoss().to(\"cuda:1\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7595822811126709\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = criterion(y_hat, target[i*bs:(i+1)*bs].view_as(y_hat))\n",
"print(time()-t)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7597024440765381\n"
]
}
],
"source": [
"correct = 0.\n",
"total = 0.\n",
"t = time()\n",
"for i in range(nb):\n",
" y = torch.mm(X[i*bs:(i+1)*bs], R)\n",
" torch.cuda.synchronize()\n",
" y = y.to(\"cuda:1\")\n",
" y_hat = model(y)\n",
" correct += (torch.sign(y_hat)).eq(target[i*bs:(i+1)*bs].view_as(y_hat)).sum()\n",
" total += target.numel()\n",
" loss = criterion(y_hat, target[i*bs:(i+1)*bs].view_as(y_hat))\n",
" torch.cuda.synchronize()\n",
"print(time()-t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the case of the explicit loss computation there is a timing difference, while using the `MSELoss` from `torch.nn` or `torch.nn.functional` calling a synchronization at the end of each iteration or not does not make any difference."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:.conda-lightondev] *",
"language": "python",
"name": "conda-env-.conda-lightondev-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment