Created
April 22, 2020 22:43
-
-
Save viniciusmss/52e0eaeff6722dc9254e8c8479d52e16 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Data Loading" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"import torchvision.transforms as transforms\n", | |
"\n", | |
"# Prefer the GPU when CUDA is available; otherwise run on the CPU.\n", | |
"use_cuda = torch.cuda.is_available()\n", | |
"device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n", | |
"batch_size = 64\n", | |
"\n", | |
"# Convert images to tensors, then normalize each of the three channels\n", | |
"# with mean 0.5 and standard deviation 0.5.\n", | |
"channel_mean = (0.5, 0.5, 0.5)\n", | |
"channel_std = (0.5, 0.5, 0.5)\n", | |
"transform = transforms.Compose([\n", | |
"    transforms.ToTensor(),\n", | |
"    transforms.Normalize(channel_mean, channel_std),\n", | |
"])\n", | |
"\n", | |
"# Arguments shared by both dataset splits and by both loaders.\n", | |
"dataset_kwargs = dict(root='./data', download=True, transform=transform)\n", | |
"loader_kwargs = dict(batch_size=batch_size, num_workers=2)\n", | |
"\n", | |
"# Training split, loaded with shuffling enabled.\n", | |
"trainset = torchvision.datasets.CIFAR10(train=True, **dataset_kwargs)\n", | |
"trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)\n", | |
"\n", | |
"# Test split, loaded in a fixed order.\n", | |
"testset = torchvision.datasets.CIFAR10(train=False, **dataset_kwargs)\n", | |
"testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)\n", | |
"\n", | |
"# Human-readable class names, indexed by integer label.\n", | |
"classes = (\n", | |
"    'plane', 'car', 'bird', 'cat', 'deer',\n", | |
"    'dog', 'frog', 'horse', 'ship', 'truck',\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Architectures\n", | |
"\n", | |
"## Traditional CNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"class CNN(nn.Module):\n", | |
"    \"\"\"Convolutional classifier producing 10 class scores per image.\n", | |
"\n", | |
"    Three conv -> ReLU -> 2x2 max-pool stages feed two fully connected\n", | |
"    layers; dropout (p=0.25) is applied before each fully connected layer.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # convolutional layer (sees 32x32x3 image tensor)\n", | |
"        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)\n", | |
"        # convolutional layer (sees 16x16x8 tensor)\n", | |
"        self.conv2 = nn.Conv2d(8, 16, 3, padding=1)\n", | |
"        # convolutional layer (sees 8x8x16 tensor)\n", | |
"        self.conv3 = nn.Conv2d(16, 16, 3, padding=1)\n", | |
"        # max pooling layer (2x2 window, stride 2), shared by all three stages\n", | |
"        self.pool = nn.MaxPool2d(2, 2)\n", | |
"        # linear layer (16 * 4 * 4 -> 100)\n", | |
"        self.fc1 = nn.Linear(16 * 4 * 4, 100)\n", | |
"        # output layer (100 -> 10 class scores)\n", | |
"        self.fc2 = nn.Linear(100, 10)\n", | |
"        # dropout layer (p=0.25)\n", | |
"        self.dropout = nn.Dropout(0.25)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return the raw class scores for a batch of images.\n", | |
"\n", | |
"        Args:\n", | |
"            x: image tensor that the conv/pool stack reduces to 16x4x4\n", | |
"                features per image (the case for 3x32x32 inputs).\n", | |
"\n", | |
"        Returns:\n", | |
"            Tensor with one row of 10 unnormalized scores per image.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if the feature maps are not 16x4x4 per image.\n", | |
"        \"\"\"\n", | |
"        # three stages of convolution + ReLU + max pooling\n", | |
"        x = self.pool(F.relu(self.conv1(x)))\n", | |
"        x = self.pool(F.relu(self.conv2(x)))\n", | |
"        x = self.pool(F.relu(self.conv3(x)))\n", | |
"        # Guard the flatten below: view(-1, 256) would otherwise silently\n", | |
"        # regroup features across images whenever the total element count\n", | |
"        # happens to be a multiple of 256.\n", | |
"        if tuple(x.shape[-3:]) != (16, 4, 4):\n", | |
"            raise ValueError(\n", | |
"                'CNN expects 16x4x4 features after the conv stack '\n", | |
"                '(3x32x32 inputs), got {}'.format(tuple(x.shape[-3:])))\n", | |
"        # flatten the feature maps of each image\n", | |
"        x = x.view(-1, 16 * 4 * 4)\n", | |
"        # add dropout layer\n", | |
"        x = self.dropout(x)\n", | |
"        # add hidden layer, with relu activation function\n", | |
"        x = F.relu(self.fc1(x))\n", | |
"        # add dropout layer\n", | |
"        x = self.dropout(x)\n", | |
"        # output layer: raw class scores, no activation applied here\n", | |
"        x = self.fc2(x)\n", | |
"        return x\n", | |
"\n", | |
"\n", | |
"# Instantiate the baseline CNN and move it to the selected device.\n", | |
"cnn = CNN().to(device)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Bayesian CNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from blitz.modules import BayesianLinear, BayesianConv2d\n", | |
"from blitz.utils import variational_estimator\n", | |
"\n", | |
"@variational_estimator\n", | |
"class BNN(nn.Module):\n", | |
"    \"\"\"Bayesian counterpart of the CNN: same layer sizes, blitz Bayesian layers.\n", | |
"\n", | |
"    Three Bayesian conv -> ReLU -> 2x2 max-pool stages feed two Bayesian\n", | |
"    linear layers. No dropout is used.\n", | |
"    NOTE(review): the decorator presumably supplies the sample_elbo method\n", | |
"    called during training - confirm against the blitz documentation.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # convolutional layer (sees 32x32x3 image tensor)\n", | |
"        self.conv1 = BayesianConv2d(3, 8, (3,3), padding=1)\n", | |
"        # convolutional layer (sees 16x16x8 tensor)\n", | |
"        self.conv2 = BayesianConv2d(8, 16, (3,3), padding=1)\n", | |
"        # convolutional layer (sees 8x8x16 tensor)\n", | |
"        self.conv3 = BayesianConv2d(16, 16, (3,3), padding=1)\n", | |
"        # max pooling layer (2x2 window, stride 2), shared by all three stages\n", | |
"        self.pool = nn.MaxPool2d(2, 2)\n", | |
"        # linear layer (16 * 4 * 4 -> 100)\n", | |
"        self.fc1 = BayesianLinear(16 * 4 * 4, 100)\n", | |
"        # output layer (100 -> 10 class scores)\n", | |
"        self.fc2 = BayesianLinear(100, 10)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return the raw class scores for a batch of images.\n", | |
"\n", | |
"        Args:\n", | |
"            x: image tensor that the conv/pool stack reduces to 16x4x4\n", | |
"                features per image (the case for 3x32x32 inputs).\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if the feature maps are not 16x4x4 per image.\n", | |
"        \"\"\"\n", | |
"        # three stages of convolution + ReLU + max pooling\n", | |
"        x = self.pool(F.relu(self.conv1(x)))\n", | |
"        x = self.pool(F.relu(self.conv2(x)))\n", | |
"        x = self.pool(F.relu(self.conv3(x)))\n", | |
"        # Guard the flatten below: view(-1, 256) would otherwise silently\n", | |
"        # regroup features across images whenever the total element count\n", | |
"        # happens to be a multiple of 256.\n", | |
"        if tuple(x.shape[-3:]) != (16, 4, 4):\n", | |
"            raise ValueError(\n", | |
"                'BNN expects 16x4x4 features after the conv stack '\n", | |
"                '(3x32x32 inputs), got {}'.format(tuple(x.shape[-3:])))\n", | |
"        # flatten the feature maps of each image\n", | |
"        x = x.view(-1, 16 * 4 * 4)\n", | |
"        # hidden layer, with relu activation function\n", | |
"        x = F.relu(self.fc1(x))\n", | |
"        # output layer: raw class scores, no activation applied here\n", | |
"        return self.fc2(x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## BNN + Softplus" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@variational_estimator\n", | |
"class BNN_softplus(nn.Module):\n", | |
"    \"\"\"Bayesian CNN that uses softplus instead of ReLU activations.\n", | |
"\n", | |
"    Three Bayesian conv -> softplus -> 2x2 max-pool stages feed two Bayesian\n", | |
"    linear layers. Layer sizes match the ReLU-based BNN class.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # convolutional layer (sees 32x32x3 image tensor)\n", | |
"        self.conv1 = BayesianConv2d(3, 8, (3,3), padding=1)\n", | |
"        # convolutional layer (sees 16x16x8 tensor)\n", | |
"        self.conv2 = BayesianConv2d(8, 16, (3,3), padding=1)\n", | |
"        # convolutional layer (sees 8x8x16 tensor)\n", | |
"        self.conv3 = BayesianConv2d(16, 16, (3,3), padding=1)\n", | |
"        # max pooling layer (2x2 window, stride 2), shared by all three stages\n", | |
"        self.pool = nn.MaxPool2d(2, 2)\n", | |
"        # linear layer (16 * 4 * 4 -> 100)\n", | |
"        self.fc1 = BayesianLinear(16 * 4 * 4, 100)\n", | |
"        # output layer (100 -> 10 class scores)\n", | |
"        self.fc2 = BayesianLinear(100, 10)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return the raw class scores for a batch of images.\n", | |
"\n", | |
"        Args:\n", | |
"            x: image tensor that the conv/pool stack reduces to 16x4x4\n", | |
"                features per image (the case for 3x32x32 inputs).\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if the feature maps are not 16x4x4 per image.\n", | |
"        \"\"\"\n", | |
"        # three stages of convolution + softplus + max pooling\n", | |
"        x = self.pool(F.softplus(self.conv1(x)))\n", | |
"        x = self.pool(F.softplus(self.conv2(x)))\n", | |
"        x = self.pool(F.softplus(self.conv3(x)))\n", | |
"        # Guard the flatten below: view(-1, 256) would otherwise silently\n", | |
"        # regroup features across images whenever the total element count\n", | |
"        # happens to be a multiple of 256.\n", | |
"        if tuple(x.shape[-3:]) != (16, 4, 4):\n", | |
"            raise ValueError(\n", | |
"                'BNN_softplus expects 16x4x4 features after the conv stack '\n", | |
"                '(3x32x32 inputs), got {}'.format(tuple(x.shape[-3:])))\n", | |
"        # flatten the feature maps of each image\n", | |
"        x = x.view(-1, 16 * 4 * 4)\n", | |
"        # hidden layer, with softplus activation function\n", | |
"        x = F.softplus(self.fc1(x))\n", | |
"        # output layer: raw class scores, no activation applied here\n", | |
"        return self.fc2(x)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Linear BNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 35, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@variational_estimator\n", | |
"class BNN_Linear(nn.Module):\n", | |
"    \"\"\"Fully connected Bayesian network: 32*32*3 inputs -> 100 hidden -> 10 scores.\"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # hidden layer (32 * 32 * 3 -> 100)\n", | |
"        self.fc1 = BayesianLinear(32 * 32 * 3, 100)\n", | |
"        # output layer (100 -> 10)\n", | |
"        self.fc2 = BayesianLinear(100, 10)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Flatten the input to rows of 32*32*3 values and return 10 scores per row.\"\"\"\n", | |
"        flat = x.view(-1, 32 * 32 * 3)\n", | |
"        hidden = F.softplus(self.fc1(flat))\n", | |
"        scores = self.fc2(hidden)\n", | |
"        return scores" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Training\n", | |
"## Traditional CNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.optim as optim\n", | |
"\n", | |
"# Cross-entropy loss on the raw class scores; train_bnn also passes it to sample_elbo.\n", | |
"criterion = nn.CrossEntropyLoss()\n", | |
"# Plain SGD (learning rate 0.01) over the baseline CNN's parameters.\n", | |
"cnn_optimizer = optim.SGD(cnn.parameters(), lr=0.01)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0 \tTraining Loss: 2.300790\n", | |
"Epoch: 1 \tTraining Loss: 2.267595\n", | |
"Epoch: 2 \tTraining Loss: 2.110062\n", | |
"Epoch: 3 \tTraining Loss: 1.999698\n", | |
"Epoch: 4 \tTraining Loss: 1.904039\n", | |
"Epoch: 5 \tTraining Loss: 1.788206\n", | |
"Epoch: 6 \tTraining Loss: 1.677056\n", | |
"Epoch: 7 \tTraining Loss: 1.621456\n", | |
"Epoch: 8 \tTraining Loss: 1.577198\n", | |
"Epoch: 9 \tTraining Loss: 1.547966\n", | |
"Finished Training\n" | |
] | |
} | |
], | |
"source": [ | |
"# Number of training samples drawn per epoch; used to average the loss.\n", | |
"num_train_samples = len(trainloader.sampler)\n", | |
"\n", | |
"for epoch in range(10):  # ten full passes over the training set\n", | |
"    epoch_loss_sum = 0.0\n", | |
"    for images, targets in trainloader:\n", | |
"        images = images.to(device)\n", | |
"        targets = targets.to(device)\n", | |
"\n", | |
"        # reset gradients, then forward + backward + parameter update\n", | |
"        cnn_optimizer.zero_grad()\n", | |
"        batch_loss = criterion(cnn(images), targets)\n", | |
"        batch_loss.backward()\n", | |
"        cnn_optimizer.step()\n", | |
"\n", | |
"        # weight the batch mean by the batch size so that the epoch figure\n", | |
"        # is a per-sample average even if the last batch is smaller\n", | |
"        epoch_loss_sum += batch_loss.item() * images.size(0)\n", | |
"\n", | |
"    # report the mean per-sample training loss of this epoch\n", | |
"    mean_loss = epoch_loss_sum / num_train_samples\n", | |
"    print('Epoch: {} \\tTraining Loss: {:.6f}'.format(epoch, mean_loss))\n", | |
"print('Finished Training')\n", | |
"\n", | |
"# persist the trained weights\n", | |
"CNN_PATH = './cifar_cnn.pth'\n", | |
"torch.save(cnn.state_dict(), CNN_PATH)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Bayesian CNN" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train_bnn(net, optimizer, epochs=10, sample_nbr=5):\n", | |
"    \"\"\"Train a Bayesian network on ``trainloader`` using its ``sample_elbo`` loss.\n", | |
"\n", | |
"    Args:\n", | |
"        net: model exposing ``sample_elbo`` (the ``@variational_estimator``\n", | |
"            classes of this notebook are the ones passed in).\n", | |
"        optimizer: optimizer built over ``net.parameters()``.\n", | |
"        epochs: number of passes over ``trainloader`` (default 10).\n", | |
"        sample_nbr: forwarded to ``sample_elbo`` as ``sample_nbr`` (default 5).\n", | |
"\n", | |
"    Prints a running mean loss every 50 batches and the mean per-batch loss\n", | |
"    after each epoch. Returns nothing.\n", | |
"    \"\"\"\n", | |
"    num_batches = len(trainloader)\n", | |
"    # make sure the model is in training mode, whatever state it arrived in\n", | |
"    net.train()\n", | |
"    for epoch in range(epochs):  # loop over the dataset multiple times\n", | |
"        total_loss = 0.0\n", | |
"        # enumerate from 1 so that `iteration` counts the batches seen so far\n", | |
"        for iteration, (inputs, labels) in enumerate(trainloader, 1):\n", | |
"            inputs, labels = inputs.to(device), labels.to(device)\n", | |
"\n", | |
"            # zero the parameter gradients\n", | |
"            optimizer.zero_grad()\n", | |
"\n", | |
"            # forward + backward + optimize\n", | |
"            loss = net.sample_elbo(inputs=inputs,\n", | |
"                                   labels=labels,\n", | |
"                                   criterion=criterion,\n", | |
"                                   sample_nbr=sample_nbr,\n", | |
"                                   complexity_cost_weight=1 / num_batches)\n", | |
"            loss.backward()\n", | |
"            optimizer.step()\n", | |
"\n", | |
"            # Accumulate a plain float: adding the loss tensor itself would\n", | |
"            # keep autograd history alive for the whole epoch.\n", | |
"            total_loss += loss.item()\n", | |
"\n", | |
"            if iteration % 50 == 0:\n", | |
"                print(\"Epoch: {}.\\t Loss: {:.4f}\".format(epoch, total_loss / iteration), end=\"\\r\")\n", | |
"\n", | |
"        # print training statistics for the finished epoch\n", | |
"        print('Epoch: {} \\tTraining Loss: {:.6f}'.format(epoch, total_loss / num_batches))\n", | |
"    print('Finished Training')\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0 \tTraining Loss: 4.698672\n", | |
"Epoch: 1 \tTraining Loss: 0.584409\n", | |
"Epoch: 2 \tTraining Loss: 0.518806\n", | |
"Epoch: 3 \tTraining Loss: 0.494763\n", | |
"Epoch: 4 \tTraining Loss: 0.483292\n", | |
"Epoch: 5 \tTraining Loss: 0.476598\n", | |
"Epoch: 6 \tTraining Loss: 0.477477\n", | |
"Epoch: 7 \tTraining Loss: 0.473085\n", | |
"Epoch: 8 \tTraining Loss: 0.471659\n", | |
"Epoch: 9 \tTraining Loss: 0.470142\n", | |
"Finished Training\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build the ReLU Bayesian CNN on the selected device and train it with SGD (lr=0.001).\n", | |
"bnn_conv = BNN().to(device)\n", | |
"bnn_conv_optimizer = optim.SGD(bnn_conv.parameters(), lr=0.001)\n", | |
"train_bnn(bnn_conv, bnn_conv_optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## BNN + Softplus" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0 \tTraining Loss: 4.018325\n", | |
"Epoch: 1 \tTraining Loss: 0.534618\n", | |
"Epoch: 2 \tTraining Loss: 0.496891\n", | |
"Epoch: 3 \tTraining Loss: 0.495052\n", | |
"Epoch: 4 \tTraining Loss: 0.476044\n", | |
"Epoch: 5 \tTraining Loss: 0.474423\n", | |
"Epoch: 6 \tTraining Loss: 0.472371\n", | |
"Epoch: 7 \tTraining Loss: 0.473131\n", | |
"Epoch: 8 \tTraining Loss: 0.469505\n", | |
"Epoch: 9 \tTraining Loss: 0.467890\n", | |
"Finished Training\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build the softplus Bayesian CNN on the selected device and train it with SGD (lr=0.001).\n", | |
"bnn_softplus = BNN_softplus().to(device)\n", | |
"bnn_softplus_optimizer = optim.SGD(bnn_softplus.parameters(), lr=0.001)\n", | |
"train_bnn(bnn_softplus, bnn_softplus_optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Linear BNN\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 36, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 0 \tTraining Loss: 0.717498\n", | |
"Epoch: 1 \tTraining Loss: 0.567976\n", | |
"Epoch: 2 \tTraining Loss: 0.529524\n", | |
"Epoch: 3 \tTraining Loss: 0.506957\n", | |
"Epoch: 4 \tTraining Loss: 0.490132\n", | |
"Epoch: 5 \tTraining Loss: 0.477800\n", | |
"Epoch: 6 \tTraining Loss: 0.467948\n", | |
"Epoch: 7 \tTraining Loss: 0.459822\n", | |
"Epoch: 8 \tTraining Loss: 0.452051\n", | |
"Epoch: 9 \tTraining Loss: 0.446135\n", | |
"Finished Training\n" | |
] | |
} | |
], | |
"source": [ | |
"# Build the fully connected Bayesian network on the selected device and train it with SGD (lr=0.001).\n", | |
"bnn_linear = BNN_Linear().to(device)\n", | |
"bnn_linear_optimizer = optim.SGD(bnn_linear.parameters(), lr=0.001)\n", | |
"train_bnn(bnn_linear, bnn_linear_optimizer)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Test Performance " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def run_tests(net):\n", | |
"    \"\"\"Print the overall and per-class accuracy of ``net`` on ``testloader``.\n", | |
"\n", | |
"    The network is put in eval mode while it is evaluated (so dropout\n", | |
"    layers, such as the ones in ``CNN``, are disabled) and its previous\n", | |
"    train/eval mode is restored afterwards.\n", | |
"\n", | |
"    Args:\n", | |
"        net: module mapping a batch of images to one row of class scores per\n", | |
"            image; the index of the largest score is taken as the prediction.\n", | |
"    \"\"\"\n", | |
"    num_classes = len(classes)\n", | |
"    class_correct = [0] * num_classes\n", | |
"    class_total = [0] * num_classes\n", | |
"\n", | |
"    was_training = net.training\n", | |
"    net.eval()\n", | |
"    try:\n", | |
"        with torch.no_grad():\n", | |
"            for images, labels in testloader:\n", | |
"                images, labels = images.to(device), labels.to(device)\n", | |
"                outputs = net(images)\n", | |
"                _, predicted = torch.max(outputs, 1)\n", | |
"                hits = (predicted == labels)\n", | |
"                # Tally every sample of the batch (not only the first four),\n", | |
"                # so the per-class figures cover the whole test set.\n", | |
"                for label, hit in zip(labels.tolist(), hits.tolist()):\n", | |
"                    class_total[label] += 1\n", | |
"                    class_correct[label] += int(hit)\n", | |
"    finally:\n", | |
"        # leave the network in the mode the caller had it in\n", | |
"        net.train(was_training)\n", | |
"\n", | |
"    correct = sum(class_correct)\n", | |
"    total = sum(class_total)\n", | |
"    # max(..., 1) below avoids a ZeroDivisionError when nothing was evaluated\n", | |
"    print('Accuracy of the network on the %d test images: %d %%\\n' % (\n", | |
"        total, 100 * correct / max(total, 1)))\n", | |
"    for i in range(num_classes):\n", | |
"        print('Accuracy of %5s : %2d %%' % (\n", | |
"            classes[i], 100 * class_correct[i] / max(class_total[i], 1)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the network on the 10000 test images: 43 %\n", | |
"\n", | |
"Accuracy of plane : 51 %\n", | |
"Accuracy of car : 54 %\n", | |
"Accuracy of bird : 18 %\n", | |
"Accuracy of cat : 23 %\n", | |
"Accuracy of deer : 23 %\n", | |
"Accuracy of dog : 45 %\n", | |
"Accuracy of frog : 62 %\n", | |
"Accuracy of horse : 40 %\n", | |
"Accuracy of ship : 58 %\n", | |
"Accuracy of truck : 52 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluate the baseline CNN on the test split.\n", | |
"run_tests(cnn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the network on the 10000 test images: 10 %\n", | |
"\n", | |
"Accuracy of plane : 0 %\n", | |
"Accuracy of car : 84 %\n", | |
"Accuracy of bird : 0 %\n", | |
"Accuracy of cat : 0 %\n", | |
"Accuracy of deer : 0 %\n", | |
"Accuracy of dog : 0 %\n", | |
"Accuracy of frog : 0 %\n", | |
"Accuracy of horse : 1 %\n", | |
"Accuracy of ship : 15 %\n", | |
"Accuracy of truck : 0 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluate the ReLU Bayesian CNN on the test split.\n", | |
"# NOTE(review): the recorded output shows 10 % overall accuracy, i.e. chance\n", | |
"# level for 10 classes - the model does not appear to have learned; investigate.\n", | |
"run_tests(bnn_conv)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the network on the 10000 test images: 10 %\n", | |
"\n", | |
"Accuracy of plane : 0 %\n", | |
"Accuracy of car : 0 %\n", | |
"Accuracy of bird : 88 %\n", | |
"Accuracy of cat : 0 %\n", | |
"Accuracy of deer : 5 %\n", | |
"Accuracy of dog : 0 %\n", | |
"Accuracy of frog : 0 %\n", | |
"Accuracy of horse : 0 %\n", | |
"Accuracy of ship : 0 %\n", | |
"Accuracy of truck : 0 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluate the softplus Bayesian CNN on the test split.\n", | |
"# NOTE(review): the recorded output shows 10 % overall accuracy, i.e. chance\n", | |
"# level for 10 classes - the model does not appear to have learned; investigate.\n", | |
"run_tests(bnn_softplus)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Accuracy of the network on the 10000 test images: 25 %\n", | |
"\n", | |
"Accuracy of plane : 48 %\n", | |
"Accuracy of car : 34 %\n", | |
"Accuracy of bird : 22 %\n", | |
"Accuracy of cat : 20 %\n", | |
"Accuracy of deer : 14 %\n", | |
"Accuracy of dog : 22 %\n", | |
"Accuracy of frog : 35 %\n", | |
"Accuracy of horse : 20 %\n", | |
"Accuracy of ship : 48 %\n", | |
"Accuracy of truck : 26 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluate the fully connected Bayesian network on the test split.\n", | |
"run_tests(bnn_linear)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment