Last active
June 7, 2019 09:47
-
-
Save sharanry/16897f4808d74af29fad776ca4557161 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "5theVSlqFEMW" | |
}, | |
"outputs": [], | |
"source": [ | |
"import math\n", | |
"import numpy as np\n", | |
"import scipy\n", | |
"import pandas as pd\n", | |
"from scipy.stats import multivariate_normal\n", | |
"from sklearn.metrics import accuracy_score\n", | |
"from sklearn.utils import shuffle\n", | |
"from tqdm import tqdm\n", | |
"import matplotlib as mpl\n", | |
"mpl.use('Agg')\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"\n", | |
"# Ignore warnings\n", | |
"import warnings\n", | |
"warnings.filterwarnings(\"ignore\")\n", | |
"\n", | |
"import pyro\n", | |
"import torch.nn.functional as F" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "Zm5rduc9FKmh" | |
}, | |
"outputs": [], | |
"source": [ | |
"es = list(range(4, 10))\n", | |
"ns = [int(math.e**i) for i in es]\n", | |
"N = max(ns)  # only the largest size is used below\n", | |
"\n", | |
"# Seed numpy's global RNG so the sampled points, the sklearn shuffles and the\n", | |
"# label flips below are identical on every Restart & Run All.\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"CLASS_MEANS = ([-1, 0], [1, 0])  # class 0 centred at (-1, 0), class 1 at (1, 0)\n", | |
"CLASS_COV = [[1, 0], [0, 1]]     # shared identity covariance\n", | |
"CLASS_FRACS = (0.9, 0.1)         # 90 / 10 class imbalance\n", | |
"FLIP_PROB = 0.3                  # label-noise rate used for flipped_data\n", | |
"\n", | |
"\n", | |
"def make_imbalanced_data(n, fracs=CLASS_FRACS):\n", | |
"    \"\"\"Sample points from the class Gaussians and return them as a shuffled DataFrame.\n", | |
"\n", | |
"    Class k contributes int(fracs[k] * n) rows. Columns are the label 'y' first,\n", | |
"    then the feature columns 0 and 1.\n", | |
"    \"\"\"\n", | |
"    frames = []\n", | |
"    for label, (mean, frac) in enumerate(zip(CLASS_MEANS, fracs)):\n", | |
"        samples = np.random.multivariate_normal(mean, CLASS_COV, int(frac * n))\n", | |
"        frame = pd.DataFrame(samples, columns=list(range(samples.shape[1])))\n", | |
"        frame.insert(0, 'y', label)\n", | |
"        frames.append(frame)\n", | |
"    # pd.concat instead of DataFrame.append, which was removed in pandas 2.0.\n", | |
"    return shuffle(pd.concat(frames))\n", | |
"\n", | |
"\n", | |
"orig_data = make_imbalanced_data(N)\n", | |
"\n", | |
"# Create Test Data\n", | |
"test_data = make_imbalanced_data(1000)\n", | |
"\n", | |
"\n", | |
"#flipping\n", | |
"def random_flip(row, p=FLIP_PROB):\n", | |
"    \"\"\"Return row['y'] flipped (0 <-> 1) with probability p, always as an int.\n", | |
"\n", | |
"    The int() matters: apply(axis=1) presumably hands over the row as a float Series\n", | |
"    (the feature columns are floats), so returning row['y'] unchanged would mix\n", | |
"    floats into the integer labels.\n", | |
"    \"\"\"\n", | |
"    label = int(row['y'])\n", | |
"    return 1 - label if np.random.uniform() < p else label\n", | |
"\n", | |
"\n", | |
"flipped_data = orig_data.copy()\n", | |
"flipped_data['y'] = flipped_data.apply(random_flip, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "4Zw5Ybwpc5fF" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils import data\n", | |
"\n", | |
"class Dataset(data.Dataset):\n", | |
"    \"\"\"Map-style PyTorch dataset over a DataFrame of labelled 2-D points.\n", | |
"\n", | |
"    Item i is (features, label) for positional row i of `dataframe`: a NumPy array\n", | |
"    of the `feature_columns` values and the scalar `label_column` value.\n", | |
"    The columns are copied to NumPy once at construction, so later edits to\n", | |
"    `dataframe` are not seen by the dataset.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, dataframe, feature_columns=(0, 1), label_column='y'):\n", | |
"        self.dataframe = dataframe\n", | |
"        # Read from the dataframe passed in (not a global such as orig_data), so the\n", | |
"        # test and flipped loaders really serve test_data / flipped_data.\n", | |
"        self.features = dataframe[list(feature_columns)].to_numpy()\n", | |
"        self.labels = dataframe[label_column].to_numpy()\n", | |
"\n", | |
"    def __len__(self):\n", | |
"        \"\"\"Denotes the total number of samples.\"\"\"\n", | |
"        return len(self.dataframe)\n", | |
"\n", | |
"    def __getitem__(self, index):\n", | |
"        \"\"\"Return (features, label) for positional row `index`.\"\"\"\n", | |
"        return self.features[index], self.labels[index]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "ZZZWcjRCFcb3" | |
}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils import data\n", | |
"\n", | |
"# Mini-batches of 64 (features, label) pairs from the un-flipped training set, in row order.\n", | |
"orig_dataloader = data.DataLoader(dataset=Dataset(orig_data), batch_size=64)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"colab": {}, | |
"colab_type": "code", | |
"id": "Pr9nMK4bFjOr" | |
}, | |
"outputs": [], | |
"source": [ | |
"class NN(nn.Module):\n", | |
"    \"\"\"Two-layer perceptron: Linear -> ReLU -> Linear, returning unnormalised class scores.\n", | |
"\n", | |
"    The submodule names `fc1` and `out` define the parameter names ('fc1.weight',\n", | |
"    'out.bias', ...) that the Pyro model and guide attach distributions to.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, input_size, hidden_size, output_size):\n", | |
"        super().__init__()\n", | |
"        self.fc1 = nn.Linear(input_size, hidden_size)\n", | |
"        self.out = nn.Linear(hidden_size, output_size)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        hidden = F.relu(self.fc1(x))\n", | |
"        return self.out(hidden)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# 2 input features -> 10 hidden units -> 2 class scores\n", | |
"net = NN(input_size=2, hidden_size=10, output_size=2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pyro\n", | |
"from pyro.distributions import Normal, Categorical\n", | |
"from pyro.infer import SVI, Trace_ELBO\n", | |
"from pyro.optim import Adam" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Turns the network's scores into log-probabilities along dim 1 (the class axis of its output).\n", | |
"log_softmax = torch.nn.LogSoftmax(dim=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def model(x_data, y_data):\n", | |
"    \"\"\"Pyro model: N(0, 1) priors on every parameter of `net`, categorical likelihood for y_data.\n", | |
"\n", | |
"    x_data: input features for a batch; y_data: observed class labels for that batch.\n", | |
"    \"\"\"\n", | |
"    # One elementwise standard-normal prior per parameter tensor, keyed by its name\n", | |
"    # ('fc1.weight', 'fc1.bias', 'out.weight', 'out.bias').\n", | |
"    priors = {\n", | |
"        name: Normal(loc=torch.zeros_like(param), scale=torch.ones_like(param))\n", | |
"        for name, param in net.named_parameters()\n", | |
"    }\n", | |
"    # Lift the module's parameters to random variables and draw one concrete network\n", | |
"    # (sampling its weights and biases from the priors).\n", | |
"    sampled_net = pyro.random_module(\"module\", net, priors)()\n", | |
"    log_probs = log_softmax(sampled_net(x_data))\n", | |
"    pyro.sample(\"obs\", Categorical(logits=log_probs), obs=y_data)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"softplus = torch.nn.Softplus()\n", | |
"\n", | |
"def guide(x_data, y_data):\n", | |
"    \"\"\"Mean-field Normal variational posterior over every parameter of `net`.\n", | |
"\n", | |
"    For each parameter tensor, a location `<p>_mu` and an unconstrained scale\n", | |
"    `<p>_sigma` (made positive with softplus) live in Pyro's param store, where\n", | |
"    `<p>` is the short name 'fc1w', 'fc1b', 'outw' or 'outb'. x_data and y_data\n", | |
"    are unused; they are accepted so the signature matches `model`.\n", | |
"\n", | |
"    Returns one network whose weights are sampled from the variational posterior.\n", | |
"    \"\"\"\n", | |
"    posteriors = {}\n", | |
"    for name, param in net.named_parameters():\n", | |
"        layer, kind = name.split('.')\n", | |
"        short = layer + kind[0]  # 'fc1.weight' -> 'fc1w', 'out.bias' -> 'outb'\n", | |
"        # Random initial values; per Pyro's param-store semantics they are only used\n", | |
"        # the first time a given param name is registered.\n", | |
"        mu_init = torch.randn_like(param)\n", | |
"        sigma_init = torch.randn_like(param)\n", | |
"        mu = pyro.param(short + \"_mu\", mu_init)\n", | |
"        sigma = softplus(pyro.param(short + \"_sigma\", sigma_init))\n", | |
"        # Plain elementwise Normal for every site, matching the batch shape of the\n", | |
"        # model's elementwise prior for the same parameter.\n", | |
"        posteriors[name] = Normal(loc=mu, scale=sigma)\n", | |
"\n", | |
"    lifted_module = pyro.random_module(\"module\", net, posteriors)\n", | |
"\n", | |
"    return lifted_module()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Pyro-wrapped Adam; each SVI step lowers the Trace_ELBO loss (i.e. raises the ELBO) w.r.t. the guide's params.\n", | |
"optim = Adam({\"lr\": 0.01})\n", | |
"svi = SVI(model=model, guide=guide, optim=optim, loss=Trace_ELBO())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Number of training epochs (full passes over orig_dataloader).\n", | |
"num_iterations = 20" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 5%|▌ | 1/20 [00:09<03:06, 9.80s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 0 Loss 2.688241181771451\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 10%|█ | 2/20 [00:19<02:55, 9.72s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 1 Loss 1.2437119732312703\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 15%|█▌ | 3/20 [00:28<02:44, 9.68s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 2 Loss 0.9694359728184138\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 20%|██ | 4/20 [00:38<02:34, 9.65s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 3 Loss 0.8697841923285696\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 25%|██▌ | 5/20 [00:48<02:24, 9.62s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 4 Loss 0.7640265142202554\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 30%|███ | 6/20 [00:57<02:14, 9.60s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 5 Loss 0.6920861386128692\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 35%|███▌ | 7/20 [01:07<02:06, 9.77s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 6 Loss 0.5986790982581044\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 40%|████ | 8/20 [01:18<01:59, 9.96s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 7 Loss 0.56430517981977\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 45%|████▌ | 9/20 [01:27<01:48, 9.85s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 8 Loss 0.5645537788162877\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 50%|█████ | 10/20 [01:37<01:37, 9.78s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 9 Loss 0.582734988232184\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 55%|█████▌ | 11/20 [01:48<01:32, 10.25s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 10 Loss 0.560835290524554\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 60%|██████ | 12/20 [01:58<01:21, 10.15s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 11 Loss 0.5368737962636734\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 65%|██████▌ | 13/20 [02:09<01:11, 10.22s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 12 Loss 0.5437979029670406\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 70%|███████ | 14/20 [02:19<01:01, 10.23s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 13 Loss 0.5032721681746104\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 75%|███████▌ | 15/20 [02:29<00:51, 10.30s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 14 Loss 0.5194673257473221\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 80%|████████ | 16/20 [02:40<00:41, 10.45s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 15 Loss 0.4960786971346121\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 85%|████████▌ | 17/20 [02:52<00:32, 10.92s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 16 Loss 0.524999288924386\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 90%|█████████ | 18/20 [03:02<00:21, 10.66s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 17 Loss 0.5436136195304158\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
" 95%|█████████▌| 19/20 [03:12<00:10, 10.41s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 18 Loss 0.5141672538742846\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
"100%|██████████| 20/20 [03:22<00:00, 10.25s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Epoch 19 Loss 0.5097796323122257\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"# One SVI step per mini-batch; report the summed batch losses divided by the number\n", | |
"# of training examples.\n", | |
"normalizer_train = len(orig_dataloader.dataset)\n", | |
"for j in tqdm(range(num_iterations)):\n", | |
"    loss = 0\n", | |
"    # Unpack into x_batch / y_batch rather than `data`, which would shadow the\n", | |
"    # torch.utils `data` module imported earlier in the notebook.\n", | |
"    for x_batch, y_batch in orig_dataloader:\n", | |
"        # svi.step takes one gradient step and returns this batch's loss estimate\n", | |
"        loss += svi.step(x_batch.float(), y_batch)\n", | |
"    total_epoch_loss_train = loss / normalizer_train\n", | |
"\n", | |
"    print(\"Epoch \", j, \" Loss \", total_epoch_loss_train)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 26, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from torch.utils import data\n", | |
"\n", | |
"# Evaluation loaders (batches of 64, in row order): the held-out test set and the\n", | |
"# label-flipped copy of the training set.\n", | |
"test_dataloader, flipped_dataloader = (\n", | |
"    data.DataLoader(Dataset(frame), batch_size=64) for frame in (test_data, flipped_data)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 27, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Prediction when network is forced to predict\n", | |
"accuracy: 89 %\n" | |
] | |
} | |
], | |
"source": [ | |
"num_samples = 10\n", | |
"\n", | |
"def predict(x, n_samples=None):\n", | |
"    \"\"\"Predict class indices for the batch `x` by averaging networks sampled from the guide.\n", | |
"\n", | |
"    Draws `n_samples` networks (default: the global `num_samples`, read at call time),\n", | |
"    averages their raw output scores and returns the per-row argmax as a NumPy array.\n", | |
"    \"\"\"\n", | |
"    if n_samples is None:\n", | |
"        n_samples = num_samples\n", | |
"    with torch.no_grad():\n", | |
"        sampled_nets = [guide(None, None) for _ in range(n_samples)]\n", | |
"        # Named sampled_net (not `model`) to avoid confusion with the Pyro model function.\n", | |
"        scores = torch.stack([sampled_net(x) for sampled_net in sampled_nets])\n", | |
"    return np.argmax(scores.mean(0).numpy(), axis=1)\n", | |
"\n", | |
"print('Prediction when network is forced to predict')\n", | |
"correct = 0\n", | |
"total = 0\n", | |
"# Loop names avoid rebinding `data`, the torch.utils module used to build loaders.\n", | |
"for feats, labels in test_dataloader:\n", | |
"    predicted = predict(feats.float())\n", | |
"    total += labels.size(0)\n", | |
"    correct += (torch.as_tensor(predicted) == labels).sum().item()\n", | |
"print(\"accuracy: %d %%\" % (100 * correct / total))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"num_samples = 100\n", | |
"\n", | |
"def give_uncertainities(x):\n", | |
"    \"\"\"Log-probabilities of `x` under `num_samples` networks drawn from the guide.\n", | |
"\n", | |
"    Returns a NumPy array stacking one log_softmax(dim=1) output per sampled network;\n", | |
"    for a 2-D batch `x` it is indexed [sampled network, row of x, class].\n", | |
"    \"\"\"\n", | |
"    sampled_nets = [guide(None, None) for _ in range(num_samples)]\n", | |
"    features = x.float()\n", | |
"    per_net_log_probs = [\n", | |
"        F.log_softmax(sampled_net(features).data, 1).numpy() for sampled_net in sampled_nets\n", | |
"    ]\n", | |
"    return np.asarray(per_net_log_probs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Display names of the two labels; the index j is the class id used throughout.\n", | |
"classes = [str(label) for label in range(2)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"from matplotlib import style\n", | |
"style.use('fivethirtyeight')\n", | |
"import numpy as np\n", | |
"from scipy.stats import multivariate_normal\n", | |
"%matplotlib inline\n", | |
"\n", | |
"def imshow(img):\n", | |
"    \"\"\"Scatter the 2-D point `img` (x=img[0], y=img[1]) over contours of both class densities.\n", | |
"\n", | |
"    Blue: N([-1, 0], I), the class-0 Gaussian; red: N([1, 0], I), the class-1 Gaussian.\n", | |
"    Despite the name, nothing is rendered as an image.\n", | |
"    \"\"\"\n", | |
"    axis_values = np.linspace(-5, 5, 500)\n", | |
"    X, Y = np.meshgrid(axis_values, axis_values)\n", | |
"    pos = np.column_stack((X.ravel(), Y.ravel()))\n", | |
"\n", | |
"    fig, ax0 = plt.subplots(figsize=(3, 3))\n", | |
"\n", | |
"    # (mean, contour line widths, alpha, colour) for class 0 then class 1\n", | |
"    contour_specs = (\n", | |
"        ([-1, 0], np.arange(0.3, 2, .5), 0.9, 'blue'),\n", | |
"        ([1, 0], np.arange(.5, 2, .5), 0.3, 'red'),\n", | |
"    )\n", | |
"    for mean, widths, alpha, color in contour_specs:\n", | |
"        density = multivariate_normal(mean, [[1, 0], [0, 1]]).pdf(pos).reshape(500, 500)\n", | |
"        ax0.contour(X, Y, density, 6, linewidths=widths, alpha=alpha, colors=color)\n", | |
"\n", | |
"    ax0.scatter(img[0], img[1], s=10, c=\"black\", alpha=1)\n", | |
"    plt.show()\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from matplotlib import colors\n", | |
"\n", | |
"def _color_bars_by_height(counts, patches):\n", | |
"    \"\"\"Recolour histogram bars with the viridis colormap according to their relative height.\"\"\"\n", | |
"    fracs = counts / counts.max()\n", | |
"    # normalise to 0..1 so the full range of the colormap is used\n", | |
"    norm = colors.Normalize(fracs.min(), fracs.max())\n", | |
"    for frac, patch in zip(fracs, patches):\n", | |
"        patch.set_facecolor(plt.cm.viridis(norm(frac)))\n", | |
"\n", | |
"\n", | |
"def test_batch(images, labels, plot=True, threshold=0.2):\n", | |
"    \"\"\"Classify a batch with the Bayesian net while letting it abstain (\"refuse\").\n", | |
"\n", | |
"    For each row, the median over sampled networks (from `give_uncertainities`) of each\n", | |
"    class probability is computed. If any class median exceeds `threshold` the network\n", | |
"    predicts the class with the largest median; otherwise the row is \"undecided\".\n", | |
"    NOTE(review): with two classes the two medians sum to about 1, so one of them is\n", | |
"    >= 0.5 and a threshold below 0.5 never abstains.\n", | |
"\n", | |
"    The returned counts cover every row of the batch. With plot=True, only\n", | |
"    misclassified rows are displayed: per-class histograms of the sampled\n", | |
"    log-probabilities, an Incorrect/Undecided line and the point over the class\n", | |
"    densities (`imshow`), followed by a summary.\n", | |
"\n", | |
"    Returns (rows in batch, correct predictions, rows the network predicted for).\n", | |
"    \"\"\"\n", | |
"    y = give_uncertainities(images)  # indexed [sampled network, row, class]\n", | |
"    n_classes = len(classes)\n", | |
"    predicted_for_images = 0\n", | |
"    correct_predictions = 0\n", | |
"\n", | |
"    for i in range(len(labels)):\n", | |
"        label = labels[i].item()\n", | |
"        if plot:\n", | |
"            # One histogram panel per class (not a fixed 10).\n", | |
"            fig, axs = plt.subplots(1, n_classes, sharey=True, figsize=(2 * n_classes, 2), squeeze=False)\n", | |
"            axs = axs[0]\n", | |
"\n", | |
"        all_digits_prob = []\n", | |
"        highlighted_something = False\n", | |
"        for j in range(n_classes):\n", | |
"            histo = y[:, i, j]                       # sampled log-probabilities of class j\n", | |
"            prob = np.percentile(np.exp(histo), 50)  # median sampled probability\n", | |
"            all_digits_prob.append(prob)\n", | |
"            # plausible answer if the network gives it more than `threshold` probability\n", | |
"            highlight = prob > threshold\n", | |
"            highlighted_something = highlighted_something or highlight\n", | |
"\n", | |
"            if plot:\n", | |
"                counts, _, patches = axs[j].hist(\n", | |
"                    histo, bins=8, color=\"lightgray\", lw=0,\n", | |
"                    weights=np.ones(len(histo)) / len(histo), density=False)\n", | |
"                axs[j].set_title(str(j) + \" (\" + str(round(prob, 2)) + \")\")\n", | |
"                if highlight:\n", | |
"                    _color_bars_by_height(counts, patches)\n", | |
"\n", | |
"        predicted = np.argmax(all_digits_prob)\n", | |
"\n", | |
"        # Bookkeeping for every row, correctly classified or not, so the returned\n", | |
"        # counts describe the whole batch.\n", | |
"        if highlighted_something:\n", | |
"            predicted_for_images += 1\n", | |
"            if label == predicted:\n", | |
"                correct_predictions += 1.0\n", | |
"\n", | |
"        if not plot:\n", | |
"            continue\n", | |
"        if label == predicted:\n", | |
"            plt.close(fig)  # only misclassified rows are shown\n", | |
"            continue\n", | |
"\n", | |
"        print(\"Real: \", label)\n", | |
"        plt.show()\n", | |
"        print(\"Incorrect :()\" if highlighted_something else \"Undecided.\")\n", | |
"        imshow(images[i])\n", | |
"\n", | |
"    if plot:\n", | |
"        print(\"Summary\")\n", | |
"        print(\"Total images: \", len(labels))\n", | |
"        print(\"Predicted for: \", predicted_for_images)\n", | |
"        if predicted_for_images:\n", | |
"            print(\"Accuracy when predicted: \", correct_predictions / predicted_for_images)\n", | |
"        else:\n", | |
"            print(\"Accuracy when predicted: n/a (no predictions made)\")\n", | |
"\n", | |
"    return len(labels), correct_predictions, predicted_for_images\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 32, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Prediction when network can refuse\n", | |
"Total images: 1000\n", | |
"Skipped: 899\n", | |
"Accuracy when made predictions: 0 %\n" | |
] | |
} | |
], | |
"source": [ | |
"# Prediction when network can decide not to predict\n", | |
"\n", | |
"print('Prediction when network can refuse')\n", | |
"correct = 0\n", | |
"total = 0\n", | |
"total_predicted_for = 0\n", | |
"# Unpack directly instead of binding `data`, which would shadow the torch.utils module.\n", | |
"for images, labels in test_dataloader:\n", | |
"    total_minibatch, correct_minibatch, predictions_minibatch = test_batch(images, labels, plot=False)\n", | |
"    total += total_minibatch\n", | |
"    correct += correct_minibatch\n", | |
"    total_predicted_for += predictions_minibatch\n", | |
"\n", | |
"print(\"Total images: \", total)\n", | |
"print(\"Skipped: \", total - total_predicted_for)\n", | |
"if total_predicted_for:\n", | |
"    print(\"Accuracy when made predictions: %d %%\" % (100 * correct / total_predicted_for))\n", | |
"else:\n", | |
"    print(\"Accuracy when made predictions: n/a (every image was skipped)\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# preparing for evaluation: take the first mini-batch of the test loader\n", | |
"\n", | |
"dataiter = iter(test_dataloader)\n", | |
"# Built-in next(): DataLoader iterators no longer provide a .next() method in recent PyTorch.\n", | |
"images, labels = next(dataiter)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Showing incorrect classifications with their posterior distributions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Real: 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQ4AAACZCAYAAABufKdoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFrhJREFUeJzt3X+wZGV5J/DvAyNolMIoapBBGWVYgm4MBEmiWX+siAOVZcxKBKv8VWs2sZQkqzG7WLqWq1aMWinXraXWRHSNGgPormZioWAUN9lEEGMERYN3RJQJuqhxJyaCSvLuH93DNvfcH90z3X3vPf35VHVN9+n3vP2c8z09NfXMOX2qtRYAAAAAgFGHbXQBAAAAAMDmo3EIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0aBwCAAAAAB0ahyOqantVfbuqtm90LUlSVQ+vqm9V1bEbXcsiqKrTq+rrVXXfja4lSarqZ6vqa1X1IxtdCwAAALB4etk4rKpzquqzVfX9qrqlql465qqvT/Ke1tq+kbmOqqq3DRuK/1BVH66qR45Rw+FVdVFV3TSs4/aqetvI+8dX1VVVddvw/duq6l2jTcvW2leTXJ7kteNv/WKrqidU1R9V1VerqlXVKydY/c1J3tha+4eR+Y6tqsur6u+Gj0ur6sHr1PC0qvrksOl7Z1V9uapeV1VHjIx557C+5Y9/OjB/a+2TST6fZNzjFwAAAGBqetc4rKrTk/xRko8k+ckkr07yW1X1wnXWOzbJ+Ul+b9lb707ylCTnJfm5JJXko1V1n3VK+e9JfinJy5P8eJKnJrli5P27kvyPJP8qyc4kz0xyUpI/XjbPJUmeXVXHrPN5DNwvyReS/Psk3xh3peFx89gkvz+y7LAkH0qyI4P8zsogow9WVa0x3d8leUuSJyX5Z0l+I8kvJ3nDyJhfT3Lssse1Sa5urd0+Mu6SJC+uqnuNuy0AAAAA01CttY2uYaqq6r1JTmitPW5k2ZuSnNda27HGev8uyYtaayeNLDspyU1JntZau2q47EczaEj9SmvtnavM9eQkH03ymNbajRPUvjvJB5Pcv7W2f2T5V5O8vrX21nHnIqmqW5Jc0lp73Rhj/3OSU1prZ40sOyvJlUlObq3dNFz2qAzOAnxya+0TE9Ty5iRPaq2dusr7B461Z7bW3jey/N5J9ifZ3Vr7yLifBwAAAHCoenfGYZLHZ3C24aiPJDlhnd8ufGKST60w1w+TfOzAgtbad4bjfm6NuZ6R5OYkZ1bV3qq6dXi568NWW2F4RuFzknxmtGk4dG2SJ6/xeRy61fL/yoGmYZIMG8H7snb+91BVJyc5O8nVawz7lSS3Z9A4vltr7c4k10f+AAAAwJz1sXF4bLqXqH5j5L3V7EjyNyvM9a3W2j+uMN9acz0yycOSPDuDy5XPT/KgJB8fnkF2t6r6w6r6XpJvJnlokl0rzLcvySPW+DwO3Wr5r3S583r5J0mqal9VfT/JF5P8ryS/ucq4I5M8L8k7Wms/XGGI/AEAAIC562PjcC1rXZd9nyR3Tmmuw5McmeR5rbVPtNb+IoPm4Y4k5ywb+5Ikp2ZwRlqSXFpVhy8bc+ewPmZnmvkf8C+SnJbBmaQ/n+RVq4w7L8kDkrxtlfflDwAAAMzdto0uYAa+nuTHli17yPDPtW6W8c0MmjfL5zqmqg5fdtbhQ5J8aY25bsugsfTXBxa01m6vqm8lefjowNbaN4Z13VRV1w/XfWruebn1A4b1MTur5X/mCmMfkjFuvNJa+8rw6Y1V9Y9J3lNV97hr89ALk1zVWrt5lakeMKwFAAAAYG76eMbhnyd52rJlu5J8tbW2b431PpPkUSvMda8k//LAgqq6f5KfTvK/15jrzzK4+/LOkfUemOSYJLessd6BPI5ctvyfJ/n0Gutx6FbLf0dVjeb440mOz9r5r+Sw4eMed0euql
My+L3E311jXfkDAAAAc9fHuyo/NslfJHljkncnOSODpsxL1rorcVXtSvL+JA9qrd0xsvyDSR6d5AUZ3N32t5KcnORRB8ZV1buSpLX23OHrH8ngzru3Jfn1DG6w8oYMfqfuMa21O6vqGUnum0HD6rsZNBn/Uwa/jXhKa+27w7mOyuBsuF2T3MV3UVXV/ZKcOHx5RZL/meSSJH/fWtu7xnovTPKy1tqJI8sOS3JdkruS/GoGzeCLM8jzcW345amqjyX5VGvt5cPXv5HB2aZfyuDM09OT/E6ST7fWdi/73Lck+cUkD2ut3bVCXTszuNvyiWuckQgAAAAwdb0747C1dl2Sp2fwm3LXJ3ltkles1TQcuirJ/0nyC8uWPyfJJ5J8IIOG5GFJzhptLmbQ7Lv7jsmtte9lcInr3w7X/ViSv09y5vAuuUny/SQvyuDMtZsyaG7ekEFD6rsjc5+X5BZNw7GdnuSvho9jk7x4+PySddb7gyQPrqrHHVjQWvunDI6jr2WQ4UeTfDnJ7nbPjvsjc8+bpdwryZuSfHb42a/MoOH4rNEPrKr7JHlukrev1DQcenaSj2oaAgAAAPPWuzMOD0VVPSfJS5Oc1jbBjhme8XZ9kte11i7b6Hr6rqr+Y5Kfaq09faNrSe4+e3Jvkqe31q7Z6HoAAACAxdK7Mw4P0XuSvDfJQze6kKHjkrxT03Bu3pTkL6vqvhtdyNCOJK/UNAQAAAA2gjMOAQAAAIAOZxwCAAAAAB3b5vVB+/fvd2rjFnP00UfXNOaR/dYj+8Um/8Ul+8U1rewT+W9FvvuLS/aLTf6LS/aL62Cyd8YhAAAAANChcQgAAAAAdGz5xuHS0tJGl7Am9W0di7YvFm17p8m+W2zyX1yyX2zyX1yyX1yyX2zyX1yyv6ct3zgEAAAAAKZP4xAAAAAA6JjbXZWZj6PueOyEa7x3JnUAAAAAsLU54xAAAAAA6NA4BAAAAAA6NA4BAAAAgI6xGodVtauqbqqqvVV10SpjnllVX6iqG6vKD+cBAAAAwBa27s1RqurwJBcneWqSfUmuq6o9rbUvjIzZmeTlSR7fWvtOVT14VgUDAAAAALM3zhmHZyTZ21q7ubX2gySXJtm9bMy/TXJxa+07SdJau326ZQIAAAAA8zRO4/C4JLeOvN43XDbqpCQnVdWfV9U1VbVrWgUCAAAAAPO37qXKSWqFZW2FeXYmeVKS7Un+rKoe3Vr7vytNuLS0NEmN65r2fNM2z/pO2z75OqP17dy5c4rVrP1Zm9Fmr2/aZH/w+rY98p9Mn7ZH9pPp0/bMOvukX/sr6df2+O5Ppk/bI/vJ9G175D+ZPm2P7CfTp+051OzHaRzuS3L8yOvtSW5bYcw1rbUfJvlKVd2UQSPxupUmnOYBu7S0NJd/+B6sudd3x+SrzLM+WW0e897ePu3bRTtWpqFP+0v+k+nTvpL95Pq0v+Q/mT7tK9lPpk/7SvaT69P+kv9k+rSvZH9P41yqfF2SnVW1o6qOSHJBkj3LxnwwyZOTpKqOyeDS5ZunWSgAAAAAMD/rNg5ba3cluTDJlUm+mOTy1tqNVfWaqjp3OOzKJN+uqi8kuTrJb7bWvj2rogEAAACA2RrnUuW01q5IcsWyZa8aed6SvHT4AAAAAAC2uHEuVQYAAAAAFozGIQAAAADQoXEIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0bNvoAoDpOOqOx064xntnUgcAAADQD844BAAAAAA6NA4BAAAAgA6NQwAAAACgQ+MQAAAAAOjQOAQAAAAAOjQOAQAAAIAOjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACgY6zGYVXtqqqbqmpvVV20xrjzqqpV1enTKxEAAAAAmLd1G4dVdXiSi5OcneSUJM+qqlNWGHdUkl9Lcu20iwQAAAAA5mucMw
7PSLK3tXZza+0HSS5NsnuFca9N8sYkd06xPgAAAABgA4zTODwuya0jr/cNl92tqk5Ncnxr7UNTrA0AAAAA2CDbxhhTKyxrd79ZdViSNyd5/rgfurS0NO7QDZlv2uZZ32nbJ19ntL6dO3dOsZq1P2sz2uz1rUX289W37ZH/ZPq0PbKfTJ+2Z9bZJ/3aX0m/tsd3fzJ92h7ZT6Zv2yP/yfRpe2Q/mT5tz6FmP07jcF+S40deb09y28jro5I8OsknqipJfizJnqo6t7X26ZUmnOYBu7S0NJd/+B6sudd3x+SrzLM+Wc2Q7Odmyx8rG6BP+0v+k+nTvpL95Pq0v+Q/mT7tK9lPpk/7SvaT69P+kv9k+rSvZH9P41yqfF2SnVW1o6qOSHJBkj0H3myt7W+tHdNaO6G1dkKSa5Ks2jQEAAAAADa/dRuHrbW7klyY5MokX0xyeWvtxqp6TVWdO+sCAQAAAID5G+dS5bTWrkhyxbJlr1pl7JMOvSwAAAAAYCONc6kyAAAAALBgNA4BAAAAgA6NQwAAAACgQ+MQAAAAAOjQOAQAAAAAOjQOAQAAAIAOjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgI5tG10AsLLHXPZ7E42/+dwZFQIAAAAsJGccAgAAAAAdGocAAAAAQIfGIQAAAADQoXEIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0jNU4rKpdVXVTVe2tqotWeP+lVfWFqrqhqj5WVQ+ffqkAAAAAwLys2zisqsOTXJzk7CSnJHlWVZ2ybNhfJTm9tfYTSd6f5I3TLhQAAAAAmJ9xzjg8I8ne1trNrbUfJLk0ye7RAa21q1tr3xu+vCbJ9umWCQAAAADM07YxxhyX5NaR1/uS/PQa41+Q5MNrTbi0tDTGx45v2vNN2zzrO+0gWraj9e3cuXOK1az9WZvRZq9v2mR/8Pq2PfKfTJ+2R/aT6dP2zDr7pF/7K+nX9vjuT6ZP2yP7yfRte+Q/mT5tj+wn06ftOdTsx2kc1grL2ooDq56d5PQkT1xrwmkesEtLS3P5h+/Bmnt9d0y+yjzrk9UEPnP1zD9C9gdn0x0rW0Cf9pf8J9OnfSX7yfVpf8l/Mn3aV7KfTJ/2lewn16f9Jf/J9Glfyf6exmkc7kty/Mjr7UluWz6oqs5M8ookT2ytfX865XHWr759ovGf9OuSAAAAAEzBOL9xeF2SnVW1o6qOSHJBkj2jA6rq1CS/m+Tc1trt0y8TAAAAAJindRuHrbW7klyY5MokX0xyeWvtxqp6TVWdOxz2piT3S/K+qvpsVe1ZZToAAAAAYAsY51LltNauSHLFsmWvGnl+5pTrAgAAAAA20DiXKgMAAAAAC0bjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACgY9tGFwCLYvf5b51shX+trw8AAABsHI3DOfvFp75lshVOvt9sCgEAAACANTilCQAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACgQ+MQAAAAAOjQOAQAAAAAOjQOAQAAAICObRtdwFZ3/o6XTbbCicfPphDm7oKTLppshVNPmEkdAAAAALMwVuOwqnYleUuSw5Nc0lr77WXvH5nkXUl+Ksm3k5zfWrtluqXOxzOP+7WJxtcRR8yokv47/2EvmWj8ZV9780znn1Td+8iZzg8AAACwkda9VLmqDk9ycZKzk5yS5FlVdcqyYS9I8p3W2olJ3pzkDdMuFAAAAACYn2qtrT2g6meTvLq19rTh65cnSWvt9SNjrhyO+WRVbUvyjSQPaiOT79+/f+0PYtM5+uijaxrzyH7rkf1ik//ikv3imlb2ify3It/9xSX7xSb/xSX7xX
Uw2Y9zc5Tjktw68nrfcNmKY1prdyXZn+SBkxYDAAAAAGwO4zQOV+pGLu8qjzMGAAAAANgixrk5yr4ko7cC3p7ktlXG7Bteqnx0kr8dHTDNS2DYWmS/uGS/2OS/uGS/2OS/uGS/uGS/2OS/uGS/GMY54/C6JDurakdVHZHkgiR7lo3Zk+R5w+fnJfl4W+/HEwEAAACATWvdxuHwNwsvTHJlki8muby1dmNVvaaqzh0Oe3uSB1bV3iQvTXLRrAquqldX1d9U1WeHj3NWGberqm6qqr1VNbN6VvjcN1XVX1fVDVX1gaq6/yrjbqmqzw234dNzqGvN/VFVR1bVZcP3r62qE2Zd02Y07vG11c3i+9HHY2yMbXp+VX1z5Hj5pY2ocxJV9Y6qur2qPr/K+1VV/2W4zTdU1Wljziv/TZ6/7MfXt+yT2eQv+8XNfrie/Dd5/rIfX9+yT+Q/ib7lL/vx9S37ZEb5t9a21CPJq5O8bJ0xhyf5cpJHJDkiyfVJTplTfWcl2TZ8/oYkb1hl3C1JjplTTevujyQvSvLW4fMLkly20Vlv1uNrqz9m8f3o4zE25jY9P8l/3ehaJ9yuJyQ5LcnnV3n/nCQfzuC3a38mybXy70f+sl/c7GeRv+w3vt6Nyl7+Wyd/2S9u9vJf7Pxlv7jZzyr/cS5V3orOSLK3tXZza+0HSS5NsnseH9xau6oNztJMkmsy+E3IjTbO/tid5PeHz9+f5ClV5fcK+mkW348+HmMb9vfILLXW/jTLfoN2md1J3tUGrkly/6o6dp1p5b8FyH5svcs+mUn+st8ifPfH1rv8ZT+23mWfyH8Cvctf9mPrXfbJbPLfqo3DC4enVL6jqn50hfePS3LryOt9w2Xz9m8y6OSupCW5qqr+sqp+ecZ1jLM/7h4zbHzuT/LAGde1Wa13fG11s/h+9PEYG3c/PWN4vLy/qo5f4f2t5mCOD/n3I3/ZDyxi9snk+ct+cbMfdx35b36yH1jE7BP5H7CI+ct+YBGzTw4i/03ZOKyqP6mqz6/w2J3kvyV5ZJKfTPL1JL+z0hQrLJvazVrWqe/AmFckuSvJH6wyzeNba6clOTvJi6vqCdOqb6WSV1i2fH/MdJ9tJlM4vra6WWTdx2NsnHr/OMkJrbWfSPIn+f//w7aVHUxO8u9H/rIfWMTsk8lzkv3iZj/uOvLf/GQ/sIjZJ/I/YBHzl/3AImafHERO22ZUyCFprZ05zriqeluSD63w1r4ko53g7Ulum0JpSdavr6qel+TnkzylDS8iX2GO24Z/3l5VH8jgNNk/nVaNy4yzPw6M2VdV25IcnbVPb92ypnB8bXWz+H708Rhbd5taa98eefm2DH7XdKs7mOND/v3IX/YDi5h9Mnn+sl/c7MddR/6bn+wHFjH7RP4HLGL+sh9YxOyTg8h/U55xuJZl117/QpKV7hRzXZKdVbWjqo7I4Ic598ypvl1J/kOSc1tr31tlzH2r6qgDzzO4ocqKd7yZknH2x54kzxs+Py/Jx1drevbZmMfXVjeL70cfj7F1t2nZ8XJuBnee3+r2JHnu8G5bP5Nkf2vt6+usI/9+5C/7gUXMPpk8f9kvbvaJ/JN+5C/7gUXMPpH/AYuYv+wHFjH75GDyb5vgri+TPJK8O8nnktww3OBjh8sfmuSKkXHnJPlSBnfJecUc69ubwfXinx0+3rq8vgzu2nP98HHjPOpbaX8keU0GDc4kuXeS9w3r/1SSR2x01pvp+OrbYxbfjz4eY2Ns0+uH3+Hrk1yd5OSNrnmMbfrDDC7D/2EG/9v0giQvTPLC4fuV5OLhNn8uyeny70f+sl/c7GeVv+wXN3v5b438Zb+42ct/sfOX/eJmP6v8a7giAAAAAMDdttylygAAAADA7GkcAgAAAAAdGocAAAAAQIfGIQAAAADQoXEIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0PH/AByR0VdUWStgAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<Figure size 1440x144 with 10 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Incorrect :()\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 216x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Real: 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 1440x144 with 10 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Incorrect :()\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 216x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Real: 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 1440x144 with 10 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Incorrect :()\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 216x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Real: 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQ4AAACZCAYAAABufKdoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFhFJREFUeJzt3X20bHdZH/DvQ0KMQFYQKBKTW3KBmxUDZWEIAYsFLAghbRPaUl5aBFpwydLoKtRWFKFpQBGopVpTaKEUgdLw0gq3GogFQRANBMpbAo3nCoFzDRSl9HaVtxB5+sfMjcOZc8+ZOffMzLkzn89aZ+XMnt/e8/x+333WuuvJ3rOruwMAAAAAMOp2iy4AAAAAANh7NA4BAAAAgDEahwAAAADAGI1DAAAAAGCMxiEAAAAAMEbjEAAAAAAYo3EIAAAAAIzROBxRVWdV1Zer6qxF15IkVXXPqvqzqjpj0bWsgqq6oKq+UFV3XHQtSVJVP1hVn6+qOyy6FgAAAGD1LGXjsKourqqPVdU3q+qmqnrOhLu+OMkbuvvwyLFOq6pXDRuKX62qd1TVvSeo4e9V1Ueq6v9V1Zeq6r9W1X1G3j+1qv5jVX20qm6pqkMbj9Hdn0vy5iQvnLD+lVdVD6uqt1fV56qqq+oXptj95Ule2t1fHTneGVX15qr6v8Ofq6rq7tvUcN+qektVrVXVt6vq1ccY98yq+mRVfW3YILy8qm77m+zuP0xyfZJJz18AAACAXbN0jcOquiDJ25O8M8kDklye5Jeq6lnb7HdGkicm+fcb3np9kkcmeXySH0pSSf57VX33Fsd6cJKrkvyXJPdPcnGSuyb57ZFhJyW5Zfh5V21R2quTPKWq7rZV/dzmTkk+leSfJfnipDsNz5sHJfmNkW23S/JbSfYn+ZEkj05yTpK3VVVtcbg7JPl8kiuSfPwYn/djSf5Nkn+Z5H5JLkvyrIw3iV+d5Cer6vaTzgUAAABgN1R3L7qGXVVVb0xydnf/1ZFtL0vy+O7ev8V+/zjJT3T3OSPbzklyY5LHdPfvDLd9TwYNqR/v7tducaznd/ddR7b9rSQHk9y5u49sGH95kqd0932yiar6XJIXd/crt5o736mqbkry6u5+0QRj/3WS87r70SPbHp3kmiTndveNw233zeAqwB/u7vdOcNz3JjnU3c/csP33k1zf3c8a2fbsJC9KcvejVz1W1alJjiS5tLvfud3nAQAAAOyWpbviMMlDM7jacNQ7k5y9zXcXPjzJhzY51reSvPvohu7+ynDcD21xrD9IcueqekJV3a6q7pzkR5N8YGPTcEIfTPLDO9iPyR0r/88ebRomSXffkORwts5/Eqcm+caGbV/P4GrFC0Y+7xsZXLUofwAAAGCulrFxeEbGb1H94sh7x7I/yZ9scqw/6+4/3+R4xzxWd38oyeOSvDLJN5N8JclfTnLplpUf2+Ek99rhvkzmWPlvdrvzlvlP6B1JnlpVD62B70/y7OF737dhrPwBAACAuVvGxuFWtrov+7szfgXYjo5VVecmeUUGD9t4UJK/nsGVi79ZVSdN8RlHfWNYH7Oza/lP6EUZPPjmPRmcG+9P8obhexsb1fIHAAAA5m4ZG4dfSHKPDdu+d/jfrR6W8adJ7rLJse62SbPve7c51s8n+WR3v7C7P9bd70ny95P8tezsltO7DOtjdo6V/8ZzKdk+/2119zeH3294hyRnZ3AF40eGb//xhuHyBwAAAOZuGRuHH0jymA3bLkryue4+vMV+/yPJfTc51u0zuGIwSTL8vsIHJ/n9LY51xyTf3rDt6FVkWz2N91j+SpIP72A/Jnes/PdX1YGjG4a3FO/L1vlPrLtv7e7D3f2tDJrLn03y0Q3D5A8AAADM3TI2Dl+e5MKq+sWqOreqnprkp5L88jb7XT3c77ZbQrv7j5K8PckrqurhVfWAJG/M4Lvw3nR0XFW9rqpeN3Ksty
W5qKqeXVX3rqoLkrw2yc0ZPOjk6H7nDY95jySnVNUDhj+njIw5LckDk/z29EuxeqrqTkfXMckpSe4xfL3pE6tHXJ3kYRu2vSuDhuIbqurCqnpwktcnuTbJ74185rur6sUjr08ZqeFOSe4yfH3eyJj7VNXTquqcqnpgVf3bJE/M4Mne3x4ZdyCDqxHfMf1qAAAAAOxcdR/vV7XtPVX1N5L8UpJzM7il9Fe7+19ts8/tkqwleX53v3Fk+2kZNCP/TgZPwn1fksu6+9DImPcmSXc/YmTbjyf5yST3TvLVDJpNP9/d14+MuSnJPTcpZ3933zQc8w+T/Gx3nzvR5FdcVT0ig+8N3Oj3RvPZZL/TMmgIX9TdfzCy/Ywkv5bBVaudQQPvp7r7SyNjbkry3u5++vD12RlcObjR57r77OGYczJoQn//8LjXJfnn3f2+DXX9iyQP6e6NV9ECAAAAzNRSNg53qqp+NMlzkpzfe2Bhhs3Mjyd5UXe/abvxHJ+qen6SB3b34xZdSzK4ejLJoSSP6+5rF10PAAAAsFqW8Vbl4/GGDK4C+75FFzJ0ZpLXahrOzcuSfKSq7rjoQob2J/kFTUMAAABgEVxxCAAAAACMccUhAAAAADDm5Hl90JEjR1zaeII5/fTTazeOI/sTj+xXm/xXl+xX125ln8j/RORvf3XJfrXJf3XJfnXtJHtXHAIAAAAAYzQOAQAAAIAxK9k4XFtbW3QJe8aqr4X5r/b8d8q6rTb5ry7Zrzb5ry7Zry7Zrzb5ry7Zf6eVbBwCAAAAAFvTOAQAAAAAxsztqcrMx2lff9CUe7xxJnUAAAAAcGJzxSEAAAAAMEbjEAAAAAAYo3EIAAAAAIzROAQAAAAAxmgcAgAAAABjNA4BAAAAgDEahwAAAADAGI1DAAAAAGCMxiEAAAAAMEbjEAAAAAAYM1HjsKouqqobq+pQVT33GGOeUFWfqqobquqNu1smAAAAADBPJ283oKpOSnJlkh9JcjjJdVV1sLs/NTLmQJKfS/LQ7v5KVd19VgUDAAAAALM3yRWHFyY51N2f6e5bklyV5NINY34syZXd/ZUk6e4v7W6ZAAAAAMA8TdI4PDPJ+sjrw8Nto85Jck5VfaCqrq2qi3arQAAAAABg/ra9VTlJbbKtNznOgSSPSHJWkvdX1f26+/9sdsC1tbVpapyJvVDDLJx/1vT7jK7FgQMHdrGarT9rr9iLNc3T0fmvYvbHY9nmI//pLNN8ZD+dZZrPrLNPlmu9kuWaj7/96SzTfGQ/nWWbj/yns0zzkf10lmk+x5v9JI3Dw0n2jbw+K8nNm4y5tru/leSzVXVjBo3E6zY74Dz+obqVtbW1hdcwM1+ffpd5rsVeW/elPhcmMM/5L9M6r/p5sxPLtF7yn84yrZXsp7dM6yX/6SzTWsl+Osu0VrKf3jKtl/yns0xrJfvvNMmtytclOVBV+6vqlCRPSnJww5i3JfnhJKmqu2Vw6/JndrNQAAAAAGB+tm0cdvetSS5Lck2STyd5c3ffUFVXVNUlw2HXJPlyVX0qyXuS/NPu/vKsigYAAAAAZmuSW5XT3VcnuXrDtheM/N5JnjP8AQAAAABOcJPcqgwAAAAArBiNQwAAAABgjMYhAAAAADBG4xAAAAAAGKNxCAAAAACM0TgEAAAAAMZoHAIAAAAAYzQOAQAAAIAxGocAAAAAwBiNQwAAAABgjMYhAAAAADBG4xAAAAAAGKNxCAAAAACM0TgEAAAAAMZoHAIAAAAAYzQOAQAAAIAxGocAAAAAwBiNQwAAAABgjMYhAAAAADBG4xAAAAAAGKNxCAAAAACMmahxWFUXVdWNVXWoqp67xbjHV1VX1QW7VyIAAAAAMG/bNg6r6qQkVyZ5bJLzkjy5qs7bZNxpSX46yQd3u0gAAAAAYL4mueLwwiSHuvsz3X1LkquSXLrJuBcmeWmSb+xifQAAAADAApw8wZgzk6yPvD6c5MGjA6rqB5Ls6+7fqqqf2e6Aa2
trUxU5C3uhhlk4/6zp9xldiwMHDuxiNVt/1l6xF2uap6PzX8Xsj8eyzUf+01mm+ch+Oss0n1lnnyzXeiXLNR9/+9NZpvnIfjrLNh/5T2eZ5iP76SzTfI43+0kah7XJtr7tzarbJXl5kqdP+qHz+IfqVtbW1hZew8x8ffpd5rkWe23dl/pcmMA8579M67zq581OLNN6yX86y7RWsp/eMq2X/KezTGsl++ks01rJfnrLtF7yn84yrZXsv9MktyofTrJv5PVZSW4eeX1akvsleW9V3ZTkIUkOekAKAAAAAJy4JmkcXpfkQFXtr6pTkjwpycGjb3b3ke6+W3ef3d1nJ7k2ySXd/eGZVAwAAAAAzNy2jcPuvjXJZUmuSfLpJG/u7huq6oqqumTWBQIAAAAA8zfJdxymu69OcvWGbS84xthHHH9ZAAAAAMAiTdQ4BOZvfX19+0Ej9u3bt/0gAAAAgAlN8h2HAAAAAMCK0TgEAAAAAMZoHAIAAAAAYzQOAQAAAIAxGocAAAAAwBiNQwAAAABgjMYhAAAAADBG4xAAAAAAGKNxCAAAAACM0TgEAAAAAMZoHAIAAAAAYzQOAQAAAIAxGocAAAAAwBiNQwAAAABgjMYhAAAAADDm5EUXAOyO9fX1qcbv27dvRpUAAAAAy8AVhwAAAADAGI1DAAAAAGCMxiEAAAAAMGaixmFVXVRVN1bVoap67ibvP6eqPlVVn6iqd1fVPXe/VAAAAABgXrZtHFbVSUmuTPLYJOcleXJVnbdh2EeTXNDd90/y1iQv3e1CAQAAAID5meSKwwuTHOruz3T3LUmuSnLp6IDufk93f2348tokZ+1umQAAAADAPE3SODwzyfrI68PDbcfyjCTvOJ6iAAAAAIDFOnmCMbXJtt50YNVTklyQ5OFbHXBtbW2Cj52tvVDDLJy/g2s9R9fiwIEDu1jN1p+1V+zFmpLk1FNPnenxj8776H9XMfvjsWzzkf90lmk+sp/OMs1n1tkny7VeyXLNx9/+dJZpPrKfzrLNR/7TWab5yH46yzSf481+ksbh4ST7Rl6fleTmjYOq6lFJnpfk4d39za0OOI9/qG5lbW1t4TXMzNen32Wea7HX1n0vnwvr6+vbDzoOBw4cmOv89+o678RePm/2qmVaL/lPZ5nWSvbTW6b1kv90lmmtZD+dZVor2U9vmdZL/tNZprWS/Xea5Fbl65IcqKr9VXVKkiclOTg6oKp+IMm/S3JJd39p98sEAAAAAOZp2ysOu/vWqrosyTVJTkrymu6+oaquSPLh7j6Y5GVJ7pTkLVWVJJ/v7ktmWPfKuPgZr5pq/Pt/fUaFAAAAALBSJrlVOd19dZKrN2x7wcjvj9rlugAAAACABZrkVmUAAAAAYMVoHAIAAAAAYzQOAQAAAIAxGocAAAAAwJiJHo4CHL/19fVFlwAAAAAwMY1D2CGNQAAAAGCZuVUZAAAAABijcQgAAAAAjNE4BAAAAADGaBwCAAAAAGM0DgEAAACAMRqHAAAAAMCYkxddAOwV6+vriy4BAAAAYM9wxSEAAAAAMEbjEAAAAAAYo3EIAAAAAIzROAQAAAAAxmgcAgAAAABjNA4BAAAAgDEahwAAAADAmJMXXcA0nnj2P5l6nzfd9CszqOQv7KSmqTzy3NkeHwAAAAA2MVHjsKouSvKrSU5K8uru/uUN739XktcleWCSLyd5YnfftLulzscTzvzpqcbX7W8/o0o4Xuvr69uOOfXUUycaBwAAALBqtr1VuapOSnJlkscmOS/Jk6vqvA3DnpHkK919nyQvT/KS3S4UAAAAAJif6u6tB1T9YJLLu/sxw9c/lyTd/eKRMdcMx/xhVZ2c5ItJ/lKPHPzIkSNbfxB7zumnn167cRzZn3hkv9rkv7pkv7p2K/tE/icif/urS/arTf6rS/arayfZT/JwlDOTjN7LeXi4bdMx3X1rkiNJ7jptMQAAAADA3jBJ43CzbuTGrvIkYwAAAACAE8QkD0c5nG
TfyOuzktx8jDGHh7cqn57kf48O2M1bYDixyH51yX61yX91yX61yX91yX51yX61yX91yX41THLF4XVJDlTV/qo6JcmTkhzcMOZgkqcNf398kt/t7b48EQAAAADYs7ZtHA6/s/CyJNck+XSSN3f3DVV1RVVdMhz2H5LctaoOJXlOkufOquDdUlUvq6r/WVWfqKrfrKo7L7qmeaqqi6rqxqo6VFV7Pq9ZqqrLq+pPqupjw5+LF13TPOz2ObDd8arqu6rqTcP3P1hVZx/vZ87aBHN6elX96ci588xF1DmNqnpNVX2pqq4/xvtVVb82nPMnqur8CY8r/z2ev+wnt2zZJ7PJX/arm/1wP/nv8fxlP7llyz6R/zSWLX/ZT27Zsk9mlH93r+RPkkcnOXn4+0uSvGTRNc1x7icl+eMk90pySpKPJzlv0XUtcD0uT/Izi67jRD4HJjlekp9I8srh709K8qZFr8MuzOnpSX590bVOOa+HJTk/yfXHeP/iJO/I4LtrH5Lkg/Jfjvxlv7rZzyJ/2S++3kVlL/8TJ3/Zr2728l/t/GW/utnPKv9JblVeSt39Oz24mjJJrs3guxtXxYVJDnX3Z7r7liRXJbl0wTUxX7t9DkxyvEuT/Mbw97cmeWRV7eXvxFjKv5Pufl82fAftBpcmeV0PXJvkzlV1xjaHlf8JQPYTW7rsk5nkL/sThL/9iS1d/rKf2NJln8h/CkuXv+wntnTZJ7PJf2Ubhxv8oww6rqvizCTrI68PD7etssuGl+m+pqq+Z9HFzMFunwOTHO+2McOm/ZEkdz2Oz5y1Sdfo7w7PnbdW1b5N3j/R7OTckP9y5C/7gVXMPpk+f9mvbvaT7iP/vU/2A6uYfSL/o1Yxf9kPrGL2yQ7yX+rGYVW9q6qu3+Tn0pExz0tya5L/tLhK526zrv9SP8xmm3PhFUnuneQBSb6Q5FcWWux87PY5MMnxTrTzbpJ6/1uSs7v7/knelb/4P2wnsp3kJP/lyF/2A6uYfTJ9TrJf3ewn3Uf+e5/sB1Yx+0T+R61i/rIfWMXskx3ktNSNw+5+VHffb5OftydJVT0tyd9M8g96eLP3ijicZLRTflaSmxdUy1xsdS509//q7j/v7m8neVUGlywvu90+ByY53m1jqurkJKdn60uoF23bOXX3l7v7m8OXr0rywDnVNks7OTfkvxz5y35gFbNPps9f9qub/aT7yH/vk/3AKmafyP+oVcxf9gOrmH2yg/yXunG4laq6KMnPJrmku7+26Hrm7LokB6pqf1WdksEXlx5ccE0Ls+F+/r+dZNOnDy2Z3T4HJjnewSRPG/7++CS/u8cb9tvOacO5c0kGT54/0R1M8tTh07YekuRId39hm33kvxz5y35gFbNPps9f9qubfSL/ZDnyl/3AKmafyP+oVcxf9gOrmH2yk/x7Dzz1ZRE/SQ5lcF/3x4Y/r1x0TXOe/8VJ/iiDpwg9b9H1LHgtXp/kk0k+MfwjOmPRNZ2I58Bmx0tyRQbN+SQ5Nclbhn97H0pyr0WvwS7M6cVJbsjgCVzvSXLuomueYE7/OYNb8r+Vwf9tekaSZyV51vD9SnLlcM6fTHKB/Jcjf9mvbvazyl/2q5u9/E+M/GW/utnLf7Xzl/3qZj+r/Gu4IwAAAADAbVb2VmUAAAAA4Ng0DgEAAACAMRqHAAAAAMAYjUMAAAAAYIzGIQAAAAAwRuMQAAAAABijcQgAAAAAjNE4BAAAAADG/H+83exDRVUXFgAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 1440x144 with 10 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Incorrect :()\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 216x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Real: 1\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAABQ4AAACZCAYAAABufKdoAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAFalJREFUeJzt3XvUbGV9H/DvDwiSKIWIjVLOaUF9XRRtlkEE06ReGi+EtmBbGkmaRBvtikbSZe0lRqOLkgteVqtJZWmrzTJqLWDS6GmCYkw0SU1AktYLaPE9Uck5Qeul9HTVO/HpHzOHDu9+LzPnzLzzzt6fz1qzeGfPs5/5Pfu7h8X6sfdMtdYCAAAAADDphGUXAAAAAADsPRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0aBwCAAAAAB0ahwAAAABAh8YhAAAAANChcTihqvZV1Rerat+ya0mSqvorVfWFqjpz2bUMQVVdUFWfqar7L7uWJKmq766qP62qb1t2LQAAAMDw9LJxWFWXVNWHquprVfXpqnrhlLtek+StrbXDE3OdWlVvGDcUv1RV76qqh01Rwz+oqj+uqv9bVZ+rqv9cVQ+feP1vVNWvVdXhqvpKVa1X1VVVdb+jY1prdya5IcnPTr/6Yauqx1fVO6vqzqpqVfUzM+z+6iSvbK19aWK+M6vqhqr6P+PHdVX1HTvU8Miqevs4029W1Ru3GLftOdJa+8MktyWZ9vwFAAAAmJveNQ6r6oIk70zy7iSPTnJVkl+oqufusN+ZSZ6R5N9veOktSb4vyeVJvjdJJfmtqvrWbea6KMl1SX4tyXcmuSTJGUl+c2LY9yT5kyQ/lOS8JC9K8vwkr9kw3RuT/HBVPWi7+rnXA5J8LMm/TPLZaXcanzePTfIrE9tOSPIbSc5J8pQkT03yiCTvqKraZrpvS/KnSa5O8uEt3m+acyQZ5f/8qvqWadcCAAAAMA/VWlt2DXNVVW9LcnZr7a9PbHtVkstba+dss98LkvxEa+0RE9sekeSOJE9rrb1nvO3bM2pI/Xhr7U3bzPXS1toZE9v+TpIDSU5vrR3ZYr9/luTFk/uNt9+Z5JrW2uu3XTz3UVWfTvLG1trPTTH2NUnOa609dWLbU5PclOTc1tod422PzOgqwCe11t4/xbzvT3KwtfacDdunOkeq6pQkR5Jc1lp7907vBwAAADAvvbviMKMr+TY2WN6d5OwdvrvwCUk+uMlc30jy20c3tNbuHo/73m3m+oMkp1fVD1TVCVV1epIfSfKBrZqGY6cl+cIm229J8qRt9uP4bZX/p442DZOktXZ7ksPZPv9pTHWOtNa+mtFVi/IHAAAAdlUfG4dnpnuL6mcnXtvKOUn+bJO5vtBa+/NN5ttyrtbaB5M8Pcnrk3wtyd1J/nKSy7bap6r+apIXJPm3m7x8OMlDt6md47dV/pvd7rxt/tOY8RyRPwAAALDr+tg43M5292V/a5KvzmOuqjo3yesy+rGNxyb5mxldufjrVXXiJuPXkrwnyXWttdduMuVXx/WxOHPLfxozniPyBwAAAHbdScsuYAE+k+QhG7Y9ePzP7X4s4/NJHrjJXA+qqhM3XHX44CSf2GauFyf5aGvt3l9DrqofyugHM56U5L0T2x+V5Lcy+kGX520x3wPH9bE4W+X/5E3GPjgz/PDKFqY+R8Z1feY43w8AAABgJn284vADSZ62YdvFSe5srR3eZr//luSRm8z1LRldDZYkGX8X3UVJ/us2c90/yTc3bDvaeLz313ir6rFJfjfJDUme17b+pZq/luSPtnk/jt9W+Z8zviI0yb23lO/P9vlPY6pzZEz+AAAAwK7rY+Pw1UkurKqfr6pzq+pHk/xkkpfvsN+N4/3uvSW0tfaJjK4EfF1VPaGqHp3kbRl9F971R8dV1Zur6s0Tc70jycVV9U+r6mFVdUGSNyW5K6MfOklVPT6jH115Z5Jrkjy4qh
5SVfe5WrKqTk3ymCS/OeuBGKKqekBVPXqc1clJHjJ+/vAddr0xyeM3bHtvRg3Ft1bVhVV1UZK3JLk5o4bv0ff87aq6ZuL5yRM1PCDJA8fPz5uYe8dzZDzXWkbfp/iuWY4DAAAAwPGqrS9yW11V9beS/EKSczO6pfQXW2v/Zod9TkiynuSlrbW3TWw/NaNm5N9LckqS30tyZWvt4MSY9ydJa+2JE9t+PMnzkzwsyZcyaja9uLV22/j1NyV55ma1tNYmr0r8R0l+qrV27lSLH7iqemKS923y0u9O5rPJfqdm1BC+uLX2BxPbz0zySxldtdoyauD9ZGvtcxNjPp3k/a21Z42fn53kU5u8zZ2ttbMn9tv2HBmP+VdJHtda23gVLQAAAMBC9bJxeKyq6keSvDDJ+dvcNryb9ZyQ5MNJfq61dv1O4zk+VfXSJI9prT192bUko6snkxxM8vTW2s3LrgcAAAAYlj7eqnw83prRrch/admFjJ2V5E2ahrvmVUn+uKruv+xCxs5J8jOahgAAAMAyuOIQAAAAAOhwxSEAAAAA0HHSbr3RkSNHXNq4Yk477bTaedTOZL96ZD9s8h8u2Q/XvLJP5L+KfPaHS/bDJv/hkv1wHUv2rjgEAAAAADo0DgEAAACADo3DHayvry+7hIXq+/qOxZCOyZDWOk+O27DJf7hkP2zyHy7ZD5fsh03+wyX7+9I4BAAAAAA6NA4BAAAAgI5d+1VldsepX3nsjHu8bSF1AAAAALDaXHEIAAAAAHRoHAIAAAAAHVM1Dqvq4qq6o6oOVtWLthjzA1X1saq6varc/woAAAAAK2zH7zisqhOTXJvkKUkOJ7m1qg601j42MWYtyU8n+Z7W2t1V9R2LKhgAAAAAWLxprji8MMnB1tonW2tfT3Jdkss2jPnHSa5trd2dJK21z823TAAAAABgN03zq8pnJTk08fxwkos2jHlEklTVB5KcmOSq1tq7t5pwfX19xjKXa5XqPX/f7PtMrm9tbW2O1Wz/XnvZqtQ5D0fXKvvZ9G098p9Nn9Yj+9n0aT2Lzj7p1/FK+rUen/3Z9Gk9sp9N39Yj/9n0aT2yn02f1nO82U/TOKxNtrVN5llL8sQk+5L8flU9qrX2vzebcDf+Q3Ve1tfXV6refGX2XXZzfatwLFcu8+Owm2vt0zEd0jkyL306XvKfTZ+Olexn16fjJf/Z9OlYyX42fTpWsp9dn46X/GfTp2Ml+/ua5lblw0n2Tzzfl+SuTca8s7X2jdbap5LckVEjEQAAAABYQdM0Dm9NslZV51TVyUmuSHJgw5h3JHlSklTVgzK6dfmT8ywUAAAAANg9OzYOW2v3JLkyyU1JPp7khtba7VV1dVVdOh52U5IvVtXHkrwvyb9orX1xUUUDAAAAAIs1zXccprV2Y5IbN2x72cTfLckLxw8AAAAAYMVNc6syAAAAADAwGocAAAAAQIfGIQAAAADQoXEIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0aBwCAAAAAB0ahwAAAABAh8YhAAAAANChcQgAAAAAdGgcAgAAAAAdGocAAAAAQIfGIQAAAADQoXEIAAAAAHRoHAIAAAAAHRqHAAAAAECHxiEAAAAA0KFxCAAAAAB0aBwCAAAAAB0ahwAAAABAh8YhAAAAANAxVeOwqi6uqjuq6mBVvWibcZdXVauqC+ZXIgAAAACw23ZsHFbViUmuTfL9Sc5L8oNVdd4m405N8k+S3DLvIgEAAACA3TXNFYcXJjnYWvtka+3rSa5Lctkm4342ySuTfHWO9QEAAAAASzBN4/CsJIcmnh8eb7tXVX1Xkv2ttd+YY20AAAAAwJKcNMWY2mRbu/fFqhOSvDrJs6Z90/X19WmH7gmrVO/5+2bfZ3J9a2trc6xm+/fay1alznk4ulbZz6Zv65H/bPq0HtnPpk/rWXT2Sb+OV9Kv9fjsz6ZP65H9bPq2HvnPpk/rkf1s+rSe481+ms
bh4ST7J57vS3LXxPNTkzwqyfurKkkekuRAVV3aWvujzSbcjf9QnZf19fWVqjdfmX2X3VzfKhzLlcv8OOzmWvt0TId0jsxLn46X/GfTp2Ml+9n16XjJfzZ9Olayn02fjpXsZ9en4yX/2fTpWMn+vqa5VfnWJGtVdU5VnZzkiiQHjr7YWjvSWntQa+3s1trZSW5OsmXTEAAAAADY+3ZsHLbW7klyZZKbknw8yQ2ttdur6uqqunTRBQIAAAAAu2+aW5XTWrsxyY0btr1si7FPPP6ygFkdOnRo50ET9u/fv/MgAAAAYLCmuVUZAAAAABgYjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACgQ+MQAAAAAOjQOAQAAAAAOjQOAQAAAIAOjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACgQ+MQAAAAAOjQOAQAAAAAOk6aZlBVXZzkF5OcmOSNrbWXb3j9hUmek+SeJJ9P8mOttTvnXCsMyqFDh5ZdAgAAADBgOzYOq+rEJNcmeUqSw0luraoDrbWPTQz770kuaK19uaqel+SVSZ6xiIJhVWkEAgAAAKtkmluVL0xysLX2ydba15Ncl+SyyQGttfe11r48fnpzkn3zLRMAAAAA2E3T3Kp8VpLJS6UOJ7lom/HPTvKu7SZcX1+f4m33jlWq9/xjaNlOrm9tbW2O1Wz/XnvZIuo85ZRT5j7n8Ti6xqP/lP1s+rYe+c+mT+uR/Wz6tJ5FZ5/063gl/VqPz/5s+rQe2c+mb+uR/2z6tB7Zz6ZP6zne7KdpHNYm29qmA6t+OMkFSZ6w3YS78R+q87K+vr5S9eYrs++ym+tbhWO5qMz32q3Ka2tru3p+r0L201q5fy/sAX06XvKfTZ+Olexn16fjJf/Z9OlYyX42fTpWsp9dn46X/GfTp2Ml+/uapnF4OMn+ief7kty1cVBVPTnJS5I8obX2tfmUBwAAAAAswzTfcXhrkrWqOqeqTk5yRZIDkwOq6ruS/Lskl7bWPjf/MgEAAACA3bRj47C1dk+SK5PclOTjSW5ord1eVVdX1aXjYa9K8oAkb6+qD1XVgS2mAwAAAABWwDS3Kqe1dmOSGzdse9nE30+ec12MXfLsN8w0/vdfu6BCAAAAABiUaW5VBgAAAAAGRuMQAAAAAOjQOAQAAAAAOjQOAQAAAIAOjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6NA4BAAAAAA6NA4BAAAAgA6NQwAAAACg46RlFzA0Vzzu5bPt8MgzFlMIAAAAAGzDFYcAAAAAQIfGIQAAAADQoXEIAAAAAHT4jkM4RocOHVp2CQAAAAALo3F4nK4498Wz7XD6X1hMIXQca2PvlFNO0RQEAAAABs+tygAAAABAh8YhAAAAANDR+1uVn7H/BTONv/7QaxZUCRvNejvw/v37F1QJAAAAABu54hAAAAAA6JiqcVhVF1fVHVV1sKpetMnr96uq68ev31JVZ8+7UAAAAABg9+zYOKyqE5Ncm+T7k5yX5Aer6rwNw56d5O7W2sOTvDrJK+ZdKAAAAACwe6q1tv2Aqu9OclVr7Wnj5z+dJK21aybG3DQe84dVdVKSzyb5i21i8iNHjmz/Ruw5p512Ws1jHtmvHtkPm/yHS/bDNa/sE/mvIp/94ZL9sMl/uGQ/XMeS/TS3Kp+VZPJXLA6Pt206prV2T5IjSc6YtRgAAAAAYG+YpnG4WTdyY1d5mjEAAAAAwIo4aYoxh5Psn3i+L8ldW4w5PL5V+bQk/2tywDxvgWG1yH64ZD9s8h8u2Q+b/IdL9sMl+2GT/3DJfhimueLw1iRrVXVOVZ2c5IokBzaMOZDkmeO/L0/yO22nL08EAAAAAP
asHRuH4+8svDLJTUk+nuSG1trtVXV1VV06HvYfkpxRVQeTvDDJixZV8G6rqldV1f+oqo9U1a9X1enLrmkequriqrqjqg5WVW/ympequqqq/qyqPjR+XLLsmuZpEfnvNGdV3a+qrh+/fktVnT2P912kKdb0rKr6/MR58pxl1DmLqvrlqvpcVd22xetVVb80XvNHqur8KeeV/x7PX/bT61v2yWLyl/1wsx/vJ/89nr/sp9e37BP5z6Jv+ct+en3LPllQ/q01j20eSZ6a5KTx369I8opl1zSHNZ2Y5E+SPDTJyUk+nOS8Zde1lx5Jrkryz5ddx6rkP82cSX4iyevHf1+R5PplH4s5rOlZSV677FpnXNfjk5yf5LYtXr8kybsy+u7axyW5Rf79yF/2w81+EfnLfvn1Lit7+a9O/rIfbvbyH3b+sh9u9ovKf5pblQettfaeNrrqMkluzug7HlfdhUkOttY+2Vr7epLrkly25JrYPYvIf5o5L0vyK+O/fzXJ91XVXv5OjF5+Tlprv5cN30G7wWVJ3txGbk5yelWducO08l8Bsp9a77JPFpK/7FeEz/7Uepe/7KfWu+wT+c+gd/nLfmq9yz5ZTP4ah7P5sYw6s6vurCSHJp4fHm/jvq4cX7r7y1X17csuZo4Wkf80c947ZtyMP5LkjON830Wa9jj9/fF58qtVtX+T11fNsZwf8u9H/rIfGWL2yez5y3642U+7j/z3PtmPDDH7RP5HDTF/2Y8MMfvkGPLXOExSVe+tqts2eVw2MeYlSe5J8h+XV+ncbNb1H9yP2eyQ++uSPCzJo5N8Jsm/Xmqx87WI/KeZc9XOu2nq/S9Jzm6tfWeS9+b//x+2VXYsOcm/H/nLfmSI2Sez5yT74WY/7T7y3/tkPzLE7BP5HzXE/GU/MsTsk2PISeMwSWvtya21R23yeGeSVNUzk/ztJP+wjW8KX3GHk0x2yvcluWtJtSzNdrm31v5na+3PW2vfTPKGjC5j7otF5D/NnPeOqaqTkpyW7S+hXrYd19Ra+2Jr7Wvjp29I8phdqm2RjuX8kH8/8pf9yBCzT2bPX/bDzX7afeS/98l+ZIjZJ/I/aoj5y35kiNknx5C/xuEOquriJD+V5NLW2peXXc+c3JpkrarOqaqTM/ri0gNLrmlP2XCP/99NsukvEq2oReQ/zZwHkjxz/PflSX5njzfid1zThvPk0ox+eX7VHUjyo+Nf23pckiOttc/ssI/8+5G/7EeGmH0ye/6yH272ifyTfuQv+5EhZp/I/6gh5i/7kSFmnxxL/m0P/OrLXn4kOZjR/d8fGj9ev+ya5rSuS5J8IqNfEXrJsuvZa48kb0ny0SQfGX+wzlx2TXs9/83mTHJ1Rk33JDklydvHn6kPJnnoso/DHNZ0TZLbM/oFrvclOXfZNU+xpv+U0e3338jo/zY9O8lzkzx3/HoluXa85o8muUD+/chf9sPNflH5y3642ct/NfKX/XCzl/+w85f9cLNfVP413hEAAAAA4F5uVQYAAAAAOjQOAQAAAIAOjUMAAAAAoEPjEAAAAADo0DgEAAAAADo0DgEAAACADo1DAAAAAKBD4xAAAAAA6Ph/LY36lK4V0M8AAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 1440x144 with 10 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Incorrect :()\n" | |
] | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 216x216 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Summary\n", | |
"Total images: 64\n", | |
"Predicted for: 5\n", | |
"Accuracy when predicted: 0.0\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"(64, 0, 5)" | |
] | |
}, | |
"execution_count": 43, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"test_batch(images[:100], labels[:100])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"accelerator": "GPU", | |
"colab": { | |
"collapsed_sections": [], | |
"name": "Simple NeuralNet.ipynb", | |
"provenance": [], | |
"version": "0.3.2" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 1 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment