Created
February 7, 2022 03:39
-
-
Save CoryKornowicz/0c15c348dd25dd51a3eb50965fd3ea6f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torchvision\n", | |
"import numpy as np\n", | |
"import math\n", | |
"from functorch import vmap" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<torch._C.Generator at 0x7f2728e50600>" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"n_epochs = 3\n", | |
"batch_size_train = 64\n", | |
"batch_size_test = 1000\n", | |
"learning_rate = 0.01\n", | |
"momentum = 0.5\n", | |
"log_interval = 10\n", | |
"\n", | |
"random_seed = 1\n", | |
"torch.backends.cudnn.enabled = True\n", | |
"torch.manual_seed(random_seed)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/cory/miniconda3/envs/deepgen/lib/python3.6/site-packages/torchvision/datasets/mnist.py:62: UserWarning: train_data has been renamed data\n", | |
" warnings.warn(\"train_data has been renamed data\")\n", | |
"/home/cory/miniconda3/envs/deepgen/lib/python3.6/site-packages/torchvision/datasets/mnist.py:52: UserWarning: train_labels has been renamed targets\n", | |
" warnings.warn(\"train_labels has been renamed targets\")\n", | |
"/home/cory/miniconda3/envs/deepgen/lib/python3.6/site-packages/torchvision/datasets/mnist.py:67: UserWarning: test_data has been renamed data\n", | |
" warnings.warn(\"test_data has been renamed data\")\n", | |
"/home/cory/miniconda3/envs/deepgen/lib/python3.6/site-packages/torchvision/datasets/mnist.py:57: UserWarning: test_labels has been renamed targets\n", | |
" warnings.warn(\"test_labels has been renamed targets\")\n" | |
] | |
} | |
], | |
"source": [ | |
"train_dataset = torchvision.datasets.MNIST('.', train=True, download=True,\n", | |
" transform=torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize(\n", | |
" (0.1307,), (0.3081,))\n", | |
" ]))\n", | |
"\n", | |
"\n", | |
"\n", | |
"test_dataset = torchvision.datasets.MNIST('.', train=False, download=True,\n", | |
" transform=torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize(\n", | |
" (0.1307,), (0.3081,))\n", | |
" ]))\n", | |
"\n", | |
"\n", | |
"train_dataset.train_data.to(\"cuda\")\n", | |
"train_dataset.train_labels.to(\"cuda\")\n", | |
"\n", | |
"test_dataset.test_data.to(\"cuda\")\n", | |
"test_dataset.test_labels.to(\"cuda\")\n", | |
"\n", | |
"\n", | |
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)\n", | |
"test_loader = torch.utils.data.DataLoader( test_dataset, batch_size=batch_size_test, shuffle=True)\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"examples = enumerate(test_loader)\n", | |
"batch_idx, (example_data, example_targets) = next(examples)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 432x288 with 6 Axes>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 432x288 with 6 Axes>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"\n", | |
"fig = plt.figure()\n", | |
"for i in range(6):\n", | |
" plt.subplot(2,3,i+1)\n", | |
" plt.tight_layout()\n", | |
" plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n", | |
" plt.title(\"Ground Truth: {}\".format(example_targets[i]))\n", | |
" plt.xticks([])\n", | |
" plt.yticks([])\n", | |
"fig" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim\n", | |
"import functorch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Net(nn.Module):\n", | |
" def __init__(self):\n", | |
" super(Net, self).__init__()\n", | |
" self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n", | |
" self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n", | |
" self.conv2_drop = nn.Dropout2d()\n", | |
" self.fc1 = nn.Linear(320, 50)\n", | |
" self.fc2 = nn.Linear(50, 10)\n", | |
"\n", | |
" def forward(self, x):\n", | |
" x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", | |
" x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n", | |
" x = x.view(-1, 320)\n", | |
" x = F.relu(self.fc1(x))\n", | |
" x = F.dropout(x, training=self.training)\n", | |
" x = self.fc2(x)\n", | |
" return F.log_softmax(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class MaxAbsPool2D(nn.Module):\n", | |
" def __init__(self, pool_size, pad_to_fit=False):\n", | |
" super(MaxAbsPool2D, self).__init__()\n", | |
" self.pad = pad_to_fit\n", | |
" self.pool_size = pool_size\n", | |
" \n", | |
"\n", | |
" def gather_nd(self, params, indices):\n", | |
" '''\n", | |
" 4D example\n", | |
" params: tensor shaped [n_1, n_2, n_3, n_4] --> 4 dimensional\n", | |
" indices: tensor shaped [m_1, m_2, m_3, m_4, 4] --> multidimensional list of 4D indices\n", | |
" \n", | |
" returns: tensor shaped [m_1, m_2, m_3, m_4]\n", | |
" \n", | |
" ND_example\n", | |
" params: tensor shaped [n_1, ..., n_p] --> d-dimensional tensor\n", | |
" indices: tensor shaped [m_1, ..., m_i, d] --> multidimensional list of d-dimensional indices\n", | |
" \n", | |
" returns: tensor shaped [m_1, ..., m_1]\n", | |
" '''\n", | |
"\n", | |
" out_shape = indices.shape[:-1]\n", | |
" indices = indices.unsqueeze(0).transpose(0, -1) # roll last axis to fring\n", | |
" ndim = indices.shape[0]\n", | |
" indices = indices.long()\n", | |
" idx = torch.zeros_like(indices[0], device=indices.device).long()\n", | |
" m = 1\n", | |
" \n", | |
" for i in range(ndim)[::-1]:\n", | |
" idx += indices[i] * m \n", | |
" m *= params.size(i)\n", | |
" out = torch.take(params, idx.cuda())\n", | |
" return out.view(out_shape)\n", | |
"\n", | |
" def forward(self, inputs):\n", | |
" if self.pad:\n", | |
" outshape = (inputs.shape[0], \n", | |
" inputs.shape[1],\n", | |
" math.ceil(inputs.shape[2] / self.pool_size), \n", | |
" math.ceil(inputs.shape[3] / self.pool_size))\n", | |
" else: \n", | |
" outshape = (inputs.shape[0], \n", | |
" inputs.shape[1],\n", | |
" (inputs.shape[2] // self.pool_size), \n", | |
" (inputs.shape[3] // self.pool_size)) \n", | |
"\n", | |
" mod_y = inputs.shape[2] % self.pool_size\n", | |
" y1 = mod_y // 2\n", | |
" y2 = mod_y - y1\n", | |
" mod_x = inputs.shape[3] % self.pool_size\n", | |
" x1 = mod_x // 2\n", | |
" x2 = mod_x - x1\n", | |
" padding = (y1, y2, x1, x2)\n", | |
"\n", | |
" if self.pad:\n", | |
" inputs = F.pad(inputs, padding)\n", | |
"\n", | |
" batch_size = inputs.shape[0]\n", | |
" max_height = (inputs.shape[2] // self.pool_size) * self.pool_size\n", | |
" max_width = (inputs.shape[3] // self.pool_size) * self.pool_size\n", | |
" \n", | |
" stacked = torch.stack(\n", | |
" [inputs[:, :, i:max_height:self.pool_size, j:max_width:self.pool_size] \n", | |
" for i in range(self.pool_size) for j in range(self.pool_size)], dim=-1)\n", | |
"\n", | |
" inds = torch.argmax(torch.abs(stacked), dim=-1).to(\"cpu\")\n", | |
" ks = stacked.shape\n", | |
" idx = torch.stack([\n", | |
" *torch.meshgrid(\n", | |
" torch.arange(0, ks[0]), \n", | |
" torch.arange(0, ks[1]),\n", | |
" torch.arange(0, ks[2]), \n", | |
" torch.arange(0, ks[3]),\n", | |
" indexing='ij'\n", | |
" ), inds], \n", | |
" dim=-1)\n", | |
"\n", | |
" x = self.gather_nd(stacked, idx)\n", | |
" x = x.reshape(outshape)\n", | |
" return x" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class CosSim2D_REMAKE(nn.Module):\n", | |
" def __init__(self, input_channels, units=32, kernel_size=[3,3], stride=1, padding=1, depthwise_separable=False):\n", | |
" super(CosSim2D_REMAKE, self).__init__()\n", | |
" self.depthwise_separable = depthwise_separable\n", | |
" self.units = units\n", | |
" assert len(kernel_size) == 2, \"kernel of this size not supported\"\n", | |
" self.kernel_size = kernel_size\n", | |
" self.stride = stride\n", | |
" self.padding = padding\n", | |
" self.input_channels = input_channels\n", | |
"\n", | |
" if self.depthwise_separable:\n", | |
" w = torch.empty(1, np.square(kernel_size[0]), self.units)\n", | |
" nn.init.xavier_uniform_(w)\n", | |
" self.w = nn.Parameter(w, requires_grad=True)\n", | |
" else:\n", | |
" w = torch.empty(1, self.input_channels * np.square(self.kernel_size[0]), self.units)\n", | |
" nn.init.xavier_uniform_(w)\n", | |
" self.w = nn.Parameter(w, requires_grad=True)\n", | |
"\n", | |
" b = torch.empty((self.units,))\n", | |
" nn.init.zeros_(b)\n", | |
" self.b = nn.Parameter(b, requires_grad=True)\n", | |
"\n", | |
" p = torch.empty((self.units,))\n", | |
" nn.init.constant_(p, 2.0)\n", | |
" self.p = nn.Parameter(p, requires_grad=True)\n", | |
"\n", | |
" q = torch.empty((1,))\n", | |
" nn.init.uniform_(q)\n", | |
" self.q = nn.Parameter(q, requires_grad=True)\n", | |
"\n", | |
" \n", | |
" def l2_normal(self, x, axis=None, epsilon=torch.Tensor([1e-12])):\n", | |
" square_sum = torch.sum(torch.square(x), axis, keepdims=True)\n", | |
" x_inv_norm = torch.sqrt(torch.max(square_sum, epsilon.to(x.device)))\n", | |
" return x_inv_norm\n", | |
"\n", | |
" def sigplus(self, x):\n", | |
" return torch.sigmoid(x) * F.softplus(x)\n", | |
"\n", | |
" def stack(self, x):\n", | |
" x = F.pad(x, [self.padding]*4)\n", | |
" strided_x = x.unfold(2, self.kernel_size[0], self.stride).unfold(3, self.kernel_size[1], self.stride)\n", | |
" return strided_x\n", | |
" \n", | |
"\n", | |
" def forward(self, x):\n", | |
"\n", | |
" if self.depthwise_separable:\n", | |
" # print(x.shape, \"PreVex\")\n", | |
" x = vmap(self.call_body)(torch.unsqueeze(x.permute(1,0,2,3), axis=2))\n", | |
" # print(x.shape, \"PostVex\")\n", | |
" # [3, 1, 10, 28, 28]) PostVex\n", | |
" # s = x.shape\n", | |
" # print(s)\n", | |
" x = x.permute(1,0,2,3,4)\n", | |
" # print(x.shape, \"PostPermute\")\n", | |
" # [1, 3, 10, 28, 28]) PostPermute\n", | |
" b, c, f, h, w = x.shape\n", | |
" x = x.reshape(b, c*self.units, h, w)\n", | |
" # print(x.shape, \"PostView\", f)\n", | |
" return x\n", | |
" else:\n", | |
" x = self.call_body(x)\n", | |
" return x\n", | |
"\n", | |
"\n", | |
" def call_body(self, x):\n", | |
" # print(x.shape, \"PreCallBody\")\n", | |
" x = self.stack(x)\n", | |
" n, c, h, w, ks, kd = x.shape\n", | |
" x = x.reshape(n,h*w,c*ks*kd)\n", | |
" \n", | |
" q = torch.square(self.q)\n", | |
"\n", | |
" x_norm = (self.l2_normal(x, axis=2)) + q\n", | |
" w_norm = (self.l2_normal(self.w, axis=1)) + q\n", | |
"\n", | |
" x = (x / x_norm) @ (self.w / w_norm)\n", | |
" sign = torch.sign(x)\n", | |
" x = torch.abs(x) + 1e-12\n", | |
" x = x.pow(self.sigplus(self.p))\n", | |
" x = sign * x + + self.sigplus(self.b)\n", | |
"\n", | |
" x = x.reshape(-1, self.units, h, w)\n", | |
" return x\n", | |
" \n", | |
" \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 137, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"7571" | |
] | |
}, | |
"execution_count": 137, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"def count_parameters(model):\n", | |
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", | |
"\n", | |
"count_parameters(model)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 223, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"error = nn.CrossEntropyLoss()\n", | |
"\n", | |
"model = nn.Sequential(*[\n", | |
" CosSim2D_REMAKE(input_channels=1, units=10, kernel_size=[5,5], stride=2, padding=1, depthwise_separable=False),\n", | |
" CosSim2D_REMAKE(input_channels=10, units=20, kernel_size=[5,5], stride=1, padding=1, depthwise_separable=False),\n", | |
" CosSim2D_REMAKE(input_channels=20, units=8, kernel_size=[1,1], stride=1, padding=0, depthwise_separable=False),\n", | |
" # nn.Dropout(0.07),\n", | |
" # nn.MaxPool2d((2,2)),\n", | |
" CosSim2D_REMAKE(input_channels=8, units=4, kernel_size=[5,5], stride=2, padding=1, depthwise_separable=True),\n", | |
" CosSim2D_REMAKE(input_channels=32, units=10, kernel_size=[1,1], stride=1, padding=0, depthwise_separable=False),\n", | |
" # nn.Dropout(0.07),\n", | |
" nn.MaxPool2d((2,2)),\n", | |
" CosSim2D_REMAKE(input_channels=10, units=4, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=True),\n", | |
" CosSim2D_REMAKE(input_channels=40, units=10, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=False),\n", | |
" # nn.Dropout(0.07),\n", | |
" MaxAbsPool2D(2, True),\n", | |
" # nn.MaxPool2d((1,1)),\n", | |
" # nn.MaxPool2d(2),\n", | |
" # CosSim2D_REMAKE(input_channels=20, units=1, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=True),\n", | |
" # CosSim2D_REMAKE(input_channels=20, units=10, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=False),\n", | |
" # # CosSim2D_REMAKE(input_channels=20, units=10, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=False),\n", | |
" # # CosSim2D_REMAKE(input_channels=20, units=10, kernel_size=[3,3], stride=2, padding=1, depthwise_separable=False),\n", | |
" # MaxAbsPool2D(2, True),\n", | |
" nn.Flatten(),\n", | |
" nn.Linear(10, 10),\n", | |
" # nn.LeakyReLU(),\n", | |
" nn.LogSoftmax(dim=1)\n", | |
" # nn.Sigmoid()\n", | |
"]).cuda()\n", | |
"\n", | |
"# model = Net().cuda()\n", | |
"\n", | |
"learning_rate = 0.01\n", | |
"# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)\n", | |
"optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 213, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"torch.Size([2, 10])" | |
] | |
}, | |
"execution_count": 213, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_data_x = torch.randn(2, 1, 28, 28).cuda()\n", | |
"model(test_data_x).shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 224, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"n_epochs = 3\n", | |
"train_losses = []\n", | |
"train_counter = []\n", | |
"test_losses = []\n", | |
"test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 225, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def train(epoch):\n", | |
" model.train()\n", | |
" for batch_idx, (data, target) in enumerate(train_loader):\n", | |
" optimizer.zero_grad()\n", | |
" output = model(data.to(\"cuda\"))\n", | |
" loss = F.nll_loss(output, target.to(\"cuda\"))\n", | |
" loss.backward()\n", | |
" optimizer.step()\n", | |
" if batch_idx % log_interval == 0:\n", | |
" print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", | |
" epoch, batch_idx * len(data), len(train_loader.dataset),\n", | |
" 100. * batch_idx / len(train_loader), loss.item()))\n", | |
" train_losses.append(loss.item())\n", | |
" train_counter.append(\n", | |
" (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))\n", | |
" # torch.save(model.state_dict(), '/results/model.pth')\n", | |
" # torch.save(optimizer.state_dict(), '/results/optimizer.pth')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 226, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def test():\n", | |
" model.eval()\n", | |
" test_loss = 0\n", | |
" correct = 0\n", | |
" with torch.no_grad():\n", | |
" for data, target in test_loader:\n", | |
" output = model(data.to(\"cuda\"))\n", | |
" test_loss += F.nll_loss(output, target.to(\"cuda\"), reduction='sum').item()\n", | |
" pred = output.to(\"cpu\").data.max(1, keepdim=True)[1]\n", | |
" correct += pred.eq(target.to(\"cpu\").data.view_as(pred)).sum()\n", | |
" test_loss /= len(test_loader.dataset)\n", | |
" test_losses.append(test_loss)\n", | |
" print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", | |
" test_loss, correct, len(test_loader.dataset),\n", | |
" 100. * correct / len(test_loader.dataset)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 227, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"Test set: Avg. loss: 2.3178, Accuracy: 1009/10000 (10%)\n", | |
"\n", | |
"Train Epoch: 1 [0/60000 (0%)]\tLoss: 2.316218\n", | |
"Train Epoch: 1 [640/60000 (1%)]\tLoss: 2.310624\n", | |
"Train Epoch: 1 [1280/60000 (2%)]\tLoss: 2.328552\n", | |
"Train Epoch: 1 [1920/60000 (3%)]\tLoss: 2.326698\n", | |
"Train Epoch: 1 [2560/60000 (4%)]\tLoss: 2.312075\n", | |
"Train Epoch: 1 [3200/60000 (5%)]\tLoss: 2.296974\n", | |
"Train Epoch: 1 [3840/60000 (6%)]\tLoss: 2.296356\n", | |
"Train Epoch: 1 [4480/60000 (7%)]\tLoss: 2.314074\n", | |
"Train Epoch: 1 [5120/60000 (9%)]\tLoss: 2.321757\n", | |
"Train Epoch: 1 [5760/60000 (10%)]\tLoss: 2.315157\n", | |
"Train Epoch: 1 [6400/60000 (11%)]\tLoss: 2.302232\n", | |
"Train Epoch: 1 [7040/60000 (12%)]\tLoss: 2.327388\n", | |
"Train Epoch: 1 [7680/60000 (13%)]\tLoss: 2.309437\n", | |
"Train Epoch: 1 [8320/60000 (14%)]\tLoss: 2.110768\n", | |
"Train Epoch: 1 [8960/60000 (15%)]\tLoss: 1.898252\n", | |
"Train Epoch: 1 [9600/60000 (16%)]\tLoss: 1.730732\n", | |
"Train Epoch: 1 [10240/60000 (17%)]\tLoss: 1.632461\n", | |
"Train Epoch: 1 [10880/60000 (18%)]\tLoss: 1.576923\n", | |
"Train Epoch: 1 [11520/60000 (19%)]\tLoss: 1.463213\n", | |
"Train Epoch: 1 [12160/60000 (20%)]\tLoss: 1.206215\n", | |
"Train Epoch: 1 [12800/60000 (21%)]\tLoss: 1.138680\n", | |
"Train Epoch: 1 [13440/60000 (22%)]\tLoss: 1.048898\n", | |
"Train Epoch: 1 [14080/60000 (23%)]\tLoss: 1.000812\n", | |
"Train Epoch: 1 [14720/60000 (25%)]\tLoss: 1.026755\n", | |
"Train Epoch: 1 [15360/60000 (26%)]\tLoss: 1.016238\n", | |
"Train Epoch: 1 [16000/60000 (27%)]\tLoss: 0.921849\n", | |
"Train Epoch: 1 [16640/60000 (28%)]\tLoss: 0.952626\n", | |
"Train Epoch: 1 [17280/60000 (29%)]\tLoss: 0.953420\n", | |
"Train Epoch: 1 [17920/60000 (30%)]\tLoss: 0.779182\n", | |
"Train Epoch: 1 [18560/60000 (31%)]\tLoss: 1.045863\n", | |
"Train Epoch: 1 [19200/60000 (32%)]\tLoss: 1.215781\n", | |
"Train Epoch: 1 [19840/60000 (33%)]\tLoss: 0.896048\n", | |
"Train Epoch: 1 [20480/60000 (34%)]\tLoss: 0.817331\n", | |
"Train Epoch: 1 [21120/60000 (35%)]\tLoss: 0.758659\n", | |
"Train Epoch: 1 [21760/60000 (36%)]\tLoss: 0.765967\n", | |
"Train Epoch: 1 [22400/60000 (37%)]\tLoss: 0.682770\n", | |
"Train Epoch: 1 [23040/60000 (38%)]\tLoss: 0.883863\n", | |
"Train Epoch: 1 [23680/60000 (39%)]\tLoss: 0.823186\n", | |
"Train Epoch: 1 [24320/60000 (41%)]\tLoss: 0.854320\n", | |
"Train Epoch: 1 [24960/60000 (42%)]\tLoss: 0.687250\n", | |
"Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.663487\n", | |
"Train Epoch: 1 [26240/60000 (44%)]\tLoss: 0.676594\n", | |
"Train Epoch: 1 [26880/60000 (45%)]\tLoss: 0.728934\n", | |
"Train Epoch: 1 [27520/60000 (46%)]\tLoss: 0.531859\n", | |
"Train Epoch: 1 [28160/60000 (47%)]\tLoss: 0.530515\n", | |
"Train Epoch: 1 [28800/60000 (48%)]\tLoss: 0.441733\n", | |
"Train Epoch: 1 [29440/60000 (49%)]\tLoss: 0.791823\n", | |
"Train Epoch: 1 [30080/60000 (50%)]\tLoss: 0.892163\n", | |
"Train Epoch: 1 [30720/60000 (51%)]\tLoss: 0.600429\n", | |
"Train Epoch: 1 [31360/60000 (52%)]\tLoss: 0.614292\n", | |
"Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.640600\n", | |
"Train Epoch: 1 [32640/60000 (54%)]\tLoss: 0.486172\n", | |
"Train Epoch: 1 [33280/60000 (55%)]\tLoss: 0.687507\n", | |
"Train Epoch: 1 [33920/60000 (57%)]\tLoss: 0.629682\n", | |
"Train Epoch: 1 [34560/60000 (58%)]\tLoss: 0.666545\n", | |
"Train Epoch: 1 [35200/60000 (59%)]\tLoss: 0.503778\n", | |
"Train Epoch: 1 [35840/60000 (60%)]\tLoss: 0.806158\n", | |
"Train Epoch: 1 [36480/60000 (61%)]\tLoss: 0.832175\n", | |
"Train Epoch: 1 [37120/60000 (62%)]\tLoss: 0.469040\n", | |
"Train Epoch: 1 [37760/60000 (63%)]\tLoss: 0.519231\n", | |
"Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.506036\n", | |
"Train Epoch: 1 [39040/60000 (65%)]\tLoss: 0.568161\n", | |
"Train Epoch: 1 [39680/60000 (66%)]\tLoss: 0.529751\n", | |
"Train Epoch: 1 [40320/60000 (67%)]\tLoss: 0.459978\n", | |
"Train Epoch: 1 [40960/60000 (68%)]\tLoss: 0.524102\n", | |
"Train Epoch: 1 [41600/60000 (69%)]\tLoss: 0.512300\n", | |
"Train Epoch: 1 [42240/60000 (70%)]\tLoss: 0.437332\n", | |
"Train Epoch: 1 [42880/60000 (71%)]\tLoss: 0.394670\n", | |
"Train Epoch: 1 [43520/60000 (72%)]\tLoss: 0.223572\n", | |
"Train Epoch: 1 [44160/60000 (74%)]\tLoss: 0.342050\n", | |
"Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.267749\n", | |
"Train Epoch: 1 [45440/60000 (76%)]\tLoss: 0.532922\n", | |
"Train Epoch: 1 [46080/60000 (77%)]\tLoss: 0.360049\n", | |
"Train Epoch: 1 [46720/60000 (78%)]\tLoss: 0.538723\n", | |
"Train Epoch: 1 [47360/60000 (79%)]\tLoss: 0.703804\n", | |
"Train Epoch: 1 [48000/60000 (80%)]\tLoss: 0.834726\n", | |
"Train Epoch: 1 [48640/60000 (81%)]\tLoss: 0.676838\n", | |
"Train Epoch: 1 [49280/60000 (82%)]\tLoss: 0.439723\n", | |
"Train Epoch: 1 [49920/60000 (83%)]\tLoss: 0.457522\n", | |
"Train Epoch: 1 [50560/60000 (84%)]\tLoss: 0.512422\n", | |
"Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.596058\n", | |
"Train Epoch: 1 [51840/60000 (86%)]\tLoss: 0.614722\n", | |
"Train Epoch: 1 [52480/60000 (87%)]\tLoss: 0.590689\n", | |
"Train Epoch: 1 [53120/60000 (88%)]\tLoss: 0.268143\n", | |
"Train Epoch: 1 [53760/60000 (90%)]\tLoss: 0.577149\n", | |
"Train Epoch: 1 [54400/60000 (91%)]\tLoss: 0.595721\n", | |
"Train Epoch: 1 [55040/60000 (92%)]\tLoss: 0.506574\n", | |
"Train Epoch: 1 [55680/60000 (93%)]\tLoss: 0.480416\n", | |
"Train Epoch: 1 [56320/60000 (94%)]\tLoss: 0.608466\n", | |
"Train Epoch: 1 [56960/60000 (95%)]\tLoss: 0.511769\n", | |
"Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.410115\n", | |
"Train Epoch: 1 [58240/60000 (97%)]\tLoss: 0.354646\n", | |
"Train Epoch: 1 [58880/60000 (98%)]\tLoss: 0.355858\n", | |
"Train Epoch: 1 [59520/60000 (99%)]\tLoss: 0.349257\n", | |
"\n", | |
"Test set: Avg. loss: 0.3909, Accuracy: 8856/10000 (89%)\n", | |
"\n", | |
"Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.262621\n", | |
"Train Epoch: 2 [640/60000 (1%)]\tLoss: 0.374203\n", | |
"Train Epoch: 2 [1280/60000 (2%)]\tLoss: 0.629726\n", | |
"Train Epoch: 2 [1920/60000 (3%)]\tLoss: 0.472493\n", | |
"Train Epoch: 2 [2560/60000 (4%)]\tLoss: 0.416987\n", | |
"Train Epoch: 2 [3200/60000 (5%)]\tLoss: 0.483328\n", | |
"Train Epoch: 2 [3840/60000 (6%)]\tLoss: 0.430447\n", | |
"Train Epoch: 2 [4480/60000 (7%)]\tLoss: 0.381701\n", | |
"Train Epoch: 2 [5120/60000 (9%)]\tLoss: 0.257381\n", | |
"Train Epoch: 2 [5760/60000 (10%)]\tLoss: 0.586834\n", | |
"Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.399070\n", | |
"Train Epoch: 2 [7040/60000 (12%)]\tLoss: 0.285346\n", | |
"Train Epoch: 2 [7680/60000 (13%)]\tLoss: 0.214885\n", | |
"Train Epoch: 2 [8320/60000 (14%)]\tLoss: 0.187148\n", | |
"Train Epoch: 2 [8960/60000 (15%)]\tLoss: 0.174931\n", | |
"Train Epoch: 2 [9600/60000 (16%)]\tLoss: 0.468998\n", | |
"Train Epoch: 2 [10240/60000 (17%)]\tLoss: 0.318777\n", | |
"Train Epoch: 2 [10880/60000 (18%)]\tLoss: 0.245776\n", | |
"Train Epoch: 2 [11520/60000 (19%)]\tLoss: 0.273737\n", | |
"Train Epoch: 2 [12160/60000 (20%)]\tLoss: 0.317462\n", | |
"Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.184208\n", | |
"Train Epoch: 2 [13440/60000 (22%)]\tLoss: 0.338580\n", | |
"Train Epoch: 2 [14080/60000 (23%)]\tLoss: 0.298320\n", | |
"Train Epoch: 2 [14720/60000 (25%)]\tLoss: 0.293539\n", | |
"Train Epoch: 2 [15360/60000 (26%)]\tLoss: 0.217749\n", | |
"Train Epoch: 2 [16000/60000 (27%)]\tLoss: 0.493424\n", | |
"Train Epoch: 2 [16640/60000 (28%)]\tLoss: 0.279736\n", | |
"Train Epoch: 2 [17280/60000 (29%)]\tLoss: 0.498832\n", | |
"Train Epoch: 2 [17920/60000 (30%)]\tLoss: 0.261046\n", | |
"Train Epoch: 2 [18560/60000 (31%)]\tLoss: 0.382005\n", | |
"Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.199404\n", | |
"Train Epoch: 2 [19840/60000 (33%)]\tLoss: 0.307919\n", | |
"Train Epoch: 2 [20480/60000 (34%)]\tLoss: 0.165441\n", | |
"Train Epoch: 2 [21120/60000 (35%)]\tLoss: 0.382691\n", | |
"Train Epoch: 2 [21760/60000 (36%)]\tLoss: 0.389504\n", | |
"Train Epoch: 2 [22400/60000 (37%)]\tLoss: 0.386378\n", | |
"Train Epoch: 2 [23040/60000 (38%)]\tLoss: 0.506936\n", | |
"Train Epoch: 2 [23680/60000 (39%)]\tLoss: 0.414230\n", | |
"Train Epoch: 2 [24320/60000 (41%)]\tLoss: 0.323903\n", | |
"Train Epoch: 2 [24960/60000 (42%)]\tLoss: 0.306717\n", | |
"Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.341347\n", | |
"Train Epoch: 2 [26240/60000 (44%)]\tLoss: 0.249054\n", | |
"Train Epoch: 2 [26880/60000 (45%)]\tLoss: 0.266523\n", | |
"Train Epoch: 2 [27520/60000 (46%)]\tLoss: 0.487125\n", | |
"Train Epoch: 2 [28160/60000 (47%)]\tLoss: 0.145521\n", | |
"Train Epoch: 2 [28800/60000 (48%)]\tLoss: 0.484457\n", | |
"Train Epoch: 2 [29440/60000 (49%)]\tLoss: 0.338556\n", | |
"Train Epoch: 2 [30080/60000 (50%)]\tLoss: 0.406402\n", | |
"Train Epoch: 2 [30720/60000 (51%)]\tLoss: 0.394525\n", | |
"Train Epoch: 2 [31360/60000 (52%)]\tLoss: 0.540513\n", | |
"Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.254305\n", | |
"Train Epoch: 2 [32640/60000 (54%)]\tLoss: 0.315709\n", | |
"Train Epoch: 2 [33280/60000 (55%)]\tLoss: 0.243859\n", | |
"Train Epoch: 2 [33920/60000 (57%)]\tLoss: 0.179197\n", | |
"Train Epoch: 2 [34560/60000 (58%)]\tLoss: 0.292523\n", | |
"Train Epoch: 2 [35200/60000 (59%)]\tLoss: 0.382717\n", | |
"Train Epoch: 2 [35840/60000 (60%)]\tLoss: 0.243314\n", | |
"Train Epoch: 2 [36480/60000 (61%)]\tLoss: 0.386967\n", | |
"Train Epoch: 2 [37120/60000 (62%)]\tLoss: 0.415259\n", | |
"Train Epoch: 2 [37760/60000 (63%)]\tLoss: 0.280219\n", | |
"Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.572420\n", | |
"Train Epoch: 2 [39040/60000 (65%)]\tLoss: 0.246219\n", | |
"Train Epoch: 2 [39680/60000 (66%)]\tLoss: 0.272813\n", | |
"Train Epoch: 2 [40320/60000 (67%)]\tLoss: 0.434831\n", | |
"Train Epoch: 2 [40960/60000 (68%)]\tLoss: 0.398748\n", | |
"Train Epoch: 2 [41600/60000 (69%)]\tLoss: 0.346579\n", | |
"Train Epoch: 2 [42240/60000 (70%)]\tLoss: 0.462290\n", | |
"Train Epoch: 2 [42880/60000 (71%)]\tLoss: 0.376466\n", | |
"Train Epoch: 2 [43520/60000 (72%)]\tLoss: 0.621190\n", | |
"Train Epoch: 2 [44160/60000 (74%)]\tLoss: 0.425865\n", | |
"Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.188408\n", | |
"Train Epoch: 2 [45440/60000 (76%)]\tLoss: 0.235857\n", | |
"Train Epoch: 2 [46080/60000 (77%)]\tLoss: 0.480694\n", | |
"Train Epoch: 2 [46720/60000 (78%)]\tLoss: 0.463522\n", | |
"Train Epoch: 2 [47360/60000 (79%)]\tLoss: 0.454278\n", | |
"Train Epoch: 2 [48000/60000 (80%)]\tLoss: 0.241668\n", | |
"Train Epoch: 2 [48640/60000 (81%)]\tLoss: 0.228049\n", | |
"Train Epoch: 2 [49280/60000 (82%)]\tLoss: 0.371920\n", | |
"Train Epoch: 2 [49920/60000 (83%)]\tLoss: 0.303571\n", | |
"Train Epoch: 2 [50560/60000 (84%)]\tLoss: 0.224145\n", | |
"Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.231461\n", | |
"Train Epoch: 2 [51840/60000 (86%)]\tLoss: 0.117327\n", | |
"Train Epoch: 2 [52480/60000 (87%)]\tLoss: 0.219765\n", | |
"Train Epoch: 2 [53120/60000 (88%)]\tLoss: 0.220244\n", | |
"Train Epoch: 2 [53760/60000 (90%)]\tLoss: 0.252273\n", | |
"Train Epoch: 2 [54400/60000 (91%)]\tLoss: 0.373454\n", | |
"Train Epoch: 2 [55040/60000 (92%)]\tLoss: 0.308257\n", | |
"Train Epoch: 2 [55680/60000 (93%)]\tLoss: 0.249226\n", | |
"Train Epoch: 2 [56320/60000 (94%)]\tLoss: 0.305760\n", | |
"Train Epoch: 2 [56960/60000 (95%)]\tLoss: 0.319969\n", | |
"Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.154123\n", | |
"Train Epoch: 2 [58240/60000 (97%)]\tLoss: 0.200313\n", | |
"Train Epoch: 2 [58880/60000 (98%)]\tLoss: 0.237880\n", | |
"Train Epoch: 2 [59520/60000 (99%)]\tLoss: 0.293041\n", | |
"\n", | |
"Test set: Avg. loss: 0.4096, Accuracy: 8801/10000 (88%)\n", | |
"\n", | |
"Train Epoch: 3 [0/60000 (0%)]\tLoss: 0.357464\n", | |
"Train Epoch: 3 [640/60000 (1%)]\tLoss: 0.372117\n", | |
"Train Epoch: 3 [1280/60000 (2%)]\tLoss: 0.352286\n", | |
"Train Epoch: 3 [1920/60000 (3%)]\tLoss: 0.306995\n", | |
"Train Epoch: 3 [2560/60000 (4%)]\tLoss: 0.318805\n", | |
"Train Epoch: 3 [3200/60000 (5%)]\tLoss: 0.481552\n", | |
"Train Epoch: 3 [3840/60000 (6%)]\tLoss: 0.424677\n", | |
"Train Epoch: 3 [4480/60000 (7%)]\tLoss: 0.163365\n", | |
"Train Epoch: 3 [5120/60000 (9%)]\tLoss: 0.183975\n", | |
"Train Epoch: 3 [5760/60000 (10%)]\tLoss: 0.172799\n", | |
"Train Epoch: 3 [6400/60000 (11%)]\tLoss: 0.298142\n", | |
"Train Epoch: 3 [7040/60000 (12%)]\tLoss: 0.261880\n", | |
"Train Epoch: 3 [7680/60000 (13%)]\tLoss: 0.134764\n", | |
"Train Epoch: 3 [8320/60000 (14%)]\tLoss: 0.212339\n", | |
"Train Epoch: 3 [8960/60000 (15%)]\tLoss: 0.365209\n", | |
"Train Epoch: 3 [9600/60000 (16%)]\tLoss: 0.150135\n", | |
"Train Epoch: 3 [10240/60000 (17%)]\tLoss: 0.248458\n", | |
"Train Epoch: 3 [10880/60000 (18%)]\tLoss: 0.284270\n", | |
"Train Epoch: 3 [11520/60000 (19%)]\tLoss: 0.469098\n", | |
"Train Epoch: 3 [12160/60000 (20%)]\tLoss: 0.220207\n", | |
"Train Epoch: 3 [12800/60000 (21%)]\tLoss: 0.382031\n", | |
"Train Epoch: 3 [13440/60000 (22%)]\tLoss: 0.241649\n", | |
"Train Epoch: 3 [14080/60000 (23%)]\tLoss: 0.378085\n", | |
"Train Epoch: 3 [14720/60000 (25%)]\tLoss: 0.210836\n", | |
"Train Epoch: 3 [15360/60000 (26%)]\tLoss: 0.291862\n", | |
"Train Epoch: 3 [16000/60000 (27%)]\tLoss: 0.370690\n", | |
"Train Epoch: 3 [16640/60000 (28%)]\tLoss: 0.209308\n", | |
"Train Epoch: 3 [17280/60000 (29%)]\tLoss: 0.120709\n", | |
"Train Epoch: 3 [17920/60000 (30%)]\tLoss: 0.227081\n", | |
"Train Epoch: 3 [18560/60000 (31%)]\tLoss: 0.151938\n", | |
"Train Epoch: 3 [19200/60000 (32%)]\tLoss: 0.281668\n", | |
"Train Epoch: 3 [19840/60000 (33%)]\tLoss: 0.226762\n", | |
"Train Epoch: 3 [20480/60000 (34%)]\tLoss: 0.233312\n", | |
"Train Epoch: 3 [21120/60000 (35%)]\tLoss: 0.253970\n", | |
"Train Epoch: 3 [21760/60000 (36%)]\tLoss: 0.066867\n", | |
"Train Epoch: 3 [22400/60000 (37%)]\tLoss: 0.184917\n", | |
"Train Epoch: 3 [23040/60000 (38%)]\tLoss: 0.303674\n", | |
"Train Epoch: 3 [23680/60000 (39%)]\tLoss: 0.379502\n", | |
"Train Epoch: 3 [24320/60000 (41%)]\tLoss: 0.300051\n", | |
"Train Epoch: 3 [24960/60000 (42%)]\tLoss: 0.240761\n", | |
"Train Epoch: 3 [25600/60000 (43%)]\tLoss: 0.168751\n", | |
"Train Epoch: 3 [26240/60000 (44%)]\tLoss: 0.380870\n", | |
"Train Epoch: 3 [26880/60000 (45%)]\tLoss: 0.276003\n", | |
"Train Epoch: 3 [27520/60000 (46%)]\tLoss: 0.353058\n", | |
"Train Epoch: 3 [28160/60000 (47%)]\tLoss: 0.117650\n", | |
"Train Epoch: 3 [28800/60000 (48%)]\tLoss: 0.362679\n", | |
"Train Epoch: 3 [29440/60000 (49%)]\tLoss: 0.102341\n", | |
"Train Epoch: 3 [30080/60000 (50%)]\tLoss: 0.155397\n", | |
"Train Epoch: 3 [30720/60000 (51%)]\tLoss: 0.233923\n", | |
"Train Epoch: 3 [31360/60000 (52%)]\tLoss: 0.212420\n", | |
"Train Epoch: 3 [32000/60000 (53%)]\tLoss: 0.118900\n", | |
"Train Epoch: 3 [32640/60000 (54%)]\tLoss: 0.205148\n", | |
"Train Epoch: 3 [33280/60000 (55%)]\tLoss: 0.293114\n", | |
"Train Epoch: 3 [33920/60000 (57%)]\tLoss: 0.226733\n", | |
"Train Epoch: 3 [34560/60000 (58%)]\tLoss: 0.280164\n", | |
"Train Epoch: 3 [35200/60000 (59%)]\tLoss: 0.270935\n", | |
"Train Epoch: 3 [35840/60000 (60%)]\tLoss: 0.131049\n", | |
"Train Epoch: 3 [36480/60000 (61%)]\tLoss: 0.082124\n", | |
"Train Epoch: 3 [37120/60000 (62%)]\tLoss: 0.240775\n", | |
"Train Epoch: 3 [37760/60000 (63%)]\tLoss: 0.265926\n", | |
"Train Epoch: 3 [38400/60000 (64%)]\tLoss: 0.270410\n", | |
"Train Epoch: 3 [39040/60000 (65%)]\tLoss: 0.197535\n", | |
"Train Epoch: 3 [39680/60000 (66%)]\tLoss: 0.272590\n", | |
"Train Epoch: 3 [40320/60000 (67%)]\tLoss: 0.308527\n", | |
"Train Epoch: 3 [40960/60000 (68%)]\tLoss: 0.149671\n", | |
"Train Epoch: 3 [41600/60000 (69%)]\tLoss: 0.293707\n", | |
"Train Epoch: 3 [42240/60000 (70%)]\tLoss: 0.245622\n", | |
"Train Epoch: 3 [42880/60000 (71%)]\tLoss: 0.294179\n", | |
"Train Epoch: 3 [43520/60000 (72%)]\tLoss: 0.167644\n", | |
"Train Epoch: 3 [44160/60000 (74%)]\tLoss: 0.229727\n", | |
"Train Epoch: 3 [44800/60000 (75%)]\tLoss: 0.205174\n", | |
"Train Epoch: 3 [45440/60000 (76%)]\tLoss: 0.178458\n", | |
"Train Epoch: 3 [46080/60000 (77%)]\tLoss: 0.064823\n", | |
"Train Epoch: 3 [46720/60000 (78%)]\tLoss: 0.360851\n", | |
"Train Epoch: 3 [47360/60000 (79%)]\tLoss: 0.259635\n", | |
"Train Epoch: 3 [48000/60000 (80%)]\tLoss: 0.089917\n", | |
"Train Epoch: 3 [48640/60000 (81%)]\tLoss: 0.267852\n", | |
"Train Epoch: 3 [49280/60000 (82%)]\tLoss: 0.267209\n", | |
"Train Epoch: 3 [49920/60000 (83%)]\tLoss: 0.243634\n", | |
"Train Epoch: 3 [50560/60000 (84%)]\tLoss: 0.219475\n", | |
"Train Epoch: 3 [51200/60000 (85%)]\tLoss: 0.160179\n", | |
"Train Epoch: 3 [51840/60000 (86%)]\tLoss: 0.218679\n", | |
"Train Epoch: 3 [52480/60000 (87%)]\tLoss: 0.121615\n", | |
"Train Epoch: 3 [53120/60000 (88%)]\tLoss: 0.109723\n", | |
"Train Epoch: 3 [53760/60000 (90%)]\tLoss: 0.374123\n", | |
"Train Epoch: 3 [54400/60000 (91%)]\tLoss: 0.231701\n", | |
"Train Epoch: 3 [55040/60000 (92%)]\tLoss: 0.346988\n", | |
"Train Epoch: 3 [55680/60000 (93%)]\tLoss: 0.184595\n", | |
"Train Epoch: 3 [56320/60000 (94%)]\tLoss: 0.178899\n", | |
"Train Epoch: 3 [56960/60000 (95%)]\tLoss: 0.194567\n", | |
"Train Epoch: 3 [57600/60000 (96%)]\tLoss: 0.209761\n", | |
"Train Epoch: 3 [58240/60000 (97%)]\tLoss: 0.111838\n", | |
"Train Epoch: 3 [58880/60000 (98%)]\tLoss: 0.277517\n", | |
"Train Epoch: 3 [59520/60000 (99%)]\tLoss: 0.256277\n", | |
"\n", | |
"Test set: Avg. loss: 0.2302, Accuracy: 9327/10000 (93%)\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# Evaluate once before this loop does any training, then alternate one\n", | |
"# training pass and one evaluation pass per epoch.\n", | |
"test()\n", | |
"epoch_numbers = range(1, n_epochs + 1)\n", | |
"for epoch in epoch_numbers:\n", | |
"    train(epoch)\n", | |
"    test()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 136, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"----------------------------------------------------------------\n", | |
" Layer (type) Output Shape Param #\n", | |
"================================================================\n", | |
" CosSim2D_REMAKE-1 [-1, 10, 13, 13] 0\n", | |
" CosSim2D_REMAKE-2 [-1, 12, 11, 11] 0\n", | |
" CosSim2D_REMAKE-3 [-1, 8, 11, 11] 0\n", | |
" CosSim2D_REMAKE-4 [-1, 32, 6, 6] 0\n", | |
" CosSim2D_REMAKE-5 [-1, 10, 6, 6] 0\n", | |
" CosSim2D_REMAKE-6 [-1, 40, 3, 3] 0\n", | |
" CosSim2D_REMAKE-7 [-1, 10, 2, 2] 0\n", | |
" MaxAbsPool2D-8 [-1, 10, 1, 1] 0\n", | |
" Flatten-9 [-1, 10] 0\n", | |
" Linear-10 [-1, 10] 110\n", | |
" LogSoftmax-11 [-1, 10] 0\n", | |
"================================================================\n", | |
"Total params: 110\n", | |
"Trainable params: 110\n", | |
"Non-trainable params: 0\n", | |
"----------------------------------------------------------------\n", | |
"Input size (MB): 0.00\n", | |
"Forward/backward pass size (MB): 0.05\n", | |
"Params size (MB): 0.00\n", | |
"Estimated Total Size (MB): 0.05\n", | |
"----------------------------------------------------------------\n", | |
"7571\n" | |
] | |
} | |
], | |
"source": [ | |
"# Per-layer summary for a single 1x28x28 input, followed by the\n", | |
"# notebook's own parameter count.\n", | |
"import torchsummary\n", | |
"\n", | |
"# Build torchsummary's dummy input on the device the model actually lives on\n", | |
"# instead of relying on the library's default of 'cuda'.\n", | |
"summary_device = next(model.parameters()).device.type\n", | |
"torchsummary.summary(model, (1, 28, 28), device=summary_device)\n", | |
"# NOTE(review): the recorded output shows torchsummary reporting 110 total\n", | |
"# params while count_parameters prints 7571 -- presumably torchsummary does\n", | |
"# not pick up the custom layers' parameters; confirm before quoting either.\n", | |
"print(count_parameters(model))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"Text(0, 0.5, 'negative log likelihood loss')" | |
] | |
}, | |
"execution_count": 91, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# Training loss curve with the recorded test losses overlaid as points.\n", | |
"fig, ax = plt.subplots()\n", | |
"# Labels are attached to the artists themselves so the legend cannot drift\n", | |
"# out of sync with the plotting order.\n", | |
"ax.plot(train_counter, train_losses, color='blue', label='Train Loss')\n", | |
"ax.scatter(test_counter, test_losses, color='red', label='Test Loss')\n", | |
"ax.legend(loc='upper right')\n", | |
"ax.set_xlabel('number of training examples seen')\n", | |
"# Trailing semicolon keeps the Text(...) repr out of the cell output.\n", | |
"ax.set_ylabel('negative log likelihood loss');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Run the example batch through the model without tracking gradients.\n", | |
"# The batch is moved to the model's own device rather than a hard-coded\n", | |
"# 'cuda', so the cell also runs when the model lives on the CPU.\n", | |
"model_device = next(model.parameters()).device\n", | |
"with torch.no_grad():\n", | |
"    output = model(example_data.to(model_device))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Show the first six example digits, each titled with the model's prediction.\n", | |
"# The predicted class is the index of the largest output along dim 1; it is\n", | |
"# computed once for the whole batch instead of once per subplot.\n", | |
"predicted_labels = output.argmax(dim=1)\n", | |
"fig = plt.figure()\n", | |
"for i in range(6):\n", | |
"    plt.subplot(2, 3, i + 1)\n", | |
"    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')\n", | |
"    plt.title(\"Prediction: {}\".format(predicted_labels[i].item()))\n", | |
"    plt.xticks([])\n", | |
"    plt.yticks([])\n", | |
"# Lay the grid out once, after every subplot and title exists.\n", | |
"plt.tight_layout()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Random 1x1x28x28 tensor plus a MaxAbsPool2D instance to probe it with.\n", | |
"# NOTE(review): the meaning of the (2, True) arguments is defined by\n", | |
"# MaxAbsPool2D elsewhere in the notebook -- check its signature.\n", | |
"test_data_x = torch.randn((1, 1, 28, 28))\n", | |
"max_abs = MaxAbsPool2D(2, True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Print the tensor shape before and after max-abs pooling, one per line.\n", | |
"print(test_data_x.shape, max_abs(test_data_x).shape, sep=\"\\n\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Side by side: example digit 2 and the same digit after 2x2 max pooling.\n", | |
"sample_index = 2\n", | |
"panels = [\n", | |
"    example_data[sample_index][0],\n", | |
"    # Channel 0 is selected directly so imshow receives a plain 2-D image\n", | |
"    # rather than an array with a trailing singleton channel axis\n", | |
"    # (assumes single-channel batches, as used elsewhere in the notebook).\n", | |
"    F.max_pool2d(example_data, 2)[sample_index][0],\n", | |
"]\n", | |
"fig = plt.figure()\n", | |
"for position, image in enumerate(panels, start=1):\n", | |
"    plt.subplot(1, 2, position)\n", | |
"    plt.imshow(image, cmap='gray', interpolation='none')\n", | |
"    plt.xticks([])\n", | |
"    plt.yticks([])\n", | |
"# Lay the figure out once, after both panels exist.\n", | |
"plt.tight_layout()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Side by side: example digit 2 and the same digit after max-abs pooling.\n", | |
"sample_index = 2\n", | |
"panels = [\n", | |
"    example_data[sample_index][0],\n", | |
"    # Channel 0 is selected directly so imshow receives a plain 2-D image\n", | |
"    # rather than an array with a trailing singleton channel axis\n", | |
"    # (assumes single-channel batches, as used elsewhere in the notebook).\n", | |
"    max_abs(example_data)[sample_index][0],\n", | |
"]\n", | |
"fig = plt.figure()\n", | |
"for position, image in enumerate(panels, start=1):\n", | |
"    plt.subplot(1, 2, position)\n", | |
"    plt.imshow(image, cmap='gray', interpolation='none')\n", | |
"    plt.xticks([])\n", | |
"    plt.yticks([])\n", | |
"# Lay the figure out once, after both panels exist.\n", | |
"plt.tight_layout()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"interpreter": { | |
"hash": "0d6e8431264d1160935736f310d9f0e1db684933f77421304180850f988bc540" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3.6.15 ('deepgen')", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.15" | |
}, | |
"orig_nbformat": 4 | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment