{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "omniglot-1shot.ipynb",
"provenance": [],
"collapsed_sections": [
"27iaosZpkPI8"
],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/gokulanv/201adf2650444869a90f394a8932a8dd/omniglot-1shot.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4F3naO_l1Sts",
"colab_type": "text"
},
"source": [
"## DataLoad\n",
"\n",
"Here, I clone the data from the original repository and store it in the folder : **/content/omniglot/images**\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "XbqzoXeQgRrd",
"colab_type": "code",
"outputId": "335b3f0d-72ac-441b-a6a0-f6a18c1118d4",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 84
}
},
"source": [
"!git clone https://github.com/brendenlake/omniglot.git"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Cloning into 'omniglot'...\n",
"remote: Enumerating objects: 81, done.\u001b[K\n",
"remote: Total 81 (delta 0), reused 0 (delta 0), pack-reused 81\u001b[K\n",
"Unpacking objects: 100% (81/81), done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "XvC3ScQtV0Ii",
"colab_type": "code",
"colab": {}
},
"source": [
"!unzip /content/omniglot/python/images_background.zip -d /content/omniglot/images/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "GNQnTLnZWB1L",
"colab_type": "code",
"colab": {}
},
"source": [
"!unzip /content/omniglot/python/images_evaluation.zip -d /content/omniglot/images/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "IxqshMMvOGlY",
"colab_type": "code",
"colab": {}
},
"source": [
"# move all images to /content/omniglot/images/\n",
"!mv /content/omniglot/images/images_background/* /content/omniglot/images/\n",
"!mv /content/omniglot/images/images_evaluation/* /content/omniglot/images/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "9gA2y3u8QItP",
"colab_type": "code",
"colab": {}
},
"source": [
"# Removing redundant data\n",
"!rm -r /content/omniglot/images/images_evaluation/\n",
"!rm -r /content/omniglot/images/images_background/\n",
"\n",
"# Removing unnecessary directories\n",
"!rm -r /content/omniglot/matlab/\n",
"!rm -r /content/omniglot/python/"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "L54f6T_m1XP8",
"colab_type": "text"
},
"source": [
"## Headers"
]
},
{
"cell_type": "code",
"metadata": {
"id": "aWHDnWVZWS6w",
"colab_type": "code",
"colab": {}
},
"source": [
"# System headers\n",
"import os\n",
"import sys\n",
"from datetime import datetime\n",
"from os import listdir\n",
"\n",
"# PIL headers\n",
"from PIL import Image\n",
"from PIL import ImageOps\n",
"\n",
"# Pytorch headers\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"from torch import autograd\n",
"from torch.autograd import Variable\n",
"from torchvision import transforms, datasets\n",
"\n",
"# misc headers\n",
"from tqdm import tqdm\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Torch device config setup\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "nrDDhkQLGaUB",
"colab_type": "text"
},
"source": [
"# MAML\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JaPrQw8myO66",
"colab_type": "text"
},
"source": [
"### DataLoader\n",
"\n",
"OmniglotDataLoader class reads the images from the alphabet directories and stores in the variable X_data which is again stored as pickle file."
]
},
{
"cell_type": "code",
"metadata": {
"id": "07nlKbuzGwFu",
"colab_type": "code",
"outputId": "cb357f98-b962-4b83-83b8-ab5a5d0e1b59",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 542
}
},
"source": [
"class OmniglotDataLoader(object):\n",
" def __init__(self, is_plot = True):\n",
" datapath = \"/content/omniglot/images/\"\n",
" alphabets = os.listdir(datapath)\n",
" idx = 0\n",
" X_data = np.zeros((1623, 20, 28, 28), dtype=np.uint8)\n",
" # 1623 = num_characters\n",
" # 20 = each character has 20 instances\n",
" for alphabet in tqdm(alphabets):\n",
" ap_dir = datapath + alphabet\n",
" characters = os.listdir(ap_dir)\n",
" for character in characters:\n",
" char_dir = ap_dir + \"/\" + character\n",
" for j, fname in enumerate(os.listdir(char_dir)):\n",
" fpath = char_dir + \"/\" + fname\n",
" im = Image.open(fpath)\n",
" im = im.resize((28,28))\n",
" image = np.array(im, dtype=np.uint8)\n",
" X_data[idx, j, :, :] = image\n",
" idx += 1\n",
" with open(\"/content/omniglot.nparray.pk\", \"wb\") as f:\n",
" pickle.dump(X_data, f)\n",
"\n",
" if is_plot:\n",
" self.plotImg(X_data)\n",
"\n",
" def plotImg(self, X_data, size = 6):\n",
" characters = np.random.randint(low=0, high=1623, size=size)\n",
" classes = np.random.randint(low=0, high=20, size=size)\n",
" \n",
" columns = size // 2\n",
" rows = 2\n",
"\n",
" f, axs = plt.subplots(rows, columns, figsize=(10,10))\n",
" axs = axs.flatten()\n",
" for i in range(columns*rows):\n",
" axs[i].imshow(X_data[characters[i], classes[i], :, :], cmap = 'gray')\n",
" axs[i].set_title(\"Char: \" + str(characters[i]) +\" | Y: \" + str(classes[i]))\n",
" plt.show()\n",
"\n",
"\n",
"\n",
"loader = OmniglotDataLoader()\n",
"\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 50/50 [00:04<00:00, 10.74it/s]\n"
],
"name": "stderr"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlAAAAH8CAYAAAAe1JFNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3de7hsdX3n+fdHwCteQI4nBMGjNJkO\nSVpMn0bipYc8qIMmHXA644jGgNEHO5GI0SfTtElHTEfbNqLpTow2DniwVYgZb8RWE8LIGO2E8UiQ\naxR0jhE8wAFFLo0xwHf+WOt0qrZnX357V+26vV/PU8+uWquq1nfVXt+9P7XW+lWlqpAkSdLaPWTS\nBUiSJM0aA5QkSVIjA5QkSVIjA5QkSVIjA5QkSVIjA5QkSVIjA5QkSVIjA9QSSc5O8oFJ1zEOSS5L\ncvyk69Bsm/Me2ZHktEnXoelmDwgWNEAleUmSnUnuSbI7yaeTPGtCtbw9yQ1J7k7yN0l+cZn7/WKS\nSvLKgWk/neSzSb6bZNcGatia5Pal4SrJ+UkuWuWxy9aQ5AlJLkzyrX7+F5I8fb11avNMU4/09Twn\nyRVJ7k1yU5IXDczbL8nv9NvZ3Un+OsnjBuY/Jckn+3m3J3nbOpb/iL5PT10y/bf67XrFv6VJfi3J\nLUnu6vvqYa01aHNNUw8keVuSb/bbzzeSvGFg3iH9NnhHkjuT/GWSZw7Mf1iSd/b98Z0kf5jkgHXU\nsO4eSPLjSf60778f+PTuJAcn+Vjf399I8pLW+iZh4QJUktcBvwe8BdgKHAH8IXDSGJa1/xrudi/w\nL4DHAqcC/zHJM5Y8z0HAG4Br9/HY84Ff30idVXUr8GvAe5M8ol/mCcDPAr+6hvqXq+FA4IvAPwUO\nBi4A/muSAzdSr8Zr2nokydHAh4DfoOuTpwJfGrjLm4BnAD8FPAZ4GfC9/rEPBS4B/m/gh4AnAs17\nDqrqPuAVwNuTbO2f+0eB1wOvqKoHV6j/fwHOAk4AngQ8pa9ZU2raegA4D/jHVfUYum39pUn+137e\nPcAvAVuAg4D/APzJwPOeBWwHfhz4EeAngd9srXMjPQD8PfDh/vH78i7g+3Sv9UuBdyf5sdYaN11V\nLcyF7o/vPcD/tsJ9zqb7Rb8fuJsutGwfmH8W8LV+3nXACwfmnQZ8AXgncAfwO+uo8WLg9UumvQf4\nFeAy4JX7eMxzgF1reO7LgONXmP9J4HeBRwA3Ai9uqHutNdwF/NNJbwtelv39TF2P0IWnf7fMvIP6\neo9cZv7pwF80rP8O4LQV5r8L+GMg/Xqctcb63zJw+wTglkn/rr0s+/uauh5YsuzDgKuB/2Mf8x5C\n94a8gCf003YOrgvwEuCbKzz/yHtg4LH/CKgl0x5FF55+ZGDafwHeOultYbXLou2B+ing4cDHVrnf\nzwEXAY+jCzR/MDDva8Cz6ZrsTcAHkhw6MP/pwNfpkvSb+93AV62luH7vzz9jYE9TkmPp3j28Zy3P\nsUH/iu6dzEXANVV1UZJnJblzFE+e5BjgoXThTNNpGnvkOIAkV/eHUj6Q5OB+3k8A9wM/3x8i+2qS\nVy957K7+8Mvt6c4D/IlV1m0l/5quRz8CPAz43SRH9IdOjljmMT8GfHng9peBrUkev4E6ND7T2AMk\nOSvJPcBNdKHjQ0vmX0W35/Vi4P+sqtsGZy+5/sQkj11l/Zaznh5YyY8A91fVVwemfZmub6baogWo\nxwO3V9X9q9zv81X1qap6gC4JP3XvjKr646r6VlU9WFV/BNwAHDvw2G9V1e9X1f1VdV9Vfaiq/ska\n63sP3Ybzp9Cd20G32/iMWnn36EhU1U3Ab9HtTfrlftrnq+pxKz5wDZI8hu61fFNVfXejz6exmcYe\neSLdYbl/CRxFt4f09wfmPZbuj/CTgZ8Hzk7y3IH5Lwb+E/DDwH8FPtEf2mtWVfcArwZeSHfY4oGq\n+tuqelxV/e0yDzsQGNzm915/9Hpq0NhNYw9QVW+l22Z+sl/ed5fM/yd0h7BfAnx+YNZngDOTbEny\nQ8Br+umPXGX9lqtjPT2wkgPpjkwM+i4z0B+LFqDuAA5ZwzHnWwau/3fg4Xsfk+5k7iv7tH0n3XHl\nQwbu/831FJbkd/vnelH1+zDpDttdVVV/tZ7nXKdrge9U1e5RPWG/Z+1PgL+qqn8/qufVWExjj9wH\nvK+qvtr/8X4L8IKBeQC/3f8juopur8Dg/M9X1aer6vvA2+n+Qf5oYw2Drl3yczX30P1j22vv9bs3\nUIPGZxp7AOiOfVXVX9Nt1z9wHl1Vfa+qLgTOSrI30L0Z+GvgSuC/AR+nOyfp1vXU0GvtgZUs7Q/6\n21PfH4sWoP4S+Dvg5PU8OMmTgPcCZwCP7/fMXMPw7tEfGGGwhud9E/B84HlVNZjETwBe2B+auIXu\n5MFzkvzBvp5nGqUbbfRxut3Or5pwOVrdNPbIVUseU0vmLZ22dH5zT47YtQzsneiv31pVd0yoHq1s\nGntgqf2BI1eYfwDdYAX6NxZnVNVhVfUUuoD4pc04qrFGXwX2T3LUwLSnMppwNlYLFaD6Q0e/Bbwr\nyclJHpnkgCTPz9qGNj+KbsPfA5Dk5XTvLNYtyb+h2+X6nH38QT2N7p3yMf1lJ927jt/oH/uQJA+n\na5Ykefh6D02s10o19ENl/y+6d0unTlHDahnT2CPA+4CXp/s4gkfSnaD7yb7erwF/AfxGuuHaP0p3\nyO6T/WM/AByX7mMQ9gNeC9wOXL/Bmlq8H3hFkqPTfbzCb9KdqKspNG090P+NfVWSg9I5lu4Q2qX9\n/OP6c1Ufmu6jBv413blVl/fzD0vyw/1jjwP+LfDG9dazznVI/39i7/+Gh/dvrqmqe4GPAr+d5FHp\nPoLhJLrDlFNtoQIUQFWdA7yO7o/YHrpdqWfQ7SVZ7bHXAefQvUO5le4E1i+s9JgkL02yUpJ+C90Q\n2RvTfd7IPek/46Oq7qyqW/Ze6EYq3DVwDtE/pwsnn+qf4z7gz1ZbjxZJnt2fuLiclWp4Bt1HITwP\nuHNg/Z49yho1WtPWI1V1Pl0IuRz4Bt3egdcM3OUUuo8HuIPuHKd/W1WX9o/9CvALdOcXfofuD/PP\n9YfzRqI/gfae5U6grarPAG8DPgv8bb8Om/oPTG2mrQfozjfaO6rvA3TnAO49D/BhdCPj7gBupjt8\n/TNV9a1+/pF0h+7upfsombOqatT/J1bsAbr+vI9/2Kt0H/CVgfm/Qndu423AhcAvV9XU74HKP5xu\no3mX5DLg7Kq6bMKlSFMpyQ7gsqraMeFSpImwB9Zu4fZASZIkbdRaPgFV82MHsGvCNUjT7OPYI1ps\n9sAaeQhPkiSp0YYO4SU5MclXktyY5KxRFSXNKntCGmZPaF6tew9UPyT4q8Bz6T7j54vAKf0IhH06\n5JBDatu2betanjRqu3bt4vbbb8/q91w
be0Kzzp6Qhq3UExs5B+pY4Maq+jpAkovohggv2xjbtm1j\n586dG1ikNDrbt28f9VPaE5pp9oQ0bKWe2MghvMMY/jj6m/pp0qKyJ6Rh9oTm1tg/xiDJ6Ul2Jtm5\nZ8+ecS9Omnr2hDTMntAs2kiAuhk4fOD2E/tpQ6rq3KraXlXbt2zZsoHFSVPPnpCG2ROaWxsJUF8E\njkry5P67z14MXDyasqSZZE9Iw+wJza11n0ReVfcnOQP4U2A/4PxZ+O4aaVzsCWmYPaF5tqFPIq+q\nT9F9iawk7AlpKXtC88rvwpMkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpk\ngJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWpkgJIkSWq0/6QL0PRJss/p\nVbXJlUiSVrOvv9n+vR4/90BJkiQ1MkBJkiQ1MkBJkiQ1MkBJkiQ12tBJ5El2AXcDDwD3V9X2URQl\nzSp7QhpmT2hejWIU3k9X1e0jeB5pXtgT0jB7QnPHQ3iSJEmNNhqgCvizJF9KcvooCpJmnD0hDbMn\nNJc2egjvWVV1c5InAJck+Zuq+tzgHfqGOR3giCOO2ODipKlnT0jD7AnNpQ3tgaqqm/uftwEfA47d\nx33OrartVbV9y5YtG1mcNPXsCWmYPaF5te4AleRRSR699zrwPOCaURU2bklm4jKJ2rU+s94T0qjZ\nE5pnGzmEtxX4WP8Pd3/gQ1X1mZFUJc0me0IaZk9obq07QFXV14GnjrAWaabZE9Iwe0LzzI8xkCRJ\namSAkiRJamSAkiRJajSKr3KZaq2jyqpqTJWsbKURdxvVuk6ttUzqNdP6zEpPSNI0cw+UJElSIwOU\nJElSIwOUJElSIwOUJElSIwOUJElSo7kfhbecaRtZtFw9LSOmRrVOrbU4Om++OWpPkn6Qe6AkSZIa\nGaAkSZIaGaAkSZIaGaAkSZIaGaAkSZIaLewovGkzidF2Wkyj2n4cldluEt9tqcVkH46fe6AkSZIa\nGaAkSZIaGaAkSZIaGaAkSZIaGaAkSZIarRqgkpyf5LYk1wxMOzjJJUlu6H8eNN4ypekxrz2RpOlS\nVfu8aPnXcrnXrOW1XO65J2lee2JUWnur9TJNtUzj9jkua9kDtQM4ccm0s4BLq+oo4NL+trQodmBP\nSIN2YE9owawaoKrqc8C3l0w+Cbigv34BcPKI65Kmlj0hDbMntIjWew7U1qra3V+/Bdg6onqkWWVP\nSMPsCc21DZ9EXt2B+mVPfEhyepKdSXbu2bNno4uTpp49IQ2zJzSP1hugbk1yKED/87bl7lhV51bV\n9qravmXLlnUuTpp69oQ0zJ7QXFtvgLoYOLW/firwidGUs3mmbaTARkfmaOJmvidaTVsPTcK4X4MZ\n/7uwcD0xC1pHgo7qMo/W8jEGFwJ/CfxPSW5K8grgrcBzk9wAPKe/LS0Ee0IaZk9oEe2/2h2q6pRl\nZp0w4lqkmWBPSMPsCS0iP4lckiSpkQFKkiSpkQFKkiSp0arnQM26lb5PSvvW+trM6wiLebXc73dU\nv8flnn/cyx0ne0IbMYntwf9x4+ceKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEZzPwqv1SKN\ntlmkddXqWkfJjWqUzyyPzpuFGiWNh3ugJEmSGhmgJEmSGhmgJEmSGhmgJEmSGi3sSeStJ3/O8tdT\neLK4BrWeFL5I249ffyFprdwDJUmS1MgAJUmS1MgAJUmS1MgAJUmS1MgAJUmS1GjVAJXk/CS3Jblm\nYNrZSW5OcmV/ecF4y5y8qmq6JGm6TNM6aWXz2hOj2h6maRvX5pjXnpBWspY9UDuAE/cx/Z1VdUx/\n+dRoy5Km2g7sCWnQDuwJLZhVA1RVfQ749ibUIs0Ee0IaZk9oEW3kHKgzklzV77o9aGQVSbPLnpCG\n2ROaW+sNUO8GjgSOAXYD5yx3xySnJ9mZZOeePXvWuThp6tkT0jB7QnNtXQGqqm6tqgeq6kHgvcCx\nK9z33KraXlXbt2zZst46palmT0jD7AnNu3UFqCSHDtx8IXDNcvfV9HK01OjMQ0+0bg+to/ZmeXtz\nxGq7eegJaSWrfplwkguB44FDktwEvBE4PskxQAG7gFeNsUZpqtgT0jB7Qoto1QBVVafsY/J5Y6hF\nmgn2hDTMntAi8pPIJUmSGhmgJEmSGhmgJEmSGq16DpTWZ5pG6CxXy3IjoEY1MmqaXgNtHn/v0tpM\nYhTqJL97dd64B0qSJKmRAUqSJKmRAUqSJKmRAUqSJKmRAUqSJKmRo/AW2DyOipDWYla+g0/zbZx/\ng1u3cf8ftHMPlCRJUiMDlCRJUiMDlCRJUiMDlCRJUiMDlCRJUiNH4UkC2r8zcRY4EknSuLgHSpIk\nqZEBSpIkqZEBSpIkqZEBSpIkqdGqASrJ4Uk+m+S6JNcmObOffnCSS5Lc0P88aPzlSpNnT0jD7Akt\norXsgbofeH1VHQ0cB7w6ydHAWcClVXUUcGl/W1oE9gTdCLeWyyRqWU5V7fOidbMntHBWDVBVtbuq\nruiv3w1cDxwGnARc0N/tAuDkcRUpTRN7QhpmT2gRNZ0DlWQb8DTgcmBrVe3uZ90CbB1pZdIMsCek\nYfaEFsWaA1SSA4GPAK+tqrsG51W373uf+7+TnJ5kZ5Kde/bs2VCx0jSxJ6Rh9oQWyZoCVJID6Jri\ng1X10X7yrUkO7ecfCty2r8dW1blVtb2qtm/ZsmUUNUsTZ09Iw+wJLZq1jMILcB5wfVW9Y2DWxcCp\n/fVTgU+Mvjxp+tgT0jB7QotoLd+F90zgZcDVSa7sp70BeCvw4SSvAL4BvGg8JUpTZ6F6onV02nKj\n3ybxnXqOrNs0C9UTEqwhQFXV54Hl/vKdMNpypOlnT0jD7AktIj+JXJIkqZEBSpIkqZEBSpIkqZEB\nSpIkqdFaRuFJ0po58k3SInAPlCRJUiMDlCRJUiMDlCRJUiMDlCRJUiMDlCRJUiMDlCRJUiMDlCRJ\nUiMDlCRJUiMDlCRJUiMDlCRJUiO/ykWSpDnjVyqNn3ugJEmSGhmgJEmSGhmgJEmSGhmgJEmSGhmg\nJEmSGq0aoJIcnuSzSa5Lcm2SM/vpZye5OcmV/eUF4y9Xmjx7QhpmT2gRreVjDO4HXl9VVyR5NPCl\nJJf0895ZVW8fX3nSVLInpGH2hBbOqgGqqnYDu/vrdye5Hjhs3IVJ08qekIbZE1pETedAJdkGPA24\nvJ90RpKrkpyf5KAR1yZNPXtCGmZPaFGsOUAlORD4CPDaqroLeDdwJHAM3TuPc5Z53OlJdibZuWfP\nnhGULE0He0IaZk9okawpQCU5gK4pPlhVHwWoqlur6oGqehB4L3Dsvh5bVedW1faq2r5ly5ZR1S1N\nlD0hDbMntGjWMgovwHnA9VX1joHphw7c7YXANaMvT5o+9oQ0zJ7QIlrLKLxnAi8Drk5yZT/tDcAp
\nSY4BCtgFvGosFUrTx56QhtkTWjhrGYX3eSD7mPWp0ZcjTT97QhpmT2gR+UnkkiRJjQxQkiRJjQxQ\nkiRJjQxQkiRJjQxQkiRJjQxQkiRJjQxQkiRJjQxQkiRJjQxQkiRJjVJVm7ewZA/wjf7mIcDtm7bw\nyVmU9YTZW9cnVdVEv7nUnph7s7au9sRkLMp6wuyt67I9sakBamjByc6q2j6RhW+iRVlPWKx1HYdF\nef0WZT1hsdZ1HBbl9VuU9YT5WlcP4UmSJDUyQEmSJDWaZIA6d4LL3kyLsp6wWOs6Dovy+i3KesJi\nres4LMrrtyjrCXO0rhM7B0qSJGlWeQhPkiSp0aYHqCQnJvlKkhuTnLXZyx+nJOcnuS3JNQPTDk5y\nSZIb+p8HTbLGUUlyeJLPJrkuybVJzuynz+X6jpM9MfvbiP0wWvbE7G8ni9ATmxqgkuwHvAt4PnA0\ncEqSozezhjHbAZy4ZNpZwKVVdRRwaX97HtwPvL6qjgaOA17d/y7ndX3Hwp6Ym23EfhgRe2JutpO5\n74nN3gN1LHBjVX29qr4PXASctMk1jE1VfQ749pLJJwEX9NcvAE7e1KLGpKp2V9UV/fW7geuBw5jT\n9R0je2IOthH7YaTsiTnYThahJzY7QB0GfHPg9k39tHm2tap299dvAbZOsphxSLINeBpwOQuwviNm\nT8zZNmI/bJg9MWfbybz2hCeRb6LqhjzO1bDHJAcCHwFeW1V3Dc6bx/XVaM3bNmI/aKPmbTuZ557Y\n7AB1M3D4wO0n9tPm2a1JDgXof9424XpGJskBdI3xwar6aD95btd3TOyJOdlG7IeRsSfmZDuZ957Y\n7AD1ReCoJE9O8lDgxcDFm1zDZrsYOLW/firwiQnWMjJJApwHXF9V7xiYNZfrO0b2xBxsI/bDSNkT\nc7CdLEJPbPoHaSZ5AfB7wH7A+VX15k0tYIySXAgcT/dt07cCbwQ+DnwYOILuG8ZfVFVLTyCcOUme\nBfwFcDXwYD/5DXTHuOdufcfJnpj9bcR+GC17Yva3k0XoCT+JXJIkqZEnkUuSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUy\nQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUy\nQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUy\nQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUy\nQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUy\nQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmSJDUyQEmS\nJDUyQEmSJDUyQEmSJDUyQEmSJDUyQK0iydlJPjDpOsYhyWlJdky6Ds2WOe+Jy5IcP+k6NFvsicVk\ngAKSvCTJziT3JNmd5NNJnjWhWg5O8kdJ7khye5IPJnnMwPzPJtmT5K4kX05y0pLH/2qS/6+fv3O9\n65HkA0net2Ta/9zXdegKjzu6X+53+sufJzl6PTVocqapJ/p6npPkiiT3JrkpyYv2cZ9fTFJJXjkw\n7af7nvlukl0bWP7Wvh+PXzL9/CQXrfLYs5P8ff9a7r08Zb21aDLmqCd+Pck1Se7u/1f8+jqXv5Ge\n+LUkX+//T30ryTuT7L+eOiZp4QNUktcBvwe8BdgKHAH8IXDSSo9b57LWsoH8DnAQ8GTgyL6mswfm\nnwkcWlWPAU4HPrA30CR5OvBW4OeBxwLnAR9Lst86yj0TeH6S5/bP/XDgvcDrq2r3Co/7Vr/8g4FD\ngIuBFZtJ02XaeqIP4B8CfoNuu34q8KUl9zkIeANw7ZKH3wucD6zrn8ReVXUr8GvAe5M8ol/mCcDP\nAr+6hqf4o6o6cODy9Y3Uo801Zz0R4Bfp/s+cCJyR5MWtdW6wJy4GfrL/P/bjff2vaa1h0hY6QCV5\nLPDbwKur6qNVdW9V/X1V/UlVDf7BfWiS9/eJ/dok2wee46wkX+vnXZfkhQPzTkvyhT5d38FwEFrO\nk4GPV9VdVfVd4GPAj+2dWVVXVdX9e28CBwCH97e3AddW1ZeqqoD304WYJ7S9MlBVd9A1wblJHgW8\nEfhaVe1Y5XF3VtWufvkBHgD+UevyNRlT2hO/Cfznqvp0Vd1fVXdU1deW3OffA/8JuH1wYlX9v1X1\nX4ANB5b+eb4C/Hb/D+M/A6+pqj0bfW5NrznsibdV1RX9474CfAJ4ZstrMvBc6+qJqvpaVd3Z3wzw\nIDP4f2KhAxTwU8DD6ULKSn6Obi/K4+iS8x8MzPsa8Gy6dwFvYmCPUO/pdH+8twJv7ncDX7XCst4F\n/GySg/p3EP8S+PTgHZJ8Msn3gMuBy4Cd/axPA/sleXq/1+mXgCuBW1ZZv32qqj8GrgAupNvbdXq/\n/KuSvGSlxya5E/ge8Pt079o0G6axJ44DSHJ1f+jkA0kO3jszybHAduA9a1i/jfpXdH11EXBNVV2U\n5Fn99r6Sf5Hk2/0/1l8ef5kaobntiSTp61q6l6rFunqiX8e76ALeU+nC12ypqoW9AC8FblnlPmcD\nfz5w+2jgvhXufyVwUn/9NOBvG2v6YeDP6RL5g8AlwEP3cb8DgOcDrxuYFrpdtn8P3E+3Yf6zFZZ1\nGrBjlXq2AvcAZ67j9X0U8CvAz0z6d+1lzb+zaeyJ7wO7gB8BDgQ+Anywn7cf3RuI4/rblwGv3Mdz\nPAfYtYZlXQYcv8p9Xk13aPDQNdZ/dN/X+wHPAHYDp0z6d+1lbZd57Yl+3puALwMPW2FZI++JJY89\nCvh3wA9N+nfdeln0PVB3AIes4Zjz4B6c/w48fO9j+pP0rkxyZ5+4f5zusNle32ys6cPAV4FHA4+h\ne+fyA6M7qtuF/GngeUl+rp/8CuDldIf8Hgr8AvDJJD/cWMPgcm6lC2LN71Cq6l66d0DvT9J8GFET\nMY09cR/wvqr6alXdQ7dH8wX9vF8Brqqqv2p8zo24Fvh
OrXwu4P9QVddV1beq6oGq+m/Af6Q7T1Cz\nYS57IskZdOdC/UxV/V3j8pdq6olBVXVD//g/3GANm27RA9RfAn8HnLyeByd5Et2J1WcAj6+qxwHX\n0O0J2qsan/YYumPb9/aN8R7+oTH2ZX+6k833PvaTfVM9WFWfoXu3+4zGGkbpIcAjgcMmWIPWbhp7\n4qoljxm8fgLwwiS3JLmFbls/J8ng4ZNps/f8QM2GueuJJL8EnAWcUFU3NS57HAb/j82MhQ5Q1Z2k\n/VvAu5KcnOSRSQ5I8vwkb1vDUzyKbsPdA5Dk5XTvLDbii8ArkzyiPynvdLpmIck/7mt7RF/nLwD/\nHPh/Bh77M0meks5z6XbxXrPBmtYsyXOTPC3Jfuk+fuEdwHeA6zerBq3flPbE+4CX99v1I+n+8H+y\nn3ca8KN0bx6OoTt08Sa60UkkeUi6EaQHdDfz8CQP3WA9TZKc1J/TmP7clNfQnbirGTCHPfFSuj1W\nz60JjQZN8sq9RyXSjSj8N8Clk6hlIxY6QAFU1TnA6+hGNeyh25V6BvDxNTz2OuAcuncotwI/AXxh\npcckeWmSlQ6H/RLdaLqbgJuBpwCn7n043bH22/pazwT+96q6op//froT+S4D7qIbgfGqqvqb1dal\nRX8i7EuXmf04upPOv0t3+PFI4MSq+t4oa9D4TFtPVNX5dNv25cA36PYGvKafd2dV3bL3QnduyN4R\nrNC9wbgP+BTd0PP7gD9bbbezw8YAAA3DSURBVD1aJHl2kntWuMuLgRuBu/v1+A9VdcEoa9B4zVlP\n/A7weOCL+YfPJRvpAIw19MQzgauT3EvXm5+iO393pqQ/iUsLKMlpdCcHnjbhUqSpkOQy4OyqumzC\npUhTwZ5Y3sLvgZIkSWo1cx+drpG6Eljt82ukRbKDbni4pM4O7Il98hCeJElSIw/hSZIkNdpQgEpy\nYpKvJLkxyVmjKkqaVfaENMye0Lxa9yG8dN+19lXguXRD7r9I9/UE1y33mEMOOaS2bdu2ruVJo7Zr\n1y5uv/32kX2goT2hWWdPSMNW6omNnER+LHDj3g/iSnIRcBKwbGNs27aNnTt3Ljdb2lTbt29f/U5t\n7AnNNHtCGrZST2zkEN5hDH9/z034dR1abPaENMye0Nwa+0nkSU5PsjPJzj179ox7cdLUsyekYfaE\nZtFGAtTNwOEDt5/YTxtSVedW1faq2r5ly5YNLE6aevaENMye0NzaSID6InBUkif3X875YuDi0ZQl\nzSR7QhpmT2hurfsk8qq6P8kZwJ8C+wHnV9VKX5IrzTV7QhpmT2iebeirXKpq77coS8KekJayJzSv\n/CRySZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYo\nSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKkRgYoSZKk\nRgYoSZKkRvtv5MFJdgF3Aw8A91fV9lEUJc0qe0IaZk9oXm0oQPV+uqpuH8HzSPPCnpCG2ROaOx7C\nkyRJarTRAFXAnyX5UpLTR1GQNOPsCWmYPaG5tNFDeM+qqpuTPAG4JMnfVNXnBu/QN8zpAEccccQG\nFydNPXtCGmZPaC5taA9UVd3c/7wN+Bhw7D7uc25Vba+q7Vu2bNnI4qSpZ09Iw+wJzat1B6gkj0ry\n6L3XgecB14yqMGnW2BPSMHtC82wjh/C2Ah9Lsvd5PlRVnxlJVdJssiekYfaE5ta6A1RVfR146ghr\nkWaaPSENsyc0z/wYA0mSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmS\npEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEb7T7oASZK0uiRrvm9V\njbESgXugJEmSmhmgJEmSGhmgJEmSGhmgJEmSGq0aoJKcn+S2JNcMTDs4ySVJbuh/HjTeMqXpYU9I\nw+wJLaK17IHaAZy4ZNpZwKVVdRRwaX9bGpKk6TJDdmBPSIN2MIc9Mcd/wzQCqwaoqvoc8O0lk08C\nLuivXwCcPOK6pKllT0jD7AktovWeA7W1qnb3128Bto6oHmlW2RPSMHtCc23DJ5FX92ldy35iV5LT\nk+xMsnPPnj0bXZw09ewJaZg9oXm03gB1a5JDAfqfty13x6o6t6q2V9X2LVu2rHNx0tSzJ6Rh9oTm\n2noD1MXAqf31U4FPjKYcaWbZE9Iwe0JzbdXvwktyIXA8cEiSm4A3Am8FPpzkFcA3gBeNs8hxGNWI\niXF/39A4R3ZM23clzcr3PM1rT0jrZU90pmkk3jTVspJp+z/UYtUAVVWnLDPrhBHXIs0Ee0IaZk9o\nEflJ5JIkSY0MUJIkSY0MUJIkSY0MUJIkSY1WPYlcm2NWRkxI0qKYthFiszJSeVG4B0qSJKmRAUqS\nJKmRAUqSJKmRAUqSJKmRAUqSJKmRo/A2Wetou32NpFjuOaZt1MVy9bTU7+jEjfH1mw3T1ruSVuce\nKEmSpEYGKEmSpEYGKEmSpEYGKEmSpEYLexJ560mb4z4Zt/WE63m0SOs6jTyRWZLWzj1QkiRJjQxQ\nkiRJjQxQkiRJjQxQkiRJjQxQkiRJjVYNUEnOT3JbkmsGpp2d5OYkV/aXF4y3TA2qqn1eZsWs129P\nSMPsCS2iteyB2gGcuI/p76yqY/rLp0ZbljTVdmBPSIN2YE9owawaoKrqc8C3N6EWaSbYE9Iwe0KL\naCPnQJ2R5Kp+1+1BI6tIml32hDTMntDcWm+AejdwJHAMsBs4Z7k7Jjk9yc4kO/fs2bPOxUlTz56Q\nhtkTmmvrClBVdWtVPVBVDwLvBY5d4b7nVtX2qtq+ZcuW9dYpTTV7QhpmT2jereu78JIcWlW7+5sv\nBK5Z6f7SvLMnpGGT7olF/27NRV//9WgdDb5qgEpyIXA8cEiSm4A3AscnOQYoYBfwqtZCpVllT0jD\n7AktolUDVFWdso/J542hFmkm2BPSMHtCi8hPIpckSWpkgJIkSWpkgJIkSWq0rlF4kiRNs1n6fk3N\nJvdASZIkNTJASZIkNTJASZIkNTJASZIkNTJASZIkNTJASZIkNTJASZIkNTJASZIkNTJASZIkNTJA\nSZIkNTJASZIkNfK78MYkyaRLWNWoavQ7pzbPOLerWdhmtTz7UNpc7oGSJElqZICSJElqZICSJElq\nZICSJElqtGqASnJ4ks8muS7JtUnO7KcfnOSSJDf0Pw8af7nS5NkT0jB7QotoLXug7gdeX1VHA8cB\nr05yNHAWcGlVHQVc2t/WKqpqn5dxSrLPy6w8/xSaWE8st/148TJh/p/Qwlk1QFXV7qq6or9+N3A9\ncBhwEnBBf7cLgJPHVaQ0TewJaZg9oUXUdA5Ukm3A04DLga1VtbufdQuwdaSVSTPAnpCG2RNaFGsO\nUEkOBD4CvL
aq7hqcV93+433uQ05yepKdSXbu2bNnQ8VK08SekIbZE1okawpQSQ6ga4oPVtVH+8m3\nJjm0n38ocNu+HltV51bV9qravmXLllHULE2cPSENsye0aNYyCi/AecD1VfWOgVkXA6f2108FPjH6\n8qTpY09Iw+wJLaK1fBfeM4GXAVcnubKf9gbgrcCHk7wC+AbwovGUON0mMdpsUt9hN6rljuJ5lnuO\nTRqNZE9Iw+wJLZxVA1RVfR5Y7j/eCaMtR5p+9oQ0zJ7QIvKTyCVJkhoZoCRJkhoZoCRJkhoZoCRJ\nkhqtZRSeWH5017hH4e1rua3LHNXItNbXYBR1zvl36kmSZpR7oCRJkhoZoCRJkhoZoCRJkhoZoCRJ\nkhoZoCRJkho5Cm+DNum71ya+zJWMs55JjX6UJGkl7oGSJElqZICSJElqZICSJElqZICSJElq5Enk\nmknTdiK9JGmxuAdKkiSpkQFKkiSpkQFKkiSpkQFKkiSp0aoBKsnhST6b5Lok1yY5s59+dpKbk1zZ\nX14w/nKlybMnpGH2hBbRWkbh3Q+8vqquSPJo4EtJLunnvbOq3j6+8qSpZE9Iw+wJLZxVA1RV7QZ2\n99fvTnI9cNi4C5OmlT0hDbMntIiazoFKsg14GnB5P+mMJFclOT/JQSOuTZp69oQ0zJ7QolhzgEpy\nIPAR4LVVdRfwbuBI4Bi6dx7nLPO405PsTLJzz549IyhZmg72hDTMntAiWVOASnIAXVN8sKo+ClBV\nt1bVA1X1IPBe4Nh9Pbaqzq2q7VW1fcuWLaOqW5ooe0IaZk9o0axlFF6A84Drq+odA9MPHbjbC4Fr\nRl+eNH3sCWmYPaFFtJZReM8EXgZcneTKftobgFOSHAMUsAt41VgqlKaPPSENsye0cNYyCu/zQPYx\n61OjL0eafvaENMye0CLyk8glSZIaGaAkSZIaGaAkSZIaGaAkSZIaGaAkSZIaGaAkSZIaGaAkSZIa\nGaAkSZIaGaAkSZIaGaAkSZIapao2b2HJHuAb/c1DgNs3beGTsyjrCbO3rk+qqol+9bs9MfdmbV3t\niclYlPWE2VvXZXtiUwPU0IKTnVW1fSIL30SLsp6wWOs6Dovy+i3KesJires4LMrrtyjrCfO1rh7C\nkyRJamSAkiRJajTJAHXuBJe9mRZlPWGx1nUcFuX1W5T1hMVa13FYlNdvUdYT5mhdJ3YOlCRJ0qzy\nEJ4kSVKjTQ9QSU5M8pUkNyY5a7OXP05Jzk9yW5JrBqYdnOSSJDf0Pw+aZI2jkuTwJJ9Ncl2Sa5Oc\n2U+fy/UdJ3ti9rcR+2G07InZ304WoSc2NUAl2Q94F/B84GjglCRHb2YNY7YDOHHJtLOAS6vqKODS\n/vY8uB94fVUdDRwHvLr/Xc7r+o6FPTE324j9MCL2xNxsJ3PfE5u9B+pY4Maq+npVfR+4CDhpk2sY\nm6r6HPDtJZNPAi7or18AnLypRY1JVe2uqiv663cD1wOHMafrO0b2xBxsI/bDSNkTc7CdLEJPbHaA\nOgz45sDtm/pp82xrVe3ur98CbJ1kMeOQZBvwNOByFmB9R8yemLNtxH7YMHtizraTee0JTyLfRNUN\neZyrYY9JDgQ+Ary2qu4anDeP66vRmrdtxH7QRs3bdjLPPbHZAepm4PCB20/sp82zW5McCtD/vG3C\n9YxMkgPoGuODVfXRfvLcru+Y2BNzso3YDyNjT8zJdjLvPbHZAeqLwFFJnpzkocCLgYs3uYbNdjFw\nan/9VOATE6xlZJIEOA+4vqreMTBrLtd3jOyJOdhG7IeRsifmYDtZhJ7Y9A/STPIC4PeA/YDzq+rN\nm1rAGCW5EDie7tumbwXeCHwc+DBwBN03jL+oqpaeQDhzkjwL+AvgauDBfvIb6I5xz936jpM9Mfvb\niP0wWvbE7G8ni9ATfhK5JElSI08ilyRJamSAkiRJamSAkiRJamSAkiRJamSAkiRJamSAkiRJamSA\nkiRJamSAkiRJavT/A8DQ685C+w9vAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 720x720 with 6 Axes>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2IHOXGaplK6e",
"colab_type": "text"
},
"source": [
"### Meta Learner\n",
"\n",
"Meta learner is the Neural Network Architecture that learns $\\theta$ as well as $\\theta'$ (as shown in the image).\n",
"![MAML Adaptation](https://www.domluna.com/static/0661053d379cd77ae37b39e28ad470b5/d7711/maml.png)\n",
"\n",
"\n",
"The methods **init** and **forward** are used to learn the *model-agnostic parameters* $\\theta$.\n",
"\n",
"\n",
"---\n",
"\n",
"\n",
"Whereas, the method **forward_fast_weights** builds a *supervised classification(task specific)* Convnet + Dense layers architecture in order to learn $\\phi^*_1, \\phi^*_2, \\phi^*_3, ...$\n",
"\n",
"The helper method **update_fast_grad** computes the gradient with respect to the *few shot supervised classifcation(task specific) learner*."
]
},
{
"cell_type": "code",
"metadata": {
"id": "8C4vWXYQGXmj",
"colab_type": "code",
"colab": {}
},
"source": [
"class LearnerConv(nn.Module):\n",
" def __init__(self, N_way, device):\n",
" super(LearnerConv, self).__init__()\n",
" self.device = device\n",
" self.N_way = N_way\n",
" self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=2, padding=0)\n",
" self.bnorm1 = nn.BatchNorm2d(num_features=64, track_running_stats=False)\n",
"\n",
" self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=0)\n",
" self.bnorm2 = nn.BatchNorm2d(num_features=64, track_running_stats=False)\n",
"\n",
" self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=0)\n",
" self.bnorm3 = nn.BatchNorm2d(num_features=64, track_running_stats=False)\n",
"\n",
" self.conv4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=2, stride=1, padding=0)\n",
" self.bnorm4 = nn.BatchNorm2d(num_features=64, track_running_stats=False)\n",
"\n",
" self.fc = nn.Linear(in_features=64, out_features=N_way, bias=True)\n",
" self.logsoftmax = nn.LogSoftmax(dim=-1)\n",
"\n",
" self.to(device)\n",
"\n",
" def copy_model_weights(self):\n",
" '''\n",
" copies theta to update theta'\n",
" '''\n",
" fast_weights = []\n",
" for pr in self.parameters():\n",
" fast_weights.append(pr.clone())\n",
" return fast_weights\n",
"\n",
" def update_fast_grad(self, fast_params, grad, lr_a):\n",
" '''\n",
" updates task specific weights - theta'\n",
"\n",
" '''\n",
" if len(fast_params) != len(grad): raise ValueError(\"fast grad update error\")\n",
" num = len(grad)\n",
" updated_fast_params = [None for _ in range(num)]\n",
" for i in range(num):\n",
" updated_fast_params[i] = fast_params[i] - lr_a * grad[i]\n",
" return updated_fast_params\n",
"\n",
" def forward(self, x):\n",
" # x = [batch_size, 1, 28, 28]\n",
"\n",
" z1 = F.relu(self.bnorm1(self.conv1(x)))\n",
" z2 = F.relu(self.bnorm2(self.conv2(z1)))\n",
" z3 = F.relu(self.bnorm3(self.conv3(z2)))\n",
" z4 = F.relu(self.bnorm4(self.conv4(z3)))\n",
"\n",
" z4 = z4.view(-1, 64)\n",
" out = self.logsoftmax(self.fc(z4))\n",
"\n",
" return out\n",
"\n",
" def forward_fast_weights(self, x, fast_weights):\n",
" '''\n",
" trains task specific neural network to compute loss on f(theta')\n",
" '''\n",
" # x = [batch_size, 1, 28, 28]\n",
"\n",
" conv1_w = fast_weights[0]\n",
" conv1_b = fast_weights[1]\n",
" bnorm1_w = fast_weights[2]\n",
" bnorm1_b = fast_weights[3]\n",
"\n",
" conv2_w = fast_weights[4]\n",
" conv2_b = fast_weights[5]\n",
" bnorm2_w = fast_weights[6]\n",
" bnorm2_b = fast_weights[7]\n",
"\n",
" conv3_w = fast_weights[8]\n",
" conv3_b = fast_weights[9]\n",
" bnorm3_w = fast_weights[10]\n",
" bnorm3_b = fast_weights[11]\n",
"\n",
" conv4_w = fast_weights[12]\n",
" conv4_b = fast_weights[13]\n",
" bnorm4_w = fast_weights[14]\n",
" bnorm4_b = fast_weights[15]\n",
"\n",
" fc_w = fast_weights[16]\n",
" fc_b = fast_weights[17]\n",
"\n",
" z1 = F.conv2d(x, conv1_w, conv1_b, stride=2, padding=0)\n",
" z1 = F.batch_norm(z1, running_mean=self.bnorm1.running_mean, running_var=self.bnorm1.running_var, weight=bnorm1_w, bias=bnorm1_b, training=True) # how about training=True??\n",
" z1 = F.relu(z1)\n",
"\n",
" z2 = F.conv2d(z1, conv2_w, conv2_b, stride=2, padding=0)\n",
" z2 = F.batch_norm(z2, running_mean=self.bnorm2.running_mean, running_var=self.bnorm2.running_var, weight=bnorm2_w, bias=bnorm2_b, training=True) # how about training=True??\n",
" z2 = F.relu(z2)\n",
"\n",
" z3 = F.conv2d(z2, conv3_w, conv3_b, stride=2, padding=0)\n",
" z3 = F.batch_norm(z3, running_mean=self.bnorm3.running_mean, running_var=self.bnorm3.running_var, weight=bnorm3_w, bias=bnorm3_b, training=True) # how about training=True??\n",
" z3 = F.relu(z3)\n",
"\n",
" z4 = F.conv2d(z3, conv4_w, conv4_b, stride=1, padding=0)\n",
" z4 = F.batch_norm(z4, running_mean=self.bnorm4.running_mean, running_var=self.bnorm4.running_var, weight=bnorm4_w, bias=bnorm4_b, training=True) # how about training=True??\n",
" z4 = F.relu(z4)\n",
"\n",
" z4 = z4.view(-1, 64)\n",
" z4 = torch.matmul(z4, fc_w.T) + fc_b\n",
" out = F.log_softmax(z4, dim=-1)\n",
" return out"
],
"execution_count": 0,
"outputs": []
},
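{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal usage sketch (not part of the original training loop) of how the LearnerConv helpers compose into a single inner-loop adaptation step: copy $\\theta$, compute the support loss, differentiate it with *torch.autograd.grad*, and apply **update_fast_grad** to obtain $\\theta'$. The tensors *x_support* and *y_support* are hypothetical random placeholders."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Hypothetical sketch of one inner-loop step; x_support/y_support are random placeholders.\n",
"sketch_learner = LearnerConv(N_way=5, device=device)\n",
"x_support = torch.randn(5, 1, 28, 28, device=device) # one image per class (N=5, K=1)\n",
"y_support = torch.randint(0, 5, (5,), device=device) # class labels 0..4\n",
"\n",
"fast_weights = sketch_learner.copy_model_weights() # differentiable copy of theta\n",
"y_pred = sketch_learner.forward_fast_weights(x_support, fast_weights)\n",
"support_loss = nn.NLLLoss(reduction='mean')(y_pred, y_support)\n",
"# create_graph=True keeps the graph so a meta update could differentiate through this step\n",
"grads = torch.autograd.grad(support_loss, fast_weights, create_graph=True)\n",
"fast_weights = sketch_learner.update_fast_grad(fast_weights, grads, lr_a=0.4) # theta'"
],
"execution_count": 0,
"outputs": []
},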
{
"cell_type": "markdown",
"metadata": {
"id": "r69FJe5gEEah",
"colab_type": "text"
},
"source": [
"### Train\n",
"![MAML algorithm for few shot supervised learning](https://image.slidesharecdn.com/maml-181130144044/95/introduction-to-maml-model-agnostic-meta-learning-with-discussions-19-638.jpg)\n",
"\n",
"For every epoch:\n",
"\n",
"\n",
"* I sample $T_i$ tasks from p(T)\n",
"* Compute the train_loss for each task and perform gradient updates for $\\theta$\n",
"* Then, I sample $D'$ tasks from $T_i$ to compute meta_loss\n",
"* Meta loss is computed using $\\theta$ as initial parameter and the parameters are updated to compute $\\theta'$\n",
"* The total of all meta_loss is summed up and averaged using *nn.NLLLoss(reduction='mean')* and the mean is used to perform gradient descent to update $\\theta$ (Meta Update)\n",
"\n",
"\n",
"---\n",
"\n",
"\n",
"\n",
"### Visualization\n",
"\n",
"The final plot shows that the loss value reduces across epochs. The plot is a bit noisy due to the inherent feature of few shot learning i.e. for a new task, the gradient might be at an inappropriate starting point but eventually approaches the global minimum.\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "q4_sM_ZH8ndz",
"colab_type": "text"
},
"source": [
"**Note: The Viualization is from a model trained on 1000 epochs only to evaluate correctness. However, the final submission contains the model trained on 10000 epochs.**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "BXR3tCeqG5x-",
"colab_type": "code",
"outputId": "ab1c530e-c619-4b27-e21f-d094adda5dac",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 601
}
},
"source": [
"print(\"device = {}\".format(device))\n",
"\n",
"omniglot_learner = LearnerConv(N_way=5, device=device)\n",
"\n",
"lr_b = 1e-4\n",
"print(\"lr_beta = {:.2e}\".format(lr_b))\n",
"\n",
"criterion = nn.NLLLoss(reduction='mean')\n",
"optimizer = torch.optim.Adam(omniglot_learner.parameters(), lr=lr_b, betas=(0.9,0.999), eps=1e-08, weight_decay=0)\n",
"optimizer.zero_grad()\n",
"\n",
"# --------------------- Load the Omniglot data --------------------- #\n",
"omniglot_np_path = \"/content/omniglot.nparray.pk\"\n",
"with open(omniglot_np_path, 'rb') as f:\n",
" X_data = pickle.load(f, encoding=\"bytes\")\n",
"np.random.seed(28)\n",
"np.random.shuffle(X_data)\n",
"X_train = X_data[:1200,:,:,:]\n",
"X_test = X_data[1200:,:,:,:]\n",
"\n",
"# --------------------- MAML Omniglot --------------------- \n",
"def omniglot_maml_exp():\n",
" # hyperparameters\n",
" num_iterations = 1001\n",
" n_way = 5 # N\n",
" k_shot = 1 # K\n",
" batch_size = n_way * k_shot # number of images to train and validate on\n",
" metabatch_size = 32 # 32 tasks --- the number of tasks sampled per meta-update\n",
" # each task is an N-way, K-shot classification problem\n",
" lr_a = 0.4\n",
" num_grad_update = 5\n",
"\n",
" print(\"N={}\".format(n_way))\n",
" print(\"K={}\".format(k_shot))\n",
" print(\"metabatch_size={}\".format(metabatch_size))\n",
" print(\"lr_a={}\".format(lr_a))\n",
" print(\"num_grad_update={}\".format(num_grad_update))\n",
"\n",
" loss_values = []\n",
" running_loss = 0.0\n",
"\n",
" for iter in range(num_iterations):\n",
" # 1. sample batch of tasks Ti ~ p(T)\n",
" tasks = [None for _ in range(metabatch_size)]\n",
" for _i in range(metabatch_size):\n",
" # choose 5 labels from 1200 train labels\n",
" tasks[_i] = np.random.randint(low=0, high=1200, size=n_way)\n",
"\n",
" # Loss accumulator: L_theta\n",
" meta_learning_loss = 0\n",
" # 2. for each task Ti\n",
" for task in tasks:\n",
"\n",
" # copy current model weights to fast_weights\n",
" fast_weights = omniglot_learner.copy_model_weights()\n",
"\n",
" X_batch_a = np.zeros((batch_size, 28, 28))\n",
" Y_batch_a = np.zeros((batch_size))\n",
" X_batch_b = np.zeros((batch_size, 28, 28))\n",
" Y_batch_b = np.zeros((batch_size))\n",
"\n",
" # 2.1 sample K datapoints from Ti\n",
" for j1, char_id in enumerate(task):\n",
" #choose 1 character from 20 chars\n",
" instances = np.random.randint(low=0, high=20, size=k_shot)\n",
" for j2, ins in enumerate(instances):\n",
" # store datapoints in X_batch_a, Y_batch_a\n",
" X_batch_a[j1*k_shot+j2,:,:] = X_train[char_id,ins,:,:]\n",
" Y_batch_a[j1*k_shot+j2] = j1\n",
"\n",
" X_batch_a = torch.tensor(X_batch_a, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch_a = torch.tensor(Y_batch_a, dtype=torch.long).to(device)\n",
" # 2.2 compute gradient (multiple steps)\n",
" for grad_update_iter in range(num_grad_update):\n",
" Y_pred = omniglot_learner.forward_fast_weights(X_batch_a, fast_weights)\n",
" train_loss = criterion(Y_pred, Y_batch_a)\n",
"\n",
" grad = torch.autograd.grad(train_loss, fast_weights, create_graph=True)\n",
" fast_weights = omniglot_learner.update_fast_grad(fast_weights, grad, lr_a)\n",
"\n",
" # 2.3 sample K datapoints from Ti --- for meta-update step\n",
" for j1, char_id in enumerate(task):\n",
" instances = np.random.randint(low=0, high=20, size=k_shot)\n",
" for j2, ins in enumerate(instances):\n",
" X_batch_b[j1*k_shot+j2,:,:] = X_train[char_id,ins,:,:]\n",
" Y_batch_b[j1*k_shot+j2] = j1\n",
"\n",
" # 3. meta-update step\n",
" X_batch_b = torch.tensor(X_batch_b, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch_b = torch.tensor(Y_batch_b, dtype=torch.long).to(device)\n",
"\n",
" Y_pred = omniglot_learner.forward_fast_weights(X_batch_b, fast_weights)\n",
" meta_loss = criterion(Y_pred, Y_batch_b)\n",
" meta_learning_loss += meta_loss\n",
"\n",
" # 4. Backpropagation to update model's parameters\n",
" meta_learning_loss /= n_way\n",
" if iter % 100 == 0:\n",
" print(\"[{}] iteration {}: meta_learning_loss = {:.5f}\".format(str(datetime.now()), iter, meta_learning_loss))\n",
" sys.stdout.flush()\n",
"\n",
" # minimize sum of all task specfic losses\n",
" meta_learning_loss.backward()\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" if iter % 100 == 0:\n",
" loss_values.append(meta_learning_loss.item())\n",
"\n",
" if iter % 5000 == 0 and iter != 0:\n",
" print(iter)\n",
" savepath = \"/content/omniglot_4marchv2_n{}_k{}_iter{}.pt\".format(n_way, k_shot, iter)\n",
" print(\"saving a model at\", savepath)\n",
" torch.save(omniglot_learner.state_dict(), savepath)\n",
"\n",
" savepath = \"/content/omniglot_4marchv2_n{}_k{}_final.pt\".format(n_way, k_shot)\n",
" print(\"saving a model at\", savepath)\n",
" torch.save(omniglot_learner.state_dict(), savepath)\n",
"\n",
" plt.plot(loss_values)\n",
" plt.show()\n",
"\n",
" print(\"finished maml training\")\n",
"\n",
"\n",
"omniglot_maml_exp()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"device = cuda\n",
"lr_beta = 1.00e-04\n",
"N=5\n",
"K=1\n",
"metabatch_size=32\n",
"lr_a=0.4\n",
"num_grad_update=5\n",
"[2020-03-27 01:13:27.597378] iteration 0: meta_learning_loss = 8.63651\n",
"[2020-03-27 01:16:27.049208] iteration 100: meta_learning_loss = 7.71411\n",
"[2020-03-27 01:19:27.210273] iteration 200: meta_learning_loss = 6.97912\n",
"[2020-03-27 01:22:26.908446] iteration 300: meta_learning_loss = 7.37316\n",
"[2020-03-27 01:25:25.661675] iteration 400: meta_learning_loss = 6.37255\n",
"[2020-03-27 01:28:24.455490] iteration 500: meta_learning_loss = 7.13989\n",
"[2020-03-27 01:31:23.396171] iteration 600: meta_learning_loss = 7.47399\n",
"[2020-03-27 01:34:21.541504] iteration 700: meta_learning_loss = 6.64808\n",
"[2020-03-27 01:37:20.181869] iteration 800: meta_learning_loss = 7.32519\n",
"[2020-03-27 01:40:18.981792] iteration 900: meta_learning_loss = 6.38697\n",
"[2020-03-27 01:43:18.563801] iteration 1000: meta_learning_loss = 6.94127\n",
"saving a model at /content/omniglot_4marchv2_n5_k1_final.pt\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0\ndHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAAgAElEQVR4nO3dd3xU153//9cZjXpFDWk0QgLRm4Qk\nug3YuNvImBLjhluCHWfjxEm+uynfTfsm2c1vs9nE2bjgBi5xAYONCbaxjQ3YVCGa6EJCHVUkoV7m\n/P6QSDAWZiTNzJ0ZfZ6Phx6RNFf3fuSIt67OPZ9zlNYaIYQQns9kdAFCCCEcQwJdCCG8hAS6EEJ4\nCQl0IYTwEhLoQgjhJcxGXTg6OlonJycbdXkhhPBI+/btq9Zax/T2mmGBnpycTHZ2tlGXF0IIj6SU\nKrzcazLkIoQQXkICXQghvIQEuhBCeAkJdCGE8BIS6EII4SUk0IUQwktIoAshhJfwuEDPq2zk1+8d\npb3TZnQpQgjhVjwu0Itrm3nxiwK2HK80uhQhhHArHhfoV4+KJjbUn7X7io0uRQgh3IrHBbrZx8Si\ndCufnqii8nyr0eUIIYTb8LhAB1iSYaXLpnl3f5nRpQghhNvwyEAfGRvClGERrNlXjOyJKoQQ3Twy\n0AGWZiRysqKRQyX1RpcihBBuwWMD/bbUePzNJtbIw1EhhAA8ONDDAny5aWIcGw6U0drRZXQ5Qghh\nOI8NdOgedmlo7eSjoxVGlyKEEIazK9CVUk8opY4opXKVUq8rpQIuef0BpVSVUupAz9s3nVPul81K\nicISHsCafSWuuJwQQri1Kwa6UioBeBzI1FpPBHyAZb0c+qbWOq3n7XkH19krk0mxOMPK9lNVlNe3\nuOKSQgjhtuwdcjEDgUopMxAEuM0E8CUZVrSGdTmlRpcihBCGumKga61LgT8ARUA5UK+13tzLoYuV\nUoeUUmuVUom9nUsptUIpla2Uyq6qqhpQ4RckRQUzbXgka/eVyJx0IcSgZs+QyxDgdmA4YAGClVL3\nXnLYe0Cy1noy8BGwurdzaa1Xaq0ztdaZMTExA6v8IkszrBRUN7Gv8JzDzimEEJ7GniGX64ACrXWV\n1roDWAfMuvgArXWN1rqt58PngQzHlvn1bpkUT5CfD2uy5eGoEGLwsifQi4AZSqkgpZQC5gPHLj5A\nKRV/0YdZl77ubMH+Zm6ZFM/fD5fT3N7pyksLIYTbsGcMfTewFsgBDvd8zUql1K+VUlk9hz3eM63x\nIN0zYh5wUr2XtTTDSmNbJx/knnX1pYUQwi0oox4kZmZm6uzsbIedT2vN3P/6jISIQF5fMcNh5xVC\nCHeilNqntc7s7TWP7hS9mFKKJRlWdubXUFzbbHQ5Qgjhcl4T6ACLM6woBW/nyMNRIcTg41WBnhAR\nyOyUaNbuK8FmkznpQojBxasCHWBpppWScy3sKqgxuhQhhHAprwv0G8bHEepvZq3MSRdCDDJeF+iB\nfj7clmphU24551s7jC5HCCFcxusCHbqHXVo7bGw6XG50KUII4TJeGehTEiNIiQmWpQCEEIOKVwZ6\n95z0RLILz5Ff1Wh0OUII4RJeGegAi9ITMMmcdCHEIOK1gT40LIC5o2N4e18pXTInXQgxCHhtoAMs\nzUzkbEMrn+dVG12KEEI4nVcH+vxxsUQE+bImu9joUoQQwum8OtD9zT7cnmph89EK6ptlTroQwrt5\ndaBD97BLe6eNDYfcZl9rIYRwCq8P9AmWMMbGhbJWhl2EEF7O6wP9wjrpB0vqOVlx3uhyhBDCabw+\n0AHumJKA2aRYu0/mpAshvNegCPSoEH+uHRvLupxSOrpsRpcjhBBOMSgCHbofjlY3trH1RJXRpQgh\nhFMMmkCfNyaG6BA/1uyTh6NCCO80aALd18fEwrQEPjlWSU1jm9HlCCGEww2aQIfuYZdOm+bdAzIn\nXQjhfQZVoI+JC2WyNZw1MttFCOGFBlWgAyzJsHKsvIHc0nqjSxFCCIcadIGelWrBz8ckc9KFEF5n\n0AV6RJAf108YyrsHSmnvlDnpQgjvMegCHWBphpVzzR18cqzC6FKEEMJhBmWgXz0qhqFh/vJwVAjh\nVQZloPuYFIvSrWw9WUVlQ6vR5QghhEMMykCH7tkuXTbN+v2lRpcihBAOMWgDPSUmhPRhEazZV4LW\nsom0EMLzDdpAh+7O0bzKRg4U1xldihBCDNigDvTbJscT4Ctz0oUQ3mFQB3pogC83T4xnw8EyWju6\njC5HCCEGxK5AV0o9oZQ6opTKVUq9rpQKuOR1f6XUm0qpPKXUbqVUsjOKdYYlGVbOt3by4ZGzRpci\nhBADcsVAV0olAI8DmVrriYAPsOySwx4GzmmtRwL/A/ze0YU6y8wRUSREBMqwixDC49k75GIGApVS\nZiAIuHT92duB1T3vrwXmK6WUY0p0LpNJsTjDyud51ZTVtRhdjhBC9NsVA11rXQr8ASgCyoF6rfXm\nSw5LAIp7ju8E6oGoS8+llFqhlMpWSmVXVbnPVnBL0q1oDety5C5dCOG57BlyGUL3HfhwwAIEK6Xu\n7c/FtNYrtdaZWuvMmJiY/pzCKYZFBTF9eCRrZU66EMKD2TPkch1QoLWu0lp3AOuAWZccUwokAvQM\ny4QDNY4s1NmWZiZypqaZvWfOGV2KEEL0iz2BXgTMUEoF9YyLzweOXXLMBuD+nveXAFu0h93q3jIp\njmA/H9bKJtJCCA9lzxj6brofdOYAh3u+ZqVS6tdKqayew14AopRSecAPgB87qV6nCfIzc+vkeP5+\nqJzm9k6jyxFCiD6za5aL1voXWuuxWuuJWuv7tNZtWuufa6039LzeqrVeqrUeqbWeprXOd27ZzrEk\nI5Gm9i42HZY56UIIzzOoO0UvNTV5CMlRQazJlmEXIYTnkUC/iFKKJRlWdhfUUlTTbHQ5QgjRJxLo\nl1iUbkUpWCtz0oUQHkYC/RKWiECuGhnN2/tKsNk8aqKOEGKQk0DvxZIMK6V1LezM96ip9EKIQU4C\nvRc3TogjNMAsC3YJITyKBHovAnx9yEq18H5uOQ2tHUaXI4QQdpFAv4wlGVZaO2z8/VC50aUIIYRd\nJNAvIy0xgpGxITInXQjhMSTQL0MpxdIMKzlFdZyuajS6HCGEuCIJ9K9xx5QEfExKHo4KITyCBPrX\niA0LYO7oGNbllNAlc9KFEG5OAv0KlmZYqWhoY9sp99lhSQgheiOBfgXzxw1lSJCvDLsIIdyeBPoV\n+JlN3J6WwEdHKqhrbje6HCGEuCwJdDssybDS3mVjw8Eyo0sRQojLkkC3w8SEcMbFh7EmW4ZdhBDu\nSwLdTkszrBwuref42QajSxFCiF5JoNtp4ZQEfH0Ub+yRzlEhhHuSQLdTZLAft0228FZ2Meea5OGo\nEML9SKD3wbfnpdDc3sWqHWeMLkUIIb5
CAr0PRg8N5frxQ1m14wyNbZ1GlyOEEF8igd5Hj81Lob6l\ng9d3FxldihBCfIkEeh9NGTaEWSlRPLc9n7bOLqPLEUKIf5BA74fH5o2k8nwbb+8rNboUIYT4Bwn0\nfpg9MorJ1nCe3Xaazi6b0eUIIQQggd4vSikemzeSwppmNuWeNbocIYQAJND77YbxQxkZG8JTn+ah\ntayVLoQwngR6P5lMim/PTeH42fN8eqLS6HKEEEICfSCy0iwkRATy109Py126EMJwEugD4OtjYsWc\nEewrPMeeglqjyxFCDHIS6AN059REokP8eOqz00aXIoQY5CTQByjA14cHZw9n68kqckvrjS5HCDGI\nSaA7wH0zkwj1N/O03KULIQx0xUBXSo1RSh246K1BKfX9S46Zp5Sqv+iYnzuvZPcTFuDLfTOT2JRb\nzumqRqPLEUIMUlcMdK31Ca11mtY6DcgAmoH1vRy6/cJxWutfO7pQd/fQVcPx8zHx7Fa5SxdCGKOv\nQy7zgdNa60JnFOPJokP8WTY1kfX7SymrazG6HCHEINTXQF8GvH6Z12YqpQ4qpd5XSk3o7QCl1Aql\nVLZSKruqqqqPl3Z/35ozAq3hue35RpcihBiE7A50pZQfkAWs6eXlHCBJa50K/AV4p7dzaK1Xaq0z\ntdaZMTEx/anXrVmHBHF7WgJv7CmmprHN6HKEEINMX+7QbwZytNYVl76gtW7QWjf2vL8J8FVKRTuo\nRo/y7XkjaO2UbepE3zS2dfLQqr38y99yjC5FeLC+BPpdXGa4RSkVp5RSPe9P6zlvzcDL8zwjY0O5\ncXwcq3ec4Xxrh9HlCA9Q39LBfS/sZsvxSjYeKudkxXmjSxIeyq5AV0oFA9cD6y763KNKqUd7PlwC\n5CqlDgJPAsv0IF7c5LFrUmho7eRvsk2duILapnbufm4XuaX1/OeiSfibTby884zRZQkPZVega62b\ntNZRWuv6iz73jNb6mZ73/1drPUFrnaq1nqG13uGsgj3BZGsEV42M5vnPC2jtkG3qRO8qG1q589md\n5FU2snJ5JsumDSMr1cK6nFIa5K870Q/SKeokj12TQtX5NtbuKzG6FOGGSuta+MazOymta+GlB6dy\nzZhYAO6flUxzexdvy8+N6AcJdCeZOSKKtMQI2aZOfEVhTRPfeGYnNU3tvPLwNGal/HP+wMSEcNKH\nRfDKzkJstkE7ain6SQLdSZRSfOeakRTXtrDxULnR5Qg3kVd5nqXP7KS5vZPXvzWDjKTIrxxz/6xk\n8qub+Dyv2oAKhSeTQHei+WNjGT00hKc/Oy13W4KjZQ3c+ewubBreWDGTiQnhvR5388R4okP8WS1T\nX0UfSaA7kcmk+Pa8FE5UnOeT4565TV1dc7vsxuQAB4rrWLZyJ35mE289MoMxcaGXPdbPbOLuaYls\nOVFJUU2zC6sUnk4C3ckWTLZgHRLIXz1wM+kdedVM++0nPLNVljIYiD0Ftdz7/G4igvx465GZjIgJ\nueLX3D09CZNSvLpblk0S9pNAdzKzj4lH5qZwoLiOnfme02t1rLyBR17ZR3uXjbeyiz3ul5G72H6q\niuUv7mZomD9vPTKTxMggu74uLjyAmybE8ebeYlraZeqrsI8EugsszbASHeLvMRtglNW18OBLewny\n9+Hx+aMoqG4it7TB6LI8zsdHK3h4VTbJUcG8+chM4sID+vT1y2cmUd/SwYaDpU6qUHgbCXQXCPD1\n4ZtXD2f7qWoOldQZXc7Xqm/p4IGX9tDU1smqB6fx8Ozh+PooCZU+2niojEdf3ce4+FDeWDGD6BD/\nPp9j2vBIxsaFsnpHofyFJOwige4i90wfRliAmac+dd+79LbOLh55JZuC6iaevS+DcfFhhAf5Mnd0\nDBsPlctMHTut3VfC46/vZ8qwCF795nQigvz6dR6lFMtnJnO0vIF9heccXKXxbDbNspU7+c3Go0aX\n4jUk0F0kNMCX+2cl8+HRs+RVut/iSzab5kdrDrErv5b/WpLKrJH/bHbJSkugvL6VvWdqDazQM7yy\nq5AfrTnI7JHRrH5oGqEBvgM638IpFkIDzKze6X0PRz84cpZd+bW8vqeI5vZOo8vxChLoLvTArGT8\nzSae/sz9Zo38/oPjvHewjH+7aSwLpyR86bXrxsUS6OvDhoNlBlXnGZ7bls+/v5PLdeNieW55JkF+\n5gGfM8jPzDcyE3n/cDmVDa0OqNI9dNk0f/zoJOGBvjS1d/FB7lmjS/IKEuguFBXiz13ThvHugVJK\n3WibulVfFPDstnzum5HEo3NHfOX1ID8z148fyqbD5XTIMgZfobXmyU9O8dtNx7h1UjxP35tBgK+P\nw85/34wkurTmb3u8Z/XODQdLyats5Ld3TCQxMlDWPHIQCXQX+9bVI1Cq+27OHXyQW86vNh7l+vFD\n+WXWBHqWtf+KrFQL55o7+PyUtKNfTGvN7z84wR8/Osmi9AT+vCwNXx/H/rNKjg5m3ugYXttdRHun\n5/9C7eiy8aePTzE+PoxbJsazaIqVnfk1bnWT46kk0F3MEhHIHVMSeH1PEdUGb1OXfaaW771xgLTE\nCJ5cNgUfU+9hDjBndAzhgb4y7HIRm03zq/eO8szW09wzfRh/WJKK2cFhfsHyWclUnW/jwyOePzSx\nLqeEwppmfnD9aEwmxeJ0K1rD+hy5Sx8oCXQDPDI3hfYuGy99UWBYDXmVjXzz5WwsEYG8cP9UAv2+\nfojAz2zi5olxbD5yVhpd6B4D/sm6w6zacYZvXjWc3yyciOlrfiEO1NxRMSRFBXn85hdtnV08+Uke\nqYkRzB/XvWTwsKggpg2P5O2cUpmeOUAS6AZIiQnh5olxvLyj0JCNDCrPt/LAS3swmxSrH5xGZLB9\n0+qy0iw0tXexxUPXpXGUji4bP3jrAG9mF/P4tSP52a3jLjtU5Sgmk+K+GUnsPXOOI2X1V/4CN/Xm\n3mJK61r40Q2jv/TfbEm6lYLqJnKK3LtPw91JoBvksXkjOd/Wyau7XDsd7cJmxDWN7bz4wFSGRdnX\nig4wfXgUsaH+vHtg8DYZtXV28S9/y+HdA2X8601j+MENY5we5hcszUgk0NeHVzx0CmNLexd/2ZLH\ntOGRXDXyy3vI3zwpjgBfE2/LsMuASKAbZGJCOHNGx/CiC7ep6+iy8dhrORwrP89T96Qz2RrRp6/3\nMSlum2zhsxNV1LcMvi3SWtq7WPHyPj48UsEvF4znsXkjXXr98CBfFk5J4J0DpdQ1t7v02o7w6q5C\nqs638cPrR3/ll2BogC83TYhj48Ey2bZxACTQDfSdeSlUN7bzVnax06+ltean6w6z7WQVv104kWvG\nxvbrPFlpFtq7bF7xcK4vGts6eXDVHradquL3iyfxwOzhhtSxfGYSrR021mR71p1sU1snT289zdWj\nopk+IqrXYxZnWGlo7eTjYxUurs57SKAbaNrwSDKShvDs1nynz+/+n49PsWZfCY/PH8WyacP6fZ5U
\nazhJUUG8N4hmu9S3dHDfC7vZe+Ycf7ozjTun9v+/30CNiw9j2vBIXtlVSJcHLcWwascZapva+cH1\noy97zKyUaOLCAmQ/1QGQQDeQUorH5qVQWtfChgPOC8g39hTx5CenWJph5YnrRg3oXEopslItfJFX\nTdV5Y6ddukJtUzt3P7eL3NJ6/np3OrenJVz5i5zs/pnJFNU2s/WkZzycrm/p4Nmtp7luXCxThg25\n7HE+JsUd6QlsO1XtVV2xriSBbrBrx8YyNi6Up7c6Z5u6T49X8rN3cpk7OobfLZrkkAd4WakWbBo2\nHfbuvVIrG1pZtnIneZWNrFyeyU0T44wuCYAbJgxlaJg/q3Z4xsPRF7bn09DayRNfc3d+weJ0K102\nzTuD+MH7QEigG0yp7m3q8iob+cjBY4eHSup47LUcxsWH8tQ96Q7rYBw1NJSxcaFePdulrK6FO1fu\nouRcCy89OJVrxvTvmYMz+PqYuGd6EttOVpFf1Wh0OV+rtqmdFz4v4JZJcUyw9L6H6sVGxoaQlhjB\n2/tkTnp/SKC7gVsnxZMUFcRTDtymrqimmYdW7SUqxI8XH5hKsP/AF4q6WFaahZyiOoprvW/Py8Ka\nJpY+s5Pq82288vA0ZqVEX/mLXGzZtER8fRSvuHjaa189u+00zR1dPHHdle/OL1icYeVExXmOlMmm\nKn0lge4GzD4mHpmTwsGSenacHvg2dbVN7dz/0h46bZrVD00jNrRvO+XYY8FkCwDvHfKuh6Pnmtr5\nxrM7aW7v5PUVM8hIijS6pF7FhgZwy6R41maX0NTmnkvPVp5vZfWOMyxMS2DU0Mtvin2pBZPj8fMx\nyYJd/SCB7iYWZyQQG+rPXz/NG9B5Wtq7eHj1XsrqWnh+eSYpdmxI3B+JkUFkJA1x6sNcI7yyq5CK\nhjZWPzSNiQlXHiIw0vKZyZxv62T9fvcc+nrq09N0dGm+N79vD+Ijgvy4bnwsGw6WecViZK4kge4m\n/M0+fOvqEew4XcP+ov7tTtNl03zvjf0cKK7jz8vSyEx27t1lVqqF42fPc7LC/Tbs6I/Wji5W7zjD\ntWNj+9x0ZYT0YRFMTAjj5Z1n3G68uayuhb/tLmJphpXk6OA+f/3idCu1Te18dsIzZvK4Cwl0N3LX\n9GGEB/ryVD82k9Za88sNR9h8tIJf3DaemybGO6HCL7tlUjwmhdfcpa/LKaWmqZ1vXf3VNeHd0YUt\n6k5WNLIr3712k/rLljw0mn+5tn/dtHNGxxAd4idLAfSRBLobCfE388CsZD46WtHnu95ntubzyq5C\nVswZ4bIuxphQf2aPjGbDwTK3u0PsK5tN8/z2fCZbw5kxwj3HzXuTlWohIsjXrVZhLKppZk12MXdN\nG4Z1iP1rBV3M18fE7WkJbDleybkmz1vmwCgS6G7mgVnJBPn58Ewf7tLf2V/K7z84zoJUCz++aawT\nq/uqBakWimqbOVjiuSsAAnxyvJL86qaeDUhcs9iWIwT4+nDn1EQ2H62gzE02iPjzJ6fwMSm+c83A\n1rpZnG6lo0vLGvx9IIHuZoYE+3H3tGG8e7DMrimBX+RV83/WHmTGiEj+sHSyU9fk7s2NE+Lw8zF5\n/LDLym2nSYgI5GY3aR7qi3unJ2HTmr/tNn6LurzKRtbvL2H5zCSGhg1sdtV4Sxjj4sNktksfSKC7\noW9ePQKTgpVX2KbuWHkDj76yj+HRwTx7Xyb+ZsftY2mv8EBfrhkbw8ZDZR61tsjFcorOsffMOR6+\narjTdhxypsTIIOaPHcrre4po6zR2pcI/fXySAF8fHp2b4pDzLcmwcri03msevDub5/30DgJx4QEs\nTrfyZnYxled7X9OirK6FB1/aS7C/mVUPTiM80NfFVf5TVmoClefb2F0w8Dn0Rnh+ez5hAWbunJpo\ndCn9dv+sJGqa2g1djuFYeQMbD5Xz4OxkokL8HXLO29MsmE1KFuyykwS6m3pkbgqdXTZe/PzMV16r\nb+nggZf20NTWyaqHpmKJCHR9gReZPy6WYD8fjxx2Kaxp4oPcs9w7I8nh3bSuNDslmhExwaw2cH2X\n//noJKEBZlZc7Zi7c4DoEH/mjYlh/f5SOp28Iqk3uGKgK6XGKKUOXPTWoJT6/iXHKKXUk0qpPKXU\nIaVUuvNKHhyGRwdzy6R4Xt1V+KXNJNo6u1jxcjYF1U08e18GY+PCDKyyW4CvDzdMiOP93LMe1wjy\nwucFmE0mHpiVbHQpA2IyKe6fmcyB4joOFrt+G7dDJXVsPlrBt64eQXiQY/9aXJxupfJ8G5/nVTv0\nvN7oioGutT6htU7TWqcBGUAzsP6Sw24GRvW8rQCednShg9Fj80bS2NbJKzvPAN1T63605hC7C2r5\nw9JUZo10nzVGslIt1Ld0sO1kldGl2O1cU/fmIgunWIgd4AM8d7AoPYFgPx9eNmCLuv/efJIhQb48\nODvZ4ee+dlws4YG+vJ3jnh2x7qSvQy7zgdNa60t/Ym4HXtbddgERSinnd7Z4ufGWMK4ZE8OLX5yh\npb2L339wnPcOlvFvN411i3W5L3bVqGiGBPl61BSzV3YV0tph45se0kh0JaEBvizOsPLeoTJqGl23\nVv3eM7VsPVnFI3NTCA1w/LMcf7MPWakWNh85a8im6p6kr4G+DHi9l88nABfvo1bS87kvUUqtUEpl\nK6Wyq6o8507OSI9dM5LapnYeeGkPz27LZ/nMJB6d634B5Otj4pZJ8Xx0tILmdvdcLOpiF9r8rxkT\nw+g+LBzl7pbPTKK908abLtjW8IL/3nyC6BB/ls9Mcto1FmdYaeu08fdD3r0G/0DZHehKKT8gC1jT\n34tprVdqrTO11pkxMTH9Pc2gMjU5kmnJkewuqOWG8UP5xYIJbtv4kpVqoaWji4+Ouv+ekOv397T5\nz3G/X44DMTI2lNkjo3htV5FLHiLuyKtmV34t37kmhSA/5z1UTrWGkxITLLNdrqAvd+g3Azla697+\ntZYCF8/5svZ8TjjAL7Mm8NDs4Tx51xR8XNw41BdTkyOJDw9w+/1GbTbNc9vzmZQQzszLbFjsyZbP\nTKa0roVPjjt3YSutNX/YfIL48ADuGsA+tfZQSrE4w0p24TkKqpucei1P1pdAv4veh1sANgDLe2a7\nzADqtdbyt5GDjLeE8fMF4wnwdX3jUF+YTIrbJsez9WQVdc3uu/7GJ8crya9q4ltzPKvN317zx8aS\nEBHo9PVdPjtRRU5RHd+9dpRLfjYXTbFiUrBOFuy6LLsCXSkVDFwPrLvoc48qpR7t+XATkA/kAc8B\njzm4TuEhbk9LoKNL80HuWaNLuazntuWTEBHILR7Y5m8Ps4+Je2YM44u8Gk45qcPywt15YmQgSzOt\nTrnGpeLCA5g9Mpp1OaVO2X/XVf76aR65pc5Z+8iuQNdaN2mto7TW9Rd97hmt9TM972ut9Xe01ila\n60la62ynVCvc3gRLGCOig912tsv+onPsOVPrsW3+9ro
zMxE/s8lpUxg/PHKWI2UNfG/+aIftVWuP\nJRlWSuta2OWhXclf5FXzXx+ecNoNj/f+RAtDKKVYkGphZ34NFQ29L1tgpOd62vy/4cFt/vaICvFn\nwWQLb+eUOHyqX5dN88ePTjIiJpiFaRaHnvtKbhgfR4i/mbf3ed4jutaOLn62/jDJUUH9Xif+SiTQ\nhcNlpVnQGja62RSzC23+98xIIsSD2/ztdf+sJJrbu1jn4JkhGw+VcbKikSeuG+3yv3IC/Xy4dVI8\n7+eWu+1eqpfzv1vyOFPTzG/vmOS0Zw4S6MLhUmJCmGAJc7thlxc/L8DHpHjQw9v87TXZGkFaYgQv\n7yx02JhzZ5eNP318irFxodw6yZjewcUZVprbu9z6Oc2lTlac55mtp1mUnsBsJ3Z4S6ALp8hKtXCw\nuI7CGveYYtbd5l/CwrQEr2jzt9f9s5LIr27ii9OOWQdl3f5SCqqbeOL60S5fe/+CqclDGBYZ5DHb\n09lsmp+sO0xogJn/e+t4p15LAl04xYLU7rFVd5mT/uquQlo6uryukehKbpkUT1Swn0NWYWzvtPHn\nj08xKSGcG8YPdUB1/aOUYlF6Ajvzayg5d+VNYIz2+t4i9hWe42e3jicy2M+p15JAF05hiQhkWnIk\n7x4wfr/R1o4uVu88wzwva/O3h7/Zh7umDeOT4xV27YD1dd7KLqa0roUf3jDa8Pn7i9OtaA3r3XzB\nrsqGVv7z/ePMSolicbrz19yFw6MAABCbSURBVF+SQBdOsyDNwqnKRo6fNXa3mfX7S6lubGfFILs7\nv+Du6cMwKcWru/t/l97a0cVftpwiM2kIc0cbv2xHYmQQ04dHsm5/qeE3DF/nVxuP0tZp47d3THLJ\nL0EJdOE0t0yMw8ekDH04eqHNf2JCmFe2+dvDEhHIDeOH8ubeYlo7+rdF3Wu7i6hoaOMHbnB3fsHi\nDCsF1U3kFJ0zupRefXq8kr8fKue714xkeHSwS64pgS6cJirEn6tGRvPeQeOGXbZcaPO/2jvb/O21\nfGYydc0d/frl2tTWydOf5TErJYpZKe6zBv8tk+IJ9PVhrRvOSW9u7+T/vpPLqNgQHnHQ/qr2kEAX\nTnV7moWScy3kFLl+Fx3o3mg7ISLQsCl27mLGiEhGDw1h9Y4zff7lunrnGaob2/nhDaOdU1w/hfib\nuWliHBsPlfX7Lw9n+Z+PTlJa18LvFk3Cz+y6mJVAF051w4Q4/M0mQ2a7XGjzf8jL2/ztoZRi+cxk\njpQ19OmXa0NrB89uzeeaMTFkJEU6scL+WZxu5Xxrp1st2ZxbWs+LX5zhrmnDmJrs2v9mg/unXDhd\niL+Z+eNi2XiozOWb/D6/vYDQADN3enmbv73umJJAqL+5T6swvvh5AfUtHfzg+jFOq2sgZqZEER8e\n4DZz0rtsmp+uP8yQID9+fNNYl19fAl04XVaqherGdnbmu25BpaKaZt7PLeee6YOjzd8ewf5mlmRa\n2XS4nMrzV15n51xTOy9sL+DGCUOZZA13QYV952NS3DElgW0nq6h0g7WDVu84w6GSen6xYLzDN8u2\nhwS6cLp5Y2IJ9Tez4YDrhl1e+Dy/u83fCZsWe7LlM5Pp6NK8sefKW9St3J5PY3snT1zvXmPnl1qc\nYcWm4Z0Dxj4cLatr4b83n2DemBhum2zMMxsJdOF0Ab4+3Dgxjg+OnKWt0/kPry60+d+elsDQQdTm\nb4/h0cHMHR3Da7sL6fiaIbDqxjZWfXGGBZMtjI0Lc2GFfZcSE8KUYRGs3Vdi2GwqrTU/f/cINg3/\n7/aJhs2okkAXLpGVauF8ayefnXD+5uCv7e5p8796cDYSXcn9s5KoaGhj85HLP0h8+rPTtHV28f3r\nRrmwsv5bnG7lZEUjuaUNhlz/wyNn+fhYBU9cP4rEyCBDagAJdOEis1KiiA7xc3qTUWtHF6t2FDJv\nTAxj4gZXm7+95o6OZVhkEKt3nun19bP1rbyyq5BF6VZGxIS4tLb+WjDZgp/ZZMjD0YbWDn6x4Qjj\n48N4aPZwl1//YhLowiXMPiZumRTPx0craHTiOtbv7C+lurGNFXJ3flk+JsV9M5LYU1DLsfKv3tH+\n76ensNk035vvGXfnAOFBvlw/bijvHiilvdO1s6n+8OEJqs638R+LJhk+PVYCXbhMVqqFtk4bHx11\nzjrWF9r8J1jCmJkyONv87bU000qAr+krUxiLa5t5c28xd05NNHTooD8WZyRwrrmDT09UuuyaOUXn\neGVXIctnJpOaGOGy616OBLpwmfRhQ0iICHTabJctxys5XdXEijmDu83fHhFBfixMS2D9/lLqm/+5\nRd1ftpxCKeW0LdKcac6oGKJD/HnbwTs0XU5Hl42frjtMXFgAP7rRPebpS6ALlzGZuvcb3X6qmnNN\n7Q4//8rt3W3+twzyNn973TczidYOG2v2dU9hzK9q5O2cUu6dnkR8eKDB1fWd2cfEwjQLn56opNYJ\nP1+Xen57AcfPnudXWRPcptdBAl24VFaqhU6bZlOuY/cbPVBcx56CWh6cnezSXeg92QRLOFOTh/xj\ni7o/f3IKPx8T357nusWkHG1xhpWOLs0GJ89JL6pp5s+fnOTGCUO5YUKcU6/VF/KTL1xqXHwoI2ND\neNfBwy7PbcsnNMDMsmnDHHpeb7d8ZjJFtc2s3J7PhoNl3D8rmZhQf6PL6rdx8WGMjw/jbSdufKG1\n5mfvHMZsMvGrrIlOu05/SKALl1JKkZVqYe+ZWsrrWxxyTmnz778bJ8QRG+rPf75/nGA/M494wSYg\nSzKsHC6t54STNlbZcLCM7aeq+T83jiEu3L0a1yTQhctlpVrQGjYedMywy4tfFOBjUjwwK9kh5xtM\n/Mwm7p7e/VfNw1cNZ4iT97x0hdvTLJhNyilz0uua2/n1e0dJS4zg3hlJDj//QEmgC5dLjg4m1Rru\nkCajuuZ23txbTFZqgtvdLXmKB2cN57vXjvSaDbSjQvyZNyaW9ftLHb7C539sOk5dSwf/sWgSPib3\nm0klgS4MsSDVwuHSevKrGgd0nld39bT5zzG2Q8+ThQf58sMbxnjVcNWSjASqzrexPa/aYefcnV/D\nm9nFfPPq4YyLd8/1bSTQhSEWpFpQigHdpV9o8587OsbtF5ASrnXN2FgignwdNie9rbOLn6w/TGJk\nIN+f776rT0qgC0MMDQtg+vBINgxgv9F3D/S0+XvJUIFwHH+zD1mpFjYfraC+pePKX3AFT392mvyq\nJn6zcBKBfj4OqNA5JNCFYbJSE8ivauJIWd9XyLPZNCu35TM+PoxZ0uYverE43Up7p42/HxrYw/e8\nykae+vQ0WakW5o6OcVB1ziGBLgxz88Q4fH1Uv/Yb/fREd5v/I3OlzV/0brI1nJGxIQOa7aK15mfr\nDxPga+LfbxvvwOqcQwJdGGZIsB9zRsXw3sEybLa+Dbus3JaPJTxA2vzFZSmlWJxuZV/hOQqqm/p1\njjXZJewuqO
Wnt4zziIYrCXRhqKw0C2X1rewrOmf31xwsrmN3QS0PXTVc2vzF17pjSgImRb8ejlY3\ntvHbTceYlhzJNzI9Y6Nx+dcgDHXduKEE+Jp4tw9rb6zcLm3+wj5x4QFcNSqG9ftL+/xX4G82HqW5\nvZPfLZqIyQ3nnPdGAl0YKtjfzHXjhrLp8Nmv3ePyguLaZt4/XM7d04d51bxp4TyL0xMorWthV36N\n3V+z/VQV7xwo49vzRjIy1nN2vrIr0JVSEUqptUqp40qpY0qpmZe8Pk8pVa+UOtDz9nPnlCu8UVaq\nhdqmdr6wownkhc8LMCnFg7OkkUjY58YJcYT6m1lr58PRlvYufrY+lxHRwTzmYStP2nuH/mfgA631\nWCAVONbLMdu11mk9b792WIXC680dE0NYgPmKTUb/aPNPs0ibv7BbgK8Pt06O54PcszTZsf3hk1tO\nUVTbzG/vmESAr/vOOe/NFQNdKRUOzAFeANBat2ut65xdmBg8/M0+3Dwxns1HKmjt6Lrsca/tLqKl\no0saiUSfLc6w0tzexfu5X7/94fGzDTy3LZ+lGVaP3MbQnjv04UAV8JJSar9S6nmlVHAvx81USh1U\nSr2vlJrQ24mUUiuUUtlKqeyqqqqB1C28TFaahca2TrYc730/yLbOLl764gxzpM1f9ENm0hCSooK+\ndraLzab5ybrDhAX68tNbxrmwOsexJ9DNQDrwtNZ6CtAE/PiSY3KAJK11KvAX4J3eTqS1Xqm1ztRa\nZ8bEuHfHlXCtGSOiiAn1v+x+o+/s72nzv1ruzkXfKaVYNMXKzvwaSs4193rMa7sL2V9Ux7/fNs5j\nlxG2J9BLgBKt9e6ej9fSHfD/oLVu0Fo39ry/CfBVSkU7tFLh1XxMilsnxbPlRCUNrV9ee8Nm0zy3\nvYDx8WHMHul5fwYL97AoPQGA9b3sZlTR0Mr/98EJrh4VzcK0BFeX5jBXDHSt9VmgWCl1YVvr+cDR\ni49RSsWpnv5rpdS0nvPaP0dICLo3JmjvtLH5SMWXPv/ZyUryKhtZMUfa/EX/JUYGMWNEJG/nlHxl\nQbhfbjhCe5eN3yyc6NE/Y/bOcvku8JpS6hCQBvxOKfWoUurRnteXALlKqYPAk8Ay3d8l9MSglZYY\nwbDIoK/Mdnl2az7x4QHcOlna/MXALE63cqammX2F/+xM/vhoBe/nnuXx+aNIiurt8aDnsCvQtdYH\nesa+J2utF2qtz2mtn9FaP9Pz+v9qrSdorVO11jO01jucW7bwRkopFqTG80VeNdWNbcBFbf6zpc1f\nDNzNk+IJ9PX5x4JdTW2d/PzdXMYMDfWK2VPyL0S4lazUBLpsmk2Hu5c8fW57PqH+ZpZN84y1NIR7\nC/E3c/PEODYeLKe1o4v/3nyS8oZWfrdoklfcMHj+dyC8ypi4UMYMDWXDgTKKa5vZ1NPmHxrga3Rp\nwksszrByvq2TP350klU7Crhn+jAykoYYXZZDSKALt5OVZiG78Bz/b+NRTErxwOxko0sSXmTmiCgs\n4QGs3JZPdIg//3rTWKNLchgJdOF2slItAGw+WkFWmoX48ECDKxLexGRSLEq3AvDLrAmEedFff7Jc\nnXA7iZFBTBkWwf6iOr4ljUTCCb49L4WMpCHMG+NdDY4S6MIt/dtNYzlcUs+4eGnzF44X7G/mmrGx\nRpfhcBLowi3NGBHFjBHSFSpEX8gYuhBCeAkJdCGE8BIS6EII4SUk0IUQwktIoAshhJeQQBdCCC8h\ngS6EEF5CAl0IIbyEMmofCqVUFVDYzy+PBqodWI4nkO95cJDveXAYyPecpLXudc0CwwJ9IJRS2Vrr\nTKPrcCX5ngcH+Z4HB2d9zzLkIoQQXkICXQghvISnBvpKowswgHzPg4N8z4ODU75njxxDF0II8VWe\neocuhBDiEhLoQgjhJTwu0JVSNymlTiil8pRSPza6HmdTSiUqpT5VSh1VSh1RSn3P6JpcQSnlo5Ta\nr5TaaHQtrqKUilBKrVVKHVdKHVNKzTS6JmdSSj3R8zOdq5R6XSkVYHRNzqCUelEpVamUyr3oc5FK\nqY+UUqd6/neII67lUYGulPIB/grcDIwH7lJKjTe2KqfrBH6otR4PzAC+Mwi+Z4DvAceMLsLF/gx8\noLUeC6Tixd+/UioBeBzI1FpPBHyAZcZW5TSrgJsu+dyPgU+01qOAT3o+HjCPCnRgGpCntc7XWrcD\nbwC3G1yTU2mty7XWOT3vn6f7H3mCsVU5l1LKCtwKPG90La6ilAoH5gAvAGit27XWdcZW5XRmIFAp\nZQaCgDKD63EKrfU2oPaST98OrO55fzWw0BHX8rRATwCKL/q4BC8Pt4sppZKBKcBuYytxuj8B/wrY\njC7EhYYDVcBLPUNNzyulgo0uylm01qXAH4AioByo11pvNrYqlxqqtS7vef8sMNQRJ/W0QB+0lFIh\nwNvA97XWDUbX4yxKqduASq31PqNrcTEzkA48rbWeAjThoD/D3VHPmPHtdP8iswDBSql7ja3KGLp7\n7rhD5o97WqCXAokXfWzt+ZxXU0r50h3mr2mt1xldj5PNBrKUUmfoHlK7Vin1qrEluUQJUKK1vvDX\n11q6A95bXQcUaK2rtNYdwDpglsE1uVKFUioeoOd/Kx1xUk8L9L3AKKXUcKWUH90PUTYYXJNTKaUU\n3eOqx7TWfzS6HmfTWv9Ea23VWifT/f/vFq2119+5aa3PAsVKqTE9n5oPHDWwJGcrAmYopYJ6fsbn\n48UPgXuxAbi/5/37gXcdcVKzI07iKlrrTqXUvwAf0v1U/EWt9RGDy3K22cB9wGGl1IGez/1Ua73J\nwJqEc3wXeK3nZiUfeNDgepxGa71bKbUWyKF7Jtd+vHQJAKXU68A8IFopVQL8AvhP4C2l1MN0LyP+\nDYdcS1r/hRDCO3jakIsQQojLkEAXQggvIYEuhBBeQgJdCCG8hAS6EEJ4CQl0IYTwEhLoQgjhJf5/\n4SC+3KX2RX4AAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"finished maml training\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "27iaosZpkPI8",
"colab_type": "text"
},
"source": [
"### Evalutating on model trained with 10000 epochs\n",
"\n",
"To evaluate the 5 way 1 shot model, I load the model and run the backprop the loss in the 5way-1shot manner.\n",
"\n",
"Then, I evaluate the model on the remaining 19 instances of the image(eval_k_shot) and compute the test_loss as the average of all the 19 instance losses.\n",
"\n",
"The final accuracy that I get is 87.6 which is around 10% lesser than the final accuracy of the original implementation(98.7 %). I believe with higher computing power and better hyper-parameter tuning, this is achievable."
]
},
{
"cell_type": "code",
"metadata": {
"id": "VshjF5QFHP_g",
"colab_type": "code",
"outputId": "0d8a713c-c560-4b20-f974-0a1f1e5b5574",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"import pickle\n",
"from sklearn.utils import shuffle\n",
"\n",
"\n",
"omniglot_learner = LearnerConv(N_way=5, device=device)\n",
"\n",
"loadpath = \"/content/omniglot_4marchv2_n5_k1_final.pt\"\n",
"print(\"loadpath =\", loadpath)\n",
"omniglot_learner.load_state_dict(torch.load(loadpath))\n",
"omniglot_learner.eval()\n",
"\n",
"# --------------------- Load the Omniglot data --------------------- #\n",
"omniglot_np_path = \"/content/omniglot.nparray.pk\"\n",
"with open(omniglot_np_path, 'rb') as f:\n",
" X_data = pickle.load(f, encoding=\"bytes\")\n",
"np.random.seed(28)\n",
"np.random.shuffle(X_data)\n",
"X_test = X_data[1200:,:,:,:]\n",
"X_test = np.transpose(X_test, (1, 0, 2, 3))\n",
"X_test = shuffle(X_test, random_state=0)\n",
"X_test = np.transpose(X_test, (1, 0, 2, 3))\n",
"\n",
"# --------------------- MAML Omniglot experiment --------------------- #\n",
"def omniglot_maml_exp_eval():\n",
" # hyperparameters\n",
" n_way = 5 # N\n",
" k_shot = 1 # K\n",
" num_grad_update = 5 # for evaluation\n",
" batch_size = n_way*k_shot\n",
" lr_a = 0.1\n",
"\n",
" eval_k_shot = 20-k_shot # each char has 5 instaces\n",
" eval_batch_size = n_way * eval_k_shot\n",
"\n",
"\n",
" num_eval_char = X_test.shape[0]\n",
" num_iterations = int(num_eval_char/n_way)\n",
"\n",
" criterion = nn.NLLLoss(reduction='mean')\n",
" optimizer = torch.optim.SGD(omniglot_learner.parameters(), lr=lr_a, momentum=0.0)\n",
" optimizer.zero_grad()\n",
"\n",
" idx = 0\n",
" count_correct_pred = 0\n",
" count_total_pred = 0\n",
"\n",
"\n",
" for iter in range(num_iterations):\n",
" # 1. for task_i consisting of characters of [idx, idx+n_way]\n",
" omniglot_learner.load_state_dict(torch.load(loadpath))\n",
"\n",
" # 2. update the gradient 'num_grad_update' times\n",
" X_batch = np.zeros((batch_size, 28, 28))\n",
" Y_batch = np.zeros((batch_size))\n",
"\n",
" for k in range(n_way):\n",
" X_batch[k*k_shot:(k+1)*k_shot,:,:] = X_test[idx+k,:k_shot,:,:]\n",
" Y_batch[k*k_shot:(k+1)*k_shot] = k\n",
"\n",
" X_batch = torch.tensor(X_batch, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch = torch.tensor(Y_batch, dtype=torch.long).to(device)\n",
"\n",
" for j in range(num_grad_update):\n",
" # 2.2 compute gradient\n",
" Y_pred = omniglot_learner(X_batch)\n",
" loss = criterion(Y_pred, Y_batch)\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" # 3. evaluation\n",
" X_batch_eval = np.zeros((eval_batch_size, 28, 28))\n",
" Y_batch_eval = np.zeros((eval_batch_size))\n",
" for k in range(n_way):\n",
" X_batch_eval[k*eval_k_shot:(k+1)*eval_k_shot,:,:] = X_test[idx+k,k_shot:,:,:]\n",
" Y_batch_eval[k*eval_k_shot:(k+1)*eval_k_shot] = k\n",
"\n",
" X_batch_eval = torch.tensor(X_batch_eval, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch_eval = torch.tensor(Y_batch_eval, dtype=torch.long).to(device)\n",
"\n",
" Y_pred_eval = omniglot_learner(X_batch_eval)\n",
" Y_pred_eval = Y_pred_eval.argmax(dim=-1)\n",
" # print(Y_pred_eval)\n",
"\n",
" corr_pred = (Y_batch_eval == Y_pred_eval).int().sum().item()\n",
" total_pred = len(Y_batch_eval)\n",
" count_correct_pred += corr_pred\n",
" count_total_pred += total_pred\n",
"\n",
" print(\"[{}] iteration {}/{}: Accuray = {:.3f}\".format(str(datetime.now()), iter, num_iterations, corr_pred/total_pred))\n",
"\n",
" idx += n_way\n",
"\n",
" print(\"PREDICTION ACCURACY = {}\".format(count_correct_pred/count_total_pred))\n",
"\n",
"omniglot_maml_exp_eval()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"loadpath = /content/omniglot_4marchv2_n5_k1_final.pt\n",
"[2020-03-21 04:56:22.152125] iteration 0/84: Accuray = 0.863\n",
"[2020-03-21 04:56:22.167178] iteration 1/84: Accuray = 0.979\n",
"[2020-03-21 04:56:22.183969] iteration 2/84: Accuray = 0.947\n",
"[2020-03-21 04:56:22.198658] iteration 3/84: Accuray = 0.979\n",
"[2020-03-21 04:56:22.212921] iteration 4/84: Accuray = 0.947\n",
"[2020-03-21 04:56:22.226658] iteration 5/84: Accuray = 0.821\n",
"[2020-03-21 04:56:22.241054] iteration 6/84: Accuray = 0.916\n",
"[2020-03-21 04:56:22.255583] iteration 7/84: Accuray = 0.937\n",
"[2020-03-21 04:56:22.269766] iteration 8/84: Accuray = 0.537\n",
"[2020-03-21 04:56:22.285029] iteration 9/84: Accuray = 0.863\n",
"[2020-03-21 04:56:22.301741] iteration 10/84: Accuray = 0.937\n",
"[2020-03-21 04:56:22.316293] iteration 11/84: Accuray = 0.884\n",
"[2020-03-21 04:56:22.332724] iteration 12/84: Accuray = 0.811\n",
"[2020-03-21 04:56:22.346803] iteration 13/84: Accuray = 0.821\n",
"[2020-03-21 04:56:22.360980] iteration 14/84: Accuray = 0.895\n",
"[2020-03-21 04:56:22.378460] iteration 15/84: Accuray = 0.811\n",
"[2020-03-21 04:56:22.394817] iteration 16/84: Accuray = 0.758\n",
"[2020-03-21 04:56:22.408856] iteration 17/84: Accuray = 0.853\n",
"[2020-03-21 04:56:22.423103] iteration 18/84: Accuray = 0.705\n",
"[2020-03-21 04:56:22.437264] iteration 19/84: Accuray = 0.832\n",
"[2020-03-21 04:56:22.451183] iteration 20/84: Accuray = 0.926\n",
"[2020-03-21 04:56:22.465837] iteration 21/84: Accuray = 0.863\n",
"[2020-03-21 04:56:22.480214] iteration 22/84: Accuray = 0.832\n",
"[2020-03-21 04:56:22.496632] iteration 23/84: Accuray = 0.821\n",
"[2020-03-21 04:56:22.510711] iteration 24/84: Accuray = 0.726\n",
"[2020-03-21 04:56:22.524934] iteration 25/84: Accuray = 0.916\n",
"[2020-03-21 04:56:22.539173] iteration 26/84: Accuray = 0.937\n",
"[2020-03-21 04:56:22.553825] iteration 27/84: Accuray = 0.979\n",
"[2020-03-21 04:56:22.567758] iteration 28/84: Accuray = 0.937\n",
"[2020-03-21 04:56:22.582242] iteration 29/84: Accuray = 0.958\n",
"[2020-03-21 04:56:22.597375] iteration 30/84: Accuray = 0.884\n",
"[2020-03-21 04:56:22.612305] iteration 31/84: Accuray = 0.768\n",
"[2020-03-21 04:56:22.626415] iteration 32/84: Accuray = 0.989\n",
"[2020-03-21 04:56:22.640580] iteration 33/84: Accuray = 0.937\n",
"[2020-03-21 04:56:22.655268] iteration 34/84: Accuray = 0.926\n",
"[2020-03-21 04:56:22.669356] iteration 35/84: Accuray = 0.811\n",
"[2020-03-21 04:56:22.683469] iteration 36/84: Accuray = 0.968\n",
"[2020-03-21 04:56:22.698618] iteration 37/84: Accuray = 0.779\n",
"[2020-03-21 04:56:22.718118] iteration 38/84: Accuray = 0.968\n",
"[2020-03-21 04:56:22.735132] iteration 39/84: Accuray = 0.874\n",
"[2020-03-21 04:56:22.758156] iteration 40/84: Accuray = 0.863\n",
"[2020-03-21 04:56:22.789952] iteration 41/84: Accuray = 0.968\n",
"[2020-03-21 04:56:22.812144] iteration 42/84: Accuray = 0.768\n",
"[2020-03-21 04:56:22.832461] iteration 43/84: Accuray = 0.895\n",
"[2020-03-21 04:56:22.849356] iteration 44/84: Accuray = 0.789\n",
"[2020-03-21 04:56:22.864052] iteration 45/84: Accuray = 0.947\n",
"[2020-03-21 04:56:22.880311] iteration 46/84: Accuray = 0.895\n",
"[2020-03-21 04:56:22.894158] iteration 47/84: Accuray = 0.663\n",
"[2020-03-21 04:56:22.910659] iteration 48/84: Accuray = 0.947\n",
"[2020-03-21 04:56:22.924638] iteration 49/84: Accuray = 0.800\n",
"[2020-03-21 04:56:22.938653] iteration 50/84: Accuray = 0.895\n",
"[2020-03-21 04:56:22.952456] iteration 51/84: Accuray = 0.811\n",
"[2020-03-21 04:56:22.966400] iteration 52/84: Accuray = 0.842\n",
"[2020-03-21 04:56:22.980642] iteration 53/84: Accuray = 0.853\n",
"[2020-03-21 04:56:22.994466] iteration 54/84: Accuray = 0.821\n",
"[2020-03-21 04:56:23.008251] iteration 55/84: Accuray = 0.926\n",
"[2020-03-21 04:56:23.024771] iteration 56/84: Accuray = 0.895\n",
"[2020-03-21 04:56:23.039101] iteration 57/84: Accuray = 0.832\n",
"[2020-03-21 04:56:23.052930] iteration 58/84: Accuray = 0.968\n",
"[2020-03-21 04:56:23.068045] iteration 59/84: Accuray = 0.958\n",
"[2020-03-21 04:56:23.082404] iteration 60/84: Accuray = 0.853\n",
"[2020-03-21 04:56:23.096182] iteration 61/84: Accuray = 0.811\n",
"[2020-03-21 04:56:23.109746] iteration 62/84: Accuray = 0.926\n",
"[2020-03-21 04:56:23.126473] iteration 63/84: Accuray = 0.937\n",
"[2020-03-21 04:56:23.140801] iteration 64/84: Accuray = 0.958\n",
"[2020-03-21 04:56:23.158024] iteration 65/84: Accuray = 0.832\n",
"[2020-03-21 04:56:23.174244] iteration 66/84: Accuray = 0.800\n",
"[2020-03-21 04:56:23.190729] iteration 67/84: Accuray = 0.905\n",
"[2020-03-21 04:56:23.204892] iteration 68/84: Accuray = 0.926\n",
"[2020-03-21 04:56:23.219237] iteration 69/84: Accuray = 0.989\n",
"[2020-03-21 04:56:23.235719] iteration 70/84: Accuray = 0.916\n",
"[2020-03-21 04:56:23.249319] iteration 71/84: Accuray = 0.884\n",
"[2020-03-21 04:56:23.263829] iteration 72/84: Accuray = 0.937\n",
"[2020-03-21 04:56:23.278183] iteration 73/84: Accuray = 0.800\n",
"[2020-03-21 04:56:23.292406] iteration 74/84: Accuray = 0.916\n",
"[2020-03-21 04:56:23.306824] iteration 75/84: Accuray = 0.989\n",
"[2020-03-21 04:56:23.320968] iteration 76/84: Accuray = 0.789\n",
"[2020-03-21 04:56:23.337754] iteration 77/84: Accuray = 0.905\n",
"[2020-03-21 04:56:23.351696] iteration 78/84: Accuray = 0.874\n",
"[2020-03-21 04:56:23.365844] iteration 79/84: Accuray = 0.800\n",
"[2020-03-21 04:56:23.380236] iteration 80/84: Accuray = 0.916\n",
"[2020-03-21 04:56:23.394627] iteration 81/84: Accuray = 0.989\n",
"[2020-03-21 04:56:23.408363] iteration 82/84: Accuray = 0.937\n",
"[2020-03-21 04:56:23.422395] iteration 83/84: Accuray = 0.800\n",
"PREDICTION ACCURACY = 0.8764411027568922\n"
],
"name": "stdout"
}
]
},
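{
"cell_type": "markdown",
"metadata": {},
"source": [
"An aside (my addition, not part of the original experiment): the headline number above is a plain pooled mean over the 84 test tasks. Few-shot results are usually reported with a 95% confidence interval over tasks. The sketch below assumes the per-task accuracies (`corr_pred/total_pred`) have been collected into a list such as `task_accs`; the helper name and the dummy values are hypothetical."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"\n",
"def mean_with_ci(task_accs, z=1.96):\n",
"    # Mean per-task accuracy plus an approximate 95% half-width,\n",
"    # using the normal approximation to the standard error of the mean.\n",
"    accs = np.asarray(task_accs, dtype=np.float64)\n",
"    half_width = z * accs.std(ddof=1) / np.sqrt(len(accs))\n",
"    return accs.mean(), half_width\n",
"\n",
"# Usage with dummy values (not real results):\n",
"m, h = mean_with_ci([0.86, 0.98, 0.95])\n",
"print(\"accuracy = {:.3f} +/- {:.3f}\".format(m, h))"
],
"execution_count": 0,
"outputs": []
},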
{
"cell_type": "markdown",
"metadata": {
"id": "Qgdu5nvjkWO8",
"colab_type": "text"
},
"source": [
"### Evalutating on model trained with 5000 epochs"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ru7UkScRS2Q9",
"colab_type": "code",
"outputId": "66834329-1ee6-4589-ae33-f306981f464b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"import pickle\n",
"from sklearn.utils import shuffle\n",
"\n",
"\n",
"omniglot_learner = LearnerConv(N_way=5, device=device)\n",
"\n",
"loadpath = \"/content/omniglot_4marchv2_n5_k1_iter5000.pt\"\n",
"print(\"loadpath =\", loadpath)\n",
"omniglot_learner.load_state_dict(torch.load(loadpath))\n",
"omniglot_learner.eval()\n",
"\n",
"# --------------------- Load the Omniglot data --------------------- #\n",
"omniglot_np_path = \"/content/omniglot.nparray.pk\"\n",
"with open(omniglot_np_path, 'rb') as f:\n",
" X_data = pickle.load(f, encoding=\"bytes\")\n",
"np.random.seed(28)\n",
"np.random.shuffle(X_data)\n",
"# X_train = X_data[:1200,:,:,:]\n",
"X_test = X_data[1200:,:,:,:]\n",
"X_test = np.transpose(X_test, (1, 0, 2, 3))\n",
"X_test = shuffle(X_test, random_state=0)\n",
"X_test = np.transpose(X_test, (1, 0, 2, 3))\n",
"\n",
"# --------------------- MAML Omniglot experiment --------------------- #\n",
"def omniglot_maml_exp_eval():\n",
" # hyperparameters\n",
" n_way = 5 # N\n",
" k_shot = 1 # K\n",
" num_grad_update = 5 # for evaluation\n",
" batch_size = n_way*k_shot\n",
" lr_a = 0.1\n",
"\n",
" eval_k_shot = 20-k_shot # each char has 5 instaces\n",
" eval_batch_size = n_way * eval_k_shot\n",
"\n",
"\n",
" num_eval_char = X_test.shape[0]\n",
" num_iterations = int(num_eval_char/n_way)\n",
"\n",
" criterion = nn.NLLLoss(reduction='mean')\n",
" optimizer = torch.optim.SGD(omniglot_learner.parameters(), lr=lr_a, momentum=0.0)\n",
" optimizer.zero_grad()\n",
"\n",
" idx = 0\n",
" count_correct_pred = 0\n",
" count_total_pred = 0\n",
"\n",
"\n",
" for iter in range(num_iterations):\n",
" # 1. for task_i consisting of characters of [idx, idx+n_way)\n",
" omniglot_learner.load_state_dict(torch.load(loadpath))\n",
"\n",
" # 2. update the gradient 'num_grad_update' times\n",
" X_batch = np.zeros((batch_size, 28, 28))\n",
" Y_batch = np.zeros((batch_size))\n",
"\n",
" for k in range(n_way):\n",
" X_batch[k*k_shot:(k+1)*k_shot,:,:] = X_test[idx+k,:k_shot,:,:]\n",
" Y_batch[k*k_shot:(k+1)*k_shot] = k\n",
"\n",
" X_batch = torch.tensor(X_batch, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch = torch.tensor(Y_batch, dtype=torch.long).to(device)\n",
"\n",
" for j in range(num_grad_update):\n",
" # 2.2 compute gradient\n",
" Y_pred = omniglot_learner(X_batch)\n",
" loss = criterion(Y_pred, Y_batch)\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" optimizer.zero_grad()\n",
"\n",
" # 3. evaluation\n",
" X_batch_eval = np.zeros((eval_batch_size, 28, 28))\n",
" Y_batch_eval = np.zeros((eval_batch_size))\n",
" for k in range(n_way):\n",
" X_batch_eval[k*eval_k_shot:(k+1)*eval_k_shot,:,:] = X_test[idx+k,k_shot:,:,:]\n",
" Y_batch_eval[k*eval_k_shot:(k+1)*eval_k_shot] = k\n",
"\n",
" # X_batch_eval, Y_batch_eval = shuffle(X_batch_eval, Y_batch_eval, random_state=0)\n",
"\n",
" X_batch_eval = torch.tensor(X_batch_eval, dtype=torch.float32).unsqueeze(1).to(device)\n",
" Y_batch_eval = torch.tensor(Y_batch_eval, dtype=torch.long).to(device)\n",
"\n",
" Y_pred_eval = omniglot_learner(X_batch_eval)\n",
" Y_pred_eval = Y_pred_eval.argmax(dim=-1)\n",
" # print(Y_pred_eval)\n",
"\n",
" corr_pred = (Y_batch_eval == Y_pred_eval).int().sum().item()\n",
" total_pred = len(Y_batch_eval)\n",
" count_correct_pred += corr_pred\n",
" count_total_pred += total_pred\n",
"\n",
" print(\"[{}] iteration {}/{}: Accuray = {:.3f}\".format(str(datetime.now()), iter, num_iterations, corr_pred/total_pred))\n",
"\n",
" idx += n_way\n",
"\n",
" print(\"PREDICTION ACCURACY = {}\".format(count_correct_pred/count_total_pred))\n",
"\n",
"omniglot_maml_exp_eval()"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"loadpath = /content/omniglot_4marchv2_n5_k1_iter5000.pt\n",
"[2020-03-21 04:56:23.577508] iteration 0/84: Accuray = 0.632\n",
"[2020-03-21 04:56:23.592082] iteration 1/84: Accuray = 0.779\n",
"[2020-03-21 04:56:23.606431] iteration 2/84: Accuray = 0.779\n",
"[2020-03-21 04:56:23.620816] iteration 3/84: Accuray = 0.653\n",
"[2020-03-21 04:56:23.635386] iteration 4/84: Accuray = 0.821\n",
"[2020-03-21 04:56:23.649648] iteration 5/84: Accuray = 0.811\n",
"[2020-03-21 04:56:23.664244] iteration 6/84: Accuray = 0.589\n",
"[2020-03-21 04:56:23.678570] iteration 7/84: Accuray = 0.853\n",
"[2020-03-21 04:56:23.692901] iteration 8/84: Accuray = 0.453\n",
"[2020-03-21 04:56:23.707212] iteration 9/84: Accuray = 0.621\n",
"[2020-03-21 04:56:23.721607] iteration 10/84: Accuray = 0.611\n",
"[2020-03-21 04:56:23.737924] iteration 11/84: Accuray = 0.632\n",
"[2020-03-21 04:56:23.752156] iteration 12/84: Accuray = 0.621\n",
"[2020-03-21 04:56:23.768135] iteration 13/84: Accuray = 0.674\n",
"[2020-03-21 04:56:23.785060] iteration 14/84: Accuray = 0.800\n",
"[2020-03-21 04:56:23.805714] iteration 15/84: Accuray = 0.642\n",
"[2020-03-21 04:56:23.821493] iteration 16/84: Accuray = 0.600\n",
"[2020-03-21 04:56:23.835671] iteration 17/84: Accuray = 0.474\n",
"[2020-03-21 04:56:23.849563] iteration 18/84: Accuray = 0.737\n",
"[2020-03-21 04:56:23.863939] iteration 19/84: Accuray = 0.663\n",
"[2020-03-21 04:56:23.878258] iteration 20/84: Accuray = 0.663\n",
"[2020-03-21 04:56:23.892141] iteration 21/84: Accuray = 0.747\n",
"[2020-03-21 04:56:23.906052] iteration 22/84: Accuray = 0.663\n",
"[2020-03-21 04:56:23.920039] iteration 23/84: Accuray = 0.684\n",
"[2020-03-21 04:56:23.934124] iteration 24/84: Accuray = 0.589\n",
"[2020-03-21 04:56:23.949719] iteration 25/84: Accuray = 0.684\n",
"[2020-03-21 04:56:23.963905] iteration 26/84: Accuray = 0.842\n",
"[2020-03-21 04:56:23.978297] iteration 27/84: Accuray = 0.779\n",
"[2020-03-21 04:56:23.994581] iteration 28/84: Accuray = 0.832\n",
"[2020-03-21 04:56:24.008735] iteration 29/84: Accuray = 0.611\n",
"[2020-03-21 04:56:24.022924] iteration 30/84: Accuray = 0.684\n",
"[2020-03-21 04:56:24.037037] iteration 31/84: Accuray = 0.663\n",
"[2020-03-21 04:56:24.051033] iteration 32/84: Accuray = 0.842\n",
"[2020-03-21 04:56:24.066167] iteration 33/84: Accuray = 0.800\n",
"[2020-03-21 04:56:24.080397] iteration 34/84: Accuray = 0.653\n",
"[2020-03-21 04:56:24.094563] iteration 35/84: Accuray = 0.611\n",
"[2020-03-21 04:56:24.108383] iteration 36/84: Accuray = 0.632\n",
"[2020-03-21 04:56:24.122456] iteration 37/84: Accuray = 0.579\n",
"[2020-03-21 04:56:24.136752] iteration 38/84: Accuray = 0.811\n",
"[2020-03-21 04:56:24.152167] iteration 39/84: Accuray = 0.642\n",
"[2020-03-21 04:56:24.167462] iteration 40/84: Accuray = 0.642\n",
"[2020-03-21 04:56:24.181683] iteration 41/84: Accuray = 0.716\n",
"[2020-03-21 04:56:24.197923] iteration 42/84: Accuray = 0.558\n",
"[2020-03-21 04:56:24.213408] iteration 43/84: Accuray = 0.726\n",
"[2020-03-21 04:56:24.228301] iteration 44/84: Accuray = 0.600\n",
"[2020-03-21 04:56:24.246652] iteration 45/84: Accuray = 0.779\n",
"[2020-03-21 04:56:24.264016] iteration 46/84: Accuray = 0.611\n",
"[2020-03-21 04:56:24.281475] iteration 47/84: Accuray = 0.558\n",
"[2020-03-21 04:56:24.300421] iteration 48/84: Accuray = 0.579\n",
"[2020-03-21 04:56:24.317846] iteration 49/84: Accuray = 0.632\n",
"[2020-03-21 04:56:24.334940] iteration 50/84: Accuray = 0.758\n",
"[2020-03-21 04:56:24.352340] iteration 51/84: Accuray = 0.674\n",
"[2020-03-21 04:56:24.375235] iteration 52/84: Accuray = 0.432\n",
"[2020-03-21 04:56:24.394698] iteration 53/84: Accuray = 0.589\n",
"[2020-03-21 04:56:24.416745] iteration 54/84: Accuray = 0.442\n",
"[2020-03-21 04:56:24.435855] iteration 55/84: Accuray = 0.632\n",
"[2020-03-21 04:56:24.453252] iteration 56/84: Accuray = 0.621\n",
"[2020-03-21 04:56:24.471008] iteration 57/84: Accuray = 0.484\n",
"[2020-03-21 04:56:24.488195] iteration 58/84: Accuray = 0.811\n",
"[2020-03-21 04:56:24.505547] iteration 59/84: Accuray = 0.737\n",
"[2020-03-21 04:56:24.522586] iteration 60/84: Accuray = 0.558\n",
"[2020-03-21 04:56:24.539325] iteration 61/84: Accuray = 0.579\n",
"[2020-03-21 04:56:24.556411] iteration 62/84: Accuray = 0.684\n",
"[2020-03-21 04:56:24.571436] iteration 63/84: Accuray = 0.800\n",
"[2020-03-21 04:56:24.587515] iteration 64/84: Accuray = 0.863\n",
"[2020-03-21 04:56:24.601388] iteration 65/84: Accuray = 0.642\n",
"[2020-03-21 04:56:24.615190] iteration 66/84: Accuray = 0.558\n",
"[2020-03-21 04:56:24.631125] iteration 67/84: Accuray = 0.695\n",
"[2020-03-21 04:56:24.645176] iteration 68/84: Accuray = 0.884\n",
"[2020-03-21 04:56:24.659592] iteration 69/84: Accuray = 0.842\n",
"[2020-03-21 04:56:24.675903] iteration 70/84: Accuray = 0.705\n",
"[2020-03-21 04:56:24.689908] iteration 71/84: Accuray = 0.663\n",
"[2020-03-21 04:56:24.703759] iteration 72/84: Accuray = 0.747\n",
"[2020-03-21 04:56:24.717557] iteration 73/84: Accuray = 0.684\n",
"[2020-03-21 04:56:24.731441] iteration 74/84: Accuray = 0.653\n",
"[2020-03-21 04:56:24.744922] iteration 75/84: Accuray = 0.916\n",
"[2020-03-21 04:56:24.758482] iteration 76/84: Accuray = 0.516\n",
"[2020-03-21 04:56:24.772546] iteration 77/84: Accuray = 0.621\n",
"[2020-03-21 04:56:24.791888] iteration 78/84: Accuray = 0.747\n",
"[2020-03-21 04:56:24.805865] iteration 79/84: Accuray = 0.663\n",
"[2020-03-21 04:56:24.826479] iteration 80/84: Accuray = 0.779\n",
"[2020-03-21 04:56:24.849276] iteration 81/84: Accuray = 0.884\n",
"[2020-03-21 04:56:24.867714] iteration 82/84: Accuray = 0.789\n",
"[2020-03-21 04:56:24.886986] iteration 83/84: Accuray = 0.674\n",
"PREDICTION ACCURACY = 0.6807017543859649\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7cjBZj-2sk2X",
"colab_type": "text"
},
"source": [
"### Conclusion:\n",
"\n",
"I have understood the usage of Few shot leaarning and how useful it is to build a model with less data.\n",
"\n",
"### Future Scope:\n",
"To use Augmix to augment the data and improve on the few shot learning"
]
}
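,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the AugMix idea (my own illustration; the ops, magnitudes, and helper names are assumptions, not the official AugMix implementation): sample a few random augmentation chains, combine them with Dirichlet weights, then blend the composite with the original image using a Beta-distributed weight."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import numpy as np\n",
"from PIL import Image, ImageOps\n",
"\n",
"# Three simple ops on a greyscale PIL image (illustrative choices only).\n",
"def rotate(img):\n",
"    return img.rotate(np.random.uniform(-30, 30))\n",
"\n",
"def shear(img):\n",
"    s = np.random.uniform(-0.3, 0.3)\n",
"    return img.transform(img.size, Image.AFFINE, (1, s, 0, 0, 1, 0))\n",
"\n",
"def invert(img):\n",
"    return ImageOps.invert(img)\n",
"\n",
"AUGMENTATIONS = [rotate, shear, invert]\n",
"\n",
"def augmix(img, width=3, depth=2, alpha=1.0):\n",
"    # Dirichlet weights over `width` augmentation chains; a Beta weight\n",
"    # mixes the augmented composite back with the original image.\n",
"    ws = np.random.dirichlet([alpha] * width)\n",
"    m = np.random.beta(alpha, alpha)\n",
"    mix = np.zeros(np.array(img, dtype=np.float32).shape, dtype=np.float32)\n",
"    for i in range(width):\n",
"        aug = img\n",
"        for _ in range(np.random.randint(1, depth + 1)):\n",
"            op = AUGMENTATIONS[np.random.randint(len(AUGMENTATIONS))]\n",
"            aug = op(aug)\n",
"        mix += ws[i] * np.array(aug, dtype=np.float32)\n",
"    blended = (1 - m) * np.array(img, dtype=np.float32) + m * mix\n",
"    return Image.fromarray(blended.astype(np.uint8))"
],
"execution_count": 0,
"outputs": []
}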
]
}