Can We Teach AI to See Through Text? A Fun Exploration with ASCII Art and GPT-3
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [] | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Art Project - Large Language Models can be Few-shot MNIST Classifiers - [Shrijayan](https://shrijayan.cpluz.com)\n", | |
"\n", | |
"Can we get large language models (LLMs) like GPT-3 to do few-shot classification of images?\n", | |
"\n", | |
"A key challenge is how to get a large language model to understand images (channel x width x height pixel matrices). We try to bridge this gap by \"translating\" images to ASCII art.\n", | |
"\n", | |
"The key idea is to prompt an LLM to output a class label given an ASCII image + a few examples of (image, class label) pairs as context.\n", | |
"\n", | |
"### *Hypothesis* (why interesting)\n", | |
"We've seen that LLMs are very good at few-shot learning. They can learn to complete a new task based on only a few unseen examples, all expressed as natural text!\n", | |
"\n", | |
"However, while these models were trained on lots of human-produced text, it may not matter as much what the actual content of the examples are (or if the content is similar to something the LLM encountered during pretraining).\n", | |
"* Instead, to get an LLM to solve a task, the prompt we give it may only need some consistent structure or recognizable patterns across examples. \n", | |
"\n", | |
"One way to test this is to see if an LLM can reason over sequences containing ASCII art of real-world images.\n", | |
"\n", | |
"\n", | |
"\n", | |
"### *Motivations* (why important) and *difficulties* (why hard)\n", | |
"While this is just an art project for fun, there might also be some related usefulness. Personally, I've been interested in \n", | |
"1. How we can adapt benefits of foundation models to work well on individualized use-cases and local points of deployment \n", | |
"2. How to get models to handle new settings without much data (perhaps a new failure mode or edge case surfaces).\n", | |
"\n", | |
"Some difficulties with realizing the above are:\n", | |
"\n", | |
"- Training a sufficient model from scratch may not be feasible, as there may not be enough data in the deployment setting. \n", | |
"- Finetuning a pretrained model may be expensive (large FMs that not everyone has access to, hard to maintain such a model for each setting). \n", | |
"- A pretrained model's out-of-the-box zero-shot behavior may not be desired for a specific deployment scenario.\n", | |
"\n", | |
"### *Solution?*\n", | |
"However, recent LLMs have shown the ability to \"steer\" / change their classification based on a few examples, without having to finetune their weights. This is great because for the above settings, by providing a few examples illustrating the desired classification behavior, we could then get personalized inference. \n", | |
"- But the world is not just all natural language! It'd also be great to extend this capability to more modalities and domains. \n", | |
"- Given that we have these text-trained LLMs lying around, can we use text as a universal modality to do so?" | |
], | |
"metadata": { | |
"id": "qdAMTEAwR3en" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Setup" | |
], | |
"metadata": { | |
"id": "oH6-uz6qSg7r" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install numpy\n", | |
"!pip install pytorch\n", | |
"!pip install torchvision\n", | |
"!pip install pillow\n", | |
"!pip install openai # Only for using GPT-3 models" | |
], | |
"metadata": { | |
"id": "wj-CgWxsSFyO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import copy\n", | |
"import torch\n", | |
"import numpy as np\n", | |
"\n", | |
"from PIL import Image" | |
], | |
"metadata": { | |
"id": "hUbNPT5sStwL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Fake args\n", | |
"class Args():\n", | |
" def __init__(self):\n", | |
" pass\n", | |
"args = Args()\n", | |
"args.seed = 42\n", | |
"args.device = torch.device('cuda:0')\n", | |
"args.display_image = True\n", | |
"args.num_workers = 2\n", | |
"\n", | |
"np.random.seed(args.seed)\n", | |
"torch.manual_seed(args.seed)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Q6bm0782S26g", | |
"outputId": "22d30d11-4a43-4028-c966-35c284f51c81" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<torch._C.Generator at 0x7fb4d21ddc70>" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 159 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Load data\n", | |
"\n", | |
"We need to load images and convert them to ASCII. To ensure this fits into a propt that an LLM can process (frequently 2048 or 4096 tokens), we may also need to downsize the images.\n", | |
"\n", | |
"For this art project, we'll use the MNIST dataset." | |
], | |
"metadata": { | |
"id": "u_8Z7l8YS-v_" | |
} | |
}, | |
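{
"cell_type": "markdown",
"source": [
"As a rough sanity check on that budget: a `w x w` ASCII image is `w * w` characters, and a few-shot prompt repeats that for every context example plus the query. The sketch below just counts characters and, if `tiktoken` happens to be installed, tokenizes a synthetic all-`@` string as a crude proxy; the widths and the 4-example context size are illustrative assumptions, not measurements from this notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Rough prompt-budget estimate (illustrative assumptions, not exact measurements)\n",
"num_context_examples = 4  # assumed few-shot context size\n",
"for w in [16, 24, 28]:\n",
"    chars_per_image = w * w  # one ASCII character per pixel\n",
"    chars_per_prompt = chars_per_image * (num_context_examples + 1)\n",
"    msg = f'width={w}: {chars_per_image} chars/image, ~{chars_per_prompt} chars/prompt'\n",
"    try:\n",
"        import tiktoken  # optional; only used if installed\n",
"        enc = tiktoken.get_encoding('p50k_base')  # GPT-3 byte-pair encoding\n",
"        msg += f\", ~{len(enc.encode('@' * chars_per_image))} tokens/image (all-@ proxy)\"\n",
"    except ImportError:\n",
"        pass\n",
"    print(msg)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},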
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Image transforms\n", | |
"\n", | |
"We convert each image to a downsized, center-cropped, gray-scaled square image." | |
], | |
"metadata": { | |
"id": "r2AcDPUsTFMo" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torchvision\n", | |
"from torchvision import transforms\n", | |
"from torchvision.transforms import Compose, Resize, CenterCrop, Grayscale, ToTensor\n", | |
"try:\n", | |
" from torchvision.transforms import InterpolationMode\n", | |
" BICUBIC = InterpolationMode.BICUBIC\n", | |
"except ImportError:\n", | |
" BICUBIC = Image.BICUBIC" | |
], | |
"metadata": { | |
"id": "NXv9l_GYTLFd" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_transform(w, grayscale=True, to_tensor=False):\n", | |
" transforms = [\n", | |
" Resize(w, interpolation=BICUBIC),\n", | |
" CenterCrop(w),\n", | |
" ]\n", | |
" if grayscale:\n", | |
" transforms.append(Grayscale(num_output_channels=1))\n", | |
" if to_tensor:\n", | |
" transforms.append(ToTensor())\n", | |
" return Compose(transforms)" | |
], | |
"metadata": { | |
"id": "SBdlcKGpTO4h" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Specify the width of downsized images\n", | |
"args.width = 24" | |
], | |
"metadata": { | |
"id": "SOHVeNVbS4mo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Actual transform function\n", | |
"transform_grayscale = get_transform(args.width, grayscale=True)\n", | |
"\n", | |
"# Original image (for visualization)\n", | |
"transform_base = get_transform(28, grayscale=False)" | |
], | |
"metadata": { | |
"id": "A2W5hOkvTQtB" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### ASCII conversion \n", | |
"3 options to pick from: \n", | |
"* `ASCII_68`: 68 levels of brightness \n", | |
"* `ASCII_10`: 10 levels of brightness \n", | |
"* `ASCII_10n`: 10 levels of brightness using integers to denote level" | |
], | |
"metadata": { | |
"id": "5CzZaOl2TV4R" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# No periods or spaces bc I thought that might interfere with LLM\n", | |
"ASCII_68 = \"$@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\|()1{}[]?-_+~<>i!lI;:,\\\"^`'\"\n", | |
"ASCII_10 = \"@%#*+=-:`'\"\n", | |
"ASCII_10n = ''.join([str(i) for i in range(len(ASCII_10))[::-1]])" | |
], | |
"metadata": { | |
"id": "RyObH9RMTgfs" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def img_to_text(img, scale=10, numerical=False):\n", | |
" global ASCII_68, ASCII_10, ASCII_10n\n", | |
" if scale > 10:\n", | |
" ascii_scale = ASCII_68\n", | |
" elif numerical:\n", | |
" ascii_scale = ASCII_10n\n", | |
" else:\n", | |
" ascii_scale = ASCII_10\n", | |
"\n", | |
" return ''.join([ascii_scale[int(p * (scale - 1) / 255)] for p in\n", | |
" np.array(img).flatten()])" | |
], | |
"metadata": { | |
"id": "QpCs4OgOTirt" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def pretty_print_ascii(text, width, height):\n", | |
" num_lines = int(len(text) / width)\n", | |
" assert num_lines == height\n", | |
" print_str = []\n", | |
" for i in range(num_lines - 1):\n", | |
" print_str.append(f'{text[i * width: (i + 1) * width]}\\n')\n", | |
" i += 1\n", | |
" print_str.append(f'{text[i * width: (i + 1) * width]}')\n", | |
" print(''.join(print_str))" | |
], | |
"metadata": { | |
"id": "1GsASUR3TkHl" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
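{
"cell_type": "markdown",
"source": [
"As a quick check of the mapping direction, we can run `img_to_text` on a synthetic gradient (an arbitrary 8x8 toy image): dark pixels land on the dense characters at the start of the scale (e.g., `@`), bright pixels on the light ones at the end."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sanity check the brightness -> character mapping on a synthetic 8x8 gradient\n",
"gradient = np.tile(np.linspace(0, 255, 8, dtype=np.uint8), (8, 1))\n",
"gradient_img = Image.fromarray(gradient, mode='L')\n",
"# Left column is black (0) -> '@'; right column is white (255) -> a light character\n",
"pretty_print_ascii(img_to_text(gradient_img), width=8, height=8)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},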
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### See MNIST examples" | |
], | |
"metadata": { | |
"id": "NUWPFFBfTo5F" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load original data\n", | |
"train_set = torchvision.datasets.MNIST(root='.', train=True, transform=transform_base, download=True)\n", | |
"test_set = torchvision.datasets.MNIST(root='.', train=False, transform=transform_base, download=True)\n" | |
], | |
"metadata": { | |
"id": "HJy-SZmzTnF3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"train_set.__getitem__(0)[0]" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 45 | |
}, | |
"id": "hPdwx9hVUHao", | |
"outputId": "39d34dde-c9d6-4b27-e515-f69ac302c1b0" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"<PIL.Image.Image image mode=L size=28x28 at 0x7FB4C431FB50>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABAElEQVR4nGNgGMyAWUhIqK5jvdSy/9/rGRgYGFhgEnJsVjYCwQwMDAxPJgV+vniQgYGBgREqZ7iXH8r6l/SV4dn7m8gmCt3++/fv37/Htn3/iMW+gDnZf/+e5WbQnoXNNXyMs/5GoQoxwVmf/n9kSGFiwAW49/11wynJoPzx4YIcRlyygR/+/i2XxCWru+vv32nSuGQFYv/83Y3b4p9/fzpAmSyoMnohpiwM1w5h06Q+5enfv39/bcMiJVF09+/fv39P+mFKiTtd/fv3799jgZiBJLT69t+/f/8eDuDEkDJf8+jv379/v7Ryo4qzMDAwMAQGMjBc3/y35wM2V1IfAABFF16Aa0wAOwAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": {}, | |
"execution_count": 168 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load ASCII data\n", | |
"train_set_ascii = torchvision.datasets.MNIST(root='.', train=True,\n", | |
" transform=transform_grayscale, download=False)\n", | |
"test_set_ascii = torchvision.datasets.MNIST(root='.', train=False,\n", | |
" transform=transform_grayscale, download=False)" | |
], | |
"metadata": { | |
"id": "EW7TgezzTzYG" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"pretty_print_ascii(\n", | |
" img_to_text(train_set_ascii.__getitem__(0)[0]),\n", | |
" width=args.width, height=args.width\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qzUce8kCUJey", | |
"outputId": "f728ad4d-018b-4eb8-c083-73c2493810ee" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@*+%*-=%@@@\n", | |
"@@@@@@@%%+=::```=``-%@@@\n", | |
"@@@@@@#`'``''''=##%@@@@@\n", | |
"@@@@@@%:````==:*@@@@@@@@\n", | |
"@@@@@@@%#=`=@@%%@@@@@@@@\n", | |
"@@@@@@@@@#`=@@@@@@@@@@@@\n", | |
"@@@@@@@@@@=`%@@@@@@@@@@@\n", | |
"@@@@@@@@@@%::+%@@@@@@@@@\n", | |
"@@@@@@@@@@@#:':#@@@@@@@@\n", | |
"@@@@@@@@@@@@%=`:*@@@@@@@\n", | |
"@@@@@@@@@@@@@@+'`%@@@@@@\n", | |
"@@@@@@@@@@@@@#='`%@@@@@@\n", | |
"@@@@@@@@@@@*-``':@@@@@@@\n", | |
"@@@@@@@@@#-`''`=%@@@@@@@\n", | |
"@@@@@@@%+`''`=%@@@@@@@@@\n", | |
"@@@@%#-`''`=%@@@@@@@@@@@\n", | |
"@@@%:''`:-*@@@@@@@@@@@@@\n", | |
"@@@%***#%@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# LLM Time\n", | |
"\n", | |
"Now we'll create prompts, initialize GPT-3, and get outputs." | |
], | |
"metadata": { | |
"id": "7ivitinuUuSe" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Prompt Creation\n", | |
"\n", | |
"How we format the prompt, and what samples we include in the prompt could affect the LLM output. \n", | |
"\n", | |
"For formatting, we'll just follow a simple few-shot template: \n", | |
"```\n", | |
"Input: [flattened_ascii_image] \n", | |
"Output: [class_label] \n", | |
"###\n", | |
"Input: [flattened_ascii_image] \n", | |
"Output: [class_label] \n", | |
"###\n", | |
"...\n", | |
"Input: [flattened_ascii_image] \n", | |
"Output: \n", | |
"``` \n", | |
"\n", | |
"For selecting samples to include, there are many potentially fancy ways to do this. For classification with many potential groups, it'd be good to include at least one example of each group. \n", | |
"\n", | |
"For now, we'll just randomly sample from each group (e.g., MNIST digit class)." | |
], | |
"metadata": { | |
"id": "aesYe1OmU5EW" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Helpers" | |
], | |
"metadata": { | |
"id": "eWkwQlBuVDYf" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_indices_by_class(dataset):\n", | |
" \"\"\"\n", | |
" Organize data indices into classes\n", | |
" \"\"\"\n", | |
" indices_by_class = []\n", | |
" num_classes = len(np.unique(dataset.targets))\n", | |
" for i in np.arange(num_classes):\n", | |
" indices_class = np.where(dataset.targets == i)[0]\n", | |
" indices_by_class.append(indices_class)\n", | |
" return indices_by_class" | |
], | |
"metadata": { | |
"id": "HiDLQbWwUQX8" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_random_indices_by_class(indices_by_class, classes, num_samples, sample_seed):\n", | |
" \"\"\"\n", | |
" Randomly sample samples from each group\n", | |
" \"\"\"\n", | |
" np.random.seed(sample_seed)\n", | |
" reference_sample_ix = []\n", | |
" reference_sample_y = []\n", | |
" for i in classes:\n", | |
" sample_ix = np.random.choice(indices_by_class[i], size=num_samples, replace=False)\n", | |
" reference_sample_ix.append(sample_ix)\n", | |
" reference_sample_y.append(np.repeat([i], num_samples))\n", | |
" return np.concatenate(reference_sample_ix), np.concatenate(reference_sample_y)" | |
], | |
"metadata": { | |
"id": "Nrp7hWBgVHR1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_context_prompt(dataset, indices, pretty_print=False,\n", | |
" start='[START]\\n', end='[END]\\n', sep='###\\n',\n", | |
" shuffle=False, seed=42):\n", | |
" \"\"\"\n", | |
" Given dataset, dataset indices, retrieve ASCII images, ground-truth labels,\n", | |
" and format into the context of a prompt\n", | |
" \"\"\"\n", | |
" prompt = []\n", | |
" # Template: start, x, y, end, sep\n", | |
" sample = '{}Input: {}\\nOutput: {}\\n{}{}'\n", | |
" if pretty_print: # Support this later\n", | |
" raise NotImplementedError\n", | |
" else:\n", | |
" for i in indices:\n", | |
" data_x, data_y = dataset.__getitem__(i)\n", | |
" prompt.append(copy.copy(sample).format(\n", | |
" start, img_to_text(data_x), data_y, end, sep\n", | |
" ))\n", | |
" shuffle_ix = np.arange(len(indices))\n", | |
" if shuffle:\n", | |
" np.random.seed(seed)\n", | |
" np.random.shuffle(shuffle_ix)\n", | |
" prompt = [prompt[ix] for ix in shuffle_ix]\n", | |
" prompt = ''.join(prompt)\n", | |
"\n", | |
" return prompt, indices[shuffle_ix]" | |
], | |
"metadata": { | |
"id": "xtebS4E1VJK_" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_eval_prompt(dataset, ix, pretty_print=False,\n", | |
" start='[START]\\n', end='[END]\\n', sep='###\\n'):\n", | |
" \"\"\"\n", | |
" Format image we want to classify into a prompt\n", | |
" \"\"\"\n", | |
" # Template: start, x\n", | |
" prompt = '{}Input: {}\\nOutput: '\n", | |
" data_x, data_y = dataset.__getitem__(ix)\n", | |
" return prompt.format(start, img_to_text(data_x)), data_y\n", | |
"" | |
], | |
"metadata": { | |
"id": "WeyOlXlMVK-b" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Easy-mode Image Classification\n", | |
"\n", | |
"For now, we'll stick to classifying only 3 MNIST digit classes (0s, 1s, and 2s).\n", | |
"\n", | |
"(Later we can try doing more, but I don't have the budget to call OpenAI's API all day for inference and evaluation)" | |
], | |
"metadata": { | |
"id": "Hnd4v61rVZZ5" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Organize train and test into classes\n", | |
"num_classes = len(torch.unique(train_set.targets))\n", | |
"\n", | |
"# TRAIN\n", | |
"indices_by_class = get_indices_by_class(train_set)\n", | |
"\n", | |
"# TEST\n", | |
"test_indices_by_class = get_indices_by_class(test_set)" | |
], | |
"metadata": { | |
"id": "mv9OGYaEWLgL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class_indices = [0, 1, 2] # These are the classes we'll evaluate" | |
], | |
"metadata": { | |
"id": "OTgwRrO2WNi6" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Visualizing a formatted prompt \"context\" \n", | |
"We can now visualize a few-shot prompt, without the sample we want to classify:" | |
], | |
"metadata": { | |
"id": "DiCRGe9UWRPB" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"args.seed_cls = 42\n", | |
"args.num_samples = 1\n", | |
"img_dims = {'width': args.width, 'height': args.width}\n", | |
"classes = range(num_classes)\n", | |
"\n", | |
"context_ix, context_y = get_random_indices_by_class(\n", | |
" indices_by_class,\n", | |
" class_indices,\n", | |
" args.num_samples,\n", | |
" args.seed_cls\n", | |
")\n", | |
"\n", | |
"for i, ix in enumerate(context_ix):\n", | |
" print(f'Input:')\n", | |
" pretty_print_ascii(\n", | |
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
" **img_dims\n", | |
" )\n", | |
" print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n", | |
" # Unrolled\n", | |
" # print(img_to_text(train_set_ascii.__getitem__(ix)[0]))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "y9gDbEs2WP13", | |
"outputId": "53321ce9-f443-4b55-912d-015f60ca0b2e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@%+===*@@@@@@@\n", | |
"@@@@@@@@@%=:`'''`-@@@@@@\n", | |
"@@@@@@@@+``''```''+@@@@@\n", | |
"@@@@@@@=''''''''`'-@@@@@\n", | |
"@@@@@@:```:==-:````+@@@@\n", | |
"@@@@@@`''`#@@@%-'``=@@@@\n", | |
"@@@@@#``'`%@@@@%:'':%@@@\n", | |
"@@@@@:`':#@@@@@@='':%@@@\n", | |
"@@@@+`'-@@@@@@@@='':%@@@\n", | |
"@@@@=``=@@@@@@@@-'':%@@@\n", | |
"@@@@-'`#@@@@@@@+`'`=@@@@\n", | |
"@@@%:':@@@@@@@%:''`#@@@@\n", | |
"@@@*``-@@@@@%#-```=@@@@@\n", | |
"@@@='``-=##=:`'```*@@@@@\n", | |
"@@@=``''```'''````@@@@@@\n", | |
"@@@@-'```''```'`-@@@@@@@\n", | |
"@@@@+:'''''''':=%@@@@@@@\n", | |
"@@@@%%=======+%@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@#@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@-%@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=%@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=*@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=+@@@@@@@@@@@\n", | |
"@@@@@@@@@@@--@@@@@@@@@@@\n", | |
"@@@@@@@@@@@--@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=:@@@@@@@@@@@\n", | |
"@@@@@@@@@@@*`@@@@@@@@@@@\n", | |
"@@@@@@@@@@@#`@@@@@@@@@@@\n", | |
"@@@@@@@@@@@%:@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@-*@@@@@@@@@@\n", | |
"@@@@@@@@@@@@==@@@@@@@@@@\n", | |
"@@@@@@@@@@@@=-@@@@@@@@@@\n", | |
"@@@@@@@@@@@@*:%@@@@@@@@@\n", | |
"@@@@@@@@@@@@%:#@@@@@@@@@\n", | |
"@@@@@@@@@@@@@:+@@@@@@@@@\n", | |
"@@@@@@@@@@@@@**@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@%*#@@@@@@\n", | |
"@@@@@@@@@@@@%+::``-*@@@@\n", | |
"@@@@@@@@@@@=`'''`'`+@@@@\n", | |
"@@@@@@@@@@+''`:````*@@@@\n", | |
"@@@@@@@@@@%::**:'':@@@@@\n", | |
"@@@@@@@@@@@%%@*```#@@@@@\n", | |
"@@@@@@@@@@@@@*:'':@@@@@@\n", | |
"@@@@@%%%%%%+=`````-=%@@@\n", | |
"@@@@*`'''''''`''''''=@@@\n", | |
"@@@*`'``````''::::::+@@@\n", | |
"@@@='``````'`+#%%%%%%@@@\n", | |
"@@@='````''=%@@@@@@@@@@@\n", | |
"@@@+`````+%@@@@@@@@@@@@@\n", | |
"@@@@%***%@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 2\n", | |
"###\n", | |
"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Visualizing a full prompt\n", | |
"\n", | |
"Below we'll see a prompt that we can actually input to an LLM. For now we linearize the ASCII inputs, but it might be good to try the formatted ASCII too." | |
], | |
"metadata": { | |
"id": "9b9CV5YiWY61" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
" train_set_ascii, context_ix,\n", | |
" start='', end='', sep='###\\n', shuffle=True, seed=42\n", | |
")\n", | |
"\n", | |
"eval_prompt, eval_target = get_eval_prompt(\n", | |
" test_set_ascii, ix=test_indices_by_class[0][0], start=''\n", | |
")\n", | |
"# Linearized (what we'll actually send to an LLM)\n", | |
"print(context_prompt + eval_prompt)\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Qj6YnnAPWYFT", | |
"outputId": "26da1599-c2c1-4e14-a11a-6b43d4ffb0dc" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%+===*@@@@@@@@@@@@@@@@%=:`'''`-@@@@@@@@@@@@@@+``''```''+@@@@@@@@@@@@=''''''''`'-@@@@@@@@@@@:```:==-:````+@@@@@@@@@@`''`#@@@%-'``=@@@@@@@@@#``'`%@@@@%:'':%@@@@@@@@:`':#@@@@@@='':%@@@@@@@+`'-@@@@@@@@='':%@@@@@@@=``=@@@@@@@@-'':%@@@@@@@-'`#@@@@@@@+`'`=@@@@@@@%:':@@@@@@@%:''`#@@@@@@@*``-@@@@@%#-```=@@@@@@@@='``-=##=:`'```*@@@@@@@@=``''```'''````@@@@@@@@@@-'```''```'`-@@@@@@@@@@@+:'''''''':=%@@@@@@@@@@@%%=======+%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 0\n", | |
"###\n", | |
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#@@@@@@@@@@@@@@@@@@@@@@@-%@@@@@@@@@@@@@@@@@@@@@@=%@@@@@@@@@@@@@@@@@@@@@@=*@@@@@@@@@@@@@@@@@@@@@@=+@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@=:@@@@@@@@@@@@@@@@@@@@@@*`@@@@@@@@@@@@@@@@@@@@@@#`@@@@@@@@@@@@@@@@@@@@@@%:@@@@@@@@@@@@@@@@@@@@@@@-*@@@@@@@@@@@@@@@@@@@@@@==@@@@@@@@@@@@@@@@@@@@@@=-@@@@@@@@@@@@@@@@@@@@@@*:%@@@@@@@@@@@@@@@@@@@@@%:#@@@@@@@@@@@@@@@@@@@@@@:+@@@@@@@@@@@@@@@@@@@@@@**@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 1\n", | |
"###\n", | |
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%*#@@@@@@@@@@@@@@@@@@%+::``-*@@@@@@@@@@@@@@@=`'''`'`+@@@@@@@@@@@@@@+''`:````*@@@@@@@@@@@@@@%::**:'':@@@@@@@@@@@@@@@@%%@*```#@@@@@@@@@@@@@@@@@@*:'':@@@@@@@@@@@%%%%%%+=`````-=%@@@@@@@*`'''''''`''''''=@@@@@@*`'``````''::::::+@@@@@@='``````'`+#%%%%%%@@@@@@='````''=%@@@@@@@@@@@@@@+`````+%@@@@@@@@@@@@@@@@@%***%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 2\n", | |
"###\n", | |
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#=*@@@@@@@@@@@@@@@@@@@@@:':%@@@@@@@@@@@@@@@@@@@+`':%@@@@@@@@@@@@@@@@@%=````=*@@@@@@@@@@@@@@@@+'``'''`+@@@@@@@@@@@@@@#`'`':=-`'#@@@@@@@@@@@@%-`'`-#@@-'-%@@@@@@@@@@@#''`*%@@@#`':@@@@@@@@@@@#'`*@@@@@@='`@@@@@@@@@@@#'=@@@@@@@='`*@@@@@@@@@@-'=@@@@@@#`'`@@@@@@@@@@@:'=@@@@@*:'`=@@@@@@@@@@@:'=@@@@+`'`:@@@@@@@@@@@@:'=%+--```'#@@@@@@@@@@@@+'```'''``-@@@@@@@@@@@@@%:`'``'`:*@@@@@@@@@@@@@@@%=`'`:-%@@@@@@@@@@@@@@@@@@#+*%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: \n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"eval_target # Double check this should be 0" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "D8M-aUoqWnON", | |
"outputId": "f3a1e852-7def-45d3-8024-605a92f72e31" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 179 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Formatting the above (for sanity checking)\n", | |
"for i, ix in enumerate(shuffled_context_ix):\n", | |
" print(f'Input:')\n", | |
" pretty_print_ascii(\n", | |
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
" **img_dims\n", | |
" )\n", | |
" print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n", | |
"\n", | |
"# Eval sample\n", | |
"print(f'Input:')\n", | |
"pretty_print_ascii(\n", | |
" img_to_text(test_set_ascii.__getitem__(test_indices_by_class[0][0])[0]),\n", | |
" **img_dims\n", | |
")\n", | |
"print(f'Output:')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "X22Ad8hjw7Fq", | |
"outputId": "52ed577c-657c-49ce-bfe2-2252768998c3" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@%+===*@@@@@@@\n", | |
"@@@@@@@@@%=:`'''`-@@@@@@\n", | |
"@@@@@@@@+``''```''+@@@@@\n", | |
"@@@@@@@=''''''''`'-@@@@@\n", | |
"@@@@@@:```:==-:````+@@@@\n", | |
"@@@@@@`''`#@@@%-'``=@@@@\n", | |
"@@@@@#``'`%@@@@%:'':%@@@\n", | |
"@@@@@:`':#@@@@@@='':%@@@\n", | |
"@@@@+`'-@@@@@@@@='':%@@@\n", | |
"@@@@=``=@@@@@@@@-'':%@@@\n", | |
"@@@@-'`#@@@@@@@+`'`=@@@@\n", | |
"@@@%:':@@@@@@@%:''`#@@@@\n", | |
"@@@*``-@@@@@%#-```=@@@@@\n", | |
"@@@='``-=##=:`'```*@@@@@\n", | |
"@@@=``''```'''````@@@@@@\n", | |
"@@@@-'```''```'`-@@@@@@@\n", | |
"@@@@+:'''''''':=%@@@@@@@\n", | |
"@@@@%%=======+%@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@#@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@-%@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=%@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=*@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=+@@@@@@@@@@@\n", | |
"@@@@@@@@@@@--@@@@@@@@@@@\n", | |
"@@@@@@@@@@@--@@@@@@@@@@@\n", | |
"@@@@@@@@@@@=:@@@@@@@@@@@\n", | |
"@@@@@@@@@@@*`@@@@@@@@@@@\n", | |
"@@@@@@@@@@@#`@@@@@@@@@@@\n", | |
"@@@@@@@@@@@%:@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@-*@@@@@@@@@@\n", | |
"@@@@@@@@@@@@==@@@@@@@@@@\n", | |
"@@@@@@@@@@@@=-@@@@@@@@@@\n", | |
"@@@@@@@@@@@@*:%@@@@@@@@@\n", | |
"@@@@@@@@@@@@%:#@@@@@@@@@\n", | |
"@@@@@@@@@@@@@:+@@@@@@@@@\n", | |
"@@@@@@@@@@@@@**@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@%*#@@@@@@\n", | |
"@@@@@@@@@@@@%+::``-*@@@@\n", | |
"@@@@@@@@@@@=`'''`'`+@@@@\n", | |
"@@@@@@@@@@+''`:````*@@@@\n", | |
"@@@@@@@@@@%::**:'':@@@@@\n", | |
"@@@@@@@@@@@%%@*```#@@@@@\n", | |
"@@@@@@@@@@@@@*:'':@@@@@@\n", | |
"@@@@@%%%%%%+=`````-=%@@@\n", | |
"@@@@*`'''''''`''''''=@@@\n", | |
"@@@*`'``````''::::::+@@@\n", | |
"@@@='``````'`+#%%%%%%@@@\n", | |
"@@@='````''=%@@@@@@@@@@@\n", | |
"@@@+`````+%@@@@@@@@@@@@@\n", | |
"@@@@%***%@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output: 2\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@#=*@@@@@@@@@@\n", | |
"@@@@@@@@@@@:':%@@@@@@@@@\n", | |
"@@@@@@@@@@+`':%@@@@@@@@@\n", | |
"@@@@@@@@%=````=*@@@@@@@@\n", | |
"@@@@@@@@+'``'''`+@@@@@@@\n", | |
"@@@@@@@#`'`':=-`'#@@@@@@\n", | |
"@@@@@@%-`'`-#@@-'-%@@@@@\n", | |
"@@@@@@#''`*%@@@#`':@@@@@\n", | |
"@@@@@@#'`*@@@@@@='`@@@@@\n", | |
"@@@@@@#'=@@@@@@@='`*@@@@\n", | |
"@@@@@@-'=@@@@@@#`'`@@@@@\n", | |
"@@@@@@:'=@@@@@*:'`=@@@@@\n", | |
"@@@@@@:'=@@@@+`'`:@@@@@@\n", | |
"@@@@@@:'=%+--```'#@@@@@@\n", | |
"@@@@@@+'```'''``-@@@@@@@\n", | |
"@@@@@@%:`'``'`:*@@@@@@@@\n", | |
"@@@@@@@%=`'`:-%@@@@@@@@@\n", | |
"@@@@@@@@@#+*%@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"@@@@@@@@@@@@@@@@@@@@@@@@\n", | |
"Output:\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# GPT-3 Few-shot Classification\n", | |
"\n", | |
"Time to actually classify MNIST with a language model." | |
], | |
"metadata": { | |
"id": "AGoCzDdAWp1l" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import openai" | |
], | |
"metadata": { | |
"id": "LD9VrIw0XI9v" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Some helpers\n", | |
"def get_gpt_output(prompt, model):\n", | |
" response = openai.Completion.create(\n", | |
" prompt=prompt,\n", | |
" model=model,\n", | |
" # Rest are defaults\n", | |
" temperature=0,\n", | |
" max_tokens=7,\n", | |
" top_p=1.0,\n", | |
" n=1,\n", | |
" )\n", | |
" return response\n", | |
"\n", | |
"def decode_response(response):\n", | |
" result = response['choices'][0]['text']\n", | |
" try:\n", | |
" return int(result)\n", | |
" except Exception as e:\n", | |
" print(e)\n", | |
" return -1" | |
], | |
"metadata": { | |
"id": "UkvSNAVJXRQM" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
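{
"cell_type": "markdown",
"source": [
"The helpers above use the legacy `openai.Completion` interface from the pre-1.0 `openai` package, which is what was actually run here. If you end up on a newer client (`openai>=1.0`), a rough equivalent is sketched below; the client setup and whatever completions-capable model name you pass are assumptions on your side, and these functions aren't used elsewhere in this notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Sketch: the same helpers against the openai>=1.0 client interface (untested here)\n",
"try:\n",
"    from openai import OpenAI  # only exists in openai>=1.0\n",
"\n",
"    def get_gpt_output_v1(prompt, model, client=None):\n",
"        client = client or OpenAI()  # reads OPENAI_API_KEY from the environment\n",
"        return client.completions.create(\n",
"            prompt=prompt,\n",
"            model=model,\n",
"            temperature=0,\n",
"            max_tokens=7,\n",
"            top_p=1.0,\n",
"            n=1,\n",
"        )\n",
"\n",
"    def decode_response_v1(response):\n",
"        return int(response.choices[0].text)\n",
"except ImportError:\n",
"    pass  # older openai package; use the legacy helpers above"
],
"metadata": {},
"execution_count": null,
"outputs": []
},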
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Evaluation\n", | |
"\n", | |
"For evaluation, we'll pick a couple examples from the MNIST test set. In particular, we'll sample an equal number of datapoints from each class. We'll test the performance of the LLM via overall and class-specific accuracy." | |
], | |
"metadata": { | |
"id": "T64YRAhCXVz1" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def eval_predictions(predictions, classes):\n", | |
" \"\"\"\n", | |
" Report average and group-wise accuracies\n", | |
" \"\"\"\n", | |
" correct = np.array(predictions) == np.array(classes)\n", | |
" class_correct = []\n", | |
" class_total = []\n", | |
" for ix, c in enumerate(np.unique(classes)):\n", | |
" c_ix = np.where(np.array(classes) == c)[0]\n", | |
" class_correct.append(correct[c_ix].sum())\n", | |
" class_total.append(len(c_ix))\n", | |
" avg_acc = np.mean(correct) * 100\n", | |
" class_accs = np.array(class_correct) / np.array(class_total) * 100\n", | |
" print(f'Average acc: {avg_acc:.1f}%')\n", | |
" print(f'Class accs:')\n", | |
" for ix, c in enumerate(class_accs):\n", | |
" print(f'- {class_correct[ix]} / {class_total[ix]} = {c:.1f}%')\n", | |
" return avg_acc, class_accs" | |
], | |
"metadata": { | |
"id": "ALRaEE-FX_9v" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Few-shot Prompt Template\n", | |
"To make prompts, we'll use the same context for each sample: an equal number of samples from each group in the MNIST train set." | |
], | |
"metadata": { | |
"id": "x5861FmKX8Qw" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Actual GPT-3 Demo" | |
], | |
"metadata": { | |
"id": "-5CBGw16YEJQ" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"openai.api_key = \"sk-\" # Your OpenAI API key goes here" | |
], | |
"metadata": { | |
"id": "pHiH1Tj8YDUS" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"models = [\n", | |
" 'text-ada-001',\n", | |
" 'text-babbage-001',\n", | |
" 'text-curie-001',\n", | |
" 'text-davinci-002'\n", | |
"]" | |
], | |
"metadata": { | |
"id": "reMB79l6YInh" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 1) Get samples to serve as \"context\"\n", | |
"args.seed_train = 42\n", | |
"args.train_samples_per_class = 4\n", | |
"img_dims = {'width': args.width, 'height': args.width}\n", | |
"\n", | |
"context_ix, context_y = get_random_indices_by_class(\n", | |
" indices_by_class,\n", | |
" class_indices,\n", | |
" args.train_samples_per_class,\n", | |
" args.seed_train\n", | |
")" | |
], | |
"metadata": { | |
"id": "z32une07YJyg" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 2) Get examples to few-shot classify\n", | |
"args.seed_eval = 42\n", | |
"args.eval_samples_per_group = 10\n", | |
"test_eval_ix, test_eval_y = get_random_indices_by_class(\n", | |
" test_indices_by_class,\n", | |
" class_indices,\n", | |
" args.eval_samples_per_group,\n", | |
" args.seed_eval\n", | |
")" | |
], | |
"metadata": { | |
"id": "dJ7lQKYFYLLS" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 3) Get prompt context\n", | |
"for seed in range(20):\n", | |
" context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
" train_set_ascii, context_ix,\n", | |
" start='', end='', sep='###\\n', shuffle=True,\n", | |
" seed=seed # May need to change this to shuffle\n", | |
" )\n", | |
"\n", | |
" # Inspect context order (want to avoid patterns between outputs, e.g., 0, 1, 0, 1, ...)\n", | |
" print(seed, train_set.targets[shuffled_context_ix])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "goTJZGxaYMsi", | |
"outputId": "fbfc3326-c66a-40d8-bf4b-8fbf3ee19260" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n", | |
"1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n", | |
"2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n", | |
"3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n", | |
"4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n", | |
"5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n", | |
"6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n", | |
"7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n", | |
"8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n", | |
"9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n", | |
"10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n", | |
"11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n", | |
"12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n", | |
"13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n", | |
"14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n", | |
"15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n", | |
"16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n", | |
"17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n", | |
"18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n", | |
"19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 3.1) Get prompt context (cont.)\n", | |
"context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
" train_set_ascii, context_ix,\n", | |
" start='', end='', sep='###\\n', shuffle=True,\n", | |
" seed=16 # Pick a random-looking one from above\n", | |
")" | |
], | |
"metadata": { | |
"id": "i3EzD6I3YQbp" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 4) Get full prompts for each few-shot prompt, and get GPT outputs\n", | |
"model_name = 'text-davinci-003'\n", | |
"all_responses = []\n", | |
"all_targets = []\n", | |
"for test_ix in test_eval_ix:\n", | |
" # Calling this line will give OpenAI money\n", | |
" eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n", | |
" _prompt = context_prompt + eval_prompt\n", | |
" response = get_gpt_output(_prompt, model=model_name)\n", | |
" all_responses.append(response)\n", | |
" all_targets.append(target)" | |
], | |
"metadata": { | |
"id": "prLD5YMLYZUr" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
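{
"cell_type": "markdown",
"source": [
"If you rerun the loop above with more evaluation samples, transient rate-limit or network errors from the API become likely. One simple mitigation is a retry wrapper around `get_gpt_output` like the sketch below; the retry count and delay are arbitrary choices, not something tuned for this notebook."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import time\n",
"\n",
"def get_gpt_output_with_retries(prompt, model, max_retries=3, delay=5):\n",
"    \"\"\"\n",
"    Retry the API call on transient failures (retry count and delay are arbitrary)\n",
"    \"\"\"\n",
"    for attempt in range(max_retries):\n",
"        try:\n",
"            return get_gpt_output(prompt, model=model)\n",
"        except Exception as e:\n",
"            print(f'Attempt {attempt + 1} failed: {e}')\n",
"            time.sleep(delay * (attempt + 1))  # simple linear backoff\n",
"    raise RuntimeError(f'API call failed after {max_retries} retries')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},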
{ | |
"cell_type": "code", | |
"source": [ | |
"# 5) Evaluate\n", | |
"\n", | |
"# Report accuracy metrics\n", | |
"avg_acc, class_accs = eval_predictions(\n", | |
" [decode_response(r) for r in all_responses],\n", | |
" [test_set.targets[test_ix] for test_ix in test_eval_ix],\n", | |
")\n", | |
"\n", | |
"# Report predicted vs true targets\n", | |
"print(f'GPT-3 {model_name} few-shot predictions')\n", | |
"print([decode_response(r) for r in all_responses])\n", | |
"\n", | |
"print(f'Ground-truth labels')\n", | |
"print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "siiYoSPGYm5w", | |
"outputId": "91fbfbda-5e52-4090-825f-3a0a5ad457b5" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Average acc: 66.7%\n", | |
"Class accs:\n", | |
"- 4 / 10 = 40.0%\n", | |
"- 10 / 10 = 100.0%\n", | |
"- 6 / 10 = 60.0%\n", | |
"GPT-3 text-davinci-003 few-shot predictions\n", | |
"[0, 2, 0, 0, 2, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3, 2, 2, 0, 0, 2, 2, 2, 2]\n", | |
"Ground-truth labels\n", | |
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n" | |
] | |
} | |
] | |
}, | |
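{
"cell_type": "markdown",
"source": [
"Beyond the accuracy breakdown above, per-class precision and recall are quick to compute from the cached responses (nothing new is sent to the API); a small sketch:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Per-class precision and recall for the run above (reuses cached responses)\n",
"preds = np.array([decode_response(r) for r in all_responses])\n",
"labels = np.array([test_set.targets[test_ix].item() for test_ix in test_eval_ix])\n",
"for c in np.unique(labels):\n",
"    tp = np.sum((preds == c) & (labels == c))\n",
"    precision = tp / max(np.sum(preds == c), 1)\n",
"    recall = tp / max(np.sum(labels == c), 1)\n",
"    print(f'Class {c}: precision = {precision:.2f}, recall = {recall:.2f}')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},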
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Results\n", | |
"So... better than chance! Our toy idea *seems* to work on our toy example (of course not battle-tested across seeds, different prompts, etc...) \n", | |
"* If you wanted to twist your arms into interpretating the above, seems plausible that the `1`s are more distinctive than the `0`s and `2`s, and hence the higher accuracy (also good precision and recall there).\n", | |
"\n", | |
"No further analysis here; just an art project. But some bonuses below!" | |
], | |
"metadata": { | |
"id": "CX8Nmw4x0Zxz" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Bonus 1: ChatGPT\n", | |
"\n", | |
"We can also try the above with ChatGPT. I'm not going to because I don't want to paste all these examples into their text interface." | |
], | |
"metadata": { | |
"id": "Mw0cFIhVY1TW" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# # Run this block to get prompts to paste for ChatGPT\n", | |
"# str_classes = ', '.join([f\"'{c}'\" for c in class_indices])\n", | |
"# str_classes = '[' + str_classes + ']'\n", | |
"# prompt_prefix = f\"In the text below, what is the most likely next character? Choose from a character in {str_classes}:\"\n", | |
"# all_prompts = []\n", | |
"# all_targets = []\n", | |
"# for test_ix in test_eval_ix:\n", | |
"# eval_prompt, eval_target = get_eval_prompt(\n", | |
"# test_set_ascii, ix=test_ix, start=''\n", | |
"# )\n", | |
"# _prompt = context_prompt + eval_prompt\n", | |
"# print('\\n', '-' * 20, f'Sample {test_ix} (target = {eval_target})', '-' * 20)\n", | |
"# print(prompt_prefix)\n", | |
"# print(_prompt)\n", | |
"# all_prompts.append(_prompt)\n", | |
"# all_targets.append(eval_target)" | |
], | |
"metadata": { | |
"id": "OuG1VxwEZAJm" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
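{
"cell_type": "markdown",
"source": [
"If you'd rather not paste prompts into the web interface, the chat API can also be called programmatically. The sketch below uses the legacy `openai.ChatCompletion` interface (openai 0.27+) with `gpt-3.5-turbo`; neither was used for the results above, and the system message wording is just one arbitrary choice."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def get_chat_output(prompt, model='gpt-3.5-turbo'):\n",
"    \"\"\"\n",
"    Sketch: classify via the chat completions API (openai 0.27+, not used above)\n",
"    \"\"\"\n",
"    response = openai.ChatCompletion.create(\n",
"        model=model,\n",
"        temperature=0,\n",
"        max_tokens=7,\n",
"        messages=[\n",
"            {'role': 'system', 'content': 'Answer with a single digit class label.'},\n",
"            {'role': 'user', 'content': prompt},\n",
"        ],\n",
"    )\n",
"    return response['choices'][0]['message']['content']"
],
"metadata": {},
"execution_count": null,
"outputs": []
},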
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Bonus 2: CIFAR-10 classification (CIFAR-3)\n", | |
"\n", | |
"So we get better than random accuracy on MNIST. But what about \"real\" images?\n", | |
"\n", | |
"Like CIFAR-10? (Yes `channel x height x width` dims of `3 x 32 x 32` is hardly real, but it's progress)\n", | |
"\n", | |
"Let's try it out with the exact pipeline above!" | |
], | |
"metadata": { | |
"id": "6URKLSmzmc3N" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load original data\n", | |
"train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_base, download=True)\n", | |
"test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_base, download=True)\n", | |
"\n", | |
"train_set.targets = torch.tensor(train_set.targets) # CIFAR-10 targets are a list by default\n", | |
"test_set.targets = torch.tensor(test_set.targets)\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "RJQ0mOt2oo7Z", | |
"outputId": "9bcc8bf8-24d1-4698-ae0d-f4c050b5d1e9" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Files already downloaded and verified\n", | |
"Files already downloaded and verified\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Load ASCII data\n", | |
"train_set_ascii = torchvision.datasets.CIFAR10(root='.', train=True,\n", | |
" transform=transform_grayscale, download=False)\n", | |
"test_set_ascii = torchvision.datasets.CIFAR10(root='.', train=False,\n", | |
" transform=transform_grayscale, download=False)\n", | |
"\n", | |
"train_set_ascii.targets = torch.tensor(train_set_ascii.targets)\n", | |
"test_set_ascii.targets = torch.tensor(test_set_ascii.targets)" | |
], | |
"metadata": { | |
"id": "TW6Xu-2Trxt5" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Also do subsampled evaluation by class\n", | |
"class_names = ('plane', 'car', 'bird', 'cat',\n", | |
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')" | |
], | |
"metadata": { | |
"id": "8zyO_YyFpf8H" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Organize train and test into classes\n", | |
"num_classes = len(np.unique(train_set.targets)) # 10\n", | |
"\n", | |
"# TRAIN\n", | |
"indices_by_class = get_indices_by_class(train_set)\n", | |
"\n", | |
"# TEST\n", | |
"test_indices_by_class = get_indices_by_class(test_set)" | |
], | |
"metadata": { | |
"id": "1HE-RWKNp2jq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class_indices = [0, 1, 2] # These are the classes we'll evaluate -> planes vs cars vs birds\n", | |
"print([class_names[i] for i in class_indices])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wjndnBc6pliP", | |
"outputId": "42b9e68a-6cec-44f8-a815-cdaee12ed78f" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"['plane', 'car', 'bird']\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install matplotlib\n", | |
"\n", | |
"import matplotlib.pyplot as plt # so we can visualize the images below" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "wWA2EzbisliH", | |
"outputId": "88fe462f-35d8-4d1e-879e-e221a04916e6" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", | |
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (3.2.2)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.4.4)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (0.11.0)\n", | |
"Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.21.6)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (2.8.2)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (3.0.9)\n", | |
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib) (1.15.0)\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Visualize original image vs ASCII\n", | |
"args.seed_cls = 42\n", | |
"args.num_samples = 1\n", | |
"img_dims = {'width': args.width, 'height': args.width}\n", | |
"classes = range(num_classes)\n", | |
"\n", | |
"context_ix, context_y = get_random_indices_by_class(\n", | |
" indices_by_class,\n", | |
" class_indices,\n", | |
" args.num_samples,\n", | |
" args.seed_cls\n", | |
")\n", | |
"\n", | |
"# Show samples\n", | |
"for i, ix in enumerate(context_ix):\n", | |
" target = train_set_ascii.__getitem__(ix)[1]\n", | |
" print('=' * 4, f'Sample {i}, class = {class_names[target]}', '=' * 4)\n", | |
" print(f'Original:')\n", | |
" plt.imshow(train_set.__getitem__(ix)[0])\n", | |
" plt.show()\n", | |
"\n", | |
" print(f'\\nASCII: ')\n", | |
" pretty_print_ascii(\n", | |
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
" **img_dims\n", | |
" )\n", | |
" print('\\n')\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"id": "OOrij5e8uB1L", | |
"outputId": "37ea0d0c-7752-4f3f-9e48-8f21e25430d2" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"==== Sample 0, class = plane ====\n", | |
"Original:\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAS4ElEQVR4nO3dXWxd1ZUH8P+fYOcLCLFDTBIjCF9CphIBLL6pGFVUhBfoCyoPVSrQpA8gtVIfipiH8ohG01YIjSqlA2o6KlSVWkQEaKZMVBRVQgETAiQEJnwE7ODEwRbBIQRwsubBh8oEn7Uud99zzp3u/0+KbN/lfc6+23flfqyz96aZQUT+8Z3SdAdEpB5KdpFMKNlFMqFkF8mEkl0kE6fWebK+vj4bHBxsuz3JtmKdkHL8bu5bk9WYf+RKUHTfqrrvo6OjmJqamvcBkZTsJG8B8BCABQD+w8we9H5/cHAQW7ZsKY2fcor/QqO3t7c01tPT47ZNTTivb1G/o3jUt9T2ntQHZUr8xIkTbtvjx4+78WhcPNG5Ux8vMzMzbty7b1HfvDFdv359aazt0SK5AMC/A1gPYAjAnSSH2j2eiFQr5T37VQDeMrN3zOxzAH8AcFtnuiUinZaS7GsAjM75eay47StIbiQ5QnJkcnIy4XQikqLyT+PNbJOZDZvZcH9/f9WnE5ESKcm+H8A5c34eLG4TkS6UkuwvAriI5FqSvQC+D6D8o3YRaVTbpTczmyF5L4D/xmzp7VEz253SmSrLOKnlrRRVl3m89tGYRueO2qeUiaLSWuq5UzRZkky5317bpDq7mT0D4JmUY4hIPXS5rEgmlOwimVCyi2RCyS6SCSW7SCaU7CKZqHU+O+DXs6uc993kNNKobZXxqufSR+Pm1YQXLFjgtq2ylh31O3W+eXT86BoDjzdu3t9bz+wimVCyi2RCyS6SCSW7SCaU7CKZULKLZKKrSm8p0yUjVZagUss0qaW3lDGNpE6BTSkLNjnFNfXxktK3qGyn0puIuJTsIplQsotkQskukgklu0gmlOwimVCyi2Si1jo7SbdGmLobanTulGNXOcU19dxubTVxKmdUL075m1S9rXFK+yqX0AbSpnq3O+Z6ZhfJhJJdJBNKdpFMKNlFMqFkF8mEkl0kE0p2kUx0VZ09pabbzcs1p14/kFLjT62zN7kOQJN19tS58il/05S2Xiwp2UnuAzAN4DiAGTMbTjmeiFSnE8/s/2RmH3bgOCJSIb1nF8lEarIbgL+QfInkxvl+geRGkiMkRyYnJxNPJyLtSk32G8zsCgDrAdxD8tsn/4KZbTKzYTMb7u/vTzydiLQrKdnNbH/xdQLAEwCu6kSnRKTz2k52kktJnv7l9wC+C2BXpzomIp2V8mn8AIAnirreqQAeM7P/ihp5dfYq1yBPmXedeu7UWnWT20VHUurRVa5JD/h9S5lv3sq5o7iXB6lbWZdpO9nN7B0Al7XbXkTqpdKbSCaU7CKZULKLZELJLpIJJbtIJmqf4uqVNKKSQ8q0wEiTU1yrnp5bpSrPHR37+PHjbbeP/iZRSTF6rEbH7+3tLY1FpTXvfmvLZhFRsovkQskukgklu0gmlOwimVCyi2RCyS6SiVrr7IBff4zqpp6qa9Vev1OXiq5yOefUenGkyjp76rUT7dajW4lH4/bFF1+48Q8/LF+j9bTTTnPb9vT0uPEyemYXyYSSXSQTSnaRTCjZRTKhZBfJhJJdJBNKdpFM1F5n91Q5b7vKbZWPHj3qtj1w4IAb//jjj934ihUr3LhX843qvX19fW7cqwcDwAUXXODGvXH77LPP3LZvv/22G4/GZfny5W7cc+qpfmpEj5fovj311FOlsaGhIbft9ddfXxrTfHYRUbKL5ELJLpIJJbtIJpTsIplQsotkQskukona6+wp66+3e1wA+Pzzz914VCs/fPhwaWxqaspt++6777rxqJ68bNkyNz4wMFAaW7x4sdt2enrajUfz4fv7+924V8ePjv3KK6+48ajG7/VtZmbGbRv1LaqzR4+nQ4cOlcb27Nnjtl29enVpzHuch8/sJB8lOUFy15zb+kg+S3Jv8bX9qxdEpBatvIz/LYBbTrrtPgBbzewiAFuLn0Wki4XJbmbbAJz8OvU2AJuL7zcDuL3D/RKRDmv3A7oBMxsvvj8AoPRNI8mNJEdIjkxOTrZ5OhFJlfxpvM2ulli6YqKZbTKzYTMbjj7MEZHqtJvsB0muAoDi60TnuiQiVWg32bcA2FB8vwHAk53pjohUJayzk3wcwE0AVpAcA/BzAA8C+CPJuwG8B+COTnQmqm16tfSojj4x4b/4GBsbc+O7d+8ujUXz1Y8cOeLGo/nsS5cudeNezda7PgAAPv30Uzd+ySWXuPGnn37ajadcO7F///6k+Msvv1wai/YoSHksAnGdfXR0tDQWXXfhrTHgXfMRJruZ3VkS+k7UVkS6hy6XFcmEkl0kE0p2kUwo2UUyoWQXyUTtU1xTtmX2phVG5avnnnvOje/atcuNe2WeY8eOuW2jMk1U5lm4cKEb95YtTi0h7d27142PjIy4cW+6ZlT2i6aRRldkrly5sjQWbZO9ZMkSNz44OOjGI97fZXh42G3rLSX92GOPlcb0zC6SCSW7SCaU7CKZULKLZELJLpIJJbtIJpTsIpmotc4+PT2Nbdu2lcYvvvhit723LHJUZ3/jjTfc+Pbt2924N2Wxt7fXbdvT0+PGo1p3tH2wN4X2jDPOcNtG1wiMj4+78ei6CW/cvK2mgXi76UWLFrnx008/vTT25ptvum2j6xNuueXkNVi/KtqyeefOnaWxNWvWuG3PPPPM0pg3pnpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNRaZ5+YmMDDDz9cGr/mmmvc9pdffnnb5/Zqk0Bc23z//fdLY1EtOqqjR/XmaF63V+s+77zz3LYpWwsD8ZbP3jUCZ511lts2Gpeolu1dexHNV7/wwgvd+PLl/sbF0dLmZ599dmnMe6xFvDUC9Mwukgklu0gmlOwimVCyi2RCyS6SCSW7SCaU7CKZqLXOfuzYMXcecbQ+urdVbbSt8cDAgBu/66673PiOHTtKYx988IHbtq+vz41HNdlorr23NbFXzwX8Od9AvH1wNOd89erVpbGoVh3Fb775Zje+du3a0lhUw4+uAYjmu0fXAFx55ZWlsWiNAW9ckuazk3yU5ATJXXNue4DkfpI7i3+3RscRkWa18jL+twDmW5bjV2a2rvj3TGe7JSKdFia7mW0DMFVDX0SkQikf0N1L8tXiZX7pmwiSG0mOkByJ3ueISHXaTfZfA7gAwDoA4wB+UfaLZrbJzIbNbDia0CEi1Wkr+8zsoJkdN7MTAH4D4KrOdktEOq2tZCe5as6P3wPg73csIo0L6+wkHwdwE4AVJMcA/BzATSTXATAA+wD8qJWTLVmyBOvWrSuNR+uve3N1ozXEo/24I95+3Ndee63b9txzz3XjExMTbnxsbMyNe3OzL730UrftzMyMG4/m+
U9Ntf/ZbTRn/LLLLnPj1113nRv31hGIrg+I4tHjKVonwGsfjblbS3feKofJbmZ3znPzI1E7Eeku+sRMJBNKdpFMKNlFMqFkF8mEkl0kE7VOcV20aBGGhoZK49GSy950zGhr4qisFy2J7JW3Jicn3bbRdMdoO+kofsUVV5TGotJbVLLs7+9349EU2Iceeqg0FpUkzz//fDcejevo6GhpbN++fW7bqCQZbcMdXRruLf8dTe1dtmxZacybLq1ndpFMKNlFMqFkF8mEkl0kE0p2kUwo2UUyoWQXyUStdfYFCxa4Sz57tccoHtXJo1VyopqtV1c9ePBg222BuOYbLUXt1flfeOEFt623pTIQL7kcbensTUuOtrp+/vnn3Xh0Xcbhw4dLY9GYR3XyaHpudH2C93iMHstend09Z1utROT/HSW7SCaU7CKZULKLZELJLpIJJbtIJpTsIpmotc5OEosXL3bjHq/OHs0/jpYG9mqyAHDkyJG2zx2J5uJfffXVbtyrhUdjGo1LtCSy9/cEgBtvvLE0Fq0xEG2FHbX/6KOPSmPR/Y6uP4iurYi2Xfbq8N7aCYC/9Lh3v/TMLpIJJbtIJpTsIplQsotkQskukgklu0gmlOwimai1zg74NelozrlXM45qrtG87Ii3vnrKPHwgbY3xqH107KgOH83F99byB/x6depa/9F9W7lyZWnMm2cPxHXy6PqDTz75xI179fBoLX9vTJPWjSd5Dsm/knyd5G6SPy5u7yP5LMm9xVd/ZXsRaVQrL+NnAPzUzIYAXAPgHpJDAO4DsNXMLgKwtfhZRLpUmOxmNm5mO4rvpwHsAbAGwG0ANhe/thnA7VV1UkTSfaMP6EieB+ByANsBDJjZl4uIHQAwUNJmI8kRkiPR+xgRqU7LyU7yNAB/AvATM/t4bszMDIDN187MNpnZsJkNe4tNiki1Wkp2kj2YTfTfm9mfi5sPklxVxFcBKJ+KIyKNC0tvnK3NPAJgj5n9ck5oC4ANAB4svj4ZHevEiRNuaSDilaCiKYlRGScqMXntoymuUXkrEh3fK0Gl9m32RVv7ca9v3hTUVkTn9kRluygeTUONlv/2pgZHJUnvfnuP41bq7NcD+AGA10juLG67H7NJ/keSdwN4D8AdLRxLRBoSJruZ/Q1A2X//3+lsd0SkKrpcViQTSnaRTCjZRTKhZBfJhJJdJBO1T3GNpmt6vCmwUc01qu+n1pM90fTaaGpv1N6rCUfXD6ROz43iXp0/qmVHUv5mUdto3KL4woUL3bi3RXi0hLY35lpKWkSU7CK5ULKLZELJLpIJJbtIJpTsIplQsotkotY6u5m5ddeU5Z6jumnKMtWR1DnhqVLqyVE8um/ROgKeqEafOm5eHT96PERbOkfXbUxPT7vxlMdbtDZDGT2zi2RCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJmqvs3u11ZSab1TvjeYfRzV+Lx7VTFPrySntU2vVqVtde1KvAUi5tiK6X9HjKdrSOeWxXNV1GXpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTLSyP/s5AH4HYACAAdhkZg+RfADAPwM4VPzq/Wb2THAst/aZMuc8qnWnzjlPmRuduj97xDt+6j7kUY0/Ze331Ln00bh7x4/6HZ07ZS3/6PhRjd9bk949rnvUWTMAfmpmO0ieDuAlks8WsV+Z2b+1cAwRaVgr+7OPAxgvvp8muQfAmqo7JiKd9Y3es5M8D8DlALYXN91L8lWSj5JcXtJmI8kRkiNHjx5N6qyItK/lZCd5GoA/AfiJmX0M4NcALgCwDrPP/L+Yr52ZbTKzYTMbXrJkSQe6LCLtaCnZSfZgNtF/b2Z/BgAzO2hmx83sBIDfALiqum6KSKow2Tn78d4jAPaY2S/n3L5qzq99D8CuzndPRDqllU/jrwfwAwCvkdxZ3HY/gDtJrsNsOW4fgB+1ckKvNBCVUlKkLjVdZektZRvrVo7vSSk5ttK+qrZAWrk1dcyj8lgU9+57VXnQyqfxfwMw36i5NXUR6S66gk4kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNS6lDTg1xCjrWhTtgeOapdR3Dt3So2+FSn16GgJ7ajvVV4DEB079dwp06lTjg3EU2C9rcuj+93uVG89s4tkQskukgklu0gmlOwimVCyi2RCyS6SCSW7SCZY1faw856MPATgvTk3rQDwYW0d+Ga6tW/d2i9AfWtXJ/t2rpmdNV+g1mT/2snJETMbbqwDjm7tW7f2C1Df2lVX3/QyXiQTSnaRTDSd7JsaPr+nW/vWrf0C1Ld21dK3Rt+zi0h9mn5mF5GaKNlFMtFIspO8heSbJN8ieV8TfShDch/J10juJDnScF8eJTlBctec2/pIPktyb/F13j32GurbAyT3F2O3k+StDfXtHJJ/Jfk6yd0kf1zc3ujYOf2qZdxqf89OcgGA/wVwM4AxAC8CuNPMXq+1IyVI7gMwbGaNX4BB8tsAjgD4nZl9q7jtXwFMmdmDxX+Uy83sZ13StwcAHGl6G+9it6JVc7cZB3A7gB+iwbFz+nUHahi3Jp7ZrwLwlpm9Y2afA/gDgNsa6EfXM7NtAKZOuvk2AJuL7zdj9sFSu5K+dQUzGzezHcX30wC+3Ga80bFz+lWLJpJ9DYDROT+Pobv2ezcAfyH5EsmNTXdmHgNmNl58fwDAQJOdmUe4jXedTtpmvGvGrp3tz1PpA7qvu8HMrgCwHsA9xcvVrmSz78G6qXba0jbedZlnm/G/a3Ls2t3+PFUTyb4fwDlzfh4sbusKZra/+DoB4Al031bUB7/cQbf4OtFwf/6um7bxnm+bcXTB2DW5/XkTyf4igItIriXZC+D7ALY00I+vIbm0+OAEJJcC+C66byvqLQA2FN9vAPBkg335im7Zxrtsm3E0PHaNb39uZrX/A3ArZj+RfxvAvzTRh5J+nQ/gleLf7qb7BuBxzL6s+wKzn23cDaAfwFYAewH8D4C+LurbfwJ4DcCrmE2sVQ317QbMvkR/FcDO4t+tTY+d069axk2Xy4pkQh/QiWRCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJv4Pw1VAg23eYPsAAAAASUVORK5CYII=\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"ASCII: \n", | |
"::``````````````````````\n", | |
"::``````````````````````\n", | |
"::``````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"`````````````````````::`\n", | |
"``````:-=::::`':-```'-=`\n", | |
"````:-=+*+++=::-=::--=*=\n", | |
"```-+*###*****####*****-\n", | |
"-:`=#%%%%###%%@@%#%#=*=:\n", | |
"#+-=*#@@%@%%@@@%%*#+-=+-\n", | |
"%#*+-=+=+@%#@@##=--:----\n", | |
"**++==-:+#+==*#+::---=--\n", | |
"+++=====#%*+=*%===**+++=\n", | |
"++++++*+*****##***#**++=\n", | |
"+++++++++++++++++++****+\n", | |
"+++++=+++++++++++++*++++\n", | |
"++++++++++++++++++++++++\n", | |
"++++++++++++++++++++****\n", | |
"++++++=+++++++++++++***+\n", | |
"++++++==++++++++++++++++\n", | |
"+++++===++++++++++++++++\n", | |
"\n", | |
"\n", | |
"==== Sample 1, class = car ====\n", | |
"Original:\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAaB0lEQVR4nO2dXWxcZ5nH/8+cmfF4xo5jx4mbzza03bSltCkKpYguCwtUbfeicFPRC9SV0JYLKoHExSL2Ai7RagFxsUIKS0VZ8SFWgOguaGm3y1JAK2hoQ5s2pWkT58NN/BU79tjzPc9eZIrSNu//GNuZsfb9/yTL9jx+z3nnnPOfM57/+zyPuTuEEP//yfR6AkKI7iCxCxEJErsQkSCxCxEJErsQkZDt5s76sjnv7+sLxjPWpuPLlVow5uCuwshgP40PDg7R+PxyIxibm5+nY5OM8Xg2oXEzPr7Vaq16bBpp4zMpz63VCp+XdpuPNYSvFQBopYzvz4X3XUy58muZPI03kiLfQLNJw0kmHDfwsRkPX4vLlWXU6vXLHpg1id3M7gbwNQAJgH9x9y+xv+/v68P7b3pHMJ5Plun+fvPiK8FYs83F/rE7b6bxD/z1PTT+74cmgrF/++m/07GD/VzMw1s203iSy9H44sJCMJbN8rFIOW65HL/oCyUeL5frwVhlic8t8b00PlfhLwa3XhUWxa2j/HmPl3bT+Lmh/TTemp2l8eG+cLwvM0PHFhqTwdh//+aXwdiq38abWQLgnwHcA+AmAA+Y2U2r3Z4Q4sqylv/Zbwfwirsfd/c6gO8DuG99piWEWG/WIvadAE5f8vuZzmNvwMweMrNDZnao3gy/rRJCXFmu+Kfx7n7Q3Q+4+4F82v+PQogrxlrEPgHg0k8xdnUeE0JsQNYi9qcBXG9me80sD+BjAB5bn2kJIdabVVtv7t40s4cB/BwXrbdH3P2FtHEtC9tQJ2aW6NhyI+yr3vC2t3xc8AbGdnHr7fCRsJ0BAMdPvhiM5fPcxikvXqDxSo1bjq02X39Qq4XtrWyWn+JGg3u6aWTy/H7RXwyvX3j7vgN07NbCGI0//ew4jZ+eKQdj1vduOvaV2i00PjXFj9uQDdD4TD58Xkol/u/unrFtwVg7+7tgbE0+u7v/DMDP1rINIUR30HJZISJBYhciEiR2ISJBYhciEiR2ISJBYhciErqaz962BLVsKRg/U+bTGdn1nmBs7/X76Nhm9ioaP3H6MI1PTYd9+KyFnxMA9JW4J8vy9AFgOSUOD/vwbZLrDgDtNVYXbtVScvWT8NwbrQode811IzRebvD1Cc8fCa/pOLU8SMdWh8JeNgBYyvKE5VbY4weABqm/sIRNdGyhEH5ejUwhGNOdXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiISuWm+ZTILCQDjlsa9/mI6/+R3vC8YaTV7l9KVTPM10qcbTayuVcEmtUp5bJYNF/pq6ZZhXn23yDFdqr6WVqW4T2w4AqvVw+iwAwPhzqzXCczt98lU6dmInt792XL2FxqenwtdTNc9tvaEhXno8n+Wlx70SrvgLADj322BozMfp0NJ0OJ07aYTPl+7sQkSCxC5EJEjsQkSCxC5EJEjsQkSCxC5EJEjsQkRCV332XDbBjs1hf/I0wuWaAcDLJ4OxhTr3upHw1lMpNjvaHj5UI8PhtEIAuGpwB43nCmktnflrcqMRfm6GlJbKxlNcK6RMNQBYhl9CC0vhVM+J13hPkaNH+PWw77obaHzLSLiccznP1x9cyPDU4IalpB1nuM++tRROmb5rH7+Wl0lp8ZfJ6dCdXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiASJXYhI6KrPvqlUxAfuCLfp7atP0/GvzYbzn/tL3HMt7eHtgSdnee50fy68/cIIz43O9fHcaDR4rv3keV4yeZq0uk7aVTo2A+6jD43ytsnZlOe2MP9aMJYUttOxyzXudY+/zOs5Z0jJ5aUcL2Odab5E48Ucv0/WKzM0nrO5YGzrLt5O+sJS+HpgazLWJHYzGwewCKAFoOnuXFFCiJ6xHnf2D7g7fxkTQvQc/c8uRCSsVewO4HEz+72ZPXS5PzCzh8zskJkdYuukhRBXlrW+jb/T3SfMbBuAJ8zsJXd/6tI/cPeDAA4CwLW7r15bYzEhxKpZ053d3Sc636cA/BjA7esxKSHE+rNqsZtZycwGX/8ZwF0AjqzXxIQQ68ta3saPAfixmb2+ne+6+3/SnWUTjI6E89n3v+PtdIfbZ8If+r98nvvkW/Z/mMY3D/FD8drz4fzls1X+mjmf0r63vcQ93/MNHl+qhn3XfHORjt3azxP5k2He2vjUazxve2F5TzBWSHit/3yB++wZ5znlzUYuGGuBH5eH7+ctwHfu3kzjzz37Mo3PnCGtrOt8bcTZM8fDY0nd+FWL3d2PA7h1teOFEN1F1psQkSCxCxEJErsQkSCxCxEJErsQkdDVFFcYkITdEOQH++jwrQPhNNPDVd7u+cUJ/lRbZW4xTRW2BmMXWrztcbPKrbPEeJrp6C5uxaAVTg32+XA5ZQCog1tv5xb5vpfa3D4bGAufl9YyL+/dbvES3a2U9ZitbHj7tQZPj222eUvmv7n7Nh7/0I00vjx/dzA2Pcnzyl47fTYYe+p3jwdjurMLEQkSuxCRILELEQkSuxCRILELEQkSuxCRILELEQld9dmTJMEmkuI6Nc/LEjdz1wRjZSvRsSdOcVO2lnA/uVkMH6p6mx/Gvgb3i3cN87llZnlZ48rUoWCsUeYef73J1wh4m3vhuRz3q3PZ8PqFjPG2yfXMTho/1+DnvN0Kn5fl+hY69tEfnabx627g4+/5EC+TPZQNt2Ue2sqf17U3XheMDQyF11Xozi5EJEjsQkSCxC5EJEjsQkSCxC5EJEjsQkSCxC5EJHTXZ89mMTQcbm+czfPyvBfKYR9+Ys7o2KU8zxnPF0iiPQDUw3nfLW5Fw5opr6lzvFX13IlxGl88fz68b0/JpQf3urPOj2t9mZeSnquHy1wPlriPjhLPlV8Ejzc8/NySIl/bcGyWr7v47n+E24cDwE03cR9+745w7YY2+NqFViZ8wbmFn5fu7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEQld99kwmi9Jg2H+sVXjd+BdfmAzGFiu8ZXNugPuqRlrdAkCCsO9aSvhhbC5yL3t8mrce9sa1NJ6MhHOnt/TxGuRDA9xHLxb4OenLzfN4PvzcswlfV9ECr3mPSe6zz86Hc/Wv2lykY5da4fUBAHD8DK8D8Ozz/LhfvSP83Awpaz4ybO3EGnx2M3vEzKbM7Mglj42Y2RNmdqzznXdoEEL0nJW8jf8WgDe3r/gcgCfd/XoAT3Z+F0JsYFLF7u5PAXjzesz7ADza+flRAB9Z53kJIdaZ1X5AN+burzecOgdgLPSHZvaQmR0ys0Ozc+E13EKIK8uaP413dwf5VMDdD7r7AXc/sIUkwQghriyrFfukmW0HgM73qfWbkhDiSrBas
T8G4MHOzw8C+Mn6TEcIcaVI9dnN7HsA3g9g1MzOAPgCgC8B+IGZfQLASQD3r2Rn7o5qLexnz83zHOLZC2EPcduWlPzhq7lnO3Oae93zpJf47r/g2/7jBd4DfbbK85d3bSvT+P5rw8d0xxB/Pc8lLRpfKi/S+K5t3IcvDoRroM+V+XF7dYL7zWdq3GcfJE+t4LwIgeW4m3xulp/Tx395gsZvvzW8xmDXGP93NyF5+obwuolUsbv7A4HQB9PGCiE2DlouK0QkSOxCRILELkQkSOxCRILELkQkdDXFtdFs4tw0KXuc51ZLoRROK6xOX6Bjc9M8FdPr3IJqLIa3X53hr5n5Ok+f/csb+TLihz8+TuM5Yt09/iveBjspjtJ43xK3/aoNfgmdng7v/8g5nvo7MclbFy/U+XPbuTs8t8rEOB1brYdbTQPAQoXbfk/8z1Eav+s9Ydtx5z2307EZJzohpb91ZxciEiR2ISJBYhciEiR2ISJBYhciEiR2ISJBYhciErrqsxcK/dh3w83BeBa8tPCxE78OxvqrJ+nY8RPcF51vpbX/DZfvXT7BXzOzDV62ePcYL+f89r18DcHyTNhnr2X43F6d5Km9uwa4Dz8yxrdfbIafe/8QL+99c0rq8MsT/LieGQ+nTJeXeGpurcJTe41nJaNaOUfjp46Gr7fGB2+lY3Mp61FC6M4uRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCR01WdPsgmGt4a99Lctcd/1lj3h/OXz0y/RsYsXuG+6UNlN4y2ES1U3izwvu53wfPaf/4bHTxzbSePbx8L7r/bvoWOnwcsWz1ZTagw0ec75vj1hP3m4zdsezy7wcs8zZT7+j+XwGoL5BV6/YKyYsn4ge4rGrx7mdQAGSEvoxhK/HrJ9fI1ACN3ZhYgEiV2ISJDYhYgEiV2ISJDYhYgEiV2ISJDYhYiErvrscIeTGuqW8CTh3TvCtbxvu5DyujV+moYzJ6do/PhSuIVvvniAjq01N9F4OcWrfqHCvfKXF8Oecb7KT3HNuM9ez3A/+rVnuBc+dJScb+d+cb3KW3hPTy7QeLUWviZym3n9gmwySeOl2jM0vrnEr6f58+FreX6We/SlYX49hUi9s5vZI2Y2ZWZHLnnsi2Y2YWaHO1/3rmrvQoiusZK38d8CcPdlHv+qu+/vfP1sfaclhFhvUsXu7k8B4P2JhBAbnrV8QPewmT3XeZsf/IfWzB4ys0Nmdmh6ZmYNuxNCrIXViv3rAK4FsB/AWQBfDv2hux909wPufmDrKC9eKIS4cqxK7O4+6e4td28D+AYA3nZSCNFzViV2M9t+ya8fBXAk9LdCiI1Bqs9uZt8D8H4Ao2Z2BsAXALzfzPYDcADjAD65kp15u41GJZzHO7/Ic86PvHImGJs4O03HDjj3i/f28/7tC+fDtds9xUevNd9J4yjx/OXSCP+sY3gwfBr7m+F69wAwOc394JZxz7edsv25cvi5tVJ8dBjf9mCTjx/y8HFx4/3X8+VZGu8z3qdgZp4f1989F66PcOeZOTp2597tNB4iVezu/sBlHv7mqvYmhOgZWi4rRCRI7EJEgsQuRCRI7EJEgsQuRCR0NcXVW200FsPlfWdSUvteeCVszR0/tUTHmvMU16TJ0yWHEbZ5Zmef5tsu8NbDmzfxtsmlk3+g8daF8NyWMvy4tBZ4O+gB8HLNWzbxNtuNevi4Li9xa63Qz9NQNyX8eikT26+SDbcOBwBv8eeFIrdbqwlP3z17IVwW/cJiiiXZIlath8ux684uRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCR01Wc3AxLy8pJt8TTUeWLLHpst0rHFHI8PgO+71gz7zUszx+hYFLnP3p/wls8Dy7/l4y18UJNB/rwHSzytuLnIy3uPZLgP38iHU5rnqnwNQKbF20VXqzz1t7IULnOd2XQtHZsM8vLdlfxWGr8wy1Nktw2Ez1kxrSNzi5Tvls8uhJDYhYgEiV2ISJDYhYgEiV2ISJDYhYgEiV2ISOhuy2YzJPmwpzxQ5L5qf8GCsfOL3O+tD9xK462+FE83G253V+sPe5sAYEXuybac59LPkXx1AJivhcf3l7nH327zbbe5zY76NC+DXW+QnPIK33crJZc+yfHJbdmyLxhrZnfSsUvg+ez1Cn/e5fnw+gIAuG5rIRgbHQrnul/cObneSEh3diEiQWIXIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEiobs+O4AWybctlcLeIwCMDpNE39prdGy5yVsyN/qvpvFqbkcw1hrdTccOX8XjpeqzND6Uv57GR0jKeovk4QPAmTPjNN4u8vG5PK/tPpAbCsayrLgBgEx/eCwA3HDje2m8XQ/npP/qaX69zM2eo/GkzvPVM7VJGt++7bpgbCilBoGzNgNr8dnNbLeZ/cLMXjSzF8zs053HR8zsCTM71vk+nLYtIUTvWMnb+CaAz7r7TQDuAPApM7sJwOcAPOnu1wN4svO7EGKDkip2dz/r7s90fl4EcBTATgD3AXi082ePAvjIlZqkEGLt/Fkf0JnZNQBuA/BbAGPufrYTOgdgLDDmITM7ZGaHps/z/3OEEFeOFYvdzAYA/BDAZ9zfmLnh7o7ARwPuftDdD7j7ga0jW9Y0WSHE6lmR2M0sh4tC/467/6jz8KSZbe/EtwOYujJTFEKsB6nWm5kZgG8COOruX7kk9BiABwF8qfP9J2nbcne0WmErp1jiqX3bx8Jph6UCTymcXT5D4/UWTzNtJ+E0VW9yG2ehOkHjtZQS2sWd76Lx0dtuCsaaKeWYT2SP0ngzy621dkoZ7MKmUjA2MMDP9/npcFoxAJyt8tTh8mz4nE7NnqRj61XeDrrovAz20GA4HRsAtm8bDG+7xGtJW5ad0/B+V+KzvxfAxwE8b2aHO499HhdF/gMz+wSAkwDuX8G2hBA9IlXs7v5rhF8uPri+0xFCXCm0XFaISJDYhYgEiV2ISJDYhYgEiV2ISOh6imuGVcF17k0W+sLTHRkI+7kAUF9keYFAq07a4AJoZMNeerPJPfza8im+7RYvW3yyso3Gy82wF94Ef161Cj8uSUoKK+kWDQA4NxNOFZ1GSp1q5xt/9fnjNF5fDrejzvgcHZsYLxWdK3AvvLSJn9NGK5zO3UrRAVibbFPLZiGiR2IXIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEioas+u8GQtMMeord5XnelGm7xm2R5GepCkfvwtQYvmdwg+csZ8JzxbIrXncEMjVem+fgT58eDsS3D3Mveyjs6o1jg7aj7stwTzpK2y0MD/Jxli7yU9EtLfI3AqflKMFZNudZyeV4subBpE41XW/y4WzY8PpPj1xOMXQ/y2YWIHoldiEiQ2IWIBIldiEiQ2IWIBIldiEiQ2IWIhO7ms7vD22H/Mcly7zPbF86trjW5j44Sf10rJnzfrcWwX1yvcB+8tcxzp5vgfjHavN00muH1ByVaYxzYtZn7ycUsv0SShPvsrUY4L7yQcA+/1eDxhXnel6RcDeecZwo76dhC/3YaTzyc
Kw8AQyWeD7//lmuCsYG0ls2sKISFz4fu7EJEgsQuRCRI7EJEgsQuRCRI7EJEgsQuRCRI7EJEwkr6s+8G8G0AY7iYLHvQ3b9mZl8E8HcApjt/+nl3/xnblruj0Qz77JkU33X7ti3B2OZBXqd7aZr3EU8y/FAU+8KJ32k530tl3me8UeN+MTI8N7rZCPv0Zry++eAQ99n7Sb9vAMj2cR9/cTlcB6CV8LF15+c0k0/pDV8Kzz1fSOlRkPBa/6NFfs4/eOfNNH7Hu/YEYzme5o92g92j19afvQngs+7+jJkNAvi9mT3RiX3V3f9pBdsQQvSYlfRnPwvgbOfnRTM7CoAvPxJCbDj+rP/ZzewaALcB+G3noYfN7Dkze8TMLvt+0MweMrNDZnZoZo6/nRVCXDlWLHYzGwDwQwCfcfcFAF8HcC2A/bh45//y5ca5+0F3P+DuB0aHR9ZhykKI1bAisZtZDheF/h13/xEAuPuku7fcvQ3gGwBuv3LTFEKslVSxm5kB+CaAo+7+lUsevzQt6KMAjqz/9IQQ68VKPo1/L4CPA3jezA53Hvs8gAfMbD8u2nHjAD6ZuiUzwMJ2S6bFp3Pj1eG0w3ffwtsa//QX4ZbLALBwIVx2GAAySXhu/QO8HvPIyCiNNxtbabyQ5/ZZQtJzRzbzsct1/nrfdJ5+W0j6aXxxOWwFVap829UGT+31ZtjWA4ASSR3enOPW275rwjYvALzrlhto/I7bb6TxHVsHSZTPrU3sa/ewJbiST+N/Hdg79dSFEBsLraATIhIkdiEiQWIXIhIkdiEiQWIXIhIkdiEioaulpL0NNGrhkszZDE95HN0Uzv2756+479luT9L4/z79exqfnJoNxmopS/6z+ZQy1xnemrjZSGkJXQifxoVZXiJ7fnKZxq3NSyIPDvDWxQuLF4Kx5WW+776UVM/+fLi0OADs2RZee3HDNXvp2P37rqPxa/ZcReNXDfPU4cTD57S5yFOa67XwOfW2WjYLET0SuxCRILELEQkSuxCRILELEQkSuxCRILELEQnG8l/XfWdm0wBOXvLQKICZrk3gz2Ojzm2jzgvQ3FbLes7tane/bIGEror9LTs3O+TuB3o2AcJGndtGnRegua2Wbs1Nb+OFiASJXYhI6LXYD/Z4/4yNOreNOi9Ac1stXZlbT/9nF0J0j17f2YUQXUJiFyISeiJ2M7vbzP5oZq+Y2ed6MYcQZjZuZs+b2WEzO9TjuTxiZlNmduSSx0bM7AkzO9b5zhOnuzu3L5rZROfYHTaze3s0t91m9gsze9HMXjCzT3ce7+mxI/PqynHr+v/sZpYAeBnAhwGcAfA0gAfc/cWuTiSAmY0DOODuPV+AYWbvA1AG8G13v7nz2D8COO/uX+q8UA67+99vkLl9EUC51228O92Ktl/aZhzARwD8LXp47Mi87kcXjlsv7uy3A3jF3Y+7ex3A9wHc14N5bHjc/SkAb66Dcx+ARzs/P4qLF0vXCcxtQ+DuZ939mc7PiwBebzPe02NH5tUVeiH2nQBOX/L7GWysfu8O4HEz+72ZPdTryVyGMXc/2/n5HICxXk7mMqS28e4mb2ozvmGO3Wran68VfUD3Vu5093cCuAfApzpvVzckfvF/sI3kna6ojXe3uEyb8T/Ry2O32vbna6UXYp8AsPuS33d1HtsQuPtE5/sUgB9j47Winny9g27n+1SP5/MnNlIb78u1GccGOHa9bH/eC7E/DeB6M9trZnkAHwPwWA/m8RbMrNT54ARmVgJwFzZeK+rHADzY+flBAD/p4VzewEZp4x1qM44eH7uetz93965/AbgXFz+RfxXAP/RiDoF5vQ3AHzpfL/R6bgC+h4tv6xq4+NnGJwBsAfAkgGMA/gvAyAaa278CeB7Ac7gorO09mtuduPgW/TkAhztf9/b62JF5deW4abmsEJGgD+iEiASJXYhIkNiFiASJXYhIkNiFiASJXYhIkNiFiIT/A8EtMvZ2MEM6AAAAAElFTkSuQmCC\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"ASCII: \n", | |
"#%@%+*@@%%%%%%%%@%######\n", | |
"#@@**#@@@@@@@@%*%@%#%##*\n", | |
"#%#**%@@@@@@@@@#*%%###%#\n", | |
"#%*##%%%####%%%%*######*\n", | |
"+##*#%####%%##%%%####%*+\n", | |
"=+*+=***##%##%%%%%#==*+*\n", | |
"==+*+****######%###-``:-\n", | |
"-=+******%%%%%%%%%%#:'``\n", | |
":=*##*###%%%%%%##%##=```\n", | |
":+#**#*###***+++*###*:'`\n", | |
"-+***#*#*-+++++*####%+``\n", | |
":=*#####=:+*+==++*#%%*``\n", | |
"`:+##%###****=-=+#%##*:`\n", | |
"`:=*#%###*******#####*-`\n", | |
"`-*##%########%%%%%%##=`\n", | |
"`=#%%%%%#%%%%#%%%%%%#%=`\n", | |
"`-%%%%%%%%%%%%%@@%%#%%=`\n", | |
"`-#@%%%####%%%%%**%%%%-`\n", | |
"`:#@@@%%%*+###*#*#%@%#:`\n", | |
"`:*@@@@@@%%%####%%%@%#-`\n", | |
"`:*%%@@@@@%%#####@@%%#:`\n", | |
"::*%%%@@@@@#####%@%%#=``\n", | |
"::=%%%%%%%%##*##%%%#+:``\n", | |
"::-+%@@%%%%%%%%%%%*+-:``\n", | |
"\n", | |
"\n", | |
"==== Sample 2, class = bird ====\n", | |
"Original:\n" | |
] | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
], | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAXRElEQVR4nO2dXWycZ5mG72dmPGOP/xLHSWqSNCVtWpoCW1hvhQTaBaFFpSeFk4oeoK5UbTgACSQOFnUP6GG1WkAcrJDCUlFWLAgJED2oWLoVqwpphepWoU0b2rRp0iQ4dhKn8fh3fr5nDzxdmeL3fo3nd/e9L8nyeB5/3/d+P/d8M3O/z/OYu0MI8f+fXK8HIIToDhK7EIkgsQuRCBK7EIkgsQuRCIWubqxU9GJ5MBiP+gLsH8x2MqS2YB3fdgfXH3NjendYEXWKInEjcUPGl43cBqOHLbICIxdzFll3hnwwVltdRX29uuVZa0nsZnYvgG8DyAP4V3d/jP1/sTyI2z7xkWA8fnLDoYaFDwAAWK61NzFM0Pk833Yu9mIQE5TzsbuHVxA7pp7FLno+ODO+/hwJZ5GX93qjxrddr0bijWCsmK3SZQfLfGyrdR4v5YdpfMDDY1/iu41qYSwYe/2//jsY27ECzCwP4F8AfBrAMQAPmtmxna5PCNFZWrnd3QPgdXc/6+5VAD8GcH97hiWEaDetiP0AgAub/r7YfO6PMLPjZjZjZjP19cj7EyFEx+j4t/HufsLdp919ulAa6PTmhBABWhH7JQCHNv19sPmcEKIPaUXszwE4ambvNbMigM8BeLI9wxJCtJsdW2/uXjezLwH4D2xYb4+7+8t8KeNWDneBQBwm5CK+Zi7Pd7UVqzwXsfViPnw8zq09dmCyiGnrOX7Q88w7A5BFjntG7ieeRfY74/udxW5VhbC9tlbjHymrq+s07nUe37V7iMYPHZgKxi5eWaLLzi+Gd5wd0ZZ8dnd/CsBTraxDCNEdNF1WiESQ2IVIBIldiESQ2IVIBIldiESQ2IVIhK7mswMOJ76vRXI9mR+dRY3ySO5zzKcnXnqr+eyt+vDMXY1k38Ij6bNm4TTRjRVwn55FI9mxsOjYinzbhXAaaYPkhANAoRGRRiS9trZaofE8xsOxyIEZINcq05Du7EIkgsQuRCJI7EIkgsQuRCJI7EIkgsQuRCJ02XozWmk1WkqaOVCRKqit0knrLdpbs6XemzFbL7LpSLwYsfZKpIqqOy9TVo/YY3Xn8QY5ZxFXD3CeAjtY3kXjpcHIvtXC8YFYtWKWlkxsO93ZhUgEiV2IRJDYhUgEiV2IRJDYhUgEiV2IRJDYhUiE7vrsFummGuko6ix9L5om2trrGuuGGu2UGlt3ZL9zsVLSLB4z8WOtiSOjHx8p0fj+kXBJZa+v0WXnrvI00qsLdRrn8x8idctj0sjxUtHLa3zfrl25EYw1CpEOsIXw+Wa7rDu7EIkgsQuRCBK7EIkgsQuRCBK7EIkgsQuRCBK7EInQ5Xx2wFkL38iyRpbNIonXHvHh85GSyFhbDoaKkXLLFtuzXCQvGzy3uuHEb47ML7BITjjy/LhUw12RN9Y/PhKM3bJ/gi6ba1yj8WvXeblmPi+DLkqnLgBAFjkn61V+zpdXwm2ZM+PHvFgKH1N2rbUkdjM7B6ACoAGg7u7TraxPCNE52nFn/4S7X23DeoQQHUSf2YVIhFbF7gB+ZWbPm9nxrf7BzI6b2YyZzdTX+VxnIUTnaPVt/Mfc/ZKZ7QPwtJn93t2f3fwP7n4CwAkAKE+Mt1Q6UQixc1q6s7v7pebveQA/B3BPOwYlhGg/Oxa7mQ2b2eg7jwF8CsCpdg1MCNFeWnkbvx/Az5s5wwUA/+7uv6RLuPFC5NEWveF4LlYfPRaPGMaT5fC2b9kbqSGei3x6KfCc8BtLfGw3KovBmBk/xZbxbS9V+fcsy+HpBwCAC7Ph5adG99Blx4d5S+bcAPeja0a8cFZXAYBF5m1YxudWlEcHaXx8OHzc18MWPABgYCA8CSBPhr1jsbv7WQB/sdPlhRDdRdabEIkgsQuRCBK7EIkgsQuRCBK7EInQ1RRXMyBP0jmzSHnfHMlLtIi9VVsL21MAsDvPLaaDY+Phba9y/+nQzXtpvJHxsY+scett71C4/e/aCl92bKxM4/nyJI2/+NZ1Gq8shUsqX77Ej/ngMLc0LdIvOsvCl3cuVlq8HjknZb7tAzftp/HKQvi47d4zRpedGA+fs2IxbDfqzi5EIkjsQiSCxC5EIkjsQiSCxC5EIkjsQiSCxC5EInS5lLQhR1ILY62Pqc/OyikDGB3gr2u37Qv76AAw6GFPeHkp3H4XAErZKI1XSIoqAFQun6fxAVJ6eDDPUy3Xrlyi8cM3T9H4ZI3v+8qNcHy+yi+/w8d4qenhIZ6eu7zC+hfzZQ3rNL5rlI991xg/7hfOhde/1OD1Wxtr4W3Xa+E5F7qzC5EIErsQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIPWjZ3BmYfw8AExO8bPFgiedWjw2E84THh3j73uvzf6DxyuLbNB5pJg0rhfO+iyWer75yY4HGhzzs2wLA0Um+/gapi7ywxr3sbInPP5ga4fMXFlfDufSNSD57PnIbLOR5KWmv8RoHrGX0YHmYL1siufREB7qzC5EIErsQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIXfbZHYjkrDNY22XPca/78kKFxq8uX6Hxv7zjcDB2ZB/38C++coHGVyJ+cq44RONGWviWIp5trcb7AzeYIQxg326+/vJo+LhdJLnuALBc4TXpxwZ57fZdFr7WFiL56lbgsxvGh7nHX7Kwxw8AI4Nhnz6mkLm3w3Mf6qTeffTObmaPm9m8mZ3a9NyEmT1tZmeav3fH1iOE6C3beRv/fQD3vuu5rwF4xt2PAnim+bcQoo+Jit3dnwXw7jmV9wN4ovn4CQCfafO4hBBtZqdf0O1399nm48sAgo2tzOy4mc2Y2Ux9nc8/F0J0jpa/jfeNKpHBbwXc/YS7T7v7dKFUbHVzQogdslOxz5nZFAA0f8+3b0hCiE6wU7E/CeCh5uOHAPyiPcMRQnSKqM9uZj8C8HEAk2Z2EcDXATwG4Cdm9jCA8wAe6OQg34G2MXfuB2OAe9X1SB/yN6+G85OnJnnN+ZFx3me8ts5zn0uDfN/KA2HPNlvnufJjo3x+wtwi98JR4f3fB0rh+unjZZ4LP+z8XnR9jR+3MYTP+WLGffB9kXM6PjxC49fm5mi8uhqe37BS5/MHGgPh6ykjc1GiYnf3BwOhT8aWFUL0D5ouK0QiSOxCJILELkQiSOxCJILELkQidL2UNCPWsjlzknYYsd4ssqte5DbQHLHHzszxFru3D/P2wKOjPE20kfHX5MnJvcHYSqQddJ6kgQKAjXDLshFJLc6IfVZzvmy1xttwV1f49OuV62FbsJbnKax7D4zR+LV5nhL9+iwvwV1FOC26VOZJpAPkWrVc2LbTnV2IRJDYhUgEiV2IRJDYhUgEi
V2IRJDYhUgEiV2IROh+y2bipcd8dlpKOrJsLtL42Iz79LVC2G++ssjTJe8aD6d5AkCBeKMA8NbsNRr3Ynj9Y8N823t28zLYeyIlkxcjKa43lsJxz/Oxrc9xLzur8hTXrBKON+rcB794jkujmudzJ+qD/LgOFEmKrPHrgV3JTjSiO7sQiSCxC5EIErsQiSCxC5EIErsQiSCxC5EIErsQidD9fPaIn925zUby3SPxnIW72aw2eN51jZQ0BoChIV62eGw3nyMwNB7Of969m6+7OMj94lzMC8cKja9kYT/70NR76LJDBX5Oxob55btvX7jk8ulLvB30W3y3sB4pDz5U4sc1UkagI+jOLkQiSOxCJILELkQiSOxCJILELkQiSOxCJILELkQidNlnN+Ry4dcXJzEg4oXnuCcby3f3mM9Ocs7X67z++YUFXt/8tokpGr9r3xEan9wTbjc9PBieHwAA16/O0/j5WR5fbvB9q1TDcwTORertr1zj7aZHyvyclQrrwdjBUX5cKlWex385x+cfNEheOQDkyfUYnxNCw0Gid3Yze9zM5s3s1KbnHjWzS2Z2svlz3842L4ToFtt5G/99APdu8fy33P3u5s9T7R2WEKLdRMXu7s8CWOjCWIQQHaSVL+i+ZGYvNt/mBydnm9lxM5sxs5n6evgzlBCis+xU7N8BcCuAuwHMAvhG6B/d/YS7T7v7dCGSHCCE6Bw7Eru7z7l7w90zAN8FcE97hyWEaDc7EruZbfaKPgvgVOh/hRD9QdRnN7MfAfg4gEkzuwjg6wA+bmZ3A3AA5wB8YVtbMyBHTMKshVz3VvPVozXrSQJyPVL3/Y3rPDm6WuP56gfHuCdcIuHF60t02doa3+98ifeOv3yee+FnL4drvy9U3qTL1qq8tvtd7+P58Icnwpf33jF+zirLDRq/vsTz4S3He6znCuGxxWTArmW2bFTs7v7gFk9/L7acEKK/0HRZIRJBYhciESR2IRJBYhciESR2IRKh+6WkW4JYDpEl49ZaxLpDuFx0luPrfjvjMwcr89wee/0tnppw03g4VXT20gW6bCTTE/f9zV/R+DLvVo3nT54Pxm5kfOOH7rydxs8u8XvVyFg4fucET1E9OsBtvxtrkZbP66QlMwAvhK+JThVb151diESQ2IVIBIldiESQ2IVIBIldiESQ2IVIBIldiET4P+WzU6884qO3vO0svH6P9N/1Ai81nQ3xssVrazwF9tqbZ4Oxcp77we+/7X00PnPqNRp/+tmXaPxqJTz2Ax/g295z6x003qhXaPzV6+H02qECP6ZHytyHP7aHp8BWFhZp/O01kupdHqPL5nJh2Tpx6XVnFyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEkFiFyIRuuuze/OHxRkk0Te2KCthDWyjpTMN89fMnHNPN8u4Zzs0wE/T7UduDsamjx2ky164+BaN//LZ52h8fpHv2+FjHwzHPnAnXbaR5/vdMF7mutIIj+3Fq7wUNEb4/IRyxluZ7TOe6L+0GF6+WirTZXMDrD6CfHYhkkdiFyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEqG7PruBF8VupVVtjr9uxerCx/LhGyQefcWshWvOA0BhnXu2d0yN0/iHj+wPxhpVXpN+cYX7waNT+2i8dHAXjd98ZzhnPYscuFqde92x+uoFUpt9ZZ179K9VeD764UF+TsciPnxpOZyL3xjiPnuxFJ47YWTGSfQ6NbNDZvZrM3vFzF42sy83n58ws6fN7EzzN29ILYToKdt5G18H8FV3PwbgIwC+aGbHAHwNwDPufhTAM82/hRB9SlTs7j7r7i80H1cAnAZwAMD9AJ5o/tsTAD7TqUEKIVrnz/rMbma3APgQgN8C2O/us83QZQBbfnA0s+MAjgNAsTy003EKIVpk29/Gm9kIgJ8C+Iq7/9G3F76RRbLlNwPufsLdp919ujDIGxwKITrHtsRuZgPYEPoP3f1nzafnzGyqGZ8CMN+ZIQoh2kH0bbxteFbfA3Da3b+5KfQkgIcAPNb8/YtWB5OL2Geg9lmk5XIsxbWFFFiPpKhavUrjh/dyI+PY4T003lgLp2tevMrbPS+s8xTV4sQkjQ+PTNF4VgyX0a5GrDV3fj3kY2nJ5JrwArczr0aup7EcL2N9U5FfEzeNhI/7mcXZYAwAGmVSepxci9v5zP5RAJ8H8JKZnWw+9wg2RP4TM3sYwHkAD2xjXUKIHhEVu7v/BuHb5ifbOxwhRKfQdFkhEkFiFyIRJHYhEkFiFyIRJHYhEqHrpaRjJZsZsSxVRhbZbpZxv5n5l97gaaJ7RvjMwWNHD9B4KcfTKW8shf3q8/M36LJ/WORzAIYn30PjuSJPcWVOemxuQ6w+eOycFljp8Rxvo70eKVN9nbTwBoCJEj/n5dHw9oeX+DlZXV8NxpyULdedXYhEkNiFSASJXYhEkNiFSASJXYhEkNiFSASJXYhE6EEpaWZ+RlobE2vTWvDvAaAR8U2tGvY2xwrcB3/vwZtoPF/kPv0Lr5yl8cUby8HYap6XJS7s4j55YZDkTgPInPvVLCfdjJ/vaAfvFuZdeJ7n0hec++RLGd/v8xkv4b1aeTsYGxgeo8uCxC2XD8Z0ZxciESR2IRJBYhciESR2IRJBYhciESR2IRJBYhciEbrrs8Noa+WsEfG6EfZlY7nNMdPWiY8OALtL4RXc9h5e972Y5/nJr75xhsZ//9ZVGm80SG70nhG6LAZ43nYW66scuV8YO+6tpbNvA1I33nhd96Jxn73hgzQ+T65VgNetHx7k52RgKHxOmb50ZxciESR2IRJBYhciESR2IRJBYhciESR2IRJBYhciEbbTn/0QgB8A2I8N6/OEu3/bzB4F8PcArjT/9RF3fyq2voaFX19idcQLWTgHOVaPvl5fp/GJSE769B3hPuRjZb7t06+8QePnr/Lc6lqe55yD9EBfrUd6nEfc7CwXmfsQOW4g5ztmpFukxoBH7lUZ2baxcQHwmE9ukf3Oc5++OBnuFdAotDK3Iayh7UyqqQP4qru/YGajAJ43s6ebsW+5+z9vYx1CiB6znf7sswBmm48rZnYaAG9hIoToO/6sz+xmdguADwH4bfOpL5nZi2b2uJltOWfUzI6b2YyZzdTX+VtpIUTn2LbYzWwEwE8BfMXdFwF8B8CtAO7Gxp3/G1st5+4n3H3a3acLkf5XQojOsS2xm9kANoT+Q3f/GQC4+5y7N3yjk9x3AdzTuWEKIVolKnYzMwDfA3Da3b+56fnNX09/FsCp9g9PCNEutvNt/EcBfB7AS2Z2svncIwAeNLO7sWGgnAPwhe1skJUWzjJuZ3gjHM+q/PuAcuQTxPSdR2h8366wpfHa2VfpspdmF2i8mttP44URXs45lirK8Eir6nhXZZ4qyshZuOxxO2Dpnha7z8XCNHcXyMUOHGsZnY9sPGIbhtjOt/G/wdaXU9RTF0L0D5pBJ0QiSOxCJILELkQiSOxCJILELkQiSOxCJEKXS0kDRvIa3bnPXq+uBGMjkT35yN0fpPHDe7kR/7uXnwvGTr85T5et
5rlPnitF2iJHjHR2TC3i9zIvutfExp7bod8MxFNY86T1MRAfW7TdND3usfO9M/r3TAsh2orELkQiSOxCJILELkQiSOxCJILELkQiSOxCJILFSjC3dWNmVwCc3/TUJADej7h39OvY+nVcgMa2U9o5tsPuvnerQFfF/icbN5tx9+meDYDQr2Pr13EBGttO6dbY9DZeiESQ2IVIhF6L/USPt8/o17H167gAjW2ndGVsPf3MLoToHr2+swshuoTELkQi9ETsZnavmb1qZq+b2dd6MYYQZnbOzF4ys5NmNtPjsTxuZvNmdmrTcxNm9rSZnWn+3rLHXo/G9qiZXWoeu5Nmdl+PxnbIzH5tZq+Y2ctm9uXm8z09dmRcXTluXf/MbmZ5AK8B+FsAFwE8B+BBd3+lqwMJYGbnAEy7e88nYJjZXwNYAvADd39/87l/ArDg7o81Xyh3u/s/9MnYHgWw1Os23s1uRVOb24wD+AyAv0MPjx0Z1wPownHrxZ39HgCvu/tZd68C+DGA+3swjr7H3Z8F8O52MvcDeKL5+AlsXCxdJzC2vsDdZ939hebjCoB32oz39NiRcXWFXoj9AIALm/6+iP7q9+4AfmVmz5vZ8V4PZgv2u/ts8/FlALx3VPeJtvHuJu9qM943x24n7c9bRV/Q/Skfc/cPA/g0gC823672Jb7xGayfvNNttfHuFlu0Gf9fennsdtr+vFV6IfZLAA5t+vtg87m+wN0vNX/PA/g5+q8V9dw7HXSbv3m1yy7ST228t2ozjj44dr1sf94LsT8H4KiZvdfMigA+B+DJHozjTzCz4eYXJzCzYQCfQv+1on4SwEPNxw8B+EUPx/JH9Esb71CbcfT42PW8/bm7d/0HwH3Y+Eb+DQD/2IsxBMZ1BMDvmj8v93psAH6Ejbd1NWx8t/EwgD0AngFwBsB/Apjoo7H9G4CXALyIDWFN9WhsH8PGW/QXAZxs/tzX62NHxtWV46bpskIkgr6gEyIRJHYhEkFiFyIRJHYhEkFiFyIRJHYhEkFiFyIR/gcK4w4eII1bDQAAAABJRU5ErkJggg==\n" | |
}, | |
"metadata": { | |
"needs_background": "light" | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\n", | |
"ASCII: \n", | |
"########################\n", | |
"#####################**#\n", | |
"#####################**#\n", | |
"#############***######*#\n", | |
"#############*++########\n", | |
"######****####*+########\n", | |
"#####**++****##**###****\n", | |
"####***+****+##**###+*#*\n", | |
"####**+*****++***###+*#*\n", | |
"#####*+**********##*+*##\n", | |
"######***********##**###\n", | |
"######*+**++=++**#%####*\n", | |
"######*+**+==++++*#####*\n", | |
"#######*+===+******####*\n", | |
"########*+**+#%%#*+*##**\n", | |
"##########%#*%@##*++*#**\n", | |
"##########**%@%**#***##*\n", | |
"#########++*%%***##**##*\n", | |
"########++*###**####**##\n", | |
"#######*++*###**#*##**##\n", | |
"######*+*#####**#*#**##*\n", | |
"#####*+*#####***##***#**\n", | |
"####*+*#######*##*******\n", | |
"###*=*########**********\n", | |
"\n", | |
"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 1) Get samples to serve as \"context\"\n", | |
"args.seed_train = 42\n", | |
"args.train_samples_per_class = 4\n", | |
"img_dims = {'width': args.width, 'height': args.width}\n", | |
"\n", | |
"context_ix, context_y = get_random_indices_by_class(\n", | |
" indices_by_class,\n", | |
" class_indices,\n", | |
" args.train_samples_per_class,\n", | |
" args.seed_train\n", | |
")" | |
], | |
"metadata": { | |
"id": "9Qrhqhy0qRbV" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
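{ | |
"cell_type": "code", | |
"source": [ | |
"# A minimal, standalone sketch of the per-class sampling idea behind\n", | |
"# get_random_indices_by_class (the actual helper is defined earlier in this\n", | |
"# notebook and may differ): draw n indices per class with a fixed seed so the\n", | |
"# context set is reproducible. Everything here is illustrative, not the\n", | |
"# notebook's implementation.\n", | |
"import numpy as np\n", | |
"\n", | |
"def sample_indices_by_class_sketch(indices_by_class, class_indices, n, seed):\n", | |
"    rng = np.random.default_rng(seed)\n", | |
"    ix, y = [], []\n", | |
"    for c in class_indices:\n", | |
"        picked = rng.choice(indices_by_class[c], size=n, replace=False)\n", | |
"        ix.extend(picked.tolist())\n", | |
"        y.extend([c] * n)\n", | |
"    return ix, y" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |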
{ | |
"cell_type": "code", | |
"source": [ | |
"# 2) Get examples to few-shot classify\n", | |
"args.seed_eval = 42\n", | |
"args.eval_samples_per_group = 10\n", | |
"test_eval_ix, test_eval_y = get_random_indices_by_class(\n", | |
" test_indices_by_class,\n", | |
" class_indices,\n", | |
" args.eval_samples_per_group,\n", | |
" args.seed_eval\n", | |
")" | |
], | |
"metadata": { | |
"id": "QWW1ab8atP5Z" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 3) Get prompt context\n", | |
"for seed in range(20):\n", | |
" context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
" train_set_ascii, context_ix,\n", | |
" start='', end='', sep='###\\n', shuffle=True,\n", | |
" seed=seed # May need to change this to shuffle\n", | |
" )\n", | |
"\n", | |
" # Inspect context order (want to avoid patterns between outputs, e.g., 0, 1, 0, 1, ...)\n", | |
" print(seed, train_set.targets[shuffled_context_ix])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "V8ZfWFBMtS-E", | |
"outputId": "ae34e695-dae9-4055-bbc6-c4844ad79ec2" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n", | |
"1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n", | |
"2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n", | |
"3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n", | |
"4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n", | |
"5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n", | |
"6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n", | |
"7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n", | |
"8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n", | |
"9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n", | |
"10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n", | |
"11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n", | |
"12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n", | |
"13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n", | |
"14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n", | |
"15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n", | |
"16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n", | |
"17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n", | |
"18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n", | |
"19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n" | |
] | |
} | |
] | |
}, | |
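{ | |
"cell_type": "code", | |
"source": [ | |
"# A hedged sketch of how get_context_prompt plausibly assembles the few-shot\n", | |
"# context (the real helper is defined earlier in the notebook and may differ):\n", | |
"# shuffle the context indices with a seeded RNG, then append one\n", | |
"# 'Input: ... / Output: <label>' block per example, separated by sep.\n", | |
"# img_to_text is the notebook's ASCII converter, assumed to be in scope.\n", | |
"import random\n", | |
"\n", | |
"def build_context_prompt_sketch(ascii_dataset, context_ix, start='', end='',\n", | |
"                                sep='###\\n', shuffle=True, seed=0):\n", | |
"    ix = list(context_ix)\n", | |
"    if shuffle:\n", | |
"        random.Random(seed).shuffle(ix)\n", | |
"    prompt = start\n", | |
"    for i in ix:\n", | |
"        ascii_img, label = ascii_dataset[i]\n", | |
"        prompt += f'Input:\\n{img_to_text(ascii_img)}\\nOutput: {label}\\n{sep}'\n", | |
"    return prompt + end, ix" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |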
{ | |
"cell_type": "code", | |
"source": [ | |
"# 3.1) Get prompt context (cont.)\n", | |
"context_prompt, shuffled_context_ix = get_context_prompt(\n", | |
" train_set_ascii, context_ix,\n", | |
" start='', end='', sep='###\\n', shuffle=True,\n", | |
" seed=7 # Pick a random-looking one from above\n", | |
")" | |
], | |
"metadata": { | |
"id": "_QXJgEJQtVCO" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 3.2 Visualize a prompt\n", | |
"# Formatting the above (for sanity checking)\n", | |
"for i, ix in enumerate(shuffled_context_ix):\n", | |
" print(f'Input:')\n", | |
" pretty_print_ascii(\n", | |
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n", | |
" **img_dims\n", | |
" )\n", | |
" print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n", | |
"\n", | |
"# Eval sample\n", | |
"print(f'Input:')\n", | |
"pretty_print_ascii(\n", | |
" img_to_text(test_set_ascii.__getitem__(test_indices_by_class[1][0])[0]),\n", | |
" **img_dims\n", | |
")\n", | |
"print(f'Output:')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "R4DOPULMyKKa", | |
"outputId": "235beff4-ccb7-482f-cfb7-5d66028d76d1" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Input:\n", | |
"+=+##%#%@@@%@@%@@@@%#=--\n", | |
"###%%%#%@@@@@%%%%%%##=``\n", | |
"%%%@@@#%@%@@%%%%%%#**-'`\n", | |
"%%%@@%#%%%%%%%%%%%#++-'`\n", | |
"%%%%%%#%%%%%%%%%%%#**-'`\n", | |
"%%%%%%##%#*#%%%%%%#**:'`\n", | |
"%@%%#*+***+**#%@%%%%%-'`\n", | |
"@@%*+=-*###*##+#@%%%%='`\n", | |
"@%#++=:*@%%%#%*=%@@##='`\n", | |
"%#+===+*%@@%%%#++%@%*='`\n", | |
"#*=-=++*%%%%%%#*++%##='`\n", | |
"#+==+**#%%###*****###='`\n", | |
"#=-=++*#*+--++++*+*+*-'`\n", | |
"*=++=+++++=+**++++++=-``\n", | |
"#***++++++**####****++-`\n", | |
"***+****+***+==+=*###+==\n", | |
"+=+==+***%#*=::-=#%##++=\n", | |
"=-+===++****+===+***++*=\n", | |
":-+==+**#=+##****=:``=#=\n", | |
"-+##***#*==*#*###+-=+*#=\n", | |
":=+#%%%#*-=*#**#%%%%%%#=\n", | |
":::-*%@@%--*####%%%%%@*-\n", | |
"-----=+*%*+%@@@@@@@@@%+-\n", | |
"---==--=+*##%%%%##**++==\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"`=+=:````'`''```'''''`=*\n", | |
"-+**=:````````````````-+\n", | |
"+###*-'`:-:::`````````=#\n", | |
"*###+-::+*==-`:-`'''':*#\n", | |
"####*+++*#+**++*-```:=#%\n", | |
"##*##*###%##*+*+=---==#%\n", | |
"#######%%###*=:``:::::*%\n", | |
"******+**++++=:``:````+*\n", | |
"==-=======--=----+=-::--\n", | |
"=====---===----=====----\n", | |
"==============--====----\n", | |
"=======++***+=--=--=----\n", | |
"+====+*******+=-======--\n", | |
"++==+####*****+========-\n", | |
"+*+=*%%%%##*****========\n", | |
"**+=###%%%%####*+=======\n", | |
"#*=+*++%%%%%###*+=======\n", | |
"*+==++*##%#####*+=======\n", | |
"=====++++*++**+=========\n", | |
"+++==+++++++============\n", | |
"++++++++++++++++++++++++\n", | |
"+++++++++*+++++**++*++++\n", | |
"++++++++**+++++++++*++++\n", | |
"++++++++**+++++++*+++++=\n", | |
"Output: 2\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"********+++++++++++*****\n", | |
"**+++++++++++++++*+++++*\n", | |
"++++++++++++++++*#++++++\n", | |
"+++++++++++*++++##++++++\n", | |
"++++++++++#%+=++##++++++\n", | |
"++++=====+%%+==+%*++=+++\n", | |
"=========#%#===*%#+=====\n", | |
"========+%%#**##%*======\n", | |
"======--*%%%%%%%%*-=====\n", | |
"-------=%%%%%%%%%*------\n", | |
"-------*%%%%%%%%%*------\n", | |
"-----:=%%%%%%%%%%+=++**+\n", | |
"::::::*%%%%####%%##%%##=\n", | |
":::::=%%%%%####%%%%#*=-:\n", | |
":::::*#%%%#####%##+-::::\n", | |
":::`=**%%###%%#*=:```:::\n", | |
"```:**#%####*+-````````:\n", | |
"```=#%%%##+=:```````````\n", | |
"``:#%%%*+:``````````````\n", | |
"``=%#+-:````````````````\n", | |
"`:+=:```````````````````\n", | |
"`::`````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"--===--------------==+++\n", | |
"===++==----------===+*#-\n", | |
"++:++==============++**:\n", | |
"+=:=+==============++**:\n", | |
"+=`=+===============+**:\n", | |
"+=`-=====-===-==--==+**:\n", | |
"+=`-=----:----------=+*:\n", | |
"*+`---:-----:--:::---++:\n", | |
"*+`:--=---------:::--=+:\n", | |
"*+:::*#+=++===+--:::-=+:\n", | |
"=::`=@@%%%#####+-=-:-=+:\n", | |
"=```:=====++***+==:`:=+:\n", | |
"*:::`::::::::::=:`````::\n", | |
"*---:------------::```::\n", | |
"=--+#+------:-----+#=:-:\n", | |
"--=%#%+----------=%##=-:\n", | |
"::+#+**--:-:::::-**=#=::\n", | |
"--+*=+#--:--:-:::#+-*+::\n", | |
"*#%*=*%+++++++==+%*=*#=:\n", | |
"%@@%+#@@@@@@@@@@@@#+%@#:\n", | |
"+**#**##*###**#**#***+=:\n", | |
":::::::::::::::::::::```\n", | |
"::::::::::::::::::``:::`\n", | |
":-:::---:::::-::::::::::\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"::``````````````````````\n", | |
"::``````````````````````\n", | |
"::``````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"````````````````````````\n", | |
"`````````````````````::`\n", | |
"``````:-=::::`':-```'-=`\n", | |
"````:-=+*+++=::-=::--=*=\n", | |
"```-+*###*****####*****-\n", | |
"-:`=#%%%%###%%@@%#%#=*=:\n", | |
"#+-=*#@@%@%%@@@%%*#+-=+-\n", | |
"%#*+-=+=+@%#@@##=--:----\n", | |
"**++==-:+#+==*#+::---=--\n", | |
"+++=====#%*+=*%===**+++=\n", | |
"++++++*+*****##***#**++=\n", | |
"+++++++++++++++++++****+\n", | |
"+++++=+++++++++++++*++++\n", | |
"++++++++++++++++++++++++\n", | |
"++++++++++++++++++++****\n", | |
"++++++=+++++++++++++***+\n", | |
"++++++==++++++++++++++++\n", | |
"+++++===++++++++++++++++\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"```````````````````````:\n", | |
"```````````````````````:\n", | |
"`````````````::::`::::::\n", | |
"::``:++*-```:::::`::::::\n", | |
"::-=+#%%*=-::::::```````\n", | |
"--+##%%%%##++---+=-==-::\n", | |
"###%%%#%#%%##*#+##*##++*\n", | |
"%%%%%%%%%%%#=+#*+****#%%\n", | |
"#%%%%%%%%%#+:=#*+++*#%#*\n", | |
"*#%%%%%%%%+--=++****%@%#\n", | |
"+=#@%%%#*==+++**###%%@@%\n", | |
"++#%%#*=--=++*##%##%@@@%\n", | |
"#**##=+==++#%%%####%@@@%\n", | |
"*##%%=+++*%@@@@@@%%%%@@@\n", | |
"#%@@@**#*#%@@@@@%##**###\n", | |
"%@@@@##@%%%%%%#***#*****\n", | |
"%%%%####**********#%****\n", | |
"##%#**********##%##%####\n", | |
"#############%%@@@@@%%%%\n", | |
"************##%%%%%#####\n", | |
"************************\n", | |
"************************\n", | |
"****++++++++*****+**++++\n", | |
"**************++***+++++\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"###**********###########\n", | |
"#***********############\n", | |
"********#***############\n", | |
"**#######****###########\n", | |
"**######*++===+*###*****\n", | |
"###**+==--==-===*##*****\n", | |
"*++===----=--====*#***##\n", | |
"====+==-=====+++++**####\n", | |
"+*++++===+++*#*+****####\n", | |
"+**#*++==++=###****#####\n", | |
"==+**++==+++*****++*####\n", | |
"++++**+=++*+***#**+*####\n", | |
"++**++*+++*+*****######*\n", | |
"+++**+++++*##+****######\n", | |
"++++****++**++*#########\n", | |
"+++*+**##**++*#%########\n", | |
"**+++***#**+**#%########\n", | |
"***+**********##########\n", | |
"*******###****##*#******\n", | |
"*****######***##********\n", | |
"*****####%#**##*********\n", | |
"**+***####+*#%#*********\n", | |
"*****#####+*#%#*********\n", | |
"#***###%#**###**********\n", | |
"Output: 2\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"########################\n", | |
"#####################**#\n", | |
"#####################**#\n", | |
"#############***######*#\n", | |
"#############*++########\n", | |
"######****####*+########\n", | |
"#####**++****##**###****\n", | |
"####***+****+##**###+*#*\n", | |
"####**+*****++***###+*#*\n", | |
"#####*+**********##*+*##\n", | |
"######***********##**###\n", | |
"######*+**++=++**#%####*\n", | |
"######*+**+==++++*#####*\n", | |
"#######*+===+******####*\n", | |
"########*+**+#%%#*+*##**\n", | |
"##########%#*%@##*++*#**\n", | |
"##########**%@%**#***##*\n", | |
"#########++*%%***##**##*\n", | |
"########++*###**####**##\n", | |
"#######*++*###**#*##**##\n", | |
"######*+*#####**#*#**##*\n", | |
"#####*+*#####***##***#**\n", | |
"####*+*#######*##*******\n", | |
"###*=*########**********\n", | |
"Output: 2\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"%%%%%%%%%%%%%%%%%%%%#%%#\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%%%%%%%%%%%%%%%%%%%%\n", | |
"%%%%%##%%%%%%%%%%%%*#%%%\n", | |
"%%%%%#=-*#%%%%%%%%+-*%%%\n", | |
"%#+++++==+***##**=:=#%%%\n", | |
"%%*+=--==*#*+=-::``-#%%#\n", | |
"%###*+++==**++====-=*+=-\n", | |
"%+=--=++++*++==----=-:::\n", | |
"%%#+----::::=+-:::::--::\n", | |
"%##***++-::-=*+==-======\n", | |
"%#+=======++**+====++***\n", | |
"%#*+++++++*****+++=+*###\n", | |
"%%#********####***+*####\n", | |
"%%##*****#######****####\n", | |
"%%##########%%####*##%%%\n", | |
"%%%#######%%%%#######%%%\n", | |
"%%%%%####%%%%%%#####%%%%\n", | |
"%%%%%%%%%%%%%%%%####%%%%\n", | |
"%%%%%%%%%%%%%%%%%###%%%%\n", | |
"Output: 0\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"`::::::::::`:`::::::::::\n", | |
":=+==++++++=---=====+===\n", | |
":+++=+++===--:-===++====\n", | |
":=======----::--=====+==\n", | |
":------=+****+===--=====\n", | |
"::::::--**##**+==------=\n", | |
"::::---:+####**=---:---:\n", | |
":-:-+**+*####***=--:::::\n", | |
":-+*++**#**+==--::```:::\n", | |
"::+###**#+*++==---::````\n", | |
"::+######++*++******+:``\n", | |
"::*#######*+-=+*%%##*-``\n", | |
"::=#%%##%%#****##***+=:`\n", | |
":`-#%%%%%#######*===++-`\n", | |
":`-%%#%%%**%%##%%***##+:\n", | |
":`-%*=+*#**%%##%@%%%%#+`\n", | |
":``-:::-=*#%%%%#*++*%#:`\n", | |
":`'''```:+%%+===---=#+`'\n", | |
":`'''''''-++----:::::``'\n", | |
":`''''''``````````````''\n", | |
"::``````````````````````\n", | |
"`::::::::::::::::::::::`\n", | |
"`::----:````````````````\n", | |
"`:--==-:`'''''''''''''''\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"-==-------------::--::::\n", | |
"==----------::--::-:::::\n", | |
"=====-------::::::-:::::\n", | |
"======------::::--::::::\n", | |
"=====-------:::=*=::::::\n", | |
"======--------:#%=::::::\n", | |
"======--------=%%-::::::\n", | |
"-=========----#%+-::::::\n", | |
"-===========-=%%=---::::\n", | |
"-===========-+@%=---::::\n", | |
"-============*@%=---::::\n", | |
"---------====#@#----::::\n", | |
"::::---::---+#@*--------\n", | |
"::::::::::--#%%=--------\n", | |
":::::::::::=@%*+=-------\n", | |
"::::::::::-%@@#=-=------\n", | |
"::::::::::*@@#*+==------\n", | |
":::::::::+@@@==+-------:\n", | |
"::--::::-#@@*---------::\n", | |
"::--::::+%#*-:::::----::\n", | |
":--:::::*@*-::::::------\n", | |
"--:::::-*#=::::::::-:---\n", | |
"::::::::++:::::::::::---\n", | |
"::::::::::::::::::::----\n", | |
"Output: 2\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"#%@%+*@@%%%%%%%%@%######\n", | |
"#@@**#@@@@@@@@%*%@%#%##*\n", | |
"#%#**%@@@@@@@@@#*%%###%#\n", | |
"#%*##%%%####%%%%*######*\n", | |
"+##*#%####%%##%%%####%*+\n", | |
"=+*+=***##%##%%%%%#==*+*\n", | |
"==+*+****######%###-``:-\n", | |
"-=+******%%%%%%%%%%#:'``\n", | |
":=*##*###%%%%%%##%##=```\n", | |
":+#**#*###***+++*###*:'`\n", | |
"-+***#*#*-+++++*####%+``\n", | |
":=*#####=:+*+==++*#%%*``\n", | |
"`:+##%###****=-=+#%##*:`\n", | |
"`:=*#%###*******#####*-`\n", | |
"`-*##%########%%%%%%##=`\n", | |
"`=#%%%%%#%%%%#%%%%%%#%=`\n", | |
"`-%%%%%%%%%%%%%@@%%#%%=`\n", | |
"`-#@%%%####%%%%%**%%%%-`\n", | |
"`:#@@@%%%*+###*#*#%@%#:`\n", | |
"`:*@@@@@@%%%####%%%@%#-`\n", | |
"`:*%%@@@@@%%#####@@%%#:`\n", | |
"::*%%%@@@@@#####%@%%#=``\n", | |
"::=%%%%%%%%##*##%%%#+:``\n", | |
"::-+%@@%%%%%%%%%%%*+-:``\n", | |
"Output: 1\n", | |
"###\n", | |
"\n", | |
"Input:\n", | |
"#***++++****#**###%%%%%%\n", | |
"++*******#*########%@%%%\n", | |
"*****##########%#%%@@%%%\n", | |
"***#############*#%@%#%%\n", | |
"######%%########*#%%##%%\n", | |
"%%%%%%%#########%@%##%%%\n", | |
"%%%%%#####**####%@%%%%%%\n", | |
"%%%%%#####***++*#%##%%%%\n", | |
"%%%#%####*=-----=%%##%%%\n", | |
"%#%#####+++**#%=`+@@#%%%\n", | |
"####%%#*###@@@%=:-#@%%%%\n", | |
"###########%%*=:::*%@%#%\n", | |
"#####**####=:```:+%%@*=#\n", | |
"#%##%##%%%*``::`-#%#*==#\n", | |
"%%+=#%%@@%+:-::::-+====#\n", | |
"%#:-#%%@@%+---+#+-:-==*%\n", | |
"%**+*%@@%#+--+%%@*-=+*@@\n", | |
"%#%#*%@@#*=:-*%%%%++*@@%\n", | |
"%%#**%@%#*---#%%%@*=#@@%\n", | |
"%%##*%#*==---#%*#@%*%@@@\n", | |
"%%##*##*+++++#%##@@%@@@@\n", | |
"%%%%@@@@@@@@@@%#%@@@@@@@\n", | |
"%%%%%%%@%%@@@@@%%@@@@@@%\n", | |
"%%#**#%%####%%%%@@@@@@%%\n", | |
"Output:\n" | |
] | |
} | |
] | |
}, | |
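{ | |
"cell_type": "code", | |
"source": [ | |
"# A small hedged sketch of the pretty-printing used above (the notebook's own\n", | |
"# pretty_print_ascii is defined earlier and may differ): re-wrap the flat ASCII\n", | |
"# string into rows of `width` characters so the image structure is visible.\n", | |
"def pretty_print_ascii_sketch(ascii_text, width, height):\n", | |
"    flat = ascii_text.replace('\\n', '')\n", | |
"    for row in range(height):\n", | |
"        print(flat[row * width:(row + 1) * width])" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |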
{ | |
"cell_type": "code", | |
"source": [ | |
"# 4) Get full prompts for each few-shot prompt, and get GPT outputs\n", | |
"model_name = 'text-davinci-003'\n", | |
"all_responses = []\n", | |
"all_targets = []\n", | |
"for test_ix in test_eval_ix:\n", | |
" # Calling this line will give OpenAI money\n", | |
" eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n", | |
" _prompt = context_prompt + eval_prompt\n", | |
" response = get_gpt_output(_prompt, model=model_name)\n", | |
" all_responses.append(response)\n", | |
" all_targets.append(target)" | |
], | |
"metadata": { | |
"id": "-BxHv8bZtgAW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
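{ | |
"cell_type": "code", | |
"source": [ | |
"# A hedged sketch of what get_gpt_output might look like with the legacy\n", | |
"# openai (pre-1.0) Python SDK that served text-davinci-003; the notebook's\n", | |
"# actual helper may differ. Assumes openai.api_key is set elsewhere. Greedy\n", | |
"# decoding (temperature=0) and a small max_tokens keep the completion to\n", | |
"# roughly just the class label.\n", | |
"import openai\n", | |
"\n", | |
"def get_gpt_output_sketch(prompt, model='text-davinci-003'):\n", | |
"    response = openai.Completion.create(\n", | |
"        model=model,\n", | |
"        prompt=prompt,\n", | |
"        max_tokens=4,\n", | |
"        temperature=0.0,\n", | |
"    )\n", | |
"    return response['choices'][0]['text']" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |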
{ | |
"cell_type": "code", | |
"source": [ | |
"# 5) Evaluate\n", | |
"\n", | |
"# Report accuracy metrics\n", | |
"avg_acc, class_accs = eval_predictions(\n", | |
" [decode_response(r) for r in all_responses],\n", | |
" [test_set.targets[test_ix] for test_ix in test_eval_ix],\n", | |
")\n", | |
"print('')\n", | |
"\n", | |
"# Report predicted vs true targets\n", | |
"print(f'GPT-3 {model_name} few-shot predictions')\n", | |
"print([decode_response(r) for r in all_responses])\n", | |
"print('')\n", | |
"print(f'Ground-truth labels')\n", | |
"print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_J3VoNmIuZXH", | |
"outputId": "40a094f4-4b9a-4063-eb1d-2154688f4100" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Average acc: 50.0%\n", | |
"Class accs:\n", | |
"- 7 / 10 = 70.0%\n", | |
"- 0 / 10 = 0.0%\n", | |
"- 8 / 10 = 80.0%\n", | |
"\n", | |
"GPT-3 text-davinci-003 few-shot predictions\n", | |
"[0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2]\n", | |
"\n", | |
"Ground-truth labels\n", | |
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n" | |
] | |
} | |
] | |
}, | |
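{ | |
"cell_type": "code", | |
"source": [ | |
"# A hedged sketch of the evaluation step: decode_response plausibly pulls the\n", | |
"# first integer out of the completion text, and eval_predictions compares\n", | |
"# predictions against ground-truth labels overall and per class. The notebook's\n", | |
"# actual helpers may differ (e.g., in how unparseable responses are handled).\n", | |
"import re\n", | |
"from collections import defaultdict\n", | |
"\n", | |
"def decode_response_sketch(response_text):\n", | |
"    match = re.search(r'\\d+', str(response_text))\n", | |
"    return int(match.group()) if match else -1\n", | |
"\n", | |
"def eval_predictions_sketch(preds, targets):\n", | |
"    targets = [int(t) for t in targets]\n", | |
"    correct, total = defaultdict(int), defaultdict(int)\n", | |
"    for p, t in zip(preds, targets):\n", | |
"        total[t] += 1\n", | |
"        correct[t] += int(p == t)\n", | |
"    class_accs = {c: correct[c] / total[c] for c in sorted(total)}\n", | |
"    avg_acc = sum(correct.values()) / len(targets)\n", | |
"    return avg_acc, class_accs" | |
], | |
"metadata": {}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |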
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### Results: CIFAR-10\n", | |
"So still better than chance in aggregate... but questionable if it actually works. Curious how no labels for cars (`1`'s get outputted), even though as a prior I'd suspect planes and birds to be the confusing distinction.\n", | |
"\n", | |
"* There's probably also brittleness and a lot to do with the sample selection and prompting.\n", | |
"\n", | |
"Lots of battle-testing to do!" | |
], | |
"metadata": { | |
"id": "QaLxKRhrzuVR" | |
} | |
} | |
] | |
} |