Last active
February 17, 2022 17:39
-
-
Save pmineiro/171edfa6963b7d14e6f3d10dc38af9a4 to your computer and use it in GitHub Desktop.
IGL with action-dependent feedback, MNIST demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "287318ac", | |
"metadata": {}, | |
"source": [ | |
"# Supervised" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"id": "7fb42e91", | |
"metadata": { | |
"code_folding": [ | |
0, | |
6, | |
35, | |
36, | |
58, | |
63, | |
70, | |
84, | |
114, | |
121 | |
], | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"n \tmean \tsince \tacc \tsince \n", | |
"1 \t2.30257 \t2.30257 \t0.15625 \t0.15625 \n", | |
"2 \t2.40082 \t2.49907 \t0.21875 \t0.28125 \n", | |
"3 \t2.21948 \t1.85680 \t0.29167 \t0.43750 \n", | |
"5 \t1.95383 \t1.55536 \t0.40000 \t0.56250 \n", | |
"9 \t1.56551 \t1.08011 \t0.50694 \t0.64062 \n", | |
"17 \t1.18324 \t0.75318 \t0.63419 \t0.77734 \n", | |
"33 \t0.87001 \t0.53720 \t0.73059 \t0.83301 \n", | |
"65 \t0.65744 \t0.43823 \t0.79615 \t0.86377 \n", | |
"129 \t0.49248 \t0.32494 \t0.84726 \t0.89917 \n", | |
"257 \t0.39149 \t0.28972 \t0.87840 \t0.90979 \n", | |
"513 \t0.31768 \t0.24357 \t0.90244 \t0.92657 \n", | |
"938 \t0.27401 \t0.22131 \t0.91613 \t0.93265 \n", | |
"testacc 0.9558735489845276 testloss 0.14295095205307007\n" | |
] | |
} | |
], | |
"source": [ | |
"def supervisedLearn():\n", | |
" import itertools\n", | |
" import numpy\n", | |
" import torch\n", | |
" import torchvision\n", | |
" \n", | |
" class EasyAcc:\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
" self.sumsq = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def __isub__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" return self.sum / max(self.n, 1)\n", | |
"\n", | |
" def var(self):\n", | |
" from math import sqrt\n", | |
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
"\n", | |
" def semean(self):\n", | |
" from math import sqrt\n", | |
" return self.var() / sqrt(max(self.n, 1))\n", | |
"\n", | |
" class RFFSoftmax(torch.nn.Module):\n", | |
" def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
" from math import pi\n", | |
" import numpy as np\n", | |
"\n", | |
" super(RFFSoftmax, self).__init__()\n", | |
"\n", | |
" torch.manual_seed(seed)\n", | |
" nobs = hilo.shape[1]\n", | |
" high = hilo[1, :]\n", | |
" low = hilo[0, :]\n", | |
" \n", | |
" self.rff = torch.nn.Linear(nobs, numrff)\n", | |
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
" self.rff.weight.requires_grad = False\n", | |
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
" self.rff.bias.requires_grad = False\n", | |
" self.sqrtrff = np.sqrt(numrff)\n", | |
" self.final = torch.nn.Linear(numrff, naction)\n", | |
" self.final.weight.data *= 0.01\n", | |
" self.final.bias.data *= 0.01\n", | |
"\n", | |
" def logits(self, x):\n", | |
" with torch.no_grad():\n", | |
" rff = self.rff(x).cos() / self.sqrtrff\n", | |
" return self.final(rff)\n", | |
"\n", | |
" transform = torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ])\n", | |
" mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
" \n", | |
" quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(quantile_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
" pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
" break\n", | |
" \n", | |
" train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
" test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
" \n", | |
" opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=0.1)\n", | |
" loss = torch.nn.CrossEntropyLoss()\n", | |
" acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" \n", | |
" print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
" 'n', 'mean', 'since',\n", | |
" 'acc', 'since',\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" \n", | |
" opt.zero_grad()\n", | |
" ld = pi.logits(flat)\n", | |
" output = loss(ld, labels)\n", | |
" output.backward()\n", | |
" opt.step()\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" pred = ld.argmax(dim=1)\n", | |
" acc += torch.mean((labels == pred).float())\n", | |
" accsincelast += torch.mean((labels == pred).float())\n", | |
" avloss += output\n", | |
" avlosssincelast += output\n", | |
"\n", | |
" if (bno & (bno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast = EasyAcc(), EasyAcc()\n", | |
" \n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast = EasyAcc(), EasyAcc()\n", | |
" testacc, testloss = EasyAcc(), EasyAcc()\n", | |
" with torch.no_grad():\n", | |
" for ti, tl in train_loader:\n", | |
" flat = ti.reshape(ti.shape[0], -1)\n", | |
" ld = pi.logits(flat)\n", | |
" output = loss(ld, tl)\n", | |
" testloss += output\n", | |
" testpred = ld.argmax(dim=1)\n", | |
" testacc += torch.mean((tl == testpred).float())\n", | |
"\n", | |
" print(f'testacc {testacc.mean()} testloss {testloss.mean()}')\n", | |
" \n", | |
"supervisedLearn()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "55cda05e", | |
"metadata": {}, | |
"source": [ | |
"# CB" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 80, | |
"id": "de5fa0ff", | |
"metadata": { | |
"code_folding": [ | |
0, | |
6, | |
7, | |
10, | |
20, | |
49, | |
50, | |
73, | |
78 | |
] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"n \tloss \tsince \tacc \tsince \treward \tsince \n", | |
"1 \t0.69313 \t0.69313 \t0.15625 \t0.15625 \t0.06250 \t0.06250 \n", | |
"2 \t0.55508 \t0.41703 \t0.10156 \t0.04688 \t0.05469 \t0.04688 \n", | |
"3 \t0.46854 \t0.29545 \t0.09896 \t0.09375 \t0.05729 \t0.06250 \n", | |
"5 \t0.46672 \t0.46399 \t0.11563 \t0.14062 \t0.10000 \t0.16406 \n", | |
"9 \t0.42025 \t0.36216 \t0.18229 \t0.26562 \t0.16493 \t0.24609 \n", | |
"17 \t0.37235 \t0.31846 \t0.22151 \t0.26562 \t0.20956 \t0.25977 \n", | |
"33 \t0.34175 \t0.30924 \t0.27415 \t0.33008 \t0.26089 \t0.31543 \n", | |
"65 \t0.34546 \t0.34929 \t0.38726 \t0.50391 \t0.36250 \t0.46729 \n", | |
"129 \t0.31504 \t0.28415 \t0.61216 \t0.84058 \t0.56468 \t0.77002 \n", | |
"257 \t0.26828 \t0.22116 \t0.75298 \t0.89490 \t0.69127 \t0.81885 \n", | |
"513 \t0.22252 \t0.17659 \t0.83458 \t0.91650 \t0.76459 \t0.83820 \n", | |
"938 \t0.19493 \t0.16162 \t0.87785 \t0.93007 \t0.80360 \t0.85070 \n", | |
"testacc 0.9445962309837341\n" | |
] | |
} | |
], | |
"source": [ | |
"def cbLearn():\n", | |
" import itertools\n", | |
" import numpy\n", | |
" import torch\n", | |
" import torchvision\n", | |
" \n", | |
" class FastCB:\n", | |
" def __init__(self, gamma):\n", | |
" self.gamma = gamma\n", | |
"\n", | |
" def sample(self, fhat):\n", | |
" N, K = fhat.shape\n", | |
" rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
" fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
" fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n", | |
" probs = K / (K + self.gamma * (1 - fhatrando / fhatstar))\n", | |
" unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
" shouldexplore = (unif <= probs).long()\n", | |
" return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n", | |
"\n", | |
" class EasyAcc:\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
" self.sumsq = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def __isub__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" return self.sum / max(self.n, 1)\n", | |
"\n", | |
" def var(self):\n", | |
" from math import sqrt\n", | |
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
"\n", | |
" def semean(self):\n", | |
" from math import sqrt\n", | |
" return self.var() / sqrt(max(self.n, 1))\n", | |
"\n", | |
" class RFFSoftmax(torch.nn.Module):\n", | |
" def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
" from math import pi\n", | |
" import numpy as np\n", | |
"\n", | |
" super(RFFSoftmax, self).__init__()\n", | |
"\n", | |
" torch.manual_seed(seed)\n", | |
" nobs = hilo.shape[1]\n", | |
" high = hilo[1, :]\n", | |
" low = hilo[0, :]\n", | |
" \n", | |
" self.rff = torch.nn.Linear(nobs, numrff)\n", | |
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
" self.rff.weight.requires_grad = False\n", | |
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
" self.rff.bias.requires_grad = False\n", | |
" self.sqrtrff = np.sqrt(numrff)\n", | |
" self.final = torch.nn.Linear(numrff, naction)\n", | |
" self.final.weight.data *= 0.01\n", | |
" self.final.bias.data *= 0.01\n", | |
" self.sigmoid = torch.nn.Sigmoid()\n", | |
"\n", | |
" def forward(self, x):\n", | |
" with torch.no_grad():\n", | |
" rff = self.rff(x).cos() / self.sqrtrff\n", | |
" return self.final(rff)\n", | |
" \n", | |
" def density(self, logits):\n", | |
" return self.sigmoid(logits)\n", | |
"\n", | |
" transform = torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ])\n", | |
" mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
" \n", | |
" quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(quantile_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
" pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
" break\n", | |
" \n", | |
" train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
" test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
" sampler = FastCB(gamma=100)\n", | |
" \n", | |
" opt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-1)\n", | |
" log_loss = torch.nn.BCEWithLogitsLoss()\n", | |
" acc, accsincelast, avloss, avlosssincelast, avreward, avrewardsincelast = [ EasyAcc() for _ in range(6) ]\n", | |
" \n", | |
" print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
" 'n', 'loss', 'since', \n", | |
" 'acc', 'since',\n", | |
" 'reward', 'since',\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" \n", | |
" opt.zero_grad()\n", | |
" logit = pi(flat)\n", | |
" with torch.no_grad():\n", | |
" fhat = pi.density(logit)\n", | |
" sample = sampler.sample(fhat)\n", | |
" reward = (sample == labels).unsqueeze(1).float()\n", | |
" \n", | |
" samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n", | |
" loss = log_loss(samplelogit, reward)\n", | |
" loss.backward()\n", | |
" opt.step()\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" pred = logit.argmax(dim=1)\n", | |
" acc += torch.mean((labels == pred).float())\n", | |
" accsincelast += torch.mean((labels == pred).float())\n", | |
" avloss += loss\n", | |
" avlosssincelast += loss\n", | |
" avreward += torch.mean(reward)\n", | |
" avrewardsincelast += torch.mean(reward)\n", | |
" \n", | |
" if (bno & (bno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
" \n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
" testacc = EasyAcc()\n", | |
" with torch.no_grad():\n", | |
" for ti, tl in train_loader:\n", | |
" flat = ti.reshape(ti.shape[0], -1)\n", | |
" logit = pi(flat)\n", | |
" testpred = logit.argmax(dim=1)\n", | |
" testacc += torch.mean((tl == testpred).float())\n", | |
"\n", | |
" print(f'testacc {testacc.mean()}')\n", | |
"\n", | |
"cbLearn()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "09186826", | |
"metadata": {}, | |
"source": [ | |
"# IGL ($y_a \\perp x, a|r_a$)\n", | |
"$y_a$ is a (randomly selected) \"zero\" image or a (randomly selected) \"one\" image depending only upon $r_a$." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 113, | |
"id": "67459e03", | |
"metadata": { | |
"code_folding": [ | |
0, | |
6, | |
22, | |
51, | |
83, | |
90, | |
100, | |
113, | |
139, | |
188, | |
197 | |
] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n", | |
"1 \t1.38627 \t1.38627 \t0.32812 \t0.32812 \t0.09375 \t0.09375 \t0.49997 \t0.49997 \n", | |
"2 \t1.37295 \t1.35963 \t0.18750 \t0.04688 \t0.07812 \t0.06250 \t0.48236 \t0.46476 \n", | |
"3 \t1.35572 \t1.32125 \t0.17188 \t0.14062 \t0.08333 \t0.09375 \t0.46612 \t0.43364 \n", | |
"5 \t1.33542 \t1.30497 \t0.14687 \t0.10938 \t0.09375 \t0.10938 \t0.43557 \t0.38975 \n", | |
"9 \t1.28511 \t1.22223 \t0.14583 \t0.14453 \t0.11111 \t0.13281 \t0.38302 \t0.31733 \n", | |
"17 \t1.19374 \t1.09095 \t0.15533 \t0.16602 \t0.11949 \t0.12891 \t0.31242 \t0.23300 \n", | |
"33 \t1.19390 \t1.19407 \t0.16714 \t0.17969 \t0.12689 \t0.13477 \t0.26854 \t0.22192 \n", | |
"65 \t1.25434 \t1.31667 \t0.25024 \t0.33594 \t0.18438 \t0.24365 \t0.43038 \t0.59728 \n", | |
"129 \t1.12546 \t0.99457 \t0.45094 \t0.65479 \t0.33285 \t0.48364 \t0.54036 \t0.65205 \n", | |
"257 \t0.88065 \t0.63392 \t0.65394 \t0.85852 \t0.51283 \t0.69421 \t0.65974 \t0.78006 \n", | |
"513 \t0.68728 \t0.49315 \t0.77656 \t0.89966 \t0.63411 \t0.75586 \t0.74226 \t0.82510 \n", | |
"938 \t0.56352 \t0.41413 \t0.84097 \t0.91871 \t0.70281 \t0.78574 \t0.78930 \t0.84607 \n", | |
"testacc 0.928521454334259\n" | |
] | |
} | |
], | |
"source": [ | |
"def iglLearn():\n", | |
" import itertools\n", | |
" import numpy\n", | |
" import torch\n", | |
" import torchvision\n", | |
" \n", | |
" class SquareCB(object):\n", | |
" def __init__(self, gamma):\n", | |
" super(SquareCB, self).__init__()\n", | |
"\n", | |
" self.gamma = gamma\n", | |
"\n", | |
" def sample(self, fhat):\n", | |
" N, K = fhat.shape\n", | |
" rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
" fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
" fhatrando = torch.gather(input=fhat, dim=1, index=rando)\n", | |
" probs = K / (K + self.gamma * (fhatstar - fhatrando))\n", | |
" unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
" shouldexplore = (unif <= probs).long()\n", | |
" return (ahatstar + shouldexplore * (rando - ahatstar)).squeeze(1)\n", | |
" \n", | |
" class EasyAcc:\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
" self.sumsq = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def __isub__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" return self.sum / max(self.n, 1)\n", | |
"\n", | |
" def var(self):\n", | |
" from math import sqrt\n", | |
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
"\n", | |
" def semean(self):\n", | |
" from math import sqrt\n", | |
" return self.var() / sqrt(max(self.n, 1))\n", | |
"\n", | |
" class RFFSoftmax(torch.nn.Module):\n", | |
" def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
" from math import pi\n", | |
" import numpy as np\n", | |
"\n", | |
" super(RFFSoftmax, self).__init__()\n", | |
"\n", | |
" torch.manual_seed(seed)\n", | |
" nobs = hilo.shape[1]\n", | |
" high = hilo[1, :]\n", | |
" low = hilo[0, :]\n", | |
" \n", | |
" self.rff = torch.nn.Linear(nobs, numrff)\n", | |
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
" self.rff.weight.requires_grad = False\n", | |
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
" self.rff.bias.requires_grad = False\n", | |
" self.sqrtrff = np.sqrt(numrff)\n", | |
" self.final = torch.nn.Linear(numrff, naction)\n", | |
" self.final.weight.data *= 0.01\n", | |
" self.final.bias.data *= 0.01\n", | |
" self.sigmoid = torch.nn.Sigmoid()\n", | |
"\n", | |
" def forward(self, x):\n", | |
" with torch.no_grad():\n", | |
" rff = self.rff(x).cos() / self.sqrtrff\n", | |
" return self.final(rff)\n", | |
" \n", | |
" def density(self, logits):\n", | |
" return self.sigmoid(logits)\n", | |
"\n", | |
" transform = torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ])\n", | |
" mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
" \n", | |
" quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(quantile_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
" pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
" decoder = RFFSoftmax(hilo, 1, 2000, 0.01, 2112)\n", | |
" break\n", | |
" \n", | |
" zero_one_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n", | |
" zeros = []\n", | |
" ones = []\n", | |
" for bno, (images, labels) in enumerate(zero_one_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" if labels[0] == 0:\n", | |
" zeros.append(flat)\n", | |
" elif labels[0] == 1:\n", | |
" ones.append(flat)\n", | |
" \n", | |
" if len(zeros) > 100 and len(ones) > 100:\n", | |
" break \n", | |
" zeros = torch.cat(zeros, dim=0)\n", | |
" ones = torch.cat(ones, dim=0)\n", | |
" \n", | |
" # pre-train to get policy \"better than random\"\n", | |
" if True:\n", | |
" preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-3) # 0.1\n", | |
" preloss = torch.nn.CrossEntropyLoss()\n", | |
" pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
"\n", | |
" preopt.zero_grad()\n", | |
" ld = pi.forward(flat)\n", | |
" output = preloss(ld, labels)\n", | |
" output.backward()\n", | |
" preopt.step()\n", | |
"\n", | |
" if bno > 0:\n", | |
" break\n", | |
" \n", | |
" train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
" test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
" \n", | |
" opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n", | |
" log_loss = torch.nn.BCEWithLogitsLoss(reduce='none')\n", | |
" sampler = SquareCB(gamma=100)\n", | |
" acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" \n", | |
" print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
" 'n', 'loss', 'since', \n", | |
" 'acc', 'since',\n", | |
" 'reward', 'since',\n", | |
" 'fake', 'since',\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" \n", | |
" opt.zero_grad()\n", | |
" logit = pi(flat)\n", | |
" with torch.no_grad():\n", | |
" fhat = pi.density(logit)\n", | |
" sample = sampler.sample(fhat)\n", | |
" reward = (sample == labels).unsqueeze(1).float()\n", | |
" pred = logit.argmax(dim=1)\n", | |
" ispred = (sample == pred).unsqueeze(1).float()\n", | |
" antipred = logit.argmin(dim=1)\n", | |
" isantipred = (sample == antipred).unsqueeze(1).float()\n", | |
" zerossample = torch.randint(low=0, high=zeros.shape[0], size=(fhat.shape[0], 1))\n", | |
" zerofeedback = torch.gather(input=zeros, index=zerossample.expand(-1, zeros.shape[1]), dim=0)\n", | |
" onessample = torch.randint(low=0, high=ones.shape[0], size=(fhat.shape[0], 1))\n", | |
" onefeedback = torch.gather(input=ones, index=onessample.expand(-1, ones.shape[1]), dim=0)\n", | |
" feedback = zerofeedback + reward * (onefeedback - zerofeedback) \n", | |
" \n", | |
" samplelogit = torch.gather(input=logit, index=sample.unsqueeze(1), dim=1)\n", | |
" fakelogit = decoder(feedback)\n", | |
" fakereward = decoder.density(fakelogit)\n", | |
" predloss = torch.mean(log_loss(fakelogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n", | |
" antipredloss = torch.mean(log_loss(1 - fakelogit, isantipred) + log_loss(1 - samplelogit, fakereward.detach()))\n", | |
" loss = torch.min(predloss, antipredloss)\n", | |
"\n", | |
" loss.backward()\n", | |
" opt.step()\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" pred = logit.argmax(dim=1)\n", | |
" acc += torch.mean((labels == pred).float())\n", | |
" accsincelast += torch.mean((labels == pred).float())\n", | |
" avloss += loss\n", | |
" avlosssincelast += loss\n", | |
" avreward += torch.mean(reward)\n", | |
" avrewardsincelast += torch.mean(reward)\n", | |
" avfake += torch.mean(fakereward)\n", | |
" avfakesincelast += torch.mean(fakereward)\n", | |
" \n", | |
" if (bno & (bno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" avfake.mean(), avfakesincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" \n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" avfake.mean(), avfakesincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
" testacc = EasyAcc()\n", | |
" with torch.no_grad():\n", | |
" for ti, tl in train_loader:\n", | |
" flat = ti.reshape(ti.shape[0], -1)\n", | |
" logit = pi(flat)\n", | |
" testpred = logit.argmax(dim=1)\n", | |
" testacc += torch.mean((tl == testpred).float())\n", | |
"\n", | |
" print(f'testacc {testacc.mean()}')\n", | |
"\n", | |
"iglLearn()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "90994504", | |
"metadata": {}, | |
"source": [ | |
"# IGL ($y_a \\perp x|r_a$)\n", | |
"$y_a$ is an image of the action taken if $r_a = 1$, e.g., if $a=3$, a \"three\" image; otherwise if $r_a = 0$, an image of $(9-a)$, e.g., if $a=3$, a \"six\" image." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 288, | |
"id": "4d4e6631", | |
"metadata": { | |
"code_folding": [ | |
6, | |
16, | |
26, | |
31, | |
55, | |
84, | |
85, | |
108, | |
113, | |
116, | |
117, | |
140, | |
145, | |
148, | |
155, | |
164, | |
172, | |
199, | |
213, | |
233, | |
271, | |
305, | |
314 | |
], | |
"scrolled": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"n \tloss \tsince \tacc \tsince \treward \tsince \tfake \tsince \n", | |
"1 \t1.38891 \t1.38891 \t0.28125 \t0.28125 \t0.23438 \t0.23438 \t0.49940 \t0.49940 \n", | |
"2 \t1.38223 \t1.37554 \t0.29688 \t0.31250 \t0.19531 \t0.15625 \t0.49962 \t0.49983 \n", | |
"3 \t1.38101 \t1.37858 \t0.31771 \t0.35938 \t0.18750 \t0.17188 \t0.49995 \t0.50062 \n", | |
"5 \t1.38767 \t1.39765 \t0.29063 \t0.25000 \t0.14687 \t0.08594 \t0.49908 \t0.49777 \n", | |
"9 \t1.38487 \t1.38138 \t0.19965 \t0.08594 \t0.13368 \t0.11719 \t0.50189 \t0.50540 \n", | |
"17 \t1.35851 \t1.32885 \t0.18199 \t0.16211 \t0.13787 \t0.14258 \t0.50388 \t0.50613 \n", | |
"33 \t1.34694 \t1.33465 \t0.20028 \t0.21973 \t0.14725 \t0.15723 \t0.50259 \t0.50121 \n", | |
"65 \t1.33522 \t1.32314 \t0.27428 \t0.35059 \t0.19447 \t0.24316 \t0.50141 \t0.50020 \n", | |
"129 \t1.26931 \t1.20236 \t0.35913 \t0.44531 \t0.26211 \t0.33081 \t0.50427 \t0.50716 \n", | |
"257 \t1.14736 \t1.02446 \t0.52420 \t0.69055 \t0.40096 \t0.54089 \t0.53432 \t0.56461 \n", | |
"513 \t0.99472 \t0.84148 \t0.65068 \t0.77765 \t0.52537 \t0.65027 \t0.57622 \t0.61829 \n", | |
"938 \t0.86970 \t0.71879 \t0.73391 \t0.83438 \t0.61042 \t0.71309 \t0.61330 \t0.65806 \n", | |
"testacc 0.8389525413513184\n" | |
] | |
} | |
], | |
"source": [ | |
"def iglADepLearn():\n", | |
" import itertools\n", | |
" import numpy\n", | |
" import torch\n", | |
" import torchvision\n", | |
" \n", | |
" class WeightedReservoir(object):\n", | |
" def __init__(self, n, seed):\n", | |
" import random\n", | |
" \n", | |
" super().__init__()\n", | |
" self.n = n\n", | |
" self.items = []\n", | |
" self.wsum = 0\n", | |
" self.gen = random.Random(seed) \n", | |
" \n", | |
" def insert(self, item, weight):\n", | |
" if weight > 0:\n", | |
" self.wsum += weight\n", | |
" if self.wsum * self.gen.random() < weight:\n", | |
" if len(self.items) < self.n:\n", | |
" self.items.append(item)\n", | |
" else:\n", | |
" index = self.gen.randrange(0, self.n) \n", | |
" self.items[index] = item\n", | |
" \n", | |
" def sample(self):\n", | |
" assert len(self.items) > 0\n", | |
" index = self.gen.randrange(0, len(self.items))\n", | |
" return self.items[index]\n", | |
" \n", | |
" class SquareCB(object):\n", | |
" def __init__(self, gamma):\n", | |
" super().__init__()\n", | |
"\n", | |
" self.gamma = gamma\n", | |
"\n", | |
" def sample(self, fhat, *, keepdim=False):\n", | |
" N, K = fhat.shape\n", | |
" fhatstar, ahatstar = torch.max(fhat, dim=1, keepdim=True)\n", | |
" probs = 1 / (K + self.gamma * (fhatstar - fhat))\n", | |
" psum = torch.sum(probs, dim=1, keepdim=True)\n", | |
" phatstar = psum + torch.gather(input=probs, dim=1, index=ahatstar)\n", | |
"\n", | |
" rando = torch.randint(high=K, size=(N, 1), device=fhat.device)\n", | |
" prando = torch.gather(input=probs, dim=1, index=rando)\n", | |
" unif = torch.rand(size=(N, 1), device=fhat.device)\n", | |
" shouldexplore = (unif <= K * prando).long()\n", | |
" actions = ahatstar + shouldexplore * (rando - ahatstar)\n", | |
" pactions = phatstar + shouldexplore * (prando - phatstar)\n", | |
" if not keepdim:\n", | |
" actions = actions.squeeze(1)\n", | |
" pactions = pactions.squeeze(1)\n", | |
" return actions, pactions\n", | |
" \n", | |
" class EasyAcc:\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
" self.sumsq = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def __isub__(self, other):\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" self.sumsq += other*other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" return self.sum / max(self.n, 1)\n", | |
"\n", | |
" def var(self):\n", | |
" from math import sqrt\n", | |
" return sqrt(self.sumsq / max(self.n, 1) - self.mean()**2)\n", | |
"\n", | |
" def semean(self):\n", | |
" from math import sqrt\n", | |
" return self.var() / sqrt(max(self.n, 1))\n", | |
"\n", | |
" class RFFBilinearSoftmax(torch.nn.Module):\n", | |
" def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
" from math import pi\n", | |
" import numpy as np\n", | |
"\n", | |
" super().__init__()\n", | |
"\n", | |
" torch.manual_seed(seed)\n", | |
" nobs = hilo.shape[1]\n", | |
" high = hilo[1, :]\n", | |
" low = hilo[0, :]\n", | |
" \n", | |
" self.rff = torch.nn.Linear(nobs, numrff)\n", | |
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
" self.rff.weight.requires_grad = False\n", | |
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
" self.rff.bias.requires_grad = False\n", | |
" self.sqrtrff = np.sqrt(numrff)\n", | |
" self.final = torch.nn.Bilinear(naction, numrff, 1)\n", | |
" self.final.weight.data *= 0.01\n", | |
" self.final.bias.data *= 0.01\n", | |
" self.sigmoid = torch.nn.Sigmoid()\n", | |
"\n", | |
" def forward(self, a, y):\n", | |
" with torch.no_grad():\n", | |
" rff = self.rff(y).cos() / self.sqrtrff\n", | |
" return self.final(a, rff)\n", | |
" \n", | |
" def density(self, logits):\n", | |
" return self.sigmoid(logits)\n", | |
"\n", | |
" class RFFSoftmax(torch.nn.Module):\n", | |
" def __init__(self, hilo, naction, numrff, sigma, seed):\n", | |
" from math import pi\n", | |
" import numpy as np\n", | |
"\n", | |
" super().__init__()\n", | |
"\n", | |
" torch.manual_seed(seed)\n", | |
" nobs = hilo.shape[1]\n", | |
" high = hilo[1, :]\n", | |
" low = hilo[0, :]\n", | |
" \n", | |
" self.rff = torch.nn.Linear(nobs, numrff)\n", | |
" self.rff.weight.data = torch.matmul(torch.empty(numrff, nobs).cauchy_(sigma = sigma), \n", | |
" torch.diag(torch.tensor([ 1.0/v if v > 1e-6 else 0. for v in high - low ])).float())\n", | |
" self.rff.weight.requires_grad = False\n", | |
" self.rff.bias.data = 2 * pi * torch.rand(numrff)\n", | |
" self.rff.bias.requires_grad = False\n", | |
" self.sqrtrff = np.sqrt(numrff)\n", | |
" self.final = torch.nn.Linear(numrff, naction)\n", | |
" self.final.weight.data *= 0.01\n", | |
" self.final.bias.data *= 0.01\n", | |
" self.sigmoid = torch.nn.Sigmoid()\n", | |
"\n", | |
" def forward(self, x):\n", | |
" with torch.no_grad():\n", | |
" rff = self.rff(x).cos() / self.sqrtrff\n", | |
" return self.final(rff)\n", | |
" \n", | |
" def preq1(self, logits):\n", | |
" return self.sigmoid(logits)\n", | |
"\n", | |
" transform = torchvision.transforms.Compose([\n", | |
" torchvision.transforms.ToTensor(),\n", | |
" torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", | |
" ])\n", | |
" mnist_train = torchvision.datasets.MNIST('/tmp/mnist', train=True, download=True, transform=transform)\n", | |
" \n", | |
" quantile_loader = torch.utils.data.DataLoader(mnist_train, batch_size=10000, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(quantile_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" hilo = numpy.quantile(flat.numpy(), [ 0.01, 0.99 ], axis=0)\n", | |
" pi = RFFSoftmax(hilo, 10, 2000, 0.01, 45)\n", | |
" decoder = RFFBilinearSoftmax(hilo, 10, 2000, 0.01, 2112)\n", | |
" break\n", | |
" \n", | |
" feedback_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)\n", | |
" feedbacks = [ [] for _ in range(10) ]\n", | |
" for bno, (images, labels) in enumerate(feedback_loader):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
" feedbacks[labels[0]].append(flat)\n", | |
" if all(len(x) > 100 for x in feedbacks):\n", | |
" break \n", | |
" feedbacks = torch.cat([ torch.cat(x[:100], dim=0).unsqueeze(0) for x in feedbacks ], dim=0)\n", | |
" \n", | |
" # pre-train to get policy \"better than random\"\n", | |
" if True:\n", | |
" preopt = torch.optim.Adam(( p for p in pi.parameters() if p.requires_grad ), lr=1e-2) # 0.1\n", | |
" preloss = torch.nn.CrossEntropyLoss()\n", | |
" pretrain_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ pretrain_loader for _ in range(1) ])):\n", | |
" flat = images.reshape(images.shape[0], -1)\n", | |
"\n", | |
" preopt.zero_grad()\n", | |
" ld = pi.forward(flat)\n", | |
" output = preloss(ld, labels)\n", | |
" output.backward()\n", | |
" preopt.step()\n", | |
"\n", | |
" if bno > 0:\n", | |
" break\n", | |
" \n", | |
" train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True)\n", | |
" mnist_test = torchvision.datasets.MNIST('/tmp/mnist', train=False, download=True, transform=transform)\n", | |
" test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=1000, shuffle=True)\n", | |
" \n", | |
" opt = torch.optim.Adam(( p for p in itertools.chain(pi.parameters(), decoder.parameters()) if p.requires_grad ), lr=1e-2)\n", | |
" log_loss = torch.nn.BCEWithLogitsLoss(reduce='none')\n", | |
" sampler = SquareCB(gamma=100)\n", | |
" acc, accsincelast, avloss, avlosssincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" avreward, avrewardsincelast, avfake, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" reservoirs = [ WeightedReservoir(20, 1973+a) for a in range(10) ]\n", | |
" \n", | |
" print('{:<5s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}\\t{:<8s}'.format(\n", | |
" 'n', 'loss', 'since', \n", | |
" 'acc', 'since',\n", | |
" 'reward', 'since',\n", | |
" 'fake', 'since',\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" for bno, (images, labels) in enumerate(itertools.chain(*[ train_loader for _ in range(1) ])):\n", | |
" flatimage = images.reshape(images.shape[0], -1)\n", | |
" \n", | |
" opt.zero_grad()\n", | |
" logit = pi(flatimage)\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" fhat = pi.preq1(logit)\n", | |
" sample, probs = sampler.sample(fhat, keepdim=True)\n", | |
" \n", | |
" reward = (sample == labels.unsqueeze(1)).float()\n", | |
" pred = logit.argmax(dim=1, keepdim=True)\n", | |
" ispred = (sample == pred).float()\n", | |
" antipred = logit.argmin(dim=1, keepdim=True)\n", | |
" isantipred = (sample == antipred).float()\n", | |
" \n", | |
" # this assumes a particular majorization (Torch tensors are row-major)\n", | |
" bigfeedbacks = feedbacks.unsqueeze(0).expand(fhat.shape[0], -1, -1, -1).reshape(fhat.shape[0], -1, flatimage.shape[1]) # Batch x (A x Rep) x Pixels\n", | |
" nreps = feedbacks.shape[1]\n", | |
" goodwhich = feedbacks.shape[1] * sample.squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n", | |
" goodwhich = goodwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n", | |
" goodfeedbacks = torch.gather(input=bigfeedbacks, index=goodwhich, dim=1).squeeze(1)\n", | |
" badwhich = feedbacks.shape[1] * (9-sample).squeeze(1) + torch.randint(low=0, high=feedbacks.shape[1], size=(fhat.shape[0],))\n", | |
" badwhich = badwhich.unsqueeze(1).unsqueeze(2).expand(-1, -1, flatimage.shape[1])\n", | |
" badfeedbacks = torch.gather(input=bigfeedbacks, index=badwhich, dim=1).squeeze(1)\n", | |
" \n", | |
" if False:\n", | |
" import matplotlib.pyplot as plt\n", | |
"\n", | |
" fig, axs = plt.subplots(1, 10)\n", | |
" for n, (s, f) in enumerate(zip(sample, goodfeedbacks)):\n", | |
" if n > 9:\n", | |
" break\n", | |
" axs[n].imshow(f.reshape(28, 28))\n", | |
" axs[n].set_title(f'{s.item()}')\n", | |
" \n", | |
" plt.show()\n", | |
" \n", | |
" fig, axs = plt.subplots(1, 10)\n", | |
" for n, (s, f) in enumerate(zip(sample, badfeedbacks)):\n", | |
" if n > 9:\n", | |
" break\n", | |
" axs[n].imshow(f.reshape(28, 28))\n", | |
" axs[n].set_title(f'{s.item()}')\n", | |
" \n", | |
" plt.show()\n", | |
" assert False\n", | |
" \n", | |
" feedback = badfeedbacks + reward * (goodfeedbacks - badfeedbacks)\n", | |
" onehotsample = torch.nn.functional.one_hot(sample.squeeze(1), num_classes=fhat.shape[1]).float()\n", | |
" \n", | |
" # insert then sample ... means the first time we play an action there will be no update, that's ok\n", | |
" for s, p, r, f in zip(sample, probs, reward, feedback):\n", | |
" reservoirs[s.item()].insert((f, r), 1/p)\n", | |
" \n", | |
" compfeedback = []\n", | |
" compreward = []\n", | |
" for s in sample:\n", | |
" f, r = reservoirs[s.item()].sample()\n", | |
" compfeedback.append(f.unsqueeze(0))\n", | |
" compreward.append(r.unsqueeze(0))\n", | |
" compfeedback = torch.cat(compfeedback, dim=0)\n", | |
" compreward = torch.cat(compreward, dim=0)\n", | |
" \n", | |
" if False:\n", | |
" import matplotlib.pyplot as plt\n", | |
"\n", | |
" fig, axs = plt.subplots(1, 10)\n", | |
" for n, (s, f, r) in enumerate(zip(sample, compfeedback, compreward)):\n", | |
" if n > 9:\n", | |
" break\n", | |
" axs[n].imshow(f.reshape(28, 28))\n", | |
" axs[n].set_title(f'{s.item()} {r.long().item()}')\n", | |
" \n", | |
" plt.show()\n", | |
" assert False\n", | |
"\n", | |
" samplelogit = torch.gather(input=logit, index=sample, dim=1)\n", | |
" fakelogit = decoder(onehotsample, feedback)\n", | |
" fakereward = decoder.density(fakelogit)\n", | |
" fakecomplogit = decoder(onehotsample, compfeedback)\n", | |
" predloss = torch.mean(log_loss(fakelogit - fakecomplogit, ispred) + log_loss(samplelogit, fakereward.detach()))\n", | |
" antipredloss = torch.mean(log_loss(fakecomplogit - fakelogit, isantipred) + log_loss(1 - samplelogit, fakereward.detach()))\n", | |
" loss = torch.min(predloss, antipredloss)\n", | |
" loss.backward()\n", | |
" opt.step()\n", | |
" \n", | |
" with torch.no_grad():\n", | |
" acc += torch.mean((labels.unsqueeze(1) == pred).float())\n", | |
" accsincelast += torch.mean((labels.unsqueeze(1) == pred).float())\n", | |
" avloss += loss\n", | |
" avlosssincelast += loss\n", | |
" avreward += torch.mean(reward)\n", | |
" avrewardsincelast += torch.mean(reward)\n", | |
" avfake += torch.mean(fakereward)\n", | |
" avfakesincelast += torch.mean(fakereward)\n", | |
" \n", | |
" if (bno & (bno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" avfake.mean(), avfakesincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast, avfakesincelast = [ EasyAcc() for _ in range(4) ]\n", | |
" \n", | |
" print('{:<5d}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}\\t{:<8.5f}'.format(\n", | |
" avloss.n, avloss.mean(), avlosssincelast.mean(), \n", | |
" acc.mean(), accsincelast.mean(), \n", | |
" avreward.mean(), avrewardsincelast.mean(),\n", | |
" avfake.mean(), avfakesincelast.mean(),\n", | |
" ),\n", | |
" flush=True)\n", | |
" accsincelast, avlosssincelast, avrewardsincelast = EasyAcc(), EasyAcc(), EasyAcc()\n", | |
" testacc = EasyAcc()\n", | |
" with torch.no_grad():\n", | |
" for ti, tl in train_loader:\n", | |
" flat = ti.reshape(ti.shape[0], -1)\n", | |
" logit = pi(flat)\n", | |
" testpred = logit.argmax(dim=1)\n", | |
" testacc += torch.mean((tl == testpred).float())\n", | |
"\n", | |
" print(f'testacc {testacc.mean()}')\n", | |
"\n", | |
"iglADepLearn()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "ef8c19c9", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.7" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment