Created January 13, 2025 06:55
Can We Teach AI to See Through Text? A Fun Exploration with ASCII Art and GPT-3
| { | |
| "nbformat": 4, | |
| "nbformat_minor": 0, | |
| "metadata": { | |
| "colab": { | |
| "provenance": [] | |
| }, | |
| "kernelspec": { | |
| "name": "python3", | |
| "display_name": "Python 3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Art Project - Large Language Models can be Few-shot MNIST Classifiers - [Shrijayan](https://shrijayan.cpluz.com)\n", | |
| "\n", | |
| "Can we get large language models (LLMs) like GPT-3 to do few-shot classification of images?\n", | |
| "\n", | |
| "A key challenge is how to get a large language model to understand images (channel x width x height pixel matrices). We try to bridge this gap by \"translating\" images to ASCII art.\n", | |
| "\n", | |
| "The key idea is to prompt an LLM to output a class label given an ASCII image + a few examples of (image, class label) pairs as context.\n", | |
| "\n", | |
| "### *Hypothesis* (why interesting)\n", | |
| "We've seen that LLMs are very good at few-shot learning. They can learn to complete a new task from only a few examples given in the prompt, all expressed as natural text!\n", | |
| "\n", | |
| "However, while these models were trained on lots of human-produced text, it may not matter as much what the examples actually contain (or whether that content resembles something the LLM encountered during pretraining).\n", | |
| "* Instead, to get an LLM to solve a task, the prompt we give it may only need some consistent structure or recognizable patterns across examples. \n", | |
| "\n", | |
| "One way to test this is to see if an LLM can reason over sequences containing ASCII art of real-world images.\n", | |
| "\n", | |
| "\n", | |
| "\n", | |
| "### *Motivations* (why important) and *difficulties* (why hard)\n", | |
| "While this is just an art project for fun, it may also be practically useful. Personally, I've been interested in:\n", | |
| "1. How we can adapt the benefits of foundation models to individualized use cases and local points of deployment.\n", | |
| "2. How to get models to handle new settings without much data (e.g., when a new failure mode or edge case surfaces).\n", | |
| "\n", | |
| "Some difficulties with realizing the above are:\n", | |
| "\n", | |
| "- Training a sufficiently capable model from scratch may not be feasible, as there may not be enough data in the deployment setting. \n", | |
| "- Finetuning a pretrained model may be expensive: large foundation models (FMs) are not accessible to everyone, and it is hard to maintain a finetuned copy for each setting. \n", | |
| "- A pretrained model's out-of-the-box zero-shot behavior may not be desired for a specific deployment scenario.\n", | |
| "\n", | |
| "### *Solution?*\n", | |
| "However, recent LLMs have shown the ability to \"steer\" (i.e., change) their classification based on a few examples, without finetuning their weights. This is great for the settings above: by providing a few examples that illustrate the desired classification behavior, we could get personalized inference. \n", | |
| "- But the world is not all natural language! It'd also be great to extend this capability to more modalities and domains. \n", | |
| "- Given that we have these text-trained LLMs lying around, can we use text as a universal modality to do so?" | |
| ], | |
| "metadata": { | |
| "id": "qdAMTEAwR3en" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Setup" | |
| ], | |
| "metadata": { | |
| "id": "oH6-uz6qSg7r" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install numpy\n", | |
| "!pip install torch\n", | |
| "!pip install torchvision\n", | |
| "!pip install pillow\n", | |
| "!pip install 'openai<1.0' # Only for using GPT-3 models; openai.Completion was removed in openai>=1.0" | |
| ], | |
| "metadata": { | |
| "id": "wj-CgWxsSFyO" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import copy\n", | |
| "import torch\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "from PIL import Image" | |
| ], | |
| "metadata": { | |
| "id": "hUbNPT5sStwL" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Fake args\n", | |
| "class Args():\n", | |
| "    def __init__(self):\n", | |
| "        pass\n", | |
| "args = Args()\n", | |
| "args.seed = 42\n", | |
| "args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", | |
| "args.display_image = True\n", | |
| "args.num_workers = 2\n", | |
| "\n", | |
| "np.random.seed(args.seed)\n", | |
| "torch.manual_seed(args.seed)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Q6bm0782S26g", | |
| "outputId": "22d30d11-4a43-4028-c966-35c284f51c81" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<torch._C.Generator at 0x7fb4d21ddc70>" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 159 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Load data\n", | |
| "\n", | |
| "We need to load images and convert them to ASCII. To ensure the result fits into a prompt that an LLM can process (context windows are frequently 2048 or 4096 tokens), we may also need to downsize the images.\n", | |
| "\n", | |
| "For this art project, we'll use the MNIST dataset." | |
| ], | |
| "metadata": { | |
| "id": "u_8Z7l8YS-v_" | |
| } | |
| }, | |
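A minimal sketch (not part of the original notebook) for checking the context-length constraint above: it bounds the prompt length for a given image width. It assumes the few-shot layout used later in this notebook with `start=''` and `end=''`, single-digit labels, and a byte-level BPE tokenizer like GPT-3's, in which every token covers at least one byte, so for pure-ASCII text the character count is an upper bound on the token count. The helper names `prompt_token_upper_bound` and `max_width_that_fits` are hypothetical.

```python
# Per-example overhead of the few-shot template, assuming start='' and end=''
# and single-digit labels: "Input: " + x + "\nOutput: " + y + "\n" + "###\n"
CONTEXT_OVERHEAD = len("Input: ") + len("\nOutput: ") + 1 + len("\n") + len("###\n")
# The query to classify: "Input: " + x + "\nOutput: "
QUERY_OVERHEAD = len("Input: ") + len("\nOutput: ")


def prompt_token_upper_bound(width, n_context):
    """Upper bound on prompt tokens for pure-ASCII text under byte-level BPE.

    Each ASCII character is one byte and every byte-level BPE token covers at
    least one byte, so the character count bounds the token count.
    """
    pixels = width * width  # one character per pixel of a square image
    return n_context * (pixels + CONTEXT_OVERHEAD) + pixels + QUERY_OVERHEAD


def max_width_that_fits(context_len, n_context, reserve=7):
    """Largest square width whose prompt is guaranteed to fit, leaving
    `reserve` tokens for the completion (the notebook later uses max_tokens=7)."""
    width = 0
    while prompt_token_upper_bound(width + 1, n_context) + reserve <= context_len:
        width += 1
    return width


if __name__ == "__main__":
    print(prompt_token_upper_bound(24, 3))
    print(max_width_that_fits(2048, 3), max_width_that_fits(4096, 3))
```

With 3 context examples (one per class, as used below), `prompt_token_upper_bound(24, 3)` gives 2386, so a 24-pixel width is only guaranteed to fit a 4096-token window, while `max_width_that_fits(2048, 3)` gives 22. The actual token count is typically lower, since BPE can merge runs of repeated characters.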
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Image transforms\n", | |
| "\n", | |
| "We convert each image to a downsized, center-cropped, gray-scaled square image." | |
| ], | |
| "metadata": { | |
| "id": "r2AcDPUsTFMo" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import torchvision\n", | |
| "from torchvision import transforms\n", | |
| "from torchvision.transforms import Compose, Resize, CenterCrop, Grayscale, ToTensor\n", | |
| "try:\n", | |
| "    from torchvision.transforms import InterpolationMode\n", | |
| "    BICUBIC = InterpolationMode.BICUBIC\n", | |
| "except ImportError:\n", | |
| "    BICUBIC = Image.BICUBIC" | |
| ], | |
| "metadata": { | |
| "id": "NXv9l_GYTLFd" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def get_transform(w, grayscale=True, to_tensor=False):\n", | |
| "    # Named transform_list so it doesn't shadow the torchvision `transforms` module\n", | |
| "    transform_list = [\n", | |
| "        Resize(w, interpolation=BICUBIC),\n", | |
| "        CenterCrop(w),\n", | |
| "    ]\n", | |
| "    if grayscale:\n", | |
| "        transform_list.append(Grayscale(num_output_channels=1))\n", | |
| "    if to_tensor:\n", | |
| "        transform_list.append(ToTensor())\n", | |
| "    return Compose(transform_list)" | |
| ], | |
| "metadata": { | |
| "id": "SBdlcKGpTO4h" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Specify the width of downsized images\n", | |
| "args.width = 24" | |
| ], | |
| "metadata": { | |
| "id": "SOHVeNVbS4mo" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Actual transform function\n", | |
| "transform_grayscale = get_transform(args.width, grayscale=True)\n", | |
| "\n", | |
| "# Original image (for visualization)\n", | |
| "transform_base = get_transform(28, grayscale=False)" | |
| ], | |
| "metadata": { | |
| "id": "A2W5hOkvTQtB" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### ASCII conversion \n", | |
| "3 options to pick from: \n", | |
| "* `ASCII_68`: 68 levels of brightness \n", | |
| "* `ASCII_10`: 10 levels of brightness \n", | |
| "* `ASCII_10n`: 10 levels of brightness using integers to denote level" | |
| ], | |
| "metadata": { | |
| "id": "5CzZaOl2TV4R" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# No periods or spaces, because I thought they might interfere with the LLM\n", | |
| "ASCII_68 = \"$@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+~<>i!lI;:,\\\"^`'\"\n", | |
| "ASCII_10 = \"@%#*+=-:`'\"\n", | |
| "ASCII_10n = ''.join([str(i) for i in range(len(ASCII_10))[::-1]])" | |
| ], | |
| "metadata": { | |
| "id": "RyObH9RMTgfs" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def img_to_text(img, scale=10, numerical=False):\n", | |
| "    if scale > 10:\n", | |
| "        ascii_scale = ASCII_68\n", | |
| "    elif numerical:\n", | |
| "        ascii_scale = ASCII_10n\n", | |
| "    else:\n", | |
| "        ascii_scale = ASCII_10\n", | |
| "\n", | |
| "    # Cast to a wide int type: under NumPy 2's promotion rules a uint8 scalar\n", | |
| "    # times a Python int stays uint8, so p * (scale - 1) would silently overflow\n", | |
| "    pixels = np.asarray(img, dtype=np.int64).flatten()\n", | |
| "    return ''.join([ascii_scale[int(p * (scale - 1) / 255)] for p in pixels])" | |
| ], | |
| "metadata": { | |
| "id": "QpCs4OgOTirt" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
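To make the brightness-to-character mapping in `img_to_text` concrete, here is a minimal pure-Python sketch of the same rule, `index = int(p * (levels - 1) / 255)`, applied to single 8-bit pixel values. The helper names `pixel_to_index`, `pixel_to_char`, and `bucket_sizes` are hypothetical, and the ramp is copied from `ASCII_10`. Plain Python ints are used, so there is no uint8 overflow to worry about.

```python
# Same 10-level ramp as ASCII_10: densest character first.
ASCII_10 = "@%#*+=-:`'"


def pixel_to_index(p, levels):
    """Ramp index for an 8-bit intensity, using img_to_text's rule."""
    if not 0 <= p <= 255:
        raise ValueError(f"pixel value out of range: {p}")
    return int(p * (levels - 1) / 255)


def pixel_to_char(p, ramp=ASCII_10):
    """Map an intensity in [0, 255] to a character of the ramp."""
    return ramp[pixel_to_index(p, len(ramp))]


def bucket_sizes(ramp=ASCII_10):
    """Count how many of the 256 possible intensities land on each character."""
    counts = [0] * len(ramp)
    for p in range(256):
        counts[pixel_to_index(p, len(ramp))] += 1
    return counts


if __name__ == "__main__":
    # MNIST background (0) becomes '@'; fully saturated stroke pixels (255) become "'".
    print(pixel_to_char(0), pixel_to_char(255))
    print(bucket_sizes())
```

Because the index is floored, each of the first nine characters covers 28 or 29 intensity values, but the last character (`'`) is reached only when `p == 255`: `bucket_sizes()` returns `[29, 28, 28, 29, 28, 28, 29, 28, 28, 1]`.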
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def pretty_print_ascii(text, width, height):\n", | |
| "    assert len(text) == width * height, 'text length must equal width * height'\n", | |
| "    # Print the flattened string one row of `width` characters at a time\n", | |
| "    rows = [text[i * width: (i + 1) * width] for i in range(height)]\n", | |
| "    print('\\n'.join(rows))" | |
| ], | |
| "metadata": { | |
| "id": "1GsASUR3TkHl" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### See MNIST examples" | |
| ], | |
| "metadata": { | |
| "id": "NUWPFFBfTo5F" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load original data\n", | |
| "train_set = torchvision.datasets.MNIST(root='.', train=True, transform=transform_base, download=True)\n", | |
| "test_set = torchvision.datasets.MNIST(root='.', train=False, transform=transform_base, download=True)\n" | |
| ], | |
| "metadata": { | |
| "id": "HJy-SZmzTnF3" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "train_set.__getitem__(0)[0]" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 45 | |
| }, | |
| "id": "hPdwx9hVUHao", | |
| "outputId": "39d34dde-c9d6-4b27-e515-f69ac302c1b0" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "<PIL.Image.Image image mode=L size=28x28 at 0x7FB4C431FB50>" | |
| ], | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABAElEQVR4nGNgGMyAWUhIqK5jvdSy/9/rGRgYGFhgEnJsVjYCwQwMDAxPJgV+vniQgYGBgREqZ7iXH8r6l/SV4dn7m8gmCt3++/fv37/Htn3/iMW+gDnZf/+e5WbQnoXNNXyMs/5GoQoxwVmf/n9kSGFiwAW49/11wynJoPzx4YIcRlyygR/+/i2XxCWru+vv32nSuGQFYv/83Y3b4p9/fzpAmSyoMnohpiwM1w5h06Q+5enfv39/bcMiJVF09+/fv39P+mFKiTtd/fv3799jgZiBJLT69t+/f/8eDuDEkDJf8+jv379/v7Ryo4qzMDAwMAQGMjBc3/y35wM2V1IfAABFF16Aa0wAOwAAAABJRU5ErkJggg==\n" | |
| }, | |
| "metadata": {}, | |
| "execution_count": 168 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load ASCII data\n", | |
| "train_set_ascii = torchvision.datasets.MNIST(root='.', train=True,\n", | |
| " transform=transform_grayscale, download=False)\n", | |
| "test_set_ascii = torchvision.datasets.MNIST(root='.', train=False,\n", | |
| " transform=transform_grayscale, download=False)" | |
| ], | |
| "metadata": { | |
| "id": "EW7TgezzTzYG" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "pretty_print_ascii(\n", | |
| " img_to_text(train_set_ascii.__getitem__(0)[0]),\n", | |
| " width=args.width, height=args.width\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "qzUce8kCUJey", | |
| "outputId": "f728ad4d-018b-4eb8-c083-73c2493810ee" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@*+%*-=%@@@\n", | |
| "@@@@@@@%%+=::```=``-%@@@\n", | |
| "@@@@@@#`'``''''=##%@@@@@\n", | |
| "@@@@@@%:````==:*@@@@@@@@\n", | |
| "@@@@@@@%#=`=@@%%@@@@@@@@\n", | |
| "@@@@@@@@@#`=@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@=`%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@%::+%@@@@@@@@@\n", | |
| "@@@@@@@@@@@#:':#@@@@@@@@\n", | |
| "@@@@@@@@@@@@%=`:*@@@@@@@\n", | |
| "@@@@@@@@@@@@@@+'`%@@@@@@\n", | |
| "@@@@@@@@@@@@@#='`%@@@@@@\n", | |
| "@@@@@@@@@@@*-``':@@@@@@@\n", | |
| "@@@@@@@@@#-`''`=%@@@@@@@\n", | |
| "@@@@@@@%+`''`=%@@@@@@@@@\n", | |
| "@@@@%#-`''`=%@@@@@@@@@@@\n", | |
| "@@@%:''`:-*@@@@@@@@@@@@@\n", | |
| "@@@%***#%@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# LLM Time\n", | |
| "\n", | |
| "Now we'll create prompts, initialize GPT-3, and get outputs." | |
| ], | |
| "metadata": { | |
| "id": "7ivitinuUuSe" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Prompt Creation\n", | |
| "\n", | |
| "How we format the prompt, and which samples we include in it, can affect the LLM output. \n", | |
| "\n", | |
| "For formatting, we'll just follow a simple few-shot template: \n", | |
| "```\n", | |
| "Input: [flattened_ascii_image] \n", | |
| "Output: [class_label] \n", | |
| "###\n", | |
| "Input: [flattened_ascii_image] \n", | |
| "Output: [class_label] \n", | |
| "###\n", | |
| "...\n", | |
| "Input: [flattened_ascii_image] \n", | |
| "Output: \n", | |
| "``` \n", | |
| "\n", | |
| "There are many potentially fancy ways to select which samples to include. For classification over many potential groups, it'd be good to include at least one example of each group. \n", | |
| "\n", | |
| "For now, we'll just randomly sample from each group (e.g., MNIST digit class)." | |
| ], | |
| "metadata": { | |
| "id": "aesYe1OmU5EW" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Helpers" | |
| ], | |
| "metadata": { | |
| "id": "eWkwQlBuVDYf" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def get_indices_by_class(dataset):\n", | |
| "    \"\"\"\n", | |
| "    Organize data indices into classes\n", | |
| "    \"\"\"\n", | |
| "    indices_by_class = []\n", | |
| "    num_classes = len(np.unique(dataset.targets))\n", | |
| "    for i in np.arange(num_classes):\n", | |
| "        indices_class = np.where(dataset.targets == i)[0]\n", | |
| "        indices_by_class.append(indices_class)\n", | |
| "    return indices_by_class" | |
| ], | |
| "metadata": { | |
| "id": "HiDLQbWwUQX8" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def get_random_indices_by_class(indices_by_class, classes, num_samples, sample_seed):\n", | |
| "    \"\"\"\n", | |
| "    Randomly sample `num_samples` indices (without replacement) from each class in `classes`\n", | |
| "    \"\"\"\n", | |
| "    np.random.seed(sample_seed)\n", | |
| "    reference_sample_ix = []\n", | |
| "    reference_sample_y = []\n", | |
| "    for i in classes:\n", | |
| "        sample_ix = np.random.choice(indices_by_class[i], size=num_samples, replace=False)\n", | |
| "        reference_sample_ix.append(sample_ix)\n", | |
| "        reference_sample_y.append(np.repeat([i], num_samples))\n", | |
| "    return np.concatenate(reference_sample_ix), np.concatenate(reference_sample_y)" | |
| ], | |
| "metadata": { | |
| "id": "Nrp7hWBgVHR1" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
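The per-class sampling above can also be sketched with a private random generator instead of reseeding NumPy's global state, which leaves the randomness of other cells unaffected. This is a minimal pure-Python variant, not the notebook's implementation; `indices_by_label` and `sample_per_class` are hypothetical names and the labels are toy data.

```python
import random
from collections import defaultdict


def indices_by_label(labels):
    """Group dataset indices by their class label."""
    groups = defaultdict(list)
    for ix, y in enumerate(labels):
        groups[y].append(ix)
    return groups


def sample_per_class(labels, classes, num_samples, seed):
    """Draw num_samples indices per class without replacement.

    Uses a private random.Random(seed), so the global RNG state is untouched.
    Returns (indices, labels) in class order, like get_random_indices_by_class.
    """
    rng = random.Random(seed)
    groups = indices_by_label(labels)
    sample_ix, sample_y = [], []
    for c in classes:
        # random.Random.sample raises ValueError if a class has too few items
        chosen = rng.sample(groups[c], num_samples)
        sample_ix.extend(chosen)
        sample_y.extend([c] * num_samples)
    return sample_ix, sample_y


if __name__ == "__main__":
    toy_labels = [0, 1, 2, 0, 1, 2, 0, 1, 2, 3]
    print(sample_per_class(toy_labels, classes=[0, 1, 2], num_samples=2, seed=42))
```

The same seed always yields the same context examples, which keeps few-shot runs reproducible without touching `np.random`.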
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def get_context_prompt(dataset, indices, pretty_print=False,\n", | |
| "                       start='[START]\\n', end='[END]\\n', sep='###\\n',\n", | |
| "                       shuffle=False, seed=42):\n", | |
| "    \"\"\"\n", | |
| "    Given dataset, dataset indices, retrieve ASCII images, ground-truth labels,\n", | |
| "    and format into the context of a prompt\n", | |
| "    \"\"\"\n", | |
| "    prompt = []\n", | |
| "    # Template: start, x, y, end, sep\n", | |
| "    sample = '{}Input: {}\\nOutput: {}\\n{}{}'\n", | |
| "    if pretty_print:  # Support this later\n", | |
| "        raise NotImplementedError\n", | |
| "    else:\n", | |
| "        for i in indices:\n", | |
| "            data_x, data_y = dataset.__getitem__(i)\n", | |
| "            prompt.append(sample.format(\n", | |
| "                start, img_to_text(data_x), data_y, end, sep\n", | |
| "            ))\n", | |
| "    shuffle_ix = np.arange(len(indices))\n", | |
| "    if shuffle:\n", | |
| "        np.random.seed(seed)\n", | |
| "        np.random.shuffle(shuffle_ix)\n", | |
| "    prompt = [prompt[ix] for ix in shuffle_ix]\n", | |
| "    prompt = ''.join(prompt)\n", | |
| "\n", | |
| "    return prompt, indices[shuffle_ix]" | |
| ], | |
| "metadata": { | |
| "id": "xtebS4E1VJK_" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
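Stripped of the dataset plumbing, the context and evaluation prompts are plain string templates. The following minimal sketch reproduces the same layout (with `start=''` and `end=''`, as in the full-prompt cell later on) from already-converted ASCII strings; `build_prompt` is a hypothetical name and the tiny "images" are toy data.

```python
def build_prompt(context, query, sep="###\n"):
    """Assemble a few-shot prompt from (ascii_image, label) pairs plus one query.

    Each context example becomes an Input line, an Output line, and the
    separator; the query ends with "Output: " so the completion is the label.
    """
    parts = [f"Input: {x}\nOutput: {y}\n{sep}" for x, y in context]
    parts.append(f"Input: {query}\nOutput: ")
    return "".join(parts)


if __name__ == "__main__":
    # Two toy 2x2 "images", already flattened, plus a query image
    demo = build_prompt([("@'@'", 0), ("''@@", 1)], query="@@''")
    print(demo)
```

Keeping the query's trailing `"Output: "` identical to the context examples is what lets the model treat the label as the natural next token.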
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def get_eval_prompt(dataset, ix, pretty_print=False,\n", | |
| "                    start='[START]\\n', end='[END]\\n', sep='###\\n'):\n", | |
| "    \"\"\"\n", | |
| "    Format image we want to classify into a prompt\n", | |
| "    \"\"\"\n", | |
| "    # Template: start, x\n", | |
| "    prompt = '{}Input: {}\\nOutput: '\n", | |
| "    data_x, data_y = dataset.__getitem__(ix)\n", | |
| "    return prompt.format(start, img_to_text(data_x)), data_y\n", | |
| "" | |
| ], | |
| "metadata": { | |
| "id": "WeyOlXlMVK-b" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Easy-mode Image Classification\n", | |
| "\n", | |
| "For now, we'll stick to classifying only 3 MNIST digit classes (0s, 1s, and 2s).\n", | |
| "\n", | |
| "(Later we can try doing more, but I don't have the budget to call OpenAI's API all day for inference and evaluation)" | |
| ], | |
| "metadata": { | |
| "id": "Hnd4v61rVZZ5" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Organize train and test into classes\n", | |
| "num_classes = len(torch.unique(train_set.targets))\n", | |
| "\n", | |
| "# TRAIN\n", | |
| "indices_by_class = get_indices_by_class(train_set)\n", | |
| "\n", | |
| "# TEST\n", | |
| "test_indices_by_class = get_indices_by_class(test_set)" | |
| ], | |
| "metadata": { | |
| "id": "mv9OGYaEWLgL" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "class_indices = [0, 1, 2] # These are the classes we'll evaluate" | |
| ], | |
| "metadata": { | |
| "id": "OTgwRrO2WNi6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Visualizing a formatted prompt \"context\" \n", | |
| "We can now visualize a few-shot prompt, without the sample we want to classify:" | |
| ], | |
| "metadata": { | |
| "id": "DiCRGe9UWRPB" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "args.seed_cls = 42\n", | |
| "args.num_samples = 1\n", | |
| "img_dims = {'width': args.width, 'height': args.width}\n", | |
| "classes = range(num_classes)\n", | |
| "\n", | |
| "context_ix, context_y = get_random_indices_by_class(\n", | |
| " indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.num_samples,\n", | |
| " args.seed_cls\n", | |
| ")\n", | |
| "\n", | |
| "for i, ix in enumerate(context_ix):\n", | |
| "    print(f'Input:')\n", | |
| "    pretty_print_ascii(\n", | |
| "        img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
| "        **img_dims\n", | |
| "    )\n", | |
| "    print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n", | |
| "    # Unrolled\n", | |
| "    # print(img_to_text(train_set_ascii.__getitem__(ix)[0]))" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "y9gDbEs2WP13", | |
| "outputId": "53321ce9-f443-4b55-912d-015f60ca0b2e" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@%+===*@@@@@@@\n", | |
| "@@@@@@@@@%=:`'''`-@@@@@@\n", | |
| "@@@@@@@@+``''```''+@@@@@\n", | |
| "@@@@@@@=''''''''`'-@@@@@\n", | |
| "@@@@@@:```:==-:````+@@@@\n", | |
| "@@@@@@`''`#@@@%-'``=@@@@\n", | |
| "@@@@@#``'`%@@@@%:'':%@@@\n", | |
| "@@@@@:`':#@@@@@@='':%@@@\n", | |
| "@@@@+`'-@@@@@@@@='':%@@@\n", | |
| "@@@@=``=@@@@@@@@-'':%@@@\n", | |
| "@@@@-'`#@@@@@@@+`'`=@@@@\n", | |
| "@@@%:':@@@@@@@%:''`#@@@@\n", | |
| "@@@*``-@@@@@%#-```=@@@@@\n", | |
| "@@@='``-=##=:`'```*@@@@@\n", | |
| "@@@=``''```'''````@@@@@@\n", | |
| "@@@@-'```''```'`-@@@@@@@\n", | |
| "@@@@+:'''''''':=%@@@@@@@\n", | |
| "@@@@%%=======+%@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@#@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@-%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=*@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=+@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@--@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@--@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=:@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@*`@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@#`@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@%:@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@-*@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@==@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@=-@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@*:%@@@@@@@@@\n", | |
| "@@@@@@@@@@@@%:#@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@:+@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@**@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@%*#@@@@@@\n", | |
| "@@@@@@@@@@@@%+::``-*@@@@\n", | |
| "@@@@@@@@@@@=`'''`'`+@@@@\n", | |
| "@@@@@@@@@@+''`:````*@@@@\n", | |
| "@@@@@@@@@@%::**:'':@@@@@\n", | |
| "@@@@@@@@@@@%%@*```#@@@@@\n", | |
| "@@@@@@@@@@@@@*:'':@@@@@@\n", | |
| "@@@@@%%%%%%+=`````-=%@@@\n", | |
| "@@@@*`'''''''`''''''=@@@\n", | |
| "@@@*`'``````''::::::+@@@\n", | |
| "@@@='``````'`+#%%%%%%@@@\n", | |
| "@@@='````''=%@@@@@@@@@@@\n", | |
| "@@@+`````+%@@@@@@@@@@@@@\n", | |
| "@@@@%***%@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Visualizing a full prompt\n", | |
| "\n", | |
| "Below is a prompt that we can actually send to an LLM. For now we linearize the ASCII inputs (each image becomes a single line of characters), but it might be good to try the formatted, row-by-row ASCII too." | |
| ], | |
| "metadata": { | |
| "id": "9b9CV5YiWY61" | |
| } | |
| }, | |
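To try the formatted ASCII instead of the linearized version, one option is to re-insert a newline after every `width` characters of the flattened string before building the prompt. A minimal sketch; the helper name `to_rows` is hypothetical. Note that it adds `height - 1` newline characters per image, which counts against the context budget.

```python
def to_rows(flat, width):
    """Re-wrap a flattened ASCII image into lines of `width` characters."""
    if width <= 0 or len(flat) % width != 0:
        raise ValueError("flattened length must be a positive multiple of width")
    return "\n".join(flat[i:i + width] for i in range(0, len(flat), width))


if __name__ == "__main__":
    flat = "@@''" + "''@@" + "@''@"  # a toy 4-wide, 3-tall image, already flattened
    print(to_rows(flat, 4))
```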
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
| " train_set_ascii, context_ix,\n", | |
| " start='', end='', sep='###\\n', shuffle=True, seed=42\n", | |
| ")\n", | |
| "\n", | |
| "eval_prompt, eval_target = get_eval_prompt(\n", | |
| " test_set_ascii, ix=test_indices_by_class[0][0], start=''\n", | |
| ")\n", | |
| "# Linearized (what we'll actually send to an LLM)\n", | |
| "print(context_prompt + eval_prompt)\n" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "Qj6YnnAPWYFT", | |
| "outputId": "26da1599-c2c1-4e14-a11a-6b43d4ffb0dc" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%+===*@@@@@@@@@@@@@@@@%=:`'''`-@@@@@@@@@@@@@@+``''```''+@@@@@@@@@@@@=''''''''`'-@@@@@@@@@@@:```:==-:````+@@@@@@@@@@`''`#@@@%-'``=@@@@@@@@@#``'`%@@@@%:'':%@@@@@@@@:`':#@@@@@@='':%@@@@@@@+`'-@@@@@@@@='':%@@@@@@@=``=@@@@@@@@-'':%@@@@@@@-'`#@@@@@@@+`'`=@@@@@@@%:':@@@@@@@%:''`#@@@@@@@*``-@@@@@%#-```=@@@@@@@@='``-=##=:`'```*@@@@@@@@=``''```'''````@@@@@@@@@@-'```''```'`-@@@@@@@@@@@+:'''''''':=%@@@@@@@@@@@%%=======+%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#@@@@@@@@@@@@@@@@@@@@@@@-%@@@@@@@@@@@@@@@@@@@@@@=%@@@@@@@@@@@@@@@@@@@@@@=*@@@@@@@@@@@@@@@@@@@@@@=+@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@=:@@@@@@@@@@@@@@@@@@@@@@*`@@@@@@@@@@@@@@@@@@@@@@#`@@@@@@@@@@@@@@@@@@@@@@%:@@@@@@@@@@@@@@@@@@@@@@@-*@@@@@@@@@@@@@@@@@@@@@@==@@@@@@@@@@@@@@@@@@@@@@=-@@@@@@@@@@@@@@@@@@@@@@*:%@@@@@@@@@@@@@@@@@@@@@%:#@@@@@@@@@@@@@@@@@@@@@@:+@@@@@@@@@@@@@@@@@@@@@@**@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%*#@@@@@@@@@@@@@@@@@@%+::``-*@@@@@@@@@@@@@@@=`'''`'`+@@@@@@@@@@@@@@+''`:````*@@@@@@@@@@@@@@%::**:'':@@@@@@@@@@@@@@@@%%@*```#@@@@@@@@@@@@@@@@@@*:'':@@@@@@@@@@@%%%%%%+=`````-=%@@@@@@@*`'''''''`''''''=@@@@@@*`'``````''::::::+@@@@@@='``````'`+#%%%%%%@@@@@@='````''=%@@@@@@@@@@@@@@+`````+%@@@@@@@@@@@@@@@@@%***%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#=*@@@@@@@@@@@@@@@@@@@@@:':%@@@@@@@@@@@@@@@@@@@+`':%@@@@@@@@@@@@@@@@@%=````=*@@@@@@@@@@@@@@@@+'``'''`+@@@@@@@@@@@@@@#`'`':=-`'#@@@@@@@@@@@@%-`'`-#@@-'-%@@@@@@@@@@@#''`*%@@@#`':@@@@@@@@@@@#'`*@@@@@@='`@@@@@@@@@@@#'=@@@@@@@='`*@@@@@@@@@@-'=@@@@@@#`'`@@@@@@@@@@@:'=@@@@@*:'`=@@@@@@@@@@@:'=@@@@+`'`:@@@@@@@@@@@@:'=%+--```'#@@@@@@@@@@@@+'```'''``-@@@@@@@@@@@@@%:`'``'`:*@@@@@@@@@@@@@@@%=`'`:-%@@@@@@@@@@@@@@@@@@#+*%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: \n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "eval_target # Double check this should be 0" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "D8M-aUoqWnON", | |
| "outputId": "f3a1e852-7def-45d3-8024-605a92f72e31" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "execute_result", | |
| "data": { | |
| "text/plain": [ | |
| "0" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "execution_count": 179 | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Formatting the above (for sanity checking)\n", | |
| "for i, ix in enumerate(shuffled_context_ix):\n", | |
| "    print(f'Input:')\n", | |
| "    pretty_print_ascii(\n", | |
| "        img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
| "        **img_dims\n", | |
| "    )\n", | |
| "    print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n", | |
| "\n", | |
| "# Eval sample\n", | |
| "print(f'Input:')\n", | |
| "pretty_print_ascii(\n", | |
| " img_to_text(test_set_ascii.__getitem__(test_indices_by_class[0][0])[0]),\n", | |
| " **img_dims\n", | |
| ")\n", | |
| "print(f'Output:')" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "X22Ad8hjw7Fq", | |
| "outputId": "52ed577c-657c-49ce-bfe2-2252768998c3" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@%+===*@@@@@@@\n", | |
| "@@@@@@@@@%=:`'''`-@@@@@@\n", | |
| "@@@@@@@@+``''```''+@@@@@\n", | |
| "@@@@@@@=''''''''`'-@@@@@\n", | |
| "@@@@@@:```:==-:````+@@@@\n", | |
| "@@@@@@`''`#@@@%-'``=@@@@\n", | |
| "@@@@@#``'`%@@@@%:'':%@@@\n", | |
| "@@@@@:`':#@@@@@@='':%@@@\n", | |
| "@@@@+`'-@@@@@@@@='':%@@@\n", | |
| "@@@@=``=@@@@@@@@-'':%@@@\n", | |
| "@@@@-'`#@@@@@@@+`'`=@@@@\n", | |
| "@@@%:':@@@@@@@%:''`#@@@@\n", | |
| "@@@*``-@@@@@%#-```=@@@@@\n", | |
| "@@@='``-=##=:`'```*@@@@@\n", | |
| "@@@=``''```'''````@@@@@@\n", | |
| "@@@@-'```''```'`-@@@@@@@\n", | |
| "@@@@+:'''''''':=%@@@@@@@\n", | |
| "@@@@%%=======+%@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@#@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@-%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=*@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=+@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@--@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@--@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@=:@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@*`@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@#`@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@%:@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@-*@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@==@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@=-@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@*:%@@@@@@@@@\n", | |
| "@@@@@@@@@@@@%:#@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@:+@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@**@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@%*#@@@@@@\n", | |
| "@@@@@@@@@@@@%+::``-*@@@@\n", | |
| "@@@@@@@@@@@=`'''`'`+@@@@\n", | |
| "@@@@@@@@@@+''`:````*@@@@\n", | |
| "@@@@@@@@@@%::**:'':@@@@@\n", | |
| "@@@@@@@@@@@%%@*```#@@@@@\n", | |
| "@@@@@@@@@@@@@*:'':@@@@@@\n", | |
| "@@@@@%%%%%%+=`````-=%@@@\n", | |
| "@@@@*`'''''''`''''''=@@@\n", | |
| "@@@*`'``````''::::::+@@@\n", | |
| "@@@='``````'`+#%%%%%%@@@\n", | |
| "@@@='````''=%@@@@@@@@@@@\n", | |
| "@@@+`````+%@@@@@@@@@@@@@\n", | |
| "@@@@%***%@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@#=*@@@@@@@@@@\n", | |
| "@@@@@@@@@@@:':%@@@@@@@@@\n", | |
| "@@@@@@@@@@+`':%@@@@@@@@@\n", | |
| "@@@@@@@@%=````=*@@@@@@@@\n", | |
| "@@@@@@@@+'``'''`+@@@@@@@\n", | |
| "@@@@@@@#`'`':=-`'#@@@@@@\n", | |
| "@@@@@@%-`'`-#@@-'-%@@@@@\n", | |
| "@@@@@@#''`*%@@@#`':@@@@@\n", | |
| "@@@@@@#'`*@@@@@@='`@@@@@\n", | |
| "@@@@@@#'=@@@@@@@='`*@@@@\n", | |
| "@@@@@@-'=@@@@@@#`'`@@@@@\n", | |
| "@@@@@@:'=@@@@@*:'`=@@@@@\n", | |
| "@@@@@@:'=@@@@+`'`:@@@@@@\n", | |
| "@@@@@@:'=%+--```'#@@@@@@\n", | |
| "@@@@@@+'```'''``-@@@@@@@\n", | |
| "@@@@@@%:`'``'`:*@@@@@@@@\n", | |
| "@@@@@@@%=`'`:-%@@@@@@@@@\n", | |
| "@@@@@@@@@#+*%@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
| "Output:\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# GPT-3 Few-shot Classification\n", | |
| "\n", | |
| "Time to actually classify MNIST with a language model." | |
| ], | |
| "metadata": { | |
| "id": "AGoCzDdAWp1l" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "import openai" | |
| ], | |
| "metadata": { | |
| "id": "LD9VrIw0XI9v" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
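| "# Some helpers\n", | |
| "def get_gpt_output(prompt, model):\n", | |
| "    # Legacy Completions endpoint (openai<1.0 Python SDK)\n", | |
| "    response = openai.Completion.create(\n", | |
| "        prompt=prompt,\n", | |
| "        model=model,\n", | |
| "        # Deterministic, short completions: we only need a label back\n", | |
| "        temperature=0,\n", | |
| "        max_tokens=7,\n", | |
| "        top_p=1.0,\n", | |
| "        n=1,\n", | |
| "    )\n", | |
| "    return response\n", | |
| "\n", | |
| "def decode_response(response):\n", | |
| "    # Parse the first token of the completion (e.g. ' 2\\n###') as an int label;\n", | |
| "    # return -1 if the model did not answer with an integer\n", | |
| "    result = response['choices'][0]['text']\n", | |
| "    try:\n", | |
| "        return int(result.strip().split()[0])\n", | |
| "    except (ValueError, IndexError) as e:\n", | |
| "        print(e)\n", | |
| "        return -1" | |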
| ], | |
| "metadata": { | |
| "id": "UkvSNAVJXRQM" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Evaluation\n", | |
| "\n", | |
| "For evaluation, we sample a small set of examples from the MNIST test set, with an equal number of datapoints from each class. We measure the LLM's performance with overall and per-class accuracy." | |
| ], | |
| "metadata": { | |
| "id": "T64YRAhCXVz1" | |
| } | |
| }, | |
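| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "A minimal sketch of the class-balanced sampling described above, in case you want to reuse it outside this notebook. `sample_balanced_indices` is a hypothetical helper (not the notebook's `get_random_indices_by_class`, which is defined earlier), and it assumes the labels come as a flat array-like of ints:\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "def sample_balanced_indices(labels, classes, n_per_class, seed=42):\n", | |
| "    # Hypothetical helper: draw n_per_class indices per class, without replacement\n", | |
| "    rng = np.random.default_rng(seed)\n", | |
| "    labels = np.asarray(labels)\n", | |
| "    indices, targets = [], []\n", | |
| "    for c in classes:\n", | |
| "        class_ix = np.flatnonzero(labels == c)\n", | |
| "        chosen = rng.choice(class_ix, size=n_per_class, replace=False)\n", | |
| "        indices.extend(chosen.tolist())\n", | |
| "        targets.extend([c] * n_per_class)\n", | |
| "    return indices, targets\n", | |
| "\n", | |
| "labels = [0, 1, 2, 0, 1, 2, 0, 1, 2, 3]\n", | |
| "ix, y = sample_balanced_indices(labels, classes=[0, 1, 2], n_per_class=2)\n", | |
| "print(y)  # [0, 0, 1, 1, 2, 2]\n", | |
| "```\n", | |
| "\n", | |
| "Because every class contributes the same number of test points, the overall accuracy is just the mean of the per-class accuracies." | |
| ], | |
| "metadata": {} | |
| }, | |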
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "def eval_predictions(predictions, classes):\n", | |
| " \"\"\"\n", | |
| " Report average and group-wise accuracies\n", | |
| " \"\"\"\n", | |
| " correct = np.array(predictions) == np.array(classes)\n", | |
| " class_correct = []\n", | |
| " class_total = []\n", | |
| " for ix, c in enumerate(np.unique(classes)):\n", | |
| " c_ix = np.where(np.array(classes) == c)[0]\n", | |
| " class_correct.append(correct[c_ix].sum())\n", | |
| " class_total.append(len(c_ix))\n", | |
| " avg_acc = np.mean(correct) * 100\n", | |
| " class_accs = np.array(class_correct) / np.array(class_total) * 100\n", | |
| " print(f'Average acc: {avg_acc:.1f}%')\n", | |
| " print(f'Class accs:')\n", | |
| " for ix, c in enumerate(class_accs):\n", | |
| " print(f'- {class_correct[ix]} / {class_total[ix]} = {c:.1f}%')\n", | |
| " return avg_acc, class_accs" | |
| ], | |
| "metadata": { | |
| "id": "ALRaEE-FX_9v" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Few-shot Prompt Template\n", | |
| "To build prompts, we use the same context for every test sample: an equal number of labeled examples from each class in the MNIST train set." | |
| ], | |
| "metadata": { | |
| "id": "x5861FmKX8Qw" | |
| } | |
| }, | |
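| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "The printed prompt above follows a simple layout: an `Input:` header, the ASCII rows, an `Output: <label>` line, then a `###` separator and a blank line. The query example stops at `Output:` so the model fills in the label. Here is a rough sketch of that layout, using a hypothetical `format_example` helper. The notebook's own `get_context_prompt` / `get_eval_prompt` may differ in whitespace details:\n", | |
| "\n", | |
| "```python\n", | |
| "def format_example(ascii_rows, label=None):\n", | |
| "    # Hypothetical formatter mirroring the printed prompt layout above\n", | |
| "    lines = ['Input:'] + list(ascii_rows)\n", | |
| "    if label is None:\n", | |
| "        # Query example: stop at 'Output:' so the model completes the label\n", | |
| "        lines.append('Output:')\n", | |
| "        return '\\n'.join(lines)\n", | |
| "    # Context example: labeled output, then a '###' separator and a blank line\n", | |
| "    lines += [f'Output: {label}', '###', '']\n", | |
| "    return '\\n'.join(lines) + '\\n'\n", | |
| "\n", | |
| "context = format_example(['@@`@@', '@@`@@'], label=1)\n", | |
| "query = format_example(['@`@`@', '@`:`@'])\n", | |
| "prompt = context + query\n", | |
| "```\n", | |
| "\n", | |
| "Keeping the separator identical across examples gives the model a consistent pattern to continue." | |
| ], | |
| "metadata": {} | |
| }, | |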
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "## Actual GPT-3 Demo" | |
| ], | |
| "metadata": { | |
| "id": "-5CBGw16YEJQ" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "openai.api_key = \"sk-\" # Your OpenAI API key goes here" | |
| ], | |
| "metadata": { | |
| "id": "pHiH1Tj8YDUS" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
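| "# Candidate GPT-3 completion models (the run below uses text-davinci-003)\n", | |
| "models = [\n", | |
| "    'text-ada-001',\n", | |
| "    'text-babbage-001',\n", | |
| "    'text-curie-001',\n", | |
| "    'text-davinci-002',\n", | |
| "    'text-davinci-003',\n", | |
| "]" | |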
| ], | |
| "metadata": { | |
| "id": "reMB79l6YInh" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 1) Get samples to serve as \"context\"\n", | |
| "args.seed_train = 42\n", | |
| "args.train_samples_per_class = 4\n", | |
| "img_dims = {'width': args.width, 'height': args.width}\n", | |
| "\n", | |
| "context_ix, context_y = get_random_indices_by_class(\n", | |
| " indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.train_samples_per_class,\n", | |
| " args.seed_train\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "z32une07YJyg" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 2) Get examples to few-shot classify\n", | |
| "args.seed_eval = 42\n", | |
| "args.eval_samples_per_group = 10\n", | |
| "test_eval_ix, test_eval_y = get_random_indices_by_class(\n", | |
| " test_indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.eval_samples_per_group,\n", | |
| " args.seed_eval\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "dJ7lQKYFYLLS" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 3) Get prompt context\n", | |
| "for seed in range(20):\n", | |
| "    context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
| "        train_set_ascii, context_ix,\n", | |
| "        start='', end='', sep='###\\n', shuffle=True,\n", | |
| "        seed=seed  # Each seed gives a different context ordering\n", | |
| "    )\n", | |
| "\n", | |
| "    # Inspect the label order: we want to avoid regular patterns, e.g., 0, 1, 0, 1, ...\n", | |
| "    print(seed, train_set.targets[shuffled_context_ix])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "goTJZGxaYMsi", | |
| "outputId": "fbfc3326-c66a-40d8-bf4b-8fbf3ee19260" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n", | |
| "1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n", | |
| "2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n", | |
| "3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n", | |
| "4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n", | |
| "5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n", | |
| "6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n", | |
| "7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n", | |
| "8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n", | |
| "9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n", | |
| "10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n", | |
| "11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n", | |
| "12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n", | |
| "13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n", | |
| "14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n", | |
| "15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n", | |
| "16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n", | |
| "17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n", | |
| "18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n", | |
| "19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n" | |
| ] | |
| } | |
| ] | |
| }, | |
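| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Eyeballing the orderings works, but the check can also be scripted. Below is a small sketch that scores an ordering by its longest run of identical labels and flags strictly periodic sequences. Both helpers are hypothetical, and the two orderings are copied from seeds 3 and 16 in the output above:\n", | |
| "\n", | |
| "```python\n", | |
| "def longest_run(labels):\n", | |
| "    # Length of the longest streak of identical consecutive labels\n", | |
| "    if not labels:\n", | |
| "        return 0\n", | |
| "    best = run = 1\n", | |
| "    for prev, cur in zip(labels, labels[1:]):\n", | |
| "        run = run + 1 if cur == prev else 1\n", | |
| "        best = max(best, run)\n", | |
| "    return best\n", | |
| "\n", | |
| "def is_periodic(labels, period):\n", | |
| "    # True if the labels repeat with the given period, e.g. 0, 1, 0, 1, ... for period 2\n", | |
| "    return all(labels[i] == labels[i - period] for i in range(period, len(labels)))\n", | |
| "\n", | |
| "# Orderings copied from seeds 3 and 16 in the output above\n", | |
| "orders = {\n", | |
| "    3: [1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2],\n", | |
| "    16: [1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2],\n", | |
| "}\n", | |
| "for seed, order in orders.items():\n", | |
| "    print(seed, longest_run(order), is_periodic(order, 3))\n", | |
| "```\n", | |
| "\n", | |
| "Seed 3 ends with a run of three `2`s, while seed 16 never repeats a label more than twice in a row. This is consistent with picking seed 16 in the next cell." | |
| ], | |
| "metadata": {} | |
| }, | |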
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 3.1) Get prompt context (cont.)\n", | |
| "context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
| " train_set_ascii, context_ix,\n", | |
| " start='', end='', sep='###\\n', shuffle=True,\n", | |
| " seed=16 # Pick a random-looking one from above\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "i3EzD6I3YQbp" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 4) Build the full prompt for each test example, and get GPT outputs\n", | |
| "model_name = 'text-davinci-003'\n", | |
| "all_responses = []\n", | |
| "all_targets = []\n", | |
| "for test_ix in test_eval_ix:\n", | |
| "    eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n", | |
| "    _prompt = context_prompt + eval_prompt\n", | |
| "    # Calling this line will give OpenAI money\n", | |
| "    response = get_gpt_output(_prompt, model=model_name)\n", | |
| "    all_responses.append(response)\n", | |
| "    all_targets.append(target)" | |
| ], | |
| "metadata": { | |
| "id": "prLD5YMLYZUr" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 5) Evaluate\n", | |
| "\n", | |
| "# Report accuracy metrics\n", | |
| "avg_acc, class_accs = eval_predictions(\n", | |
| " [decode_response(r) for r in all_responses],\n", | |
| " [test_set.targets[test_ix] for test_ix in test_eval_ix],\n", | |
| ")\n", | |
| "\n", | |
| "# Report predicted vs true targets\n", | |
| "print(f'GPT-3 {model_name} few-shot predictions')\n", | |
| "print([decode_response(r) for r in all_responses])\n", | |
| "\n", | |
| "print(f'Ground-truth labels')\n", | |
| "print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "siiYoSPGYm5w", | |
| "outputId": "91fbfbda-5e52-4090-825f-3a0a5ad457b5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Average acc: 66.7%\n", | |
| "Class accs:\n", | |
| "- 4 / 10 = 40.0%\n", | |
| "- 10 / 10 = 100.0%\n", | |
| "- 6 / 10 = 60.0%\n", | |
| "GPT-3 text-davinci-003 few-shot predictions\n", | |
| "[0, 2, 0, 0, 2, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3, 2, 2, 0, 0, 2, 2, 2, 2]\n", | |
| "Ground-truth labels\n", | |
| "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Results\n", | |
| "So... better than chance (66.7% vs. 33.3% for three balanced classes)! Our toy idea *seems* to work on our toy example (of course, it is not battle-tested across seeds, different prompts, etc.).\n", | |
| "* If you wanted to twist your arm into interpreting the above, it seems plausible that the `1`s are more distinctive than the `0`s and `2`s, hence their higher accuracy (precision and recall are also good there: 10 of the 11 predicted `1`s are correct, and all 10 true `1`s are found).\n", | |
| "\n", | |
| "No further analysis here; just an art project. But some bonuses below!" | |
| ], | |
| "metadata": { | |
| "id": "CX8Nmw4x0Zxz" | |
| } | |
| }, | |
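| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "To put numbers on the precision/recall remark, here is a quick sketch that recomputes per-class precision and recall from the predictions and labels printed above. A predicted `3` simply counts as wrong for every class:\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "# Predictions and ground-truth labels copied from the text-davinci-003 run above\n", | |
| "preds = np.array([0, 2, 0, 0, 2, 0, 2, 1, 2, 2] + [1] * 10 + [0, 3, 2, 2, 0, 0, 2, 2, 2, 2])\n", | |
| "truth = np.array([0] * 10 + [1] * 10 + [2] * 10)\n", | |
| "\n", | |
| "scores = {}\n", | |
| "for c in [0, 1, 2]:\n", | |
| "    tp = np.sum((preds == c) & (truth == c))\n", | |
| "    precision = tp / max(np.sum(preds == c), 1)  # guard against zero predictions of c\n", | |
| "    recall = tp / np.sum(truth == c)\n", | |
| "    scores[c] = (precision, recall)\n", | |
| "    print(f'class {c}: precision {precision:.2f}, recall {recall:.2f}')\n", | |
| "```\n", | |
| "\n", | |
| "This gives roughly 0.57 precision / 0.40 recall for `0` and 0.55 / 0.60 for `2`, versus 0.91 / 1.00 for `1`, so the `1`s are indeed the easiest class in this run." | |
| ], | |
| "metadata": {} | |
| }, | |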
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Bonus 1: ChatGPT\n", | |
| "\n", | |
| "We can also try the above with ChatGPT. I'm not going to, because I don't want to paste all these examples into its text interface." | |
| ], | |
| "metadata": { | |
| "id": "Mw0cFIhVY1TW" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# # Run this block to get prompts to paste for ChatGPT\n", | |
| "# str_classes = ', '.join([f\"'{c}'\" for c in class_indices])\n", | |
| "# str_classes = '[' + str_classes + ']'\n", | |
| "# prompt_prefix = f\"In the text below, what is the most likely next character? Choose from a character in {str_classes}:\"\n", | |
| "# all_prompts = []\n", | |
| "# all_targets = []\n", | |
| "# for test_ix in test_eval_ix:\n", | |
| "# eval_prompt, eval_target = get_eval_prompt(\n", | |
| "# test_set_ascii, ix=test_ix, start=''\n", | |
| "# )\n", | |
| "# _prompt = context_prompt + eval_prompt\n", | |
| "# print('\\n', '-' * 20, f'Sample {test_ix} (target = {eval_target})', '-' * 20)\n", | |
| "# print(prompt_prefix)\n", | |
| "# print(_prompt)\n", | |
| "# all_prompts.append(_prompt)\n", | |
| "# all_targets.append(eval_target)" | |
| ], | |
| "metadata": { | |
| "id": "OuG1VxwEZAJm" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "# Bonus 2: CIFAR-10 classification (CIFAR-3)\n", | |
| "\n", | |
| "So we get better-than-chance accuracy on MNIST. But what about \"real\" images, like CIFAR-10?\n", | |
| "\n", | |
| "(Yes, `channel x height x width` dimensions of `3 x 32 x 32` are hardly real, but it's progress.) We use a 3-class subset of CIFAR-10, hence \"CIFAR-3\".\n", | |
| "\n", | |
| "Let's try it out with the exact pipeline above!" | |
| ], | |
| "metadata": { | |
| "id": "6URKLSmzmc3N" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load original data\n", | |
| "train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_base, download=True)\n", | |
| "test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_base, download=True)\n", | |
| "\n", | |
| "train_set.targets = torch.tensor(train_set.targets) # CIFAR-10 targets are a list by default\n", | |
| "test_set.targets = torch.tensor(test_set.targets)\n" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "RJQ0mOt2oo7Z", | |
| "outputId": "9bcc8bf8-24d1-4698-ae0d-f4c050b5d1e9" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Files already downloaded and verified\n", | |
| "Files already downloaded and verified\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Load ASCII data\n", | |
| "train_set_ascii = torchvision.datasets.CIFAR10(root='.', train=True,\n", | |
| " transform=transform_grayscale, download=False)\n", | |
| "test_set_ascii = torchvision.datasets.CIFAR10(root='.', train=False,\n", | |
| " transform=transform_grayscale, download=False)\n", | |
| "\n", | |
| "train_set_ascii.targets = torch.tensor(train_set_ascii.targets)\n", | |
| "test_set_ascii.targets = torch.tensor(test_set_ascii.targets)" | |
| ], | |
| "metadata": { | |
| "id": "TW6Xu-2Trxt5" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
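| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Both the MNIST and CIFAR renders come from `transform_grayscale` and `img_to_text`, which are defined earlier in the notebook. For intuition, here is a rough sketch of one way a grayscale-to-ASCII mapping can work. The character ramp and the [0, 1] intensity scale are assumptions, not the notebook's exact settings. Judging from the renders, bright pixels (MNIST strokes, the sky in the plane image) map to light characters, and dark pixels map to `@`:\n", | |
| "\n", | |
| "```python\n", | |
| "import numpy as np\n", | |
| "\n", | |
| "# Hypothetical ramp, ordered from brightest to darkest pixel\n", | |
| "RAMP = '`:-=+*#%@'\n", | |
| "\n", | |
| "def gray_to_ascii(img, ramp=RAMP):\n", | |
| "    # img: 2-D array-like of grayscale intensities in [0, 1]\n", | |
| "    img = np.clip(np.asarray(img, dtype=float), 0.0, 1.0)\n", | |
| "    # Brightest pixel -> ramp[0], darkest pixel -> ramp[-1]\n", | |
| "    idx = np.round((1.0 - img) * (len(ramp) - 1)).astype(int)\n", | |
| "    return [''.join(ramp[i] for i in row) for row in idx]\n", | |
| "\n", | |
| "rows = gray_to_ascii([[1.0, 0.5], [0.0, 1.0]])\n", | |
| "print(rows)  # ['`+', '@`']\n", | |
| "```\n", | |
| "\n", | |
| "Downsampling to a small grid first (the renders above are 24 characters wide) keeps each example short, so several examples fit in one prompt." | |
| ], | |
| "metadata": {} | |
| }, | |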
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# CIFAR-10 class names, indexed by label; we'll subsample the evaluation by class\n", | |
| "class_names = ('plane', 'car', 'bird', 'cat',\n", | |
| " 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')" | |
| ], | |
| "metadata": { | |
| "id": "8zyO_YyFpf8H" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Organize train and test into classes\n", | |
| "num_classes = len(np.unique(train_set.targets)) # 10\n", | |
| "\n", | |
| "# TRAIN\n", | |
| "indices_by_class = get_indices_by_class(train_set)\n", | |
| "\n", | |
| "# TEST\n", | |
| "test_indices_by_class = get_indices_by_class(test_set)" | |
| ], | |
| "metadata": { | |
| "id": "1HE-RWKNp2jq" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "class_indices = [0, 1, 2] # These are the classes we'll evaluate -> planes vs cars vs birds\n", | |
| "print([class_names[i] for i in class_indices])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "wjndnBc6pliP", | |
| "outputId": "42b9e68a-6cec-44f8-a815-cdaee12ed78f" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "['plane', 'car', 'bird']\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "!pip install matplotlib\n", | |
| "\n", | |
| "import matplotlib.pyplot as plt # so we can visualize the images below" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "wWA2EzbisliH", | |
| "outputId": "88fe462f-35d8-4d1e-879e-e221a04916e6" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
| "Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (3.2.2)\n", | |
| "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.4.4)\n", | |
| "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (0.11.0)\n", | |
| "Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.21.6)\n", | |
| "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (2.8.2)\n", | |
| "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (3.0.9)\n", | |
| "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib) (1.15.0)\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Visualize original image vs ASCII\n", | |
| "args.seed_cls = 42\n", | |
| "args.num_samples = 1\n", | |
| "img_dims = {'width': args.width, 'height': args.width}\n", | |
| "classes = range(num_classes)\n", | |
| "\n", | |
| "context_ix, context_y = get_random_indices_by_class(\n", | |
| " indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.num_samples,\n", | |
| " args.seed_cls\n", | |
| ")\n", | |
| "\n", | |
| "# Show samples\n", | |
| "for i, ix in enumerate(context_ix):\n", | |
| "    target = train_set_ascii[ix][1]\n", | |
| "    print('=' * 4, f'Sample {i}, class = {class_names[target]}', '=' * 4)\n", | |
| "    print('Original:')\n", | |
| "    plt.imshow(train_set[ix][0])\n", | |
| "    plt.show()\n", | |
| "\n", | |
| "    print('\\nASCII: ')\n", | |
| "    pretty_print_ascii(\n", | |
| "        img_to_text(train_set_ascii[ix][0]),\n", | |
| "        **img_dims\n", | |
| "    )\n", | |
| "    print('\\n')\n" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 1000 | |
| }, | |
| "id": "OOrij5e8uB1L", | |
| "outputId": "37ea0d0c-7752-4f3f-9e48-8f21e25430d2" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "==== Sample 0, class = plane ====\n", | |
| "Original:\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ], | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAS4ElEQVR4nO3dXWxd1ZUH8P+fYOcLCLFDTBIjCF9CphIBLL6pGFVUhBfoCyoPVSrQpA8gtVIfipiH8ohG01YIjSqlA2o6KlSVWkQEaKZMVBRVQgETAiQEJnwE7ODEwRbBIQRwsubBh8oEn7Uud99zzp3u/0+KbN/lfc6+23flfqyz96aZQUT+8Z3SdAdEpB5KdpFMKNlFMqFkF8mEkl0kE6fWebK+vj4bHBxsuz3JtmKdkHL8bu5bk9WYf+RKUHTfqrrvo6OjmJqamvcBkZTsJG8B8BCABQD+w8we9H5/cHAQW7ZsKY2fcor/QqO3t7c01tPT47ZNTTivb1G/o3jUt9T2ntQHZUr8xIkTbtvjx4+78WhcPNG5Ux8vMzMzbty7b1HfvDFdv359aazt0SK5AMC/A1gPYAjAnSSH2j2eiFQr5T37VQDeMrN3zOxzAH8AcFtnuiUinZaS7GsAjM75eay47StIbiQ5QnJkcnIy4XQikqLyT+PNbJOZDZvZcH9/f9WnE5ESKcm+H8A5c34eLG4TkS6UkuwvAriI5FqSvQC+D6D8o3YRaVTbpTczmyF5L4D/xmzp7VEz253SmSrLOKnlrRRVl3m89tGYRueO2qeUiaLSWuq5UzRZkky5317bpDq7mT0D4JmUY4hIPXS5rEgmlOwimVCyi2RCyS6SCSW7SCaU7CKZqHU+O+DXs6uc993kNNKobZXxqufSR+Pm1YQXLFjgtq2ylh31O3W+eXT86BoDjzdu3t9bz+wimVCyi2RCyS6SCSW7SCaU7CKZULKLZKKrSm8p0yUjVZagUss0qaW3lDGNpE6BTSkLNjnFNfXxktK3qGyn0puIuJTsIplQsotkQskukgklu0gmlOwimVCyi2Si1jo7SbdGmLobanTulGNXOcU19dxubTVxKmdUL075m1S9rXFK+yqX0AbSpnq3O+Z6ZhfJhJJdJBNKdpFMKNlFMqFkF8mEkl0kE0p2kUx0VZ09pabbzcs1p14/kFLjT62zN7kOQJN19tS58il/05S2Xiwp2UnuAzAN4DiAGTMbTjmeiFSnE8/s/2RmH3bgOCJSIb1nF8lEarIbgL+QfInkxvl+geRGkiMkRyYnJxNPJyLtSk32G8zsCgDrAdxD8tsn/4KZbTKzYTMb7u/vTzydiLQrKdnNbH/xdQLAEwCu6kSnRKTz2k52kktJnv7l9wC+C2BXpzomIp2V8mn8AIAnirreqQAeM7P/ihp5dfYq1yBPmXedeu7UWnWT20VHUurRVa5JD/h9S5lv3sq5o7iXB6lbWZdpO9nN7B0Al7XbXkTqpdKbSCaU7CKZULKLZELJLpIJJbtIJmqf4uqVNKKSQ8q0wEiTU1yrnp5bpSrPHR37+PHjbbeP/iZRSTF6rEbH7+3tLY1FpTXvfmvLZhFRsovkQskukgklu0gmlOwimVCyi2RCyS6SiVrr7IBff4zqpp6qa9Vev1OXiq5yOefUenGkyjp76rUT7dajW4lH4/bFF1+48Q8/LF+j9bTTTnPb9vT0uPEyemYXyYSSXSQTSnaRTCjZRTKhZBfJhJJdJBNKdpFM1F5n91Q5b7vKbZWPHj3qtj1w4IAb//jjj934ihUr3LhX843qvX19fW7cqwcDwAUXXODGvXH77LPP3LZvv/22G4/GZfny5W7cc+qpfmpEj5fovj311FOlsaGhIbft9ddfXxrTfHYRUbKL5ELJLpIJJbtIJpTsIplQsotkQskukona6+wp66+3e1wA+Pzzz914VCs/fPhwaWxqaspt++6777rxqJ68bNky
Nz4wMFAaW7x4sdt2enrajUfz4fv7+924V8ePjv3KK6+48ajG7/VtZmbGbRv1LaqzR4+nQ4cOlcb27Nnjtl29enVpzHuch8/sJB8lOUFy15zb+kg+S3Jv8bX9qxdEpBatvIz/LYBbTrrtPgBbzewiAFuLn0Wki4XJbmbbAJz8OvU2AJuL7zcDuL3D/RKRDmv3A7oBMxsvvj8AoPRNI8mNJEdIjkxOTrZ5OhFJlfxpvM2ulli6YqKZbTKzYTMbjj7MEZHqtJvsB0muAoDi60TnuiQiVWg32bcA2FB8vwHAk53pjohUJayzk3wcwE0AVpAcA/BzAA8C+CPJuwG8B+COTnQmqm16tfSojj4x4b/4GBsbc+O7d+8ujUXz1Y8cOeLGo/nsS5cudeNezda7PgAAPv30Uzd+ySWXuPGnn37ajadcO7F///6k+Msvv1wai/YoSHksAnGdfXR0tDQWXXfhrTHgXfMRJruZ3VkS+k7UVkS6hy6XFcmEkl0kE0p2kUwo2UUyoWQXyUTtU1xTtmX2phVG5avnnnvOje/atcuNe2WeY8eOuW2jMk1U5lm4cKEb95YtTi0h7d27142PjIy4cW+6ZlT2i6aRRldkrly5sjQWbZO9ZMkSNz44OOjGI97fZXh42G3rLSX92GOPlcb0zC6SCSW7SCaU7CKZULKLZELJLpIJJbtIJpTsIpmotc4+PT2Nbdu2lcYvvvhit723LHJUZ3/jjTfc+Pbt2924N2Wxt7fXbdvT0+PGo1p3tH2wN4X2jDPOcNtG1wiMj4+78ei6CW/cvK2mgXi76UWLFrnx008/vTT25ptvum2j6xNuueXkNVi/KtqyeefOnaWxNWvWuG3PPPPM0pg3pnpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNRaZ5+YmMDDDz9cGr/mmmvc9pdffnnb5/Zqk0Bc23z//fdLY1EtOqqjR/XmaF63V+s+77zz3LYpWwsD8ZbP3jUCZ511lts2Gpeolu1dexHNV7/wwgvd+PLl/sbF0dLmZ599dmnMe6xFvDUC9Mwukgklu0gmlOwimVCyi2RCyS6SCSW7SCaU7CKZqLXOfuzYMXcecbQ+urdVbbSt8cDAgBu/66673PiOHTtKYx988IHbtq+vz41HNdlorr23NbFXzwX8Od9AvH1wNOd89erVpbGoVh3Fb775Zje+du3a0lhUw4+uAYjmu0fXAFx55ZWlsWiNAW9ckuazk3yU5ATJXXNue4DkfpI7i3+3RscRkWa18jL+twDmW5bjV2a2rvj3TGe7JSKdFia7mW0DMFVDX0SkQikf0N1L8tXiZX7pmwiSG0mOkByJ3ueISHXaTfZfA7gAwDoA4wB+UfaLZrbJzIbNbDia0CEi1Wkr+8zsoJkdN7MTAH4D4KrOdktEOq2tZCe5as6P3wPg73csIo0L6+wkHwdwE4AVJMcA/BzATSTXATAA+wD8qJWTLVmyBOvWrSuNR+uve3N1ozXEo/24I95+3Ndee63b9txzz3XjExMTbnxsbMyNe3OzL730UrftzMyMG4/m+U9Ntf/ZbTRn/LLLLnPj1113nRv31hGIrg+I4tHjKVonwGsfjblbS3feKofJbmZ3znPzI1E7Eeku+sRMJBNKdpFMKNlFMqFkF8mEkl0kE7VOcV20aBGGhoZK49GSy950zGhr4qisFy2J7JW3Jicn3bbRdMdoO+kofsUVV5TGotJbVLLs7+9349EU2Iceeqg0FpUkzz//fDcejevo6GhpbN++fW7bqCQZbcMdXRruLf8dTe1dtmxZacybLq1ndpFMKNlFMqFkF8mEkl0kE0p2kUwo2UUyoWQXyUStdfYFCxa4Sz57tccoHtXJo1VyopqtV1c9ePBg222BuOYbLUXt1flfeOEFt623pTIQL7kcbensTUuOtrp+/vnn3Xh0Xcbhw4dLY9GYR3XyaHpudH2C93iMHstend09Z1utROT/HSW7SCaU7CKZULKLZELJ
LpIJJbtIJpTsIpmotc5OEosXL3bjHq/OHs0/jpYG9mqyAHDkyJG2zx2J5uJfffXVbtyrhUdjGo1LtCSy9/cEgBtvvLE0Fq0xEG2FHbX/6KOPSmPR/Y6uP4iurYi2Xfbq8N7aCYC/9Lh3v/TMLpIJJbtIJpTsIplQsotkQskukgklu0gmlOwimai1zg74NelozrlXM45qrtG87Ii3vnrKPHwgbY3xqH107KgOH83F99byB/x6depa/9F9W7lyZWnMm2cPxHXy6PqDTz75xI179fBoLX9vTJPWjSd5Dsm/knyd5G6SPy5u7yP5LMm9xVd/ZXsRaVQrL+NnAPzUzIYAXAPgHpJDAO4DsNXMLgKwtfhZRLpUmOxmNm5mO4rvpwHsAbAGwG0ANhe/thnA7VV1UkTSfaMP6EieB+ByANsBDJjZl4uIHQAwUNJmI8kRkiPR+xgRqU7LyU7yNAB/AvATM/t4bszMDIDN187MNpnZsJkNe4tNiki1Wkp2kj2YTfTfm9mfi5sPklxVxFcBKJ+KIyKNC0tvnK3NPAJgj5n9ck5oC4ANAB4svj4ZHevEiRNuaSDilaCiKYlRGScqMXntoymuUXkrEh3fK0Gl9m32RVv7ca9v3hTUVkTn9kRluygeTUONlv/2pgZHJUnvfnuP41bq7NcD+AGA10juLG67H7NJ/keSdwN4D8AdLRxLRBoSJruZ/Q1A2X//3+lsd0SkKrpcViQTSnaRTCjZRTKhZBfJhJJdJBO1T3GNpmt6vCmwUc01qu+n1pM90fTaaGpv1N6rCUfXD6ROz43iXp0/qmVHUv5mUdto3KL4woUL3bi3RXi0hLY35lpKWkSU7CK5ULKLZELJLpIJJbtIJpTsIplQsotkotY6u5m5ddeU5Z6jumnKMtWR1DnhqVLqyVE8um/ROgKeqEafOm5eHT96PERbOkfXbUxPT7vxlMdbtDZDGT2zi2RCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJmqvs3u11ZSab1TvjeYfRzV+Lx7VTFPrySntU2vVqVtde1KvAUi5tiK6X9HjKdrSOeWxXNV1GXpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTLSyP/s5AH4HYACAAdhkZg+RfADAPwM4VPzq/Wb2THAst/aZMuc8qnWnzjlPmRuduj97xDt+6j7kUY0/Ze331Ln00bh7x4/6HZ07ZS3/6PhRjd9bk949rnvUWTMAfmpmO0ieDuAlks8WsV+Z2b+1cAwRaVgr+7OPAxgvvp8muQfAmqo7JiKd9Y3es5M8D8DlALYXN91L8lWSj5JcXtJmI8kRkiNHjx5N6qyItK/lZCd5GoA/AfiJmX0M4NcALgCwDrPP/L+Yr52ZbTKzYTMbXrJkSQe6LCLtaCnZSfZgNtF/b2Z/BgAzO2hmx83sBIDfALiqum6KSKow2Tn78d4jAPaY2S/n3L5qzq99D8CuzndPRDqllU/jrwfwAwCvkdxZ3HY/gDtJrsNsOW4fgB+1ckKvNBCVUlKkLjVdZektZRvrVo7vSSk5ttK+qrZAWrk1dcyj8lgU9+57VXnQyqfxfwMw36i5NXUR6S66gk4kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNS6lDTg1xCjrWhTtgeOapdR3Dt3So2+FSn16GgJ7ajvVV4DEB079dwp06lTjg3EU2C9rcuj+93uVG89s4tkQskukgklu0gmlOwimVCyi2RCyS6SCSW7SCZY1faw856MPATgvTk3rQDwYW0d+Ga6tW/d2i9AfWtXJ/t2rpmdNV+g1mT/2snJETMbbqwDjm7tW7f2C1Df2lVX3/QyXiQTSnaRTDSd7JsaPr+nW/vWrf0C1Ld21dK3Rt+zi0h9mn5mF5GaKNlFMtFIspO8heSbJN8ieV8TfShDch/J10juJDnScF8eJTlBctec2/pIPktyb/F13j32GurbAyT3F2O3k+StDfXtHJJ/Jfk6yd0kf1zc3ujYOf2q
Zdxqf89OcgGA/wVwM4AxAC8CuNPMXq+1IyVI7gMwbGaNX4BB8tsAjgD4nZl9q7jtXwFMmdmDxX+Uy83sZ13StwcAHGl6G+9it6JVc7cZB3A7gB+iwbFz+nUHahi3Jp7ZrwLwlpm9Y2afA/gDgNsa6EfXM7NtAKZOuvk2AJuL7zdj9sFSu5K+dQUzGzezHcX30wC+3Ga80bFz+lWLJpJ9DYDROT+Pobv2ezcAfyH5EsmNTXdmHgNmNl58fwDAQJOdmUe4jXedTtpmvGvGrp3tz1PpA7qvu8HMrgCwHsA9xcvVrmSz78G6qXba0jbedZlnm/G/a3Ls2t3+PFUTyb4fwDlzfh4sbusKZra/+DoB4Al031bUB7/cQbf4OtFwf/6um7bxnm+bcXTB2DW5/XkTyf4igItIriXZC+D7ALY00I+vIbm0+OAEJJcC+C66byvqLQA2FN9vAPBkg335im7Zxrtsm3E0PHaNb39uZrX/A3ArZj+RfxvAvzTRh5J+nQ/gleLf7qb7BuBxzL6s+wKzn23cDaAfwFYAewH8D4C+LurbfwJ4DcCrmE2sVQ317QbMvkR/FcDO4t+tTY+d069axk2Xy4pkQh/QiWRCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJv4Pw1VAg23eYPsAAAAASUVORK5CYII=\n" | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| } | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "ASCII: \n", | |
| "::``````````````````````\n", | |
| "::``````````````````````\n", | |
| "::``````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "`````````````````````::`\n", | |
| "``````:-=::::`':-```'-=`\n", | |
| "````:-=+*+++=::-=::--=*=\n", | |
| "```-+*###*****####*****-\n", | |
| "-:`=#%%%%###%%@@%#%#=*=:\n", | |
| "#+-=*#@@%@%%@@@%%*#+-=+-\n", | |
| "%#*+-=+=+@%#@@##=--:----\n", | |
| "**++==-:+#+==*#+::---=--\n", | |
| "+++=====#%*+=*%===**+++=\n", | |
| "++++++*+*****##***#**++=\n", | |
| "+++++++++++++++++++****+\n", | |
| "+++++=+++++++++++++*++++\n", | |
| "++++++++++++++++++++++++\n", | |
| "++++++++++++++++++++****\n", | |
| "++++++=+++++++++++++***+\n", | |
| "++++++==++++++++++++++++\n", | |
| "+++++===++++++++++++++++\n", | |
| "\n", | |
| "\n", | |
| "==== Sample 1, class = car ====\n", | |
| "Original:\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ], | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAaB0lEQVR4nO2dXWxcZ5nH/8+cmfF4xo5jx4mbzza03bSltCkKpYguCwtUbfeicFPRC9SV0JYLKoHExSL2Ai7RagFxsUIKS0VZ8SFWgOguaGm3y1JAK2hoQ5s2pWkT58NN/BU79tjzPc9eZIrSNu//GNuZsfb9/yTL9jx+z3nnnPOfM57/+zyPuTuEEP//yfR6AkKI7iCxCxEJErsQkSCxCxEJErsQkZDt5s76sjnv7+sLxjPWpuPLlVow5uCuwshgP40PDg7R+PxyIxibm5+nY5OM8Xg2oXEzPr7Vaq16bBpp4zMpz63VCp+XdpuPNYSvFQBopYzvz4X3XUy58muZPI03kiLfQLNJw0kmHDfwsRkPX4vLlWXU6vXLHpg1id3M7gbwNQAJgH9x9y+xv+/v68P7b3pHMJ5Plun+fvPiK8FYs83F/rE7b6bxD/z1PTT+74cmgrF/++m/07GD/VzMw1s203iSy9H44sJCMJbN8rFIOW65HL/oCyUeL5frwVhlic8t8b00PlfhLwa3XhUWxa2j/HmPl3bT+Lmh/TTemp2l8eG+cLwvM0PHFhqTwdh//+aXwdiq38abWQLgnwHcA+AmAA+Y2U2r3Z4Q4sqylv/Zbwfwirsfd/c6gO8DuG99piWEWG/WIvadAE5f8vuZzmNvwMweMrNDZnao3gy/rRJCXFmu+Kfx7n7Q3Q+4+4F82v+PQogrxlrEPgHg0k8xdnUeE0JsQNYi9qcBXG9me80sD+BjAB5bn2kJIdabVVtv7t40s4cB/BwXrbdH3P2FtHEtC9tQJ2aW6NhyI+yr3vC2t3xc8AbGdnHr7fCRsJ0BAMdPvhiM5fPcxikvXqDxSo1bjq02X39Qq4XtrWyWn+JGg3u6aWTy/H7RXwyvX3j7vgN07NbCGI0//ew4jZ+eKQdj1vduOvaV2i00PjXFj9uQDdD4TD58Xkol/u/unrFtwVg7+7tgbE0+u7v/DMDP1rINIUR30HJZISJBYhciEiR2ISJBYhciEiR2ISJBYhciErqaz962BLVsKRg/U+bTGdn1nmBs7/X76Nhm9ioaP3H6MI1PTYd9+KyFnxMA9JW4J8vy9AFgOSUOD/vwbZLrDgDtNVYXbtVScvWT8NwbrQode811IzRebvD1Cc8fCa/pOLU8SMdWh8JeNgBYyvKE5VbY4weABqm/sIRNdGyhEH5ejUwhGNOdXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiISuWm+ZTILCQDjlsa9/mI6/+R3vC8YaTV7l9KVTPM10qcbTayuVcEmtUp5bJYNF/pq6ZZhXn23yDFdqr6WVqW4T2w4AqvVw+iwAwPhzqzXCczt98lU6dmInt792XL2FxqenwtdTNc9tvaEhXno8n+Wlx70SrvgLADj322BozMfp0NJ0OJ07aYTPl+7sQkSCxC5EJEjsQkSCxC5EJEjsQkSCxC5EJEjsQkRCV332XDbBjs1hf/I0wuWaAcDLJ4OxhTr3upHw1lMpNjvaHj5UI8PhtEIAuGpwB43nCmktnflrcqMRfm6GlJbKxlNcK6RMNQBYhl9CC0vhVM+J13hPkaNH+PWw77obaHzLSLiccznP1x9cyPDU4IalpB1nuM++tRROmb5rH7+Wl0lp8ZfJ6dCdXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiASJXYhI6KrPvqlUxAfuCLfp7atP0/GvzYbzn/tL3HMt7eHtgSdnee50fy68/cIIz43O9fHcaDR4rv3keV4yeZq0uk7aVTo2A+6jD43ytsnZlOe2MP9aMJYU
ttOxyzXudY+/zOs5Z0jJ5aUcL2Odab5E48Ucv0/WKzM0nrO5YGzrLt5O+sJS+HpgazLWJHYzGwewCKAFoOnuXFFCiJ6xHnf2D7g7fxkTQvQc/c8uRCSsVewO4HEz+72ZPXS5PzCzh8zskJkdYuukhRBXlrW+jb/T3SfMbBuAJ8zsJXd/6tI/cPeDAA4CwLW7r15bYzEhxKpZ053d3Sc636cA/BjA7esxKSHE+rNqsZtZycwGX/8ZwF0AjqzXxIQQ68ta3saPAfixmb2+ne+6+3/SnWUTjI6E89n3v+PtdIfbZ8If+r98nvvkW/Z/mMY3D/FD8drz4fzls1X+mjmf0r63vcQ93/MNHl+qhn3XfHORjt3azxP5k2He2vjUazxve2F5TzBWSHit/3yB++wZ5znlzUYuGGuBH5eH7+ctwHfu3kzjzz37Mo3PnCGtrOt8bcTZM8fDY0nd+FWL3d2PA7h1teOFEN1F1psQkSCxCxEJErsQkSCxCxEJErsQkdDVFFcYkITdEOQH++jwrQPhNNPDVd7u+cUJ/lRbZW4xTRW2BmMXWrztcbPKrbPEeJrp6C5uxaAVTg32+XA5ZQCog1tv5xb5vpfa3D4bGAufl9YyL+/dbvES3a2U9ZitbHj7tQZPj222eUvmv7n7Nh7/0I00vjx/dzA2Pcnzyl47fTYYe+p3jwdjurMLEQkSuxCRILELEQkSuxCRILELEQkSuxCRILELEQld9dmTJMEmkuI6Nc/LEjdz1wRjZSvRsSdOcVO2lnA/uVkMH6p6mx/Gvgb3i3cN87llZnlZ48rUoWCsUeYef73J1wh4m3vhuRz3q3PZ8PqFjPG2yfXMTho/1+DnvN0Kn5fl+hY69tEfnabx627g4+/5EC+TPZQNt2Ue2sqf17U3XheMDQyF11Xozi5EJEjsQkSCxC5EJEjsQkSCxC5EJEjsQkSCxC5EJHTXZ89mMTQcbm+czfPyvBfKYR9+Ys7o2KU8zxnPF0iiPQDUw3nfLW5Fw5opr6lzvFX13IlxGl88fz68b0/JpQf3urPOj2t9mZeSnquHy1wPlriPjhLPlV8Ejzc8/NySIl/bcGyWr7v47n+E24cDwE03cR9+745w7YY2+NqFViZ8wbmFn5fu7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEQld99kwmi9Jg2H+sVXjd+BdfmAzGFiu8ZXNugPuqRlrdAkCCsO9aSvhhbC5yL3t8mrce9sa1NJ6MhHOnt/TxGuRDA9xHLxb4OenLzfN4PvzcswlfV9ECr3mPSe6zz86Hc/Wv2lykY5da4fUBAHD8DK8D8Ozz/LhfvSP83Awpaz4ybO3EGnx2M3vEzKbM7Mglj42Y2RNmdqzznXdoEEL0nJW8jf8WgDe3r/gcgCfd/XoAT3Z+F0JsYFLF7u5PAXjzesz7ADza+flRAB9Z53kJIdaZ1X5AN+burzecOgdgLPSHZvaQmR0ys0Ozc+E13EKIK8uaP413dwf5VMDdD7r7AXc/sIUkwQghriyrFfukmW0HgM73qfWbkhDiSrBasT8G4MHOzw8C+Mn6TEcIcaVI9dnN7HsA3g9g1MzOAPgCgC8B+IGZfQLASQD3r2Rn7o5qLexnz83zHOLZC2EPcduWlPzhq7lnO3Oae93zpJf47r/g2/7jBd4DfbbK85d3bSvT+P5rw8d0xxB/Pc8lLRpfKi/S+K5t3IcvDoRroM+V+XF7dYL7zWdq3GcfJE+t4LwIgeW4m3xulp/Tx395gsZvvzW8xmDXGP93NyF5+obwuolUsbv7A4HQB9PGCiE2DlouK0QkSOxCRILELkQkSOxCRILELkQkdDXFtdFs4tw0KXuc51ZLoRROK6xOX6Bjc9M8FdPr3IJqLIa3X53hr5n5Ok+f/csb+TLihz8+TuM5Yt09/iveBjspjtJ43xK3/aoNfgmdng7v/8g5nvo7MclbFy/U+XPbuTs8t8rEOB1brYdbTQPAQoXbfk/8
z1Eav+s9Ydtx5z2307EZJzohpb91ZxciEiR2ISJBYhciEiR2ISJBYhciEiR2ISJBYhciErrqsxcK/dh3w83BeBa8tPCxE78OxvqrJ+nY8RPcF51vpbX/DZfvXT7BXzOzDV62ePcYL+f89r18DcHyTNhnr2X43F6d5Km9uwa4Dz8yxrdfbIafe/8QL+99c0rq8MsT/LieGQ+nTJeXeGpurcJTe41nJaNaOUfjp46Gr7fGB2+lY3Mp61FC6M4uRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCR01WdPsgmGt4a99Lctcd/1lj3h/OXz0y/RsYsXuG+6UNlN4y2ES1U3izwvu53wfPaf/4bHTxzbSePbx8L7r/bvoWOnwcsWz1ZTagw0ec75vj1hP3m4zdsezy7wcs8zZT7+j+XwGoL5BV6/YKyYsn4ge4rGrx7mdQAGSEvoxhK/HrJ9fI1ACN3ZhYgEiV2ISJDYhYgEiV2ISJDYhYgEiV2ISJDYhYiErvrscIeTGuqW8CTh3TvCtbxvu5DyujV+moYzJ6do/PhSuIVvvniAjq01N9F4OcWrfqHCvfKXF8Oecb7KT3HNuM9ez3A/+rVnuBc+dJScb+d+cb3KW3hPTy7QeLUWviZym3n9gmwySeOl2jM0vrnEr6f58+FreX6We/SlYX49hUi9s5vZI2Y2ZWZHLnnsi2Y2YWaHO1/3rmrvQoiusZK38d8CcPdlHv+qu+/vfP1sfaclhFhvUsXu7k8B4P2JhBAbnrV8QPewmT3XeZsf/IfWzB4ys0Nmdmh6ZmYNuxNCrIXViv3rAK4FsB/AWQBfDv2hux909wPufmDrKC9eKIS4cqxK7O4+6e4td28D+AYA3nZSCNFzViV2M9t+ya8fBXAk9LdCiI1Bqs9uZt8D8H4Ao2Z2BsAXALzfzPYDcADjAD65kp15u41GJZzHO7/Ic86PvHImGJs4O03HDjj3i/f28/7tC+fDtds9xUevNd9J4yjx/OXSCP+sY3gwfBr7m+F69wAwOc394JZxz7edsv25cvi5tVJ8dBjf9mCTjx/y8HFx4/3X8+VZGu8z3qdgZp4f1989F66PcOeZOTp2597tNB4iVezu/sBlHv7mqvYmhOgZWi4rRCRI7EJEgsQuRCRI7EJEgsQuRCR0NcXVW200FsPlfWdSUvteeCVszR0/tUTHmvMU16TJ0yWHEbZ5Zmef5tsu8NbDmzfxtsmlk3+g8daF8NyWMvy4tBZ4O+gB8HLNWzbxNtuNevi4Li9xa63Qz9NQNyX8eikT26+SDbcOBwBv8eeFIrdbqwlP3z17IVwW/cJiiiXZIlath8ux684uRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCR01Wc3AxLy8pJt8TTUeWLLHpst0rHFHI8PgO+71gz7zUszx+hYFLnP3p/wls8Dy7/l4y18UJNB/rwHSzytuLnIy3uPZLgP38iHU5rnqnwNQKbF20VXqzz1t7IULnOd2XQtHZsM8vLdlfxWGr8wy1Nktw2Ez1kxrSNzi5Tvls8uhJDYhYgEiV2ISJDYhYgEiV2ISJDYhYgEiV2ISOhuy2YzJPmwpzxQ5L5qf8GCsfOL3O+tD9xK462+FE83G253V+sPe5sAYEXuybac59LPkXx1AJivhcf3l7nH327zbbe5zY76NC+DXW+QnPIK33crJZc+yfHJbdmyLxhrZnfSsUvg+ez1Cn/e5fnw+gIAuG5rIRgbHQrnul/cObneSEh3diEiQWIXIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEiobs+O4AWybctlcLeIwCMDpNE39prdGy5yVsyN/qvpvFqbkcw1hrdTccOX8XjpeqzND6Uv57GR0jKeovk4QPAmTPjNN4u8vG5PK/tPpAbCsayrLgBgEx/eCwA3HDje2m8XQ/npP/qaX69zM2eo/GkzvPVM7VJGt++7bpgbCilBoGzNgNr8dnNbLeZ/cLMXjSzF8zs
053HR8zsCTM71vk+nLYtIUTvWMnb+CaAz7r7TQDuAPApM7sJwOcAPOnu1wN4svO7EGKDkip2dz/r7s90fl4EcBTATgD3AXi082ePAvjIlZqkEGLt/Fkf0JnZNQBuA/BbAGPufrYTOgdgLDDmITM7ZGaHps/z/3OEEFeOFYvdzAYA/BDAZ9zfmLnh7o7ARwPuftDdD7j7ga0jW9Y0WSHE6lmR2M0sh4tC/467/6jz8KSZbe/EtwOYujJTFEKsB6nWm5kZgG8COOruX7kk9BiABwF8qfP9J2nbcne0WmErp1jiqX3bx8Jph6UCTymcXT5D4/UWTzNtJ+E0VW9yG2ehOkHjtZQS2sWd76Lx0dtuCsaaKeWYT2SP0ngzy621dkoZ7MKmUjA2MMDP9/npcFoxAJyt8tTh8mz4nE7NnqRj61XeDrrovAz20GA4HRsAtm8bDG+7xGtJW5ad0/B+V+KzvxfAxwE8b2aHO499HhdF/gMz+wSAkwDuX8G2hBA9IlXs7v5rhF8uPri+0xFCXCm0XFaISJDYhYgEiV2ISJDYhYgEiV2ISOh6imuGVcF17k0W+sLTHRkI+7kAUF9keYFAq07a4AJoZMNeerPJPfza8im+7RYvW3yyso3Gy82wF94Ef161Cj8uSUoKK+kWDQA4NxNOFZ1GSp1q5xt/9fnjNF5fDrejzvgcHZsYLxWdK3AvvLSJn9NGK5zO3UrRAVibbFPLZiGiR2IXIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEioas+u8GQtMMeord5XnelGm7xm2R5GepCkfvwtQYvmdwg+csZ8JzxbIrXncEMjVem+fgT58eDsS3D3Mveyjs6o1jg7aj7stwTzpK2y0MD/Jxli7yU9EtLfI3AqflKMFZNudZyeV4subBpE41XW/y4WzY8PpPj1xOMXQ/y2YWIHoldiEiQ2IWIBIldiEiQ2IWIBIldiEiQ2IWIhO7ms7vD22H/Mcly7zPbF86trjW5j44Sf10rJnzfrcWwX1yvcB+8tcxzp5vgfjHavN00muH1ByVaYxzYtZn7ycUsv0SShPvsrUY4L7yQcA+/1eDxhXnel6RcDeecZwo76dhC/3YaTzycKw8AQyWeD7//lmuCsYG0ls2sKISFz4fu7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEwkr6s+8G8G0AY7iYLHvQ3b9mZl8E8HcApjt/+nl3/xnblruj0Qz77JkU33X7ti3B2OZBXqd7aZr3EU8y/FAU+8KJ32k530tl3me8UeN+MTI8N7rZCPv0Zry++eAQ99n7Sb9vAMj2cR9/cTlcB6CV8LF15+c0k0/pDV8Kzz1fSOlRkPBa/6NFfs4/eOfNNH7Hu/YEYzme5o92g92j19afvQngs+7+jJkNAvi9mT3RiX3V3f9pBdsQQvSYlfRnPwvgbOfnRTM7CoAvPxJCbDj+rP/ZzewaALcB+G3noYfN7Dkze8TMLvt+0MweMrNDZnZoZo6/nRVCXDlWLHYzGwDwQwCfcfcFAF8HcC2A/bh45//y5ca5+0F3P+DuB0aHR9ZhykKI1bAisZtZDheF/h13/xEAuPuku7fcvQ3gGwBuv3LTFEKslVSxm5kB+CaAo+7+lUsevzQt6KMAjqz/9IQQ68VKPo1/L4CPA3jezA53Hvs8gAfMbD8u2nHjAD6ZuiUzwMJ2S6bFp3Pj1eG0w3ffwtsa//QX4ZbLALBwIVx2GAAySXhu/QO8HvPIyCiNNxtbabyQ5/ZZQtJzRzbzsct1/nrfdJ5+W0j6aXxxOWwFVap829UGT+31ZtjWA4ASSR3enOPW275rwjYvALzrlhto/I7bb6TxHVsHSZTPrU3sa/ewJbiST+N/Hdg79dSFEBsLraATIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEioaulpL0NNGrhkszZDE95HN0Uzv2756+479luT9L4/z79exqfnJoNxmopS/6z+ZQy1xnemrjZSGkJ
XQifxoVZXiJ7fnKZxq3NSyIPDvDWxQuLF4Kx5WW+776UVM/+fLi0OADs2RZee3HDNXvp2P37rqPxa/ZcReNXDfPU4cTD57S5yFOa67XwOfW2WjYLET0SuxCRILELEQkSuxCRILELEQkSuxCRILELEQnG8l/XfWdm0wBOXvLQKICZrk3gz2Ojzm2jzgvQ3FbLes7tane/bIGEror9LTs3O+TuB3o2AcJGndtGnRegua2Wbs1Nb+OFiASJXYhI6LXYD/Z4/4yNOreNOi9Ac1stXZlbT/9nF0J0j17f2YUQXUJiFyISeiJ2M7vbzP5oZq+Y2ed6MYcQZjZuZs+b2WEzO9TjuTxiZlNmduSSx0bM7AkzO9b5zhOnuzu3L5rZROfYHTaze3s0t91m9gsze9HMXjCzT3ce7+mxI/PqynHr+v/sZpYAeBnAhwGcAfA0gAfc/cWuTiSAmY0DOODuPV+AYWbvA1AG8G13v7nz2D8COO/uX+q8UA67+99vkLl9EUC51228O92Ktl/aZhzARwD8LXp47Mi87kcXjlsv7uy3A3jF3Y+7ex3A9wHc14N5bHjc/SkAb66Dcx+ARzs/P4qLF0vXCcxtQ+DuZ939mc7PiwBebzPe02NH5tUVeiH2nQBOX/L7GWysfu8O4HEz+72ZPdTryVyGMXc/2/n5HICxXk7mMqS28e4mb2ozvmGO3Wran68VfUD3Vu5093cCuAfApzpvVzckfvF/sI3kna6ojXe3uEyb8T/Ry2O32vbna6UXYp8AsPuS33d1HtsQuPtE5/sUgB9j47Winny9g27n+1SP5/MnNlIb78u1GccGOHa9bH/eC7E/DeB6M9trZnkAHwPwWA/m8RbMrNT54ARmVgJwFzZeK+rHADzY+flBAD/p4VzewEZp4x1qM44eH7uetz93965/AbgXFz+RfxXAP/RiDoF5vQ3AHzpfL/R6bgC+h4tv6xq4+NnGJwBsAfAkgGMA/gvAyAaa278CeB7Ac7gorO09mtuduPgW/TkAhztf9/b62JF5deW4abmsEJGgD+iEiASJXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiIT/A8EtMvZ2MEM6AAAAAElFTkSuQmCC\n" | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| } | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "ASCII: \n", | |
| "#%@%+*@@%%%%%%%%@%######\n", | |
| "#@@**#@@@@@@@@%*%@%#%##*\n", | |
| "#%#**%@@@@@@@@@#*%%###%#\n", | |
| "#%*##%%%####%%%%*######*\n", | |
| "+##*#%####%%##%%%####%*+\n", | |
| "=+*+=***##%##%%%%%#==*+*\n", | |
| "==+*+****######%###-``:-\n", | |
| "-=+******%%%%%%%%%%#:'``\n", | |
| ":=*##*###%%%%%%##%##=```\n", | |
| ":+#**#*###***+++*###*:'`\n", | |
| "-+***#*#*-+++++*####%+``\n", | |
| ":=*#####=:+*+==++*#%%*``\n", | |
| "`:+##%###****=-=+#%##*:`\n", | |
| "`:=*#%###*******#####*-`\n", | |
| "`-*##%########%%%%%%##=`\n", | |
| "`=#%%%%%#%%%%#%%%%%%#%=`\n", | |
| "`-%%%%%%%%%%%%%@@%%#%%=`\n", | |
| "`-#@%%%####%%%%%**%%%%-`\n", | |
| "`:#@@@%%%*+###*#*#%@%#:`\n", | |
| "`:*@@@@@@%%%####%%%@%#-`\n", | |
| "`:*%%@@@@@%%#####@@%%#:`\n", | |
| "::*%%%@@@@@#####%@%%#=``\n", | |
| "::=%%%%%%%%##*##%%%#+:``\n", | |
| "::-+%@@%%%%%%%%%%%*+-:``\n", | |
| "\n", | |
| "\n", | |
| "==== Sample 2, class = bird ====\n", | |
| "Original:\n" | |
| ] | |
| }, | |
| { | |
| "output_type": "display_data", | |
| "data": { | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ], | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXRElEQVR4nO2dXWycZ5mG72dmPGOP/xLHSWqSNCVtWpoCW1hvhQTaBaFFpSeFk4oeoK5UbTgACSQOFnUP6GG1WkAcrJDCUlFWLAgJED2oWLoVqwpphepWoU0b2rRp0iQ4dhKn8fh3fr5nDzxdmeL3fo3nd/e9L8nyeB5/3/d+P/d8M3O/z/OYu0MI8f+fXK8HIIToDhK7EIkgsQuRCBK7EIkgsQuRCIWubqxU9GJ5MBiP+gLsH8x2MqS2YB3fdgfXH3NjendYEXWKInEjcUPGl43cBqOHLbICIxdzFll3hnwwVltdRX29uuVZa0nsZnYvgG8DyAP4V3d/jP1/sTyI2z7xkWA8fnLDoYaFDwAAWK61NzFM0Pk833Yu9mIQE5TzsbuHVxA7pp7FLno+ODO+/hwJZ5GX93qjxrddr0bijWCsmK3SZQfLfGyrdR4v5YdpfMDDY1/iu41qYSwYe/2//jsY27ECzCwP4F8AfBrAMQAPmtmxna5PCNFZWrnd3QPgdXc/6+5VAD8GcH97hiWEaDetiP0AgAub/r7YfO6PMLPjZjZjZjP19cj7EyFEx+j4t/HufsLdp919ulAa6PTmhBABWhH7JQCHNv19sPmcEKIPaUXszwE4ambvNbMigM8BeLI9wxJCtJsdW2/uXjezLwH4D2xYb4+7+8t8KeNWDneBQBwm5CK+Zi7Pd7UVqzwXsfViPnw8zq09dmCyiGnrOX7Q88w7A5BFjntG7ieeRfY74/udxW5VhbC9tlbjHymrq+s07nUe37V7iMYPHZgKxi5eWaLLzi+Gd5wd0ZZ8dnd/CsBTraxDCNEdNF1WiESQ2IVIBIldiESQ2IVIBIldiESQ2IVIhK7mswMOJ76vRXI9mR+dRY3ySO5zzKcnXnqr+eyt+vDMXY1k38Ij6bNm4TTRjRVwn55FI9mxsOjYinzbhXAaaYPkhANAoRGRRiS9trZaofE8xsOxyIEZINcq05Du7EIkgsQuRCJI7EIkgsQuRCJI7EIkgsQuRCJ02XozWmk1WkqaOVCRKqit0knrLdpbs6XemzFbL7LpSLwYsfZKpIqqOy9TVo/YY3Xn8QY5ZxFXD3CeAjtY3kXjpcHIvtXC8YFYtWKWlkxsO93ZhUgEiV2IRJDYhUgEiV2IRJDYhUgEiV2IRJDYhUiE7vrsFummGuko6ix9L5om2trrGuuGGu2UGlt3ZL9zsVLSLB4z8WOtiSOjHx8p0fj+kXBJZa+v0WXnrvI00qsLdRrn8x8idctj0sjxUtHLa3zfrl25EYw1CpEOsIXw+Wa7rDu7EIkgsQuRCBK7EIkgsQuRCBK7EIkgsQuRCBK7EInQ5Xx2wFkL38iyRpbNIonXHvHh85GSyFhbDoaKkXLLFtuzXCQvGzy3uuHEb47ML7BITjjy/LhUw12RN9Y/PhKM3bJ/gi6ba1yj8WvXeblmPi+DLkqnLgBAFjkn61V+zpdXwm2ZM+PHvFgKH1N2rbUkdjM7B6ACoAGg7u7TraxPCNE52nFn/4S7X23DeoQQHUSf2YVIhFbF7gB+ZWbPm9nxrf7BzI6b2YyZzdTX+VxnIUTnaPVt/Mfc/ZKZ7QPwtJn93t2f3fwP7n4CwAkAKE+Mt1Q6UQixc1q6s7v7pebveQA/B3BPOwYlhGg/Oxa7mQ2b2eg7jwF8CsCpdg1MCNFeWnkbvx/Az5s5wwUA/+7uv6RLuPFC5NEWveF4LlYfPRaPGMaT5fC2b9kbqSGei3x6KfCc8BtLfGw3KovBmBk/xZbxbS9V+fcsy+HpBwCAC7Ph5adG99Bl
x4d5S+bcAPeja0a8cFZXAYBF5m1YxudWlEcHaXx8OHzc18MWPABgYCA8CSBPhr1jsbv7WQB/sdPlhRDdRdabEIkgsQuRCBK7EIkgsQuRCBK7EInQ1RRXMyBP0jmzSHnfHMlLtIi9VVsL21MAsDvPLaaDY+Phba9y/+nQzXtpvJHxsY+scett71C4/e/aCl92bKxM4/nyJI2/+NZ1Gq8shUsqX77Ej/ngMLc0LdIvOsvCl3cuVlq8HjknZb7tAzftp/HKQvi47d4zRpedGA+fs2IxbDfqzi5EIkjsQiSCxC5EIkjsQiSCxC5EIkjsQiSCxC5EInS5lLQhR1ILY62Pqc/OyikDGB3gr2u37Qv76AAw6GFPeHkp3H4XAErZKI1XSIoqAFQun6fxAVJ6eDDPUy3Xrlyi8cM3T9H4ZI3v+8qNcHy+yi+/w8d4qenhIZ6eu7zC+hfzZQ3rNL5rlI991xg/7hfOhde/1OD1Wxtr4W3Xa+E5F7qzC5EIErsQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIPWjZ3BmYfw8AExO8bPFgiedWjw2E84THh3j73uvzf6DxyuLbNB5pJg0rhfO+iyWer75yY4HGhzzs2wLA0Um+/gapi7ywxr3sbInPP5ga4fMXFlfDufSNSD57PnIbLOR5KWmv8RoHrGX0YHmYL1siufREB7qzC5EIErsQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIXfbZHYjkrDNY22XPca/78kKFxq8uX6Hxv7zjcDB2ZB/38C++coHGVyJ+cq44RONGWviWIp5trcb7AzeYIQxg326+/vJo+LhdJLnuALBc4TXpxwZ57fZdFr7WFiL56lbgsxvGh7nHX7Kwxw8AI4Nhnz6mkLm3w3Mf6qTeffTObmaPm9m8mZ3a9NyEmT1tZmeav3fH1iOE6C3beRv/fQD3vuu5rwF4xt2PAnim+bcQoo+Jit3dnwXw7jmV9wN4ovn4CQCfafO4hBBtZqdf0O1399nm48sAgo2tzOy4mc2Y2Ux9nc8/F0J0jpa/jfeNKpHBbwXc/YS7T7v7dKFUbHVzQogdslOxz5nZFAA0f8+3b0hCiE6wU7E/CeCh5uOHAPyiPcMRQnSKqM9uZj8C8HEAk2Z2EcDXATwG4Cdm9jCA8wAe6OQg34G2MXfuB2OAe9X1SB/yN6+G85OnJnnN+ZFx3me8ts5zn0uDfN/KA2HPNlvnufJjo3x+wtwi98JR4f3fB0rh+unjZZ4LP+z8XnR9jR+3MYTP+WLGffB9kXM6PjxC49fm5mi8uhqe37BS5/MHGgPh6ykjc1GiYnf3BwOhT8aWFUL0D5ouK0QiSOxCJILELkQiSOxCJILELkQidL2UNCPWsjlzknYYsd4ssqte5DbQHLHHzszxFru3D/P2wKOjPE20kfHX5MnJvcHYSqQddJ6kgQKAjXDLshFJLc6IfVZzvmy1xttwV1f49OuV62FbsJbnKax7D4zR+LV5nhL9+iwvwV1FOC26VOZJpAPkWrVc2LbTnV2IRJDYhUgEiV2IRJDYhUgEiV2IRJDYhUgEiV2IROh+y2bipcd8dlpKOrJsLtL42Iz79LVC2G++ssjTJe8aD6d5AkCBeKMA8NbsNRr3Ynj9Y8N823t28zLYeyIlkxcjKa43lsJxz/Oxrc9xLzur8hTXrBKON+rcB794jkujmudzJ+qD/LgOFEmKrPHrgV3JTjSiO7sQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIErsQidD9fPaIn925zUby3SPxnIW72aw2eN51jZQ0BoChIV62eGw3nyMwNB7Of969m6+7OMj94lzMC8cKja9kYT/70NR76LJDBX5Oxob55btvX7jk8ulLvB30W3y3sB4pDz5U4sc1UkagI+jOLkQiSOxCJILELkQiSOxCJILELkQiSOxCJILELkQidNlnN+Ry4dcXJzEg4oXnuCcby3f3mM9Ocs7X67z++YUFXt/8tokp
Gr9r3xEan9wTbjc9PBieHwAA16/O0/j5WR5fbvB9q1TDcwTORertr1zj7aZHyvyclQrrwdjBUX5cKlWex385x+cfNEheOQDkyfUYnxNCw0Gid3Yze9zM5s3s1KbnHjWzS2Z2svlz3842L4ToFtt5G/99APdu8fy33P3u5s9T7R2WEKLdRMXu7s8CWOjCWIQQHaSVL+i+ZGYvNt/mBydnm9lxM5sxs5n6evgzlBCis+xU7N8BcCuAuwHMAvhG6B/d/YS7T7v7dCGSHCCE6Bw7Eru7z7l7w90zAN8FcE97hyWEaDc7EruZbfaKPgvgVOh/hRD9QdRnN7MfAfg4gEkzuwjg6wA+bmZ3A3AA5wB8YVtbMyBHTMKshVz3VvPVozXrSQJyPVL3/Y3rPDm6WuP56gfHuCdcIuHF60t02doa3+98ifeOv3yee+FnL4drvy9U3qTL1qq8tvtd7+P58Icnwpf33jF+zirLDRq/vsTz4S3He6znCuGxxWTArmW2bFTs7v7gFk9/L7acEKK/0HRZIRJBYhciESR2IRJBYhciESR2IRKh+6WkW4JYDpEl49ZaxLpDuFx0luPrfjvjMwcr89wee/0tnppw03g4VXT20gW6bCTTE/f9zV/R+DLvVo3nT54Pxm5kfOOH7rydxs8u8XvVyFg4fucET1E9OsBtvxtrkZbP66QlMwAvhK+JThVb151diESQ2IVIBIldiESQ2IVIBIldiESQ2IVIBIldiET4P+WzU6884qO3vO0svH6P9N/1Ai81nQ3xssVrazwF9tqbZ4Oxcp77we+/7X00PnPqNRp/+tmXaPxqJTz2Ax/g295z6x003qhXaPzV6+H02qECP6ZHytyHP7aHp8BWFhZp/O01kupdHqPL5nJh2Tpx6XVnFyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEkFiFyIRuuuze/OHxRkk0Te2KCthDWyjpTMN89fMnHNPN8u4Zzs0wE/T7UduDsamjx2ky164+BaN//LZ52h8fpHv2+FjHwzHPnAnXbaR5/vdMF7mutIIj+3Fq7wUNEb4/IRyxluZ7TOe6L+0GF6+WirTZXMDrD6CfHYhkkdiFyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEqG7PruBF8VupVVtjr9uxerCx/LhGyQefcWshWvOA0BhnXu2d0yN0/iHj+wPxhpVXpN+cYX7waNT+2i8dHAXjd98ZzhnPYscuFqde92x+uoFUpt9ZZ179K9VeD764UF+TsciPnxpOZyL3xjiPnuxFJ47YWTGSfQ6NbNDZvZrM3vFzF42sy83n58ws6fN7EzzN29ILYToKdt5G18H8FV3PwbgIwC+aGbHAHwNwDPufhTAM82/hRB9SlTs7j7r7i80H1cAnAZwAMD9AJ5o/tsTAD7TqUEKIVrnz/rMbma3APgQgN8C2O/us83QZQBbfnA0s+MAjgNAsTy003EKIVpk29/Gm9kIgJ8C+Iq7/9G3F76RRbLlNwPufsLdp919ujDIGxwKITrHtsRuZgPYEPoP3f1nzafnzGyqGZ8CMN+ZIQoh2kH0bbxteFbfA3Da3b+5KfQkgIcAPNb8/YtWB5OL2Geg9lmk5XIsxbWFFFiPpKhavUrjh/dyI+PY4T003lgLp2tevMrbPS+s8xTV4sQkjQ+PTNF4VgyX0a5GrDV3fj3kY2nJ5JrwArczr0aup7EcL2N9U5FfEzeNhI/7mcXZYAwAGmVSepxci9v5zP5RAJ8H8JKZnWw+9wg2RP4TM3sYwHkAD2xjXUKIHhEVu7v/BuHb5ifbOxwhRKfQdFkhEkFiFyIRJHYhEkFiFyIRJHYhEqHrpaRjJZsZsSxVRhbZbpZxv5n5l97gaaJ7RvjMwWNHD9B4KcfTKW8shf3q8/M36LJ/WORzAIYn30PjuSJPcWVOemxuQ6w+eOycFljp8Rxvo70eKVN9nbTwBoCJEj/n5dHw9oeX+DlZXV8NxpyULdedXYhEkNiFSASJ
XYhEkNiFSASJXYhEkNiFSASJXYhE6EEpaWZ+RlobE2vTWvDvAaAR8U2tGvY2xwrcB3/vwZtoPF/kPv0Lr5yl8cUby8HYap6XJS7s4j55YZDkTgPInPvVLCfdjJ/vaAfvFuZdeJ7n0hec++RLGd/v8xkv4b1aeTsYGxgeo8uCxC2XD8Z0ZxciESR2IRJBYhciESR2IRJBYhciESR2IRJBYhciEbrrs8Noa+WsEfG6EfZlY7nNMdPWiY8OALtL4RXc9h5e972Y5/nJr75xhsZ//9ZVGm80SG70nhG6LAZ43nYW66scuV8YO+6tpbNvA1I33nhd96Jxn73hgzQ+T65VgNetHx7k52RgKHxOmb50ZxciESR2IRJBYhciESR2IRJBYhciESR2IRJBYhciEbbTn/0QgB8A2I8N6/OEu3/bzB4F8PcArjT/9RF3fyq2voaFX19idcQLWTgHOVaPvl5fp/GJSE769B3hPuRjZb7t06+8QePnr/Lc6lqe55yD9EBfrUd6nEfc7CwXmfsQOW4g5ztmpFukxoBH7lUZ2baxcQHwmE9ukf3Oc5++OBnuFdAotDK3Iayh7UyqqQP4qru/YGajAJ43s6ebsW+5+z9vYx1CiB6znf7sswBmm48rZnYaAG9hIoToO/6sz+xmdguADwH4bfOpL5nZi2b2uJltOWfUzI6b2YyZzdTX+VtpIUTn2LbYzWwEwE8BfMXdFwF8B8CtAO7Gxp3/G1st5+4n3H3a3acLkf5XQojOsS2xm9kANoT+Q3f/GQC4+5y7N3yjk9x3AdzTuWEKIVolKnYzMwDfA3Da3b+56fnNX09/FsCp9g9PCNEutvNt/EcBfB7AS2Z2svncIwAeNLO7sWGgnAPwhe1skJUWzjJuZ3gjHM+q/PuAcuQTxPSdR2h8366wpfHa2VfpspdmF2i8mttP44URXs45lirK8Eir6nhXZZ4qyshZuOxxO2Dpnha7z8XCNHcXyMUOHGsZnY9sPGIbhtjOt/G/wdaXU9RTF0L0D5pBJ0QiSOxCJILELkQiSOxCJILELkQiSOxCJEKXS0kDRvIa3bnPXq+uBGMjkT35yN0fpPHDe7kR/7uXnwvGTr85T5et5rlPnitF2iJHjHR2TC3i9zIvutfExp7bod8MxFNY86T1MRAfW7TdND3usfO9M/r3TAsh2orELkQiSOxCJILELkQiSOxCJILELkQiSOxCJILFSjC3dWNmVwCc3/TUJADej7h39OvY+nVcgMa2U9o5tsPuvnerQFfF/icbN5tx9+meDYDQr2Pr13EBGttO6dbY9DZeiESQ2IVIhF6L/USPt8/o17H167gAjW2ndGVsPf3MLoToHr2+swshuoTELkQi9ETsZnavmb1qZq+b2dd6MYYQZnbOzF4ys5NmNtPjsTxuZvNmdmrTcxNm9rSZnWn+3rLHXo/G9qiZXWoeu5Nmdl+PxnbIzH5tZq+Y2ctm9uXm8z09dmRcXTluXf/MbmZ5AK8B+FsAFwE8B+BBd3+lqwMJYGbnAEy7e88nYJjZXwNYAvADd39/87l/ArDg7o81Xyh3u/s/9MnYHgWw1Os23s1uRVOb24wD+AyAv0MPjx0Z1wPownHrxZ39HgCvu/tZd68C+DGA+3swjr7H3Z8F8O52MvcDeKL5+AlsXCxdJzC2vsDdZ939hebjCoB32oz39NiRcXWFXoj9AIALm/6+iP7q9+4AfmVmz5vZ8V4PZgv2u/ts8/FlALx3VPeJtvHuJu9qM943x24n7c9bRV/Q/Skfc/cPA/g0gC823672Jb7xGayfvNNttfHuFlu0Gf9fennsdtr+vFV6IfZLAA5t+vtg87m+wN0vNX/PA/g5+q8V9dw7HXSbv3m1yy7ST228t2ozjj44dr1sf94LsT8H4KiZvdfMigA+B+DJHozjTzCz4eYXJzCzYQCfQv+1on4SwEPNxw8B+EUPx/JH9Esb71CbcfT42PW8/bm7d/0HwH3Y+Eb+DQD/2Isx
BMZ1BMDvmj8v93psAH6Ejbd1NWx8t/EwgD0AngFwBsB/Apjoo7H9G4CXALyIDWFN9WhsH8PGW/QXAZxs/tzX62NHxtWV46bpskIkgr6gEyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEkFiFyIR/gcK4w4eII1bDQAAAABJRU5ErkJggg==\n" | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| } | |
| }, | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "\n", | |
| "ASCII: \n", | |
| "########################\n", | |
| "#####################**#\n", | |
| "#####################**#\n", | |
| "#############***######*#\n", | |
| "#############*++########\n", | |
| "######****####*+########\n", | |
| "#####**++****##**###****\n", | |
| "####***+****+##**###+*#*\n", | |
| "####**+*****++***###+*#*\n", | |
| "#####*+**********##*+*##\n", | |
| "######***********##**###\n", | |
| "######*+**++=++**#%####*\n", | |
| "######*+**+==++++*#####*\n", | |
| "#######*+===+******####*\n", | |
| "########*+**+#%%#*+*##**\n", | |
| "##########%#*%@##*++*#**\n", | |
| "##########**%@%**#***##*\n", | |
| "#########++*%%***##**##*\n", | |
| "########++*###**####**##\n", | |
| "#######*++*###**#*##**##\n", | |
| "######*+*#####**#*#**##*\n", | |
| "#####*+*#####***##***#**\n", | |
| "####*+*#######*##*******\n", | |
| "###*=*########**********\n", | |
| "\n", | |
| "\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 1) Get samples to serve as \"context\"\n", | |
| "args.seed_train = 42\n", | |
| "args.train_samples_per_class = 4\n", | |
| "img_dims = {'width': args.width, 'height': args.width}\n", | |
| "\n", | |
| "context_ix, context_y = get_random_indices_by_class(\n", | |
| " indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.train_samples_per_class,\n", | |
| " args.seed_train\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "9Qrhqhy0qRbV" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 2) Get examples to few-shot classify\n", | |
| "args.seed_eval = 42\n", | |
| "args.eval_samples_per_group = 10\n", | |
| "test_eval_ix, test_eval_y = get_random_indices_by_class(\n", | |
| " test_indices_by_class,\n", | |
| " class_indices,\n", | |
| " args.eval_samples_per_group,\n", | |
| " args.seed_eval\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "QWW1ab8atP5Z" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 3) Get prompt context: try several shuffle seeds for the context order\n", | |
| "for seed in range(20):\n", | |
| " context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
| " train_set_ascii, context_ix,\n", | |
| " start='', end='', sep='###\\n', shuffle=True,\n", | |
| " seed=seed # Each seed yields a different shuffled order\n", | |
| " )\n", | |
| "\n", | |
| " # Inspect context order (want to avoid patterns between outputs, e.g., 0, 1, 0, 1, ...)\n", | |
| " print(seed, train_set.targets[shuffled_context_ix])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "V8ZfWFBMtS-E", | |
| "outputId": "ae34e695-dae9-4055-bbc6-c4844ad79ec2" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n", | |
| "1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n", | |
| "2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n", | |
| "3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n", | |
| "4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n", | |
| "5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n", | |
| "6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n", | |
| "7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n", | |
| "8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n", | |
| "9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n", | |
| "10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n", | |
| "11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n", | |
| "12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n", | |
| "13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n", | |
| "14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n", | |
| "15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n", | |
| "16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n", | |
| "17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n", | |
| "18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n", | |
| "19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 3.1) Get prompt context (cont.)\n", | |
| "context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
| " train_set_ascii, context_ix,\n", | |
| " start='', end='', sep='###\\n', shuffle=True,\n", | |
| " seed=7 # Pick a random-looking one from above\n", | |
| ")" | |
| ], | |
| "metadata": { | |
| "id": "_QXJgEJQtVCO" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 3.2) Visualize a prompt\n", | |
| "# Print the context prompt built above (for sanity checking)\n", | |
| "for ix in shuffled_context_ix:\n", | |
| " print('Input:')\n", | |
| " pretty_print_ascii(\n", | |
| " img_to_text(train_set_ascii[ix][0]),\n", | |
| " **img_dims\n", | |
| " )\n", | |
| " print(f'Output: {train_set_ascii[ix][1]}\\n###\\n')\n", | |
| "\n", | |
| "# Eval sample (the model should complete the final 'Output:')\n", | |
| "print('Input:')\n", | |
| "pretty_print_ascii(\n", | |
| " img_to_text(test_set_ascii[test_indices_by_class[1][0]][0]),\n", | |
| " **img_dims\n", | |
| ")\n", | |
| "print('Output:')" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "R4DOPULMyKKa", | |
| "outputId": "235beff4-ccb7-482f-cfb7-5d66028d76d1" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Input:\n", | |
| "+=+##%#%@@@%@@%@@@@%#=--\n", | |
| "###%%%#%@@@@@%%%%%%##=``\n", | |
| "%%%@@@#%@%@@%%%%%%#**-'`\n", | |
| "%%%@@%#%%%%%%%%%%%#++-'`\n", | |
| "%%%%%%#%%%%%%%%%%%#**-'`\n", | |
| "%%%%%%##%#*#%%%%%%#**:'`\n", | |
| "%@%%#*+***+**#%@%%%%%-'`\n", | |
| "@@%*+=-*###*##+#@%%%%='`\n", | |
| "@%#++=:*@%%%#%*=%@@##='`\n", | |
| "%#+===+*%@@%%%#++%@%*='`\n", | |
| "#*=-=++*%%%%%%#*++%##='`\n", | |
| "#+==+**#%%###*****###='`\n", | |
| "#=-=++*#*+--++++*+*+*-'`\n", | |
| "*=++=+++++=+**++++++=-``\n", | |
| "#***++++++**####****++-`\n", | |
| "***+****+***+==+=*###+==\n", | |
| "+=+==+***%#*=::-=#%##++=\n", | |
| "=-+===++****+===+***++*=\n", | |
| ":-+==+**#=+##****=:``=#=\n", | |
| "-+##***#*==*#*###+-=+*#=\n", | |
| ":=+#%%%#*-=*#**#%%%%%%#=\n", | |
| ":::-*%@@%--*####%%%%%@*-\n", | |
| "-----=+*%*+%@@@@@@@@@%+-\n", | |
| "---==--=+*##%%%%##**++==\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "`=+=:````'`''```'''''`=*\n", | |
| "-+**=:````````````````-+\n", | |
| "+###*-'`:-:::`````````=#\n", | |
| "*###+-::+*==-`:-`'''':*#\n", | |
| "####*+++*#+**++*-```:=#%\n", | |
| "##*##*###%##*+*+=---==#%\n", | |
| "#######%%###*=:``:::::*%\n", | |
| "******+**++++=:``:````+*\n", | |
| "==-=======--=----+=-::--\n", | |
| "=====---===----=====----\n", | |
| "==============--====----\n", | |
| "=======++***+=--=--=----\n", | |
| "+====+*******+=-======--\n", | |
| "++==+####*****+========-\n", | |
| "+*+=*%%%%##*****========\n", | |
| "**+=###%%%%####*+=======\n", | |
| "#*=+*++%%%%%###*+=======\n", | |
| "*+==++*##%#####*+=======\n", | |
| "=====++++*++**+=========\n", | |
| "+++==+++++++============\n", | |
| "++++++++++++++++++++++++\n", | |
| "+++++++++*+++++**++*++++\n", | |
| "++++++++**+++++++++*++++\n", | |
| "++++++++**+++++++*+++++=\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "********+++++++++++*****\n", | |
| "**+++++++++++++++*+++++*\n", | |
| "++++++++++++++++*#++++++\n", | |
| "+++++++++++*++++##++++++\n", | |
| "++++++++++#%+=++##++++++\n", | |
| "++++=====+%%+==+%*++=+++\n", | |
| "=========#%#===*%#+=====\n", | |
| "========+%%#**##%*======\n", | |
| "======--*%%%%%%%%*-=====\n", | |
| "-------=%%%%%%%%%*------\n", | |
| "-------*%%%%%%%%%*------\n", | |
| "-----:=%%%%%%%%%%+=++**+\n", | |
| "::::::*%%%%####%%##%%##=\n", | |
| ":::::=%%%%%####%%%%#*=-:\n", | |
| ":::::*#%%%#####%##+-::::\n", | |
| ":::`=**%%###%%#*=:```:::\n", | |
| "```:**#%####*+-````````:\n", | |
| "```=#%%%##+=:```````````\n", | |
| "``:#%%%*+:``````````````\n", | |
| "``=%#+-:````````````````\n", | |
| "`:+=:```````````````````\n", | |
| "`::`````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "--===--------------==+++\n", | |
| "===++==----------===+*#-\n", | |
| "++:++==============++**:\n", | |
| "+=:=+==============++**:\n", | |
| "+=`=+===============+**:\n", | |
| "+=`-=====-===-==--==+**:\n", | |
| "+=`-=----:----------=+*:\n", | |
| "*+`---:-----:--:::---++:\n", | |
| "*+`:--=---------:::--=+:\n", | |
| "*+:::*#+=++===+--:::-=+:\n", | |
| "=::`=@@%%%#####+-=-:-=+:\n", | |
| "=```:=====++***+==:`:=+:\n", | |
| "*:::`::::::::::=:`````::\n", | |
| "*---:------------::```::\n", | |
| "=--+#+------:-----+#=:-:\n", | |
| "--=%#%+----------=%##=-:\n", | |
| "::+#+**--:-:::::-**=#=::\n", | |
| "--+*=+#--:--:-:::#+-*+::\n", | |
| "*#%*=*%+++++++==+%*=*#=:\n", | |
| "%@@%+#@@@@@@@@@@@@#+%@#:\n", | |
| "+**#**##*###**#**#***+=:\n", | |
| ":::::::::::::::::::::```\n", | |
| "::::::::::::::::::``:::`\n", | |
| ":-:::---:::::-::::::::::\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "::``````````````````````\n", | |
| "::``````````````````````\n", | |
| "::``````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "````````````````````````\n", | |
| "`````````````````````::`\n", | |
| "``````:-=::::`':-```'-=`\n", | |
| "````:-=+*+++=::-=::--=*=\n", | |
| "```-+*###*****####*****-\n", | |
| "-:`=#%%%%###%%@@%#%#=*=:\n", | |
| "#+-=*#@@%@%%@@@%%*#+-=+-\n", | |
| "%#*+-=+=+@%#@@##=--:----\n", | |
| "**++==-:+#+==*#+::---=--\n", | |
| "+++=====#%*+=*%===**+++=\n", | |
| "++++++*+*****##***#**++=\n", | |
| "+++++++++++++++++++****+\n", | |
| "+++++=+++++++++++++*++++\n", | |
| "++++++++++++++++++++++++\n", | |
| "++++++++++++++++++++****\n", | |
| "++++++=+++++++++++++***+\n", | |
| "++++++==++++++++++++++++\n", | |
| "+++++===++++++++++++++++\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "```````````````````````:\n", | |
| "```````````````````````:\n", | |
| "`````````````::::`::::::\n", | |
| "::``:++*-```:::::`::::::\n", | |
| "::-=+#%%*=-::::::```````\n", | |
| "--+##%%%%##++---+=-==-::\n", | |
| "###%%%#%#%%##*#+##*##++*\n", | |
| "%%%%%%%%%%%#=+#*+****#%%\n", | |
| "#%%%%%%%%%#+:=#*+++*#%#*\n", | |
| "*#%%%%%%%%+--=++****%@%#\n", | |
| "+=#@%%%#*==+++**###%%@@%\n", | |
| "++#%%#*=--=++*##%##%@@@%\n", | |
| "#**##=+==++#%%%####%@@@%\n", | |
| "*##%%=+++*%@@@@@@%%%%@@@\n", | |
| "#%@@@**#*#%@@@@@%##**###\n", | |
| "%@@@@##@%%%%%%#***#*****\n", | |
| "%%%%####**********#%****\n", | |
| "##%#**********##%##%####\n", | |
| "#############%%@@@@@%%%%\n", | |
| "************##%%%%%#####\n", | |
| "************************\n", | |
| "************************\n", | |
| "****++++++++*****+**++++\n", | |
| "**************++***+++++\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "###**********###########\n", | |
| "#***********############\n", | |
| "********#***############\n", | |
| "**#######****###########\n", | |
| "**######*++===+*###*****\n", | |
| "###**+==--==-===*##*****\n", | |
| "*++===----=--====*#***##\n", | |
| "====+==-=====+++++**####\n", | |
| "+*++++===+++*#*+****####\n", | |
| "+**#*++==++=###****#####\n", | |
| "==+**++==+++*****++*####\n", | |
| "++++**+=++*+***#**+*####\n", | |
| "++**++*+++*+*****######*\n", | |
| "+++**+++++*##+****######\n", | |
| "++++****++**++*#########\n", | |
| "+++*+**##**++*#%########\n", | |
| "**+++***#**+**#%########\n", | |
| "***+**********##########\n", | |
| "*******###****##*#******\n", | |
| "*****######***##********\n", | |
| "*****####%#**##*********\n", | |
| "**+***####+*#%#*********\n", | |
| "*****#####+*#%#*********\n", | |
| "#***###%#**###**********\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "########################\n", | |
| "#####################**#\n", | |
| "#####################**#\n", | |
| "#############***######*#\n", | |
| "#############*++########\n", | |
| "######****####*+########\n", | |
| "#####**++****##**###****\n", | |
| "####***+****+##**###+*#*\n", | |
| "####**+*****++***###+*#*\n", | |
| "#####*+**********##*+*##\n", | |
| "######***********##**###\n", | |
| "######*+**++=++**#%####*\n", | |
| "######*+**+==++++*#####*\n", | |
| "#######*+===+******####*\n", | |
| "########*+**+#%%#*+*##**\n", | |
| "##########%#*%@##*++*#**\n", | |
| "##########**%@%**#***##*\n", | |
| "#########++*%%***##**##*\n", | |
| "########++*###**####**##\n", | |
| "#######*++*###**#*##**##\n", | |
| "######*+*#####**#*#**##*\n", | |
| "#####*+*#####***##***#**\n", | |
| "####*+*#######*##*******\n", | |
| "###*=*########**********\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "%%%%%%%%%%%%%%%%%%%%#%%#\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
| "%%%%%##%%%%%%%%%%%%*#%%%\n", | |
| "%%%%%#=-*#%%%%%%%%+-*%%%\n", | |
| "%#+++++==+***##**=:=#%%%\n", | |
| "%%*+=--==*#*+=-::``-#%%#\n", | |
| "%###*+++==**++====-=*+=-\n", | |
| "%+=--=++++*++==----=-:::\n", | |
| "%%#+----::::=+-:::::--::\n", | |
| "%##***++-::-=*+==-======\n", | |
| "%#+=======++**+====++***\n", | |
| "%#*+++++++*****+++=+*###\n", | |
| "%%#********####***+*####\n", | |
| "%%##*****#######****####\n", | |
| "%%##########%%####*##%%%\n", | |
| "%%%#######%%%%#######%%%\n", | |
| "%%%%%####%%%%%%#####%%%%\n", | |
| "%%%%%%%%%%%%%%%%####%%%%\n", | |
| "%%%%%%%%%%%%%%%%%###%%%%\n", | |
| "Output: 0\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "`::::::::::`:`::::::::::\n", | |
| ":=+==++++++=---=====+===\n", | |
| ":+++=+++===--:-===++====\n", | |
| ":=======----::--=====+==\n", | |
| ":------=+****+===--=====\n", | |
| "::::::--**##**+==------=\n", | |
| "::::---:+####**=---:---:\n", | |
| ":-:-+**+*####***=--:::::\n", | |
| ":-+*++**#**+==--::```:::\n", | |
| "::+###**#+*++==---::````\n", | |
| "::+######++*++******+:``\n", | |
| "::*#######*+-=+*%%##*-``\n", | |
| "::=#%%##%%#****##***+=:`\n", | |
| ":`-#%%%%%#######*===++-`\n", | |
| ":`-%%#%%%**%%##%%***##+:\n", | |
| ":`-%*=+*#**%%##%@%%%%#+`\n", | |
| ":``-:::-=*#%%%%#*++*%#:`\n", | |
| ":`'''```:+%%+===---=#+`'\n", | |
| ":`'''''''-++----:::::``'\n", | |
| ":`''''''``````````````''\n", | |
| "::``````````````````````\n", | |
| "`::::::::::::::::::::::`\n", | |
| "`::----:````````````````\n", | |
| "`:--==-:`'''''''''''''''\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "-==-------------::--::::\n", | |
| "==----------::--::-:::::\n", | |
| "=====-------::::::-:::::\n", | |
| "======------::::--::::::\n", | |
| "=====-------:::=*=::::::\n", | |
| "======--------:#%=::::::\n", | |
| "======--------=%%-::::::\n", | |
| "-=========----#%+-::::::\n", | |
| "-===========-=%%=---::::\n", | |
| "-===========-+@%=---::::\n", | |
| "-============*@%=---::::\n", | |
| "---------====#@#----::::\n", | |
| "::::---::---+#@*--------\n", | |
| "::::::::::--#%%=--------\n", | |
| ":::::::::::=@%*+=-------\n", | |
| "::::::::::-%@@#=-=------\n", | |
| "::::::::::*@@#*+==------\n", | |
| ":::::::::+@@@==+-------:\n", | |
| "::--::::-#@@*---------::\n", | |
| "::--::::+%#*-:::::----::\n", | |
| ":--:::::*@*-::::::------\n", | |
| "--:::::-*#=::::::::-:---\n", | |
| "::::::::++:::::::::::---\n", | |
| "::::::::::::::::::::----\n", | |
| "Output: 2\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "#%@%+*@@%%%%%%%%@%######\n", | |
| "#@@**#@@@@@@@@%*%@%#%##*\n", | |
| "#%#**%@@@@@@@@@#*%%###%#\n", | |
| "#%*##%%%####%%%%*######*\n", | |
| "+##*#%####%%##%%%####%*+\n", | |
| "=+*+=***##%##%%%%%#==*+*\n", | |
| "==+*+****######%###-``:-\n", | |
| "-=+******%%%%%%%%%%#:'``\n", | |
| ":=*##*###%%%%%%##%##=```\n", | |
| ":+#**#*###***+++*###*:'`\n", | |
| "-+***#*#*-+++++*####%+``\n", | |
| ":=*#####=:+*+==++*#%%*``\n", | |
| "`:+##%###****=-=+#%##*:`\n", | |
| "`:=*#%###*******#####*-`\n", | |
| "`-*##%########%%%%%%##=`\n", | |
| "`=#%%%%%#%%%%#%%%%%%#%=`\n", | |
| "`-%%%%%%%%%%%%%@@%%#%%=`\n", | |
| "`-#@%%%####%%%%%**%%%%-`\n", | |
| "`:#@@@%%%*+###*#*#%@%#:`\n", | |
| "`:*@@@@@@%%%####%%%@%#-`\n", | |
| "`:*%%@@@@@%%#####@@%%#:`\n", | |
| "::*%%%@@@@@#####%@%%#=``\n", | |
| "::=%%%%%%%%##*##%%%#+:``\n", | |
| "::-+%@@%%%%%%%%%%%*+-:``\n", | |
| "Output: 1\n", | |
| "###\n", | |
| "\n", | |
| "Input:\n", | |
| "#***++++****#**###%%%%%%\n", | |
| "++*******#*########%@%%%\n", | |
| "*****##########%#%%@@%%%\n", | |
| "***#############*#%@%#%%\n", | |
| "######%%########*#%%##%%\n", | |
| "%%%%%%%#########%@%##%%%\n", | |
| "%%%%%#####**####%@%%%%%%\n", | |
| "%%%%%#####***++*#%##%%%%\n", | |
| "%%%#%####*=-----=%%##%%%\n", | |
| "%#%#####+++**#%=`+@@#%%%\n", | |
| "####%%#*###@@@%=:-#@%%%%\n", | |
| "###########%%*=:::*%@%#%\n", | |
| "#####**####=:```:+%%@*=#\n", | |
| "#%##%##%%%*``::`-#%#*==#\n", | |
| "%%+=#%%@@%+:-::::-+====#\n", | |
| "%#:-#%%@@%+---+#+-:-==*%\n", | |
| "%**+*%@@%#+--+%%@*-=+*@@\n", | |
| "%#%#*%@@#*=:-*%%%%++*@@%\n", | |
| "%%#**%@%#*---#%%%@*=#@@%\n", | |
| "%%##*%#*==---#%*#@%*%@@@\n", | |
| "%%##*##*+++++#%##@@%@@@@\n", | |
| "%%%%@@@@@@@@@@%#%@@@@@@@\n", | |
| "%%%%%%%@%%@@@@@%%@@@@@@%\n", | |
| "%%#**#%%####%%%%@@@@@@%%\n", | |
| "Output:\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 4) Build the full few-shot prompt for each eval sample and query GPT-3\n", | |
| "model_name = 'text-davinci-003'\n", | |
| "all_responses = []\n", | |
| "all_targets = []\n", | |
| "for test_ix in test_eval_ix:\n", | |
| " eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n", | |
| " _prompt = context_prompt + eval_prompt\n", | |
| " # This API call is billed by OpenAI\n", | |
| " response = get_gpt_output(_prompt, model=model_name)\n", | |
| " all_responses.append(response)\n", | |
| " all_targets.append(target)" | |
| ], | |
| "metadata": { | |
| "id": "-BxHv8bZtgAW" | |
| }, | |
| "execution_count": null, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# 5) Evaluate\n", | |
| "\n", | |
| "# Report accuracy metrics\n", | |
| "avg_acc, class_accs = eval_predictions(\n", | |
| " [decode_response(r) for r in all_responses],\n", | |
| " [test_set.targets[test_ix] for test_ix in test_eval_ix],\n", | |
| ")\n", | |
| "print('')\n", | |
| "\n", | |
| "# Report predicted vs true targets\n", | |
| "print(f'GPT-3 {model_name} few-shot predictions')\n", | |
| "print([decode_response(r) for r in all_responses])\n", | |
| "print('')\n", | |
| "print('Ground-truth labels')\n", | |
| "print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "_J3VoNmIuZXH", | |
| "outputId": "40a094f4-4b9a-4063-eb1d-2154688f4100" | |
| }, | |
| "execution_count": null, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Average acc: 50.0%\n", | |
| "Class accs:\n", | |
| "- 7 / 10 = 70.0%\n", | |
| "- 0 / 10 = 0.0%\n", | |
| "- 8 / 10 = 80.0%\n", | |
| "\n", | |
| "GPT-3 text-davinci-003 few-shot predictions\n", | |
| "[0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2]\n", | |
| "\n", | |
| "Ground-truth labels\n", | |
| "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "### Results: CIFAR-10\n", | |
| "So it's still better than chance (33.3% for three balanced classes) in aggregate... but it's questionable whether it actually works. Curiously, no car labels (`1`s) are ever output, even though as a prior I'd have expected planes vs. birds to be the confusing distinction.\n", | |
| "\n", | |
| "* There's probably also some brittleness here, and results likely depend a lot on sample selection and prompting.\n", | |
| "\n", | |
| "Lots of battle-testing to do!" | |
| ], | |
| "metadata": { | |
| "id": "QaLxKRhrzuVR" | |
| } | |
| } | |
| ] | |
| } |
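A quick way to dig into the "no `1`s get output" observation in the Results cell is a confusion matrix over the 30 eval predictions. The sketch below hard-codes the predictions and ground-truth labels printed by step 5, rather than calling the notebook's `eval_predictions` (whose implementation isn't shown here). It assumes the three classes follow CIFAR-10's standard indexing (0 = airplane, 1 = automobile, 2 = bird), which is consistent with the Results text calling `1` "cars" but is not confirmed by this section.

```python
from collections import Counter

# Predictions and labels copied from the step-5 output.
# Assumed mapping (standard CIFAR-10 indices): 0=airplane, 1=automobile, 2=bird.
preds = [0, 2, 2, 0, 0, 0, 0, 0, 2, 0,
         2, 0, 2, 0, 2, 2, 2, 2, 2, 0,
         2, 2, 2, 0, 0, 2, 2, 2, 2, 2]
labels = [0] * 10 + [1] * 10 + [2] * 10


def confusion_matrix(y_true, y_pred, num_classes):
    """Rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix


cm = confusion_matrix(labels, preds, num_classes=3)
for true_class, row in enumerate(cm):
    # Per-class accuracy is the diagonal entry divided by the row total
    acc = row[true_class] / sum(row)
    print(f"true {true_class}: predicted counts {row}, acc {acc:.0%}")

overall = sum(cm[c][c] for c in range(3)) / len(labels)
print(f"overall acc {overall:.1%} vs. uniform chance {1 / 3:.1%}")
print("prediction histogram:", Counter(preds))
```

With these numbers the matrix is `[[7, 0, 3], [3, 0, 7], [2, 0, 8]]`: the class-1 column is empty, and 10 of the 15 errors are car images, 7 of which were labeled `2` (bird, under the assumed mapping).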
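In step 3 the context ordering was picked by eye ("Pick a random-looking one from above"). One hypothetical way to make that check explicit is to score each candidate order by its longest run of identical labels and by how many local `A, B, A` alternations it contains (the `0, 1, 0, 1, ...` kind of pattern the step-3 comment warns about), then take the lowest score. Both heuristics are assumptions of this sketch, not something the notebook does, and the label orders below are a subset copied from the step-3 output.

```python
# Label orders printed by the step-3 loop (a subset copied from its output).
orders = {
    0: [1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1],
    3: [1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2],
    7: [1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1],
    10: [0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2],
}


def longest_run(labels):
    """Length of the longest block of identical consecutive labels."""
    best = current = 1
    for prev, cur in zip(labels, labels[1:]):
        current = current + 1 if cur == prev else 1
        best = max(best, current)
    return best


def alternations(labels):
    """Count local A, B, A triples (e.g. 0, 1, 0), a hypothetical pattern measure."""
    return sum(
        1
        for a, b, c in zip(labels, labels[1:], labels[2:])
        if a == c and a != b
    )


def order_score(labels):
    # Lower looks "more random": prefer short runs first, then few alternations.
    return (longest_run(labels), alternations(labels))


best_seed = min(orders, key=lambda s: order_score(orders[s]))
for seed, labels in orders.items():
    print(seed, order_score(labels))
print("picked seed:", best_seed)
```

On these four orders the lowest score belongs to seed 7, the same seed the notebook settled on; seed 0 ties on run length but opens with `1, 2, 1, 2`, which the alternation count penalizes.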