Can We Teach AI to See Through Text? A Fun Exploration with ASCII Art and GPT-3
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Art Project - Large Language Models can be Few-shot MNIST Classifiers - [Shrijayan](https://shrijayan.cpluz.com)\n",
"\n",
"Can we get large language models (LLMs) like GPT-3 to do few-shot classification of images?\n",
"\n",
"A key challenge is how to get a large language model to understand images (channel x width x height pixel matrices). We try to bridge this gap by \"translating\" images to ASCII art.\n",
"\n",
"The key idea is to prompt an LLM to output a class label given an ASCII image + a few examples of (image, class label) pairs as context.\n",
"\n",
"### *Hypothesis* (why interesting)\n",
"We've seen that LLMs are very good at few-shot learning. They can learn to complete a new task from only a few examples they have not seen before, all expressed as natural text!\n",
"\n",
"However, while these models were trained on lots of human-produced text, it may not matter much what the actual content of the examples is (or whether that content resembles anything the LLM encountered during pretraining).\n",
"* Instead, to get an LLM to solve a task, the prompt we give it may only need some consistent structure or recognizable patterns across examples. \n",
"\n",
"One way to test this is to see if an LLM can reason over sequences containing ASCII art of real-world images.\n",
"\n",
"\n",
"\n",
"### *Motivations* (why important) and *difficulties* (why hard)\n",
"While this is just an art project for fun, it might also be useful. Personally, I've been interested in:\n",
"1. How we can adapt the benefits of foundation models to individualized use cases and local points of deployment.\n",
"2. How to get models to handle new settings without much data (e.g., when a new failure mode or edge case surfaces).\n",
"\n",
"Some difficulties with realizing the above are:\n",
"\n",
"- Training a sufficiently good model from scratch may not be feasible, as there may not be enough data in the deployment setting. \n",
"- Finetuning a pretrained model may be expensive (large foundation models (FMs) are not accessible to everyone, and maintaining a separate finetuned copy for each setting is hard). \n",
"- A pretrained model's out-of-the-box zero-shot behavior may not be desired for a specific deployment scenario.\n",
"\n",
"### *Solution?*\n",
"However, recent LLMs have shown the ability to \"steer\" their classification based on a few in-context examples, without any finetuning of their weights. This is promising for the settings above: by providing a few examples that illustrate the desired classification behavior, we could get personalized inference. \n",
"- But the world is not just all natural language! It'd also be great to extend this capability to more modalities and domains. \n",
"- Given that we have these text-trained LLMs lying around, can we use text as a universal modality to do so?"
],
"metadata": {
"id": "qdAMTEAwR3en"
}
},
{
"cell_type": "markdown",
"source": [
"### Setup"
],
"metadata": {
"id": "oH6-uz6qSg7r"
}
},
{
"cell_type": "code",
"source": [
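"!pip install numpy\n",
"!pip install torch  # the PyPI package is named 'torch', not 'pytorch'\n",
"!pip install torchvision\n",
"!pip install pillow\n",
"!pip install \"openai<1.0\"  # Only for GPT-3 models; the code below uses the pre-1.0 openai.Completion API"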
],
"metadata": {
"id": "wj-CgWxsSFyO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import copy\n",
"import torch\n",
"import numpy as np\n",
"\n",
"from PIL import Image"
],
"metadata": {
"id": "hUbNPT5sStwL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Fake args\n",
"class Args():\n",
" def __init__(self):\n",
" pass\n",
"args = Args()\n",
"args.seed = 42\n",
"args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"args.display_image = True\n",
"args.num_workers = 2\n",
"\n",
"np.random.seed(args.seed)\n",
"torch.manual_seed(args.seed)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Q6bm0782S26g",
"outputId": "22d30d11-4a43-4028-c966-35c284f51c81"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<torch._C.Generator at 0x7fb4d21ddc70>"
]
},
"metadata": {},
"execution_count": 159
}
]
},
{
"cell_type": "markdown",
"source": [
"### Load data\n",
"\n",
"We need to load images and convert them to ASCII. To ensure the result fits into a prompt that an LLM can process (often 2,048 or 4,096 tokens), we may also need to downsize the images.\n",
"\n",
"For this art project, we'll use the MNIST dataset."
],
"metadata": {
"id": "u_8Z7l8YS-v_"
}
},
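{
"cell_type": "markdown",
"source": [
"As a rough sketch of the budget, the cell below counts the characters in a flattened few-shot prompt for a few image widths. It assumes one ASCII character per pixel, single-digit labels, and the `Input:` / `Output:` / `###` template used later in this notebook. Characters are only a proxy: the actual token count depends on the model's tokenizer, which is not modeled here."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Rough prompt-size estimate in characters (not tokens).\n",
"# Assumes the template 'Input: <pixels>\\nOutput: <label>\\n###\\n' for each\n",
"# context example and 'Input: <pixels>\\nOutput: ' for the query.\n",
"\n",
"def context_chars(width, n_examples):\n",
"    per_example = len('Input: ') + width * width + len('\\nOutput: 0\\n###\\n')\n",
"    return per_example * n_examples\n",
"\n",
"def prompt_chars(width, n_examples):\n",
"    query = len('Input: ') + width * width + len('\\nOutput: ')\n",
"    return context_chars(width, n_examples) + query\n",
"\n",
"# 12 context examples = 4 per class x 3 classes, as in the GPT-3 demo below\n",
"for w in (28, 24, 16):\n",
"    print(f'width={w}: {prompt_chars(w, n_examples=12)} characters')"
],
"metadata": {},
"execution_count": null,
"outputs": []
},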
{
"cell_type": "markdown",
"source": [
"### Image transforms\n",
"\n",
"We convert each image to a downsized, center-cropped, gray-scaled square image."
],
"metadata": {
"id": "r2AcDPUsTFMo"
}
},
{
"cell_type": "code",
"source": [
"import torchvision\n",
"from torchvision import transforms\n",
"from torchvision.transforms import Compose, Resize, CenterCrop, Grayscale, ToTensor\n",
"try:\n",
" from torchvision.transforms import InterpolationMode\n",
" BICUBIC = InterpolationMode.BICUBIC\n",
"except ImportError:\n",
" BICUBIC = Image.BICUBIC"
],
"metadata": {
"id": "NXv9l_GYTLFd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_transform(w, grayscale=True, to_tensor=False):\n",
"    # Resize the shorter side to w, then center-crop to a w x w square.\n",
"    # 'ops' avoids shadowing the torchvision.transforms module imported above.\n",
"    ops = [\n",
"        Resize(w, interpolation=BICUBIC),\n",
"        CenterCrop(w),\n",
"    ]\n",
"    if grayscale:\n",
"        ops.append(Grayscale(num_output_channels=1))\n",
"    if to_tensor:\n",
"        ops.append(ToTensor())\n",
"    return Compose(ops)"
],
"metadata": {
"id": "SBdlcKGpTO4h"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Specify the width of downsized images\n",
"args.width = 24"
],
"metadata": {
"id": "SOHVeNVbS4mo"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Actual transform function\n",
"transform_grayscale = get_transform(args.width, grayscale=True)\n",
"\n",
"# Original image (for visualization)\n",
"transform_base = get_transform(28, grayscale=False)"
],
"metadata": {
"id": "A2W5hOkvTQtB"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### ASCII conversion \n",
"3 options to pick from: \n",
"* `ASCII_68`: 68 levels of brightness \n",
"* `ASCII_10`: 10 levels of brightness \n",
"* `ASCII_10n`: 10 levels of brightness using integers to denote level"
],
"metadata": {
"id": "5CzZaOl2TV4R"
}
},
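{
"cell_type": "markdown",
"source": [
"To make the brightness-to-character mapping concrete, the cell below is a small, stdlib-only sketch of the same quantization `img_to_text` performs: a pixel value `p` in [0, 255] maps to index `p * (levels - 1) // 255` of the ramp. Converting each pixel to a Python `int` before multiplying guards against overflow when pixels arrive as NumPy `uint8` values. MNIST digits are bright strokes on a black background, so the background becomes `@` and the strokes fall at the light end of the ramp. The names `RAMP_10`, `RAMP_10N`, and `quantize` are illustrative."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Stdlib-only sketch of the pixel -> character quantization in img_to_text.\n",
"RAMP_10 = \"@%#*+=-:`'\"  # darkest to lightest, same as ASCII_10 in the next cell\n",
"RAMP_10N = '9876543210'  # numerical variant, same as ASCII_10n\n",
"\n",
"def quantize(pixels, ramp):\n",
"    levels = len(ramp)\n",
"    # int(p) guards against uint8 overflow if pixels come from a NumPy array\n",
"    return ''.join(ramp[int(p) * (levels - 1) // 255] for p in pixels)\n",
"\n",
"print(quantize([0, 128, 255], RAMP_10))   # @+'\n",
"print(quantize([0, 128, 255], RAMP_10N))  # 950"
],
"metadata": {},
"execution_count": null,
"outputs": []
},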
{
"cell_type": "code",
"source": [
"# No periods or spaces, since I thought they might interfere with the LLM\n",
"ASCII_68 = \"$@B%8&WM#*oahkbdpqwmZO0QLCJUYXzcvunxrjft/\\\\|()1{}[]?-_+~<>i!lI;:,\\\"^`'\"\n",
"ASCII_10 = \"@%#*+=-:`'\"\n",
"ASCII_10n = ''.join([str(i) for i in range(len(ASCII_10))[::-1]])"
],
"metadata": {
"id": "RyObH9RMTgfs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def img_to_text(img, scale=10, numerical=False):\n",
"    # scale > 10 selects the 68-character ramp; use scale=68 to span all of it\n",
"    if scale > 10:\n",
"        ascii_scale = ASCII_68\n",
"    elif numerical:\n",
"        ascii_scale = ASCII_10n\n",
"    else:\n",
"        ascii_scale = ASCII_10\n",
"\n",
"    # Cast each uint8 pixel to int first so p * (scale - 1) cannot overflow\n",
"    return ''.join([ascii_scale[int(p) * (scale - 1) // 255] for p in\n",
"                    np.array(img).flatten()])"
],
"metadata": {
"id": "QpCs4OgOTirt"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def pretty_print_ascii(text, width, height):\n",
"    # Print a flattened ASCII image as `height` rows of `width` characters\n",
"    assert len(text) == width * height\n",
"    print('\\n'.join(text[i * width: (i + 1) * width] for i in range(height)))"
],
"metadata": {
"id": "1GsASUR3TkHl"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### See MNIST examples"
],
"metadata": {
"id": "NUWPFFBfTo5F"
}
},
{
"cell_type": "code",
"source": [
"# Load original data\n",
"train_set = torchvision.datasets.MNIST(root='.', train=True, transform=transform_base, download=True)\n",
"test_set = torchvision.datasets.MNIST(root='.', train=False, transform=transform_base, download=True)\n"
],
"metadata": {
"id": "HJy-SZmzTnF3"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_set.__getitem__(0)[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 45
},
"id": "hPdwx9hVUHao",
"outputId": "39d34dde-c9d6-4b27-e515-f69ac302c1b0"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<PIL.Image.Image image mode=L size=28x28 at 0x7FB4C431FB50>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABAElEQVR4nGNgGMyAWUhIqK5jvdSy/9/rGRgYGFhgEnJsVjYCwQwMDAxPJgV+vniQgYGBgREqZ7iXH8r6l/SV4dn7m8gmCt3++/fv37/Htn3/iMW+gDnZf/+e5WbQnoXNNXyMs/5GoQoxwVmf/n9kSGFiwAW49/11wynJoPzx4YIcRlyygR/+/i2XxCWru+vv32nSuGQFYv/83Y3b4p9/fzpAmSyoMnohpiwM1w5h06Q+5enfv39/bcMiJVF09+/fv39P+mFKiTtd/fv3799jgZiBJLT69t+/f/8eDuDEkDJf8+jv379/v7Ryo4qzMDAwMAQGMjBc3/y35wM2V1IfAABFF16Aa0wAOwAAAABJRU5ErkJggg==\n"
},
"metadata": {},
"execution_count": 168
}
]
},
{
"cell_type": "code",
"source": [
"# Load ASCII data\n",
"train_set_ascii = torchvision.datasets.MNIST(root='.', train=True,\n",
" transform=transform_grayscale, download=False)\n",
"test_set_ascii = torchvision.datasets.MNIST(root='.', train=False,\n",
" transform=transform_grayscale, download=False)"
],
"metadata": {
"id": "EW7TgezzTzYG"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"pretty_print_ascii(\n",
" img_to_text(train_set_ascii.__getitem__(0)[0]),\n",
" width=args.width, height=args.width\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qzUce8kCUJey",
"outputId": "f728ad4d-018b-4eb8-c083-73c2493810ee"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@*+%*-=%@@@\n",
"@@@@@@@%%+=::```=``-%@@@\n",
"@@@@@@#`'``''''=##%@@@@@\n",
"@@@@@@%:````==:*@@@@@@@@\n",
"@@@@@@@%#=`=@@%%@@@@@@@@\n",
"@@@@@@@@@#`=@@@@@@@@@@@@\n",
"@@@@@@@@@@=`%@@@@@@@@@@@\n",
"@@@@@@@@@@%::+%@@@@@@@@@\n",
"@@@@@@@@@@@#:':#@@@@@@@@\n",
"@@@@@@@@@@@@%=`:*@@@@@@@\n",
"@@@@@@@@@@@@@@+'`%@@@@@@\n",
"@@@@@@@@@@@@@#='`%@@@@@@\n",
"@@@@@@@@@@@*-``':@@@@@@@\n",
"@@@@@@@@@#-`''`=%@@@@@@@\n",
"@@@@@@@%+`''`=%@@@@@@@@@\n",
"@@@@%#-`''`=%@@@@@@@@@@@\n",
"@@@%:''`:-*@@@@@@@@@@@@@\n",
"@@@%***#%@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# LLM Time\n",
"\n",
"Now we'll create prompts, initialize GPT-3, and get outputs."
],
"metadata": {
"id": "7ivitinuUuSe"
}
},
{
"cell_type": "markdown",
"source": [
"## Prompt Creation\n",
"\n",
"How we format the prompt, and which samples we include in it, can affect the LLM's output. \n",
"\n",
"For formatting, we'll just follow a simple few-shot template: \n",
"```\n",
"Input: [flattened_ascii_image] \n",
"Output: [class_label] \n",
"###\n",
"Input: [flattened_ascii_image] \n",
"Output: [class_label] \n",
"###\n",
"...\n",
"Input: [flattened_ascii_image] \n",
"Output: \n",
"``` \n",
"\n",
"There are many potentially fancy ways to select which samples to include. For classification with many classes, it'd be good to include at least one example of each class. \n",
"\n",
"For now, we'll just randomly sample from each group (e.g., MNIST digit class)."
],
"metadata": {
"id": "aesYe1OmU5EW"
}
},
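{
"cell_type": "markdown",
"source": [
"A minimal, dependency-free sketch of the two steps above: draw the same number of examples from each class, then format them with the `Input:` / `Output:` / `###` template and append the query. The function names (`sample_per_class`, `build_prompt`) and the toy 'images' are illustrative only; the notebook's own helpers (`get_random_indices_by_class`, `get_context_prompt`, `get_eval_prompt`) are defined below."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import random\n",
"\n",
"def sample_per_class(labels, classes, k, seed=0):\n",
"    # Illustrative helper: k distinct indices per class, grouped by class\n",
"    rng = random.Random(seed)\n",
"    picked = []\n",
"    for c in classes:\n",
"        pool = [i for i, y in enumerate(labels) if y == c]\n",
"        picked.extend(rng.sample(pool, k))\n",
"    return picked\n",
"\n",
"def build_prompt(examples, query, sep='###\\n'):\n",
"    # examples: list of (ascii_image, label) pairs; query: image to classify\n",
"    context = ''.join(f'Input: {x}\\nOutput: {y}\\n{sep}' for x, y in examples)\n",
"    return context + f'Input: {query}\\nOutput: '\n",
"\n",
"# Toy flattened 2x2 'images' with labels 0, 1, 2\n",
"labels = [0, 1, 2, 0, 1, 2]\n",
"images = ['@``@', '`@`@', '``@@', '@`@`', '@@``', '`@@`']\n",
"ix = sample_per_class(labels, classes=[0, 1, 2], k=1, seed=42)\n",
"print(build_prompt([(images[i], labels[i]) for i in ix], query='@``@'))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},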
{
"cell_type": "markdown",
"source": [
"## Helpers"
],
"metadata": {
"id": "eWkwQlBuVDYf"
}
},
{
"cell_type": "code",
"source": [
"def get_indices_by_class(dataset):\n",
" \"\"\"\n",
" Organize data indices into classes\n",
" \"\"\"\n",
" indices_by_class = []\n",
" num_classes = len(np.unique(dataset.targets))\n",
" for i in np.arange(num_classes):\n",
" indices_class = np.where(dataset.targets == i)[0]\n",
" indices_by_class.append(indices_class)\n",
" return indices_by_class"
],
"metadata": {
"id": "HiDLQbWwUQX8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_random_indices_by_class(indices_by_class, classes, num_samples, sample_seed):\n",
" \"\"\"\n",
" Randomly sample `num_samples` indices (without replacement) from each class in `classes`\n",
" \"\"\"\n",
" np.random.seed(sample_seed)\n",
" reference_sample_ix = []\n",
" reference_sample_y = []\n",
" for i in classes:\n",
" sample_ix = np.random.choice(indices_by_class[i], size=num_samples, replace=False)\n",
" reference_sample_ix.append(sample_ix)\n",
" reference_sample_y.append(np.repeat([i], num_samples))\n",
" return np.concatenate(reference_sample_ix), np.concatenate(reference_sample_y)"
],
"metadata": {
"id": "Nrp7hWBgVHR1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_context_prompt(dataset, indices, pretty_print=False,\n",
" start='[START]\\n', end='[END]\\n', sep='###\\n',\n",
" shuffle=False, seed=42):\n",
" \"\"\"\n",
" Given dataset, dataset indices, retrieve ASCII images, ground-truth labels,\n",
" and format into the context of a prompt\n",
" \"\"\"\n",
" prompt = []\n",
" # Template: start, x, y, end, sep\n",
" sample = '{}Input: {}\\nOutput: {}\\n{}{}'\n",
" if pretty_print: # Support this later\n",
" raise NotImplementedError\n",
" else:\n",
" for i in indices:\n",
" data_x, data_y = dataset.__getitem__(i)\n",
" prompt.append(copy.copy(sample).format(\n",
" start, img_to_text(data_x), data_y, end, sep\n",
" ))\n",
" shuffle_ix = np.arange(len(indices))\n",
" if shuffle:\n",
" np.random.seed(seed)\n",
" np.random.shuffle(shuffle_ix)\n",
" prompt = [prompt[ix] for ix in shuffle_ix]\n",
" prompt = ''.join(prompt)\n",
"\n",
" return prompt, indices[shuffle_ix]"
],
"metadata": {
"id": "xtebS4E1VJK_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_eval_prompt(dataset, ix, pretty_print=False,\n",
" start='[START]\\n', end='[END]\\n', sep='###\\n'):\n",
" \"\"\"\n",
" Format image we want to classify into a prompt\n",
" \"\"\"\n",
" # Template: start, x\n",
" prompt = '{}Input: {}\\nOutput: '\n",
" data_x, data_y = dataset.__getitem__(ix)\n",
" return prompt.format(start, img_to_text(data_x)), data_y\n",
""
],
"metadata": {
"id": "WeyOlXlMVK-b"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Easy-mode Image Classification\n",
"\n",
"For now, we'll stick to classifying only 3 MNIST digit classes (0s, 1s, and 2s).\n",
"\n",
"(Later we can try more classes, but I don't have the budget to call OpenAI's API all day for inference and evaluation.)"
],
"metadata": {
"id": "Hnd4v61rVZZ5"
}
},
{
"cell_type": "code",
"source": [
"# Organize train and test into classes\n",
"num_classes = len(torch.unique(train_set.targets))\n",
"\n",
"# TRAIN\n",
"indices_by_class = get_indices_by_class(train_set)\n",
"\n",
"# TEST\n",
"test_indices_by_class = get_indices_by_class(test_set)"
],
"metadata": {
"id": "mv9OGYaEWLgL"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class_indices = [0, 1, 2] # These are the classes we'll evaluate"
],
"metadata": {
"id": "OTgwRrO2WNi6"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Visualizing a formatted prompt \"context\" \n",
"We can now visualize a few-shot prompt, without the sample we want to classify:"
],
"metadata": {
"id": "DiCRGe9UWRPB"
}
},
{
"cell_type": "code",
"source": [
"args.seed_cls = 42\n",
"args.num_samples = 1\n",
"img_dims = {'width': args.width, 'height': args.width}\n",
"classes = range(num_classes)\n",
"\n",
"context_ix, context_y = get_random_indices_by_class(\n",
" indices_by_class,\n",
" class_indices,\n",
" args.num_samples,\n",
" args.seed_cls\n",
")\n",
"\n",
"for i, ix in enumerate(context_ix):\n",
" print(f'Input:')\n",
" pretty_print_ascii(\n",
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n",
" **img_dims\n",
" )\n",
" print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n",
" # Unrolled\n",
" # print(img_to_text(train_set_ascii.__getitem__(ix)[0]))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "y9gDbEs2WP13",
"outputId": "53321ce9-f443-4b55-912d-015f60ca0b2e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@%+===*@@@@@@@\n",
"@@@@@@@@@%=:`'''`-@@@@@@\n",
"@@@@@@@@+``''```''+@@@@@\n",
"@@@@@@@=''''''''`'-@@@@@\n",
"@@@@@@:```:==-:````+@@@@\n",
"@@@@@@`''`#@@@%-'``=@@@@\n",
"@@@@@#``'`%@@@@%:'':%@@@\n",
"@@@@@:`':#@@@@@@='':%@@@\n",
"@@@@+`'-@@@@@@@@='':%@@@\n",
"@@@@=``=@@@@@@@@-'':%@@@\n",
"@@@@-'`#@@@@@@@+`'`=@@@@\n",
"@@@%:':@@@@@@@%:''`#@@@@\n",
"@@@*``-@@@@@%#-```=@@@@@\n",
"@@@='``-=##=:`'```*@@@@@\n",
"@@@=``''```'''````@@@@@@\n",
"@@@@-'```''```'`-@@@@@@@\n",
"@@@@+:'''''''':=%@@@@@@@\n",
"@@@@%%=======+%@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@#@@@@@@@@@@@@\n",
"@@@@@@@@@@@-%@@@@@@@@@@@\n",
"@@@@@@@@@@@=%@@@@@@@@@@@\n",
"@@@@@@@@@@@=*@@@@@@@@@@@\n",
"@@@@@@@@@@@=+@@@@@@@@@@@\n",
"@@@@@@@@@@@--@@@@@@@@@@@\n",
"@@@@@@@@@@@--@@@@@@@@@@@\n",
"@@@@@@@@@@@=:@@@@@@@@@@@\n",
"@@@@@@@@@@@*`@@@@@@@@@@@\n",
"@@@@@@@@@@@#`@@@@@@@@@@@\n",
"@@@@@@@@@@@%:@@@@@@@@@@@\n",
"@@@@@@@@@@@@-*@@@@@@@@@@\n",
"@@@@@@@@@@@@==@@@@@@@@@@\n",
"@@@@@@@@@@@@=-@@@@@@@@@@\n",
"@@@@@@@@@@@@*:%@@@@@@@@@\n",
"@@@@@@@@@@@@%:#@@@@@@@@@\n",
"@@@@@@@@@@@@@:+@@@@@@@@@\n",
"@@@@@@@@@@@@@**@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@%*#@@@@@@\n",
"@@@@@@@@@@@@%+::``-*@@@@\n",
"@@@@@@@@@@@=`'''`'`+@@@@\n",
"@@@@@@@@@@+''`:````*@@@@\n",
"@@@@@@@@@@%::**:'':@@@@@\n",
"@@@@@@@@@@@%%@*```#@@@@@\n",
"@@@@@@@@@@@@@*:'':@@@@@@\n",
"@@@@@%%%%%%+=`````-=%@@@\n",
"@@@@*`'''''''`''''''=@@@\n",
"@@@*`'``````''::::::+@@@\n",
"@@@='``````'`+#%%%%%%@@@\n",
"@@@='````''=%@@@@@@@@@@@\n",
"@@@+`````+%@@@@@@@@@@@@@\n",
"@@@@%***%@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 2\n",
"###\n",
"\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Visualizing a full prompt\n",
"\n",
"Below we'll see a prompt that we can actually input to an LLM. For now we linearize each ASCII image into a single line, but it might be worth also trying the row-by-row (newline-separated) format."
],
"metadata": {
"id": "9b9CV5YiWY61"
}
},
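{
"cell_type": "markdown",
"source": [
"The two layouts differ only in whether row breaks are kept. The stdlib sketch below (function names `to_rows` and `to_flat` are illustrative) converts a flattened ASCII image to the row-by-row form and back. The row form costs `height - 1` extra newline characters per image; whether the explicit 2-D layout actually helps the LLM is untested here."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"def to_rows(flat, width):\n",
"    # Split a flattened ASCII image into newline-separated rows of `width` chars\n",
"    assert len(flat) % width == 0, 'length must be a multiple of width'\n",
"    return '\\n'.join(flat[i:i + width] for i in range(0, len(flat), width))\n",
"\n",
"def to_flat(rows):\n",
"    # Undo to_rows by dropping the row breaks\n",
"    return rows.replace('\\n', '')\n",
"\n",
"flat = '@@`@' * 6  # toy image: 4 characters wide, 6 rows tall\n",
"rows = to_rows(flat, width=4)\n",
"print(rows)\n",
"print(len(flat), len(rows))  # 24 29: the row form adds 6 - 1 = 5 newlines\n",
"assert to_flat(rows) == flat"
],
"metadata": {},
"execution_count": null,
"outputs": []
},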
{
"cell_type": "code",
"source": [
"context_prompt, shuffled_context_ix = get_context_prompt(\n",
" train_set_ascii, context_ix,\n",
" start='', end='', sep='###\\n', shuffle=True, seed=42\n",
")\n",
"\n",
"eval_prompt, eval_target = get_eval_prompt(\n",
" test_set_ascii, ix=test_indices_by_class[0][0], start=''\n",
")\n",
"# Linearized (what we'll actually send to an LLM)\n",
"print(context_prompt + eval_prompt)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qj6YnnAPWYFT",
"outputId": "26da1599-c2c1-4e14-a11a-6b43d4ffb0dc"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%+===*@@@@@@@@@@@@@@@@%=:`'''`-@@@@@@@@@@@@@@+``''```''+@@@@@@@@@@@@=''''''''`'-@@@@@@@@@@@:```:==-:````+@@@@@@@@@@`''`#@@@%-'``=@@@@@@@@@#``'`%@@@@%:'':%@@@@@@@@:`':#@@@@@@='':%@@@@@@@+`'-@@@@@@@@='':%@@@@@@@=``=@@@@@@@@-'':%@@@@@@@-'`#@@@@@@@+`'`=@@@@@@@%:':@@@@@@@%:''`#@@@@@@@*``-@@@@@%#-```=@@@@@@@@='``-=##=:`'```*@@@@@@@@=``''```'''````@@@@@@@@@@-'```''```'`-@@@@@@@@@@@+:'''''''':=%@@@@@@@@@@@%%=======+%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 0\n",
"###\n",
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#@@@@@@@@@@@@@@@@@@@@@@@-%@@@@@@@@@@@@@@@@@@@@@@=%@@@@@@@@@@@@@@@@@@@@@@=*@@@@@@@@@@@@@@@@@@@@@@=+@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@--@@@@@@@@@@@@@@@@@@@@@@=:@@@@@@@@@@@@@@@@@@@@@@*`@@@@@@@@@@@@@@@@@@@@@@#`@@@@@@@@@@@@@@@@@@@@@@%:@@@@@@@@@@@@@@@@@@@@@@@-*@@@@@@@@@@@@@@@@@@@@@@==@@@@@@@@@@@@@@@@@@@@@@=-@@@@@@@@@@@@@@@@@@@@@@*:%@@@@@@@@@@@@@@@@@@@@@%:#@@@@@@@@@@@@@@@@@@@@@@:+@@@@@@@@@@@@@@@@@@@@@@**@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 1\n",
"###\n",
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@%*#@@@@@@@@@@@@@@@@@@%+::``-*@@@@@@@@@@@@@@@=`'''`'`+@@@@@@@@@@@@@@+''`:````*@@@@@@@@@@@@@@%::**:'':@@@@@@@@@@@@@@@@%%@*```#@@@@@@@@@@@@@@@@@@*:'':@@@@@@@@@@@%%%%%%+=`````-=%@@@@@@@*`'''''''`''''''=@@@@@@*`'``````''::::::+@@@@@@='``````'`+#%%%%%%@@@@@@='````''=%@@@@@@@@@@@@@@+`````+%@@@@@@@@@@@@@@@@@%***%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 2\n",
"###\n",
"Input: @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@#=*@@@@@@@@@@@@@@@@@@@@@:':%@@@@@@@@@@@@@@@@@@@+`':%@@@@@@@@@@@@@@@@@%=````=*@@@@@@@@@@@@@@@@+'``'''`+@@@@@@@@@@@@@@#`'`':=-`'#@@@@@@@@@@@@%-`'`-#@@-'-%@@@@@@@@@@@#''`*%@@@#`':@@@@@@@@@@@#'`*@@@@@@='`@@@@@@@@@@@#'=@@@@@@@='`*@@@@@@@@@@-'=@@@@@@#`'`@@@@@@@@@@@:'=@@@@@*:'`=@@@@@@@@@@@:'=@@@@+`'`:@@@@@@@@@@@@:'=%+--```'#@@@@@@@@@@@@+'```'''``-@@@@@@@@@@@@@%:`'``'`:*@@@@@@@@@@@@@@@%=`'`:-%@@@@@@@@@@@@@@@@@@#+*%@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: \n"
]
}
]
},
{
"cell_type": "code",
"source": [
"eval_target # Double check this should be 0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "D8M-aUoqWnON",
"outputId": "f3a1e852-7def-45d3-8024-605a92f72e31"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0"
]
},
"metadata": {},
"execution_count": 179
}
]
},
{
"cell_type": "code",
"source": [
"# Formatting the above (for sanity checking)\n",
"for i, ix in enumerate(shuffled_context_ix):\n",
" print(f'Input:')\n",
" pretty_print_ascii(\n",
" img_to_text(train_set_ascii.__getitem__(ix)[0]),\n",
" **img_dims\n",
" )\n",
" print(f'Output: {train_set_ascii.__getitem__(ix)[1]}\\n###\\n')\n",
"\n",
"# Eval sample\n",
"print(f'Input:')\n",
"pretty_print_ascii(\n",
" img_to_text(test_set_ascii.__getitem__(test_indices_by_class[0][0])[0]),\n",
" **img_dims\n",
")\n",
"print(f'Output:')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "X22Ad8hjw7Fq",
"outputId": "52ed577c-657c-49ce-bfe2-2252768998c3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@%+===*@@@@@@@\n",
"@@@@@@@@@%=:`'''`-@@@@@@\n",
"@@@@@@@@+``''```''+@@@@@\n",
"@@@@@@@=''''''''`'-@@@@@\n",
"@@@@@@:```:==-:````+@@@@\n",
"@@@@@@`''`#@@@%-'``=@@@@\n",
"@@@@@#``'`%@@@@%:'':%@@@\n",
"@@@@@:`':#@@@@@@='':%@@@\n",
"@@@@+`'-@@@@@@@@='':%@@@\n",
"@@@@=``=@@@@@@@@-'':%@@@\n",
"@@@@-'`#@@@@@@@+`'`=@@@@\n",
"@@@%:':@@@@@@@%:''`#@@@@\n",
"@@@*``-@@@@@%#-```=@@@@@\n",
"@@@='``-=##=:`'```*@@@@@\n",
"@@@=``''```'''````@@@@@@\n",
"@@@@-'```''```'`-@@@@@@@\n",
"@@@@+:'''''''':=%@@@@@@@\n",
"@@@@%%=======+%@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@#@@@@@@@@@@@@\n",
"@@@@@@@@@@@-%@@@@@@@@@@@\n",
"@@@@@@@@@@@=%@@@@@@@@@@@\n",
"@@@@@@@@@@@=*@@@@@@@@@@@\n",
"@@@@@@@@@@@=+@@@@@@@@@@@\n",
"@@@@@@@@@@@--@@@@@@@@@@@\n",
"@@@@@@@@@@@--@@@@@@@@@@@\n",
"@@@@@@@@@@@=:@@@@@@@@@@@\n",
"@@@@@@@@@@@*`@@@@@@@@@@@\n",
"@@@@@@@@@@@#`@@@@@@@@@@@\n",
"@@@@@@@@@@@%:@@@@@@@@@@@\n",
"@@@@@@@@@@@@-*@@@@@@@@@@\n",
"@@@@@@@@@@@@==@@@@@@@@@@\n",
"@@@@@@@@@@@@=-@@@@@@@@@@\n",
"@@@@@@@@@@@@*:%@@@@@@@@@\n",
"@@@@@@@@@@@@%:#@@@@@@@@@\n",
"@@@@@@@@@@@@@:+@@@@@@@@@\n",
"@@@@@@@@@@@@@**@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@%*#@@@@@@\n",
"@@@@@@@@@@@@%+::``-*@@@@\n",
"@@@@@@@@@@@=`'''`'`+@@@@\n",
"@@@@@@@@@@+''`:````*@@@@\n",
"@@@@@@@@@@%::**:'':@@@@@\n",
"@@@@@@@@@@@%%@*```#@@@@@\n",
"@@@@@@@@@@@@@*:'':@@@@@@\n",
"@@@@@%%%%%%+=`````-=%@@@\n",
"@@@@*`'''''''`''''''=@@@\n",
"@@@*`'``````''::::::+@@@\n",
"@@@='``````'`+#%%%%%%@@@\n",
"@@@='````''=%@@@@@@@@@@@\n",
"@@@+`````+%@@@@@@@@@@@@@\n",
"@@@@%***%@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output: 2\n",
"###\n",
"\n",
"Input:\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@#=*@@@@@@@@@@\n",
"@@@@@@@@@@@:':%@@@@@@@@@\n",
"@@@@@@@@@@+`':%@@@@@@@@@\n",
"@@@@@@@@%=````=*@@@@@@@@\n",
"@@@@@@@@+'``'''`+@@@@@@@\n",
"@@@@@@@#`'`':=-`'#@@@@@@\n",
"@@@@@@%-`'`-#@@-'-%@@@@@\n",
"@@@@@@#''`*%@@@#`':@@@@@\n",
"@@@@@@#'`*@@@@@@='`@@@@@\n",
"@@@@@@#'=@@@@@@@='`*@@@@\n",
"@@@@@@-'=@@@@@@#`'`@@@@@\n",
"@@@@@@:'=@@@@@*:'`=@@@@@\n",
"@@@@@@:'=@@@@+`'`:@@@@@@\n",
"@@@@@@:'=%+--```'#@@@@@@\n",
"@@@@@@+'```'''``-@@@@@@@\n",
"@@@@@@%:`'``'`:*@@@@@@@@\n",
"@@@@@@@%=`'`:-%@@@@@@@@@\n",
"@@@@@@@@@#+*%@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"@@@@@@@@@@@@@@@@@@@@@@@@\n",
"Output:\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# GPT-3 Few-shot Classification\n",
"\n",
"Time to actually classify MNIST with a language model."
],
"metadata": {
"id": "AGoCzDdAWp1l"
}
},
{
"cell_type": "code",
"source": [
"import openai"
],
"metadata": {
"id": "LD9VrIw0XI9v"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
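"# Some helpers (these use the legacy, pre-1.0 openai.Completion API)\n",
"def get_gpt_output(prompt, model):\n",
"    response = openai.Completion.create(\n",
"        prompt=prompt,\n",
"        model=model,\n",
"        temperature=0,  # greedy decoding, so outputs are (near-)deterministic\n",
"        max_tokens=7,   # a digit label needs only a token or two\n",
"        top_p=1.0,\n",
"        n=1,\n",
"    )\n",
"    return response\n",
"\n",
"def decode_response(response):\n",
"    # Keep only the first line of the completion, e.g. ' 0\\n###' -> '0'\n",
"    result = response['choices'][0]['text'].strip().split('\\n')[0]\n",
"    try:\n",
"        return int(result)\n",
"    except ValueError as e:\n",
"        print(e)\n",
"        return -1  # unparseable output counts as a wrong prediction"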
],
"metadata": {
"id": "UkvSNAVJXRQM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Evaluation\n",
"\n",
"For evaluation, we'll sample a small, class-balanced set of examples from the MNIST test set (the same number of datapoints from each class). We'll measure the LLM's performance with overall and per-class accuracy."
],
"metadata": {
"id": "T64YRAhCXVz1"
}
},
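{
"cell_type": "markdown",
"source": [
"For reference, the same metrics can be computed without NumPy. The sketch below (the name `accuracies` is illustrative) returns overall accuracy and per-class accuracy keyed by class label, and treats an unparseable completion (returned as `-1` by `decode_response`) as a wrong prediction. It mirrors, rather than replaces, `eval_predictions` below."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"from collections import Counter\n",
"\n",
"def accuracies(predictions, targets):\n",
"    # Overall and per-class accuracy in percent; -1 (unparseable) never matches\n",
"    assert len(predictions) == len(targets)\n",
"    correct, total = Counter(), Counter()\n",
"    for p, t in zip(predictions, targets):\n",
"        total[t] += 1\n",
"        correct[t] += int(p == t)\n",
"    overall = 100.0 * sum(correct.values()) / len(targets)\n",
"    per_class = {c: 100.0 * correct[c] / total[c] for c in sorted(total)}\n",
"    return overall, per_class\n",
"\n",
"overall, per_class = accuracies([0, 1, 2, -1, 1, 2], [0, 1, 2, 0, 1, 1])\n",
"print(f'{overall:.1f}%', per_class)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},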
{
"cell_type": "code",
"source": [
"def eval_predictions(predictions, classes):\n",
" \"\"\"\n",
" Report average and group-wise accuracies\n",
" \"\"\"\n",
" correct = np.array(predictions) == np.array(classes)\n",
" class_correct = []\n",
" class_total = []\n",
" for ix, c in enumerate(np.unique(classes)):\n",
" c_ix = np.where(np.array(classes) == c)[0]\n",
" class_correct.append(correct[c_ix].sum())\n",
" class_total.append(len(c_ix))\n",
" avg_acc = np.mean(correct) * 100\n",
" class_accs = np.array(class_correct) / np.array(class_total) * 100\n",
" print(f'Average acc: {avg_acc:.1f}%')\n",
" print(f'Class accs:')\n",
" for ix, c in enumerate(class_accs):\n",
" print(f'- {class_correct[ix]} / {class_total[ix]} = {c:.1f}%')\n",
" return avg_acc, class_accs"
],
"metadata": {
"id": "ALRaEE-FX_9v"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Few-shot Prompt Template\n",
"To make prompts, we'll use the same context for each sample: an equal number of samples from each group in the MNIST train set."
],
"metadata": {
"id": "x5861FmKX8Qw"
}
},
{
"cell_type": "markdown",
"source": [
"## Actual GPT-3 Demo"
],
"metadata": {
"id": "-5CBGw16YEJQ"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"openai.api_key = os.environ.get('OPENAI_API_KEY')  # Set this env var to your OpenAI API key rather than pasting the key here"
],
"metadata": {
"id": "pHiH1Tj8YDUS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
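"# Legacy GPT-3 completion models (OpenAI has since retired these)\n",
"models = [\n",
"    'text-ada-001',\n",
"    'text-babbage-001',\n",
"    'text-curie-001',\n",
"    'text-davinci-002',\n",
"    'text-davinci-003',  # the model used in the demo below\n",
"]"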
],
"metadata": {
"id": "reMB79l6YInh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 1) Get samples to serve as \"context\"\n",
"args.seed_train = 42\n",
"args.train_samples_per_class = 4\n",
"img_dims = {'width': args.width, 'height': args.width}\n",
"\n",
"context_ix, context_y = get_random_indices_by_class(\n",
" indices_by_class,\n",
" class_indices,\n",
" args.train_samples_per_class,\n",
" args.seed_train\n",
")"
],
"metadata": {
"id": "z32une07YJyg"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 2) Get examples to few-shot classify\n",
"args.seed_eval = 42\n",
"args.eval_samples_per_group = 10\n",
"test_eval_ix, test_eval_y = get_random_indices_by_class(\n",
" test_indices_by_class,\n",
" class_indices,\n",
" args.eval_samples_per_group,\n",
" args.seed_eval\n",
")"
],
"metadata": {
"id": "dJ7lQKYFYLLS"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 3) Get prompt context\n",
"for seed in range(20):\n",
" context_prompt, shuffled_context_ix = get_context_prompt(\n",
" train_set_ascii, context_ix,\n",
" start='', end='', sep='###\\n', shuffle=True,\n",
" seed=seed # Controls the shuffle order of the context examples\n",
" )\n",
"\n",
" # Inspect context order (want to avoid patterns between outputs, e.g., 0, 1, 0, 1, ...)\n",
" print(seed, train_set.targets[shuffled_context_ix])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "goTJZGxaYMsi",
"outputId": "fbfc3326-c66a-40d8-bf4b-8fbf3ee19260"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n",
"1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n",
"2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n",
"3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n",
"4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n",
"5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n",
"6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n",
"7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n",
"8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n",
"9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n",
"10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n",
"11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n",
"12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n",
"13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n",
"14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n",
"15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n",
"16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n",
"17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n",
"18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n",
"19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n"
]
}
]
},
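{
"cell_type": "markdown",
"source": [
"Eyeballing 20 shuffles works, but the check can be made a little more systematic. The sketch below is a minimal example and not part of the original pipeline. It scores each candidate label order in two ways: by its longest run of identical labels, and by whether any 6-label window repeats with a short period (like `0, 1, 0, 1, ...` or `1, 2, 0, 1, 2, 0`). The helper names `longest_run`, `has_short_cycle` and `pick_seeds` and the thresholds are hypothetical choices; the notebook does not define them.\n",
"\n",
"```python\n",
"def longest_run(labels):\n",
"    # Length of the longest block of consecutive identical labels\n",
"    best, current = 1, 1\n",
"    for prev, cur in zip(labels, labels[1:]):\n",
"        current = current + 1 if cur == prev else 1\n",
"        best = max(best, current)\n",
"    return best\n",
"\n",
"\n",
"def has_short_cycle(labels, window=6, max_period=3):\n",
"    # True if any `window`-long slice repeats with a period <= max_period\n",
"    for start in range(len(labels) - window + 1):\n",
"        chunk = labels[start:start + window]\n",
"        for p in range(1, max_period + 1):\n",
"            if all(chunk[i] == chunk[i - p] for i in range(p, window)):\n",
"                return True\n",
"    return False\n",
"\n",
"\n",
"def pick_seeds(orders, max_run=2):\n",
"    # Keep seeds whose context label order has no long runs and no short cycles\n",
"    return [seed for seed, labels in orders.items()\n",
"            if longest_run(labels) <= max_run and not has_short_cycle(labels)]\n",
"\n",
"\n",
"# A few label orders copied from the printout above\n",
"orders = {\n",
"    3: [1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2],   # ends with three 2s in a row\n",
"    13: [1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0],  # opens with a repeating 1, 2, 0 cycle\n",
"    16: [1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2],  # the seed picked in the next cell\n",
"}\n",
"print(pick_seeds(orders))  # [16]\n",
"```\n",
"\n",
"Long runs or short cycles in the context are the kind of surface pattern a model could copy instead of reading the images. That is presumably why the shuffle is inspected at all. The thresholds here are arbitrary and worth tuning."
],
"metadata": {}
},
```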
{
"cell_type": "code",
"source": [
"# 3.1) Get prompt context (cont.)\n",
"context_prompt, shuffled_context_ix = get_context_prompt(\n",
" train_set_ascii, context_ix,\n",
" start='', end='', sep='###\\n', shuffle=True,\n",
" seed=16 # Pick a random-looking one from above\n",
")"
],
"metadata": {
"id": "i3EzD6I3YQbp"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 4) Get full prompts for each few-shot prompt, and get GPT outputs\n",
"model_name = 'text-davinci-003'\n",
"all_responses = []\n",
"all_targets = []\n",
"for test_ix in test_eval_ix:\n",
" eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n",
" _prompt = context_prompt + eval_prompt\n",
" # Calling this line will give OpenAI money (one paid API request per test image)\n",
" response = get_gpt_output(_prompt, model=model_name)\n",
" all_responses.append(response)\n",
" all_targets.append(target)"
],
"metadata": {
"id": "prLD5YMLYZUr"
},
"execution_count": null,
"outputs": []
},
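{
"cell_type": "markdown",
"source": [
"The completions collected above still have to be turned into labels. The next cell does this with the notebook's own `decode_response`, which is defined earlier and not shown here. The hypothetical `parse_label` below illustrates one stricter option. It takes the first integer in the completion and returns `None` when that integer is not one of the evaluated classes. An out-of-range answer, such as the `3` that shows up in the predictions below, is then flagged explicitly instead of being scored as if it were a class.\n",
"\n",
"```python\n",
"import re\n",
"\n",
"\n",
"def parse_label(completion, allowed=(0, 1, 2)):\n",
"    # First integer in the completion text, or None if absent or not an allowed class\n",
"    match = re.search(r'-?\\d+', completion)\n",
"    if match is None:\n",
"        return None\n",
"    label = int(match.group())\n",
"    return label if label in allowed else None\n",
"\n",
"\n",
"print(parse_label(' 1\\n###'))  # 1\n",
"print(parse_label(' 3'))        # None: 3 is not an evaluated class\n",
"print(parse_label('unsure'))    # None: no digits at all\n",
"```"
],
"metadata": {}
},
```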
{
"cell_type": "code",
"source": [
"# 5) Evaluate\n",
"\n",
"# Report accuracy metrics\n",
"avg_acc, class_accs = eval_predictions(\n",
" [decode_response(r) for r in all_responses],\n",
" [test_set.targets[test_ix] for test_ix in test_eval_ix],\n",
")\n",
"\n",
"# Report predicted vs true targets\n",
"print(f'GPT-3 {model_name} few-shot predictions')\n",
"print([decode_response(r) for r in all_responses])\n",
"\n",
"print('Ground-truth labels')\n",
"print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "siiYoSPGYm5w",
"outputId": "91fbfbda-5e52-4090-825f-3a0a5ad457b5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average acc: 66.7%\n",
"Class accs:\n",
"- 4 / 10 = 40.0%\n",
"- 10 / 10 = 100.0%\n",
"- 6 / 10 = 60.0%\n",
"GPT-3 text-davinci-003 few-shot predictions\n",
"[0, 2, 0, 0, 2, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 3, 2, 2, 0, 0, 2, 2, 2, 2]\n",
"Ground-truth labels\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Results\n",
"So... better than chance (66.7% average accuracy vs. 33.3% for uniform guessing over 3 classes)! Our toy idea *seems* to work on our toy example (though it's not battle-tested across seeds, different prompts, etc.).\n",
"* If you wanted to twist your arm into interpreting the above, it seems plausible that the `1`s are more distinctive than the `0`s and `2`s, hence the higher accuracy on that class. Precision and recall are also good there: 10 of the 11 predicted `1`s are correct, and all 10 true `1`s are found.\n",
"\n",
"No further analysis here; just an art project. But some bonuses below!"
],
"metadata": {
"id": "CX8Nmw4x0Zxz"
}
},
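{
"cell_type": "markdown",
"source": [
"The per-class story can be checked directly from the two lists printed above. The small sketch below uses only those numbers. It compares overall accuracy with the 1/3 chance rate for three classes, computes per-class precision and recall, and counts outputs that are not valid classes (these are treated as wrong for every class).\n",
"\n",
"```python\n",
"from collections import Counter\n",
"\n",
"# Predictions and ground-truth labels copied from the evaluation printout\n",
"preds = [0, 2, 0, 0, 2, 0, 2, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
"         0, 3, 2, 2, 0, 0, 2, 2, 2, 2]\n",
"labels = [0] * 10 + [1] * 10 + [2] * 10\n",
"classes = [0, 1, 2]\n",
"\n",
"chance = 1 / len(classes)  # expected accuracy of uniform random guessing\n",
"accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)\n",
"\n",
"per_class = {}\n",
"for c in classes:\n",
"    tp = sum(p == c and y == c for p, y in zip(preds, labels))\n",
"    n_pred = sum(p == c for p in preds)   # how often the model answered c\n",
"    n_true = sum(y == c for y in labels)  # how many test images are class c\n",
"    per_class[c] = (tp / n_pred if n_pred else 0.0, tp / n_true)\n",
"    print(f'class {c}: precision {per_class[c][0]:.2f}, recall {per_class[c][1]:.2f}')\n",
"\n",
"invalid = Counter(p for p in preds if p not in classes)\n",
"print(f'accuracy {accuracy:.3f} vs chance {chance:.3f}; invalid outputs: {dict(invalid)}')\n",
"```\n",
"\n",
"On these lists, class `1` comes out at precision 10/11 (about 0.91) and recall 1.00. Class `0` is at about 0.57 / 0.40 and class `2` at about 0.55 / 0.60. The one invalid output is the stray `3`."
],
"metadata": {}
},
```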
{
"cell_type": "markdown",
"source": [
"## Bonus 1: ChatGPT\n",
"\n",
"We could also try the above with ChatGPT. I'm not going to, because I don't want to paste all these examples into its text interface. If you do want to try it, the commented-out cell below prints ready-to-paste prompts."
],
"metadata": {
"id": "Mw0cFIhVY1TW"
}
},
{
"cell_type": "code",
"source": [
"# # Run this block to get prompts to paste for ChatGPT\n",
"# str_classes = ', '.join([f\"'{c}'\" for c in class_indices])\n",
"# str_classes = '[' + str_classes + ']'\n",
"# prompt_prefix = f\"In the text below, what is the most likely next character? Choose from a character in {str_classes}:\"\n",
"# all_prompts = []\n",
"# all_targets = []\n",
"# for test_ix in test_eval_ix:\n",
"# eval_prompt, eval_target = get_eval_prompt(\n",
"# test_set_ascii, ix=test_ix, start=''\n",
"# )\n",
"# _prompt = context_prompt + eval_prompt\n",
"# print('\\n', '-' * 20, f'Sample {test_ix} (target = {eval_target})', '-' * 20)\n",
"# print(prompt_prefix)\n",
"# print(_prompt)\n",
"# all_prompts.append(_prompt)\n",
"# all_targets.append(eval_target)"
],
"metadata": {
"id": "OuG1VxwEZAJm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Bonus 2: CIFAR-10 classification (CIFAR-3)\n",
"\n",
"So we get better-than-chance accuracy on MNIST. But what about \"real\" images, like CIFAR-10?\n",
"\n",
"(Yes, `channel x height x width` dims of `3 x 32 x 32` are hardly real, but it's progress.)\n",
"\n",
"Let's try it out with the exact pipeline above!"
],
"metadata": {
"id": "6URKLSmzmc3N"
}
},
{
"cell_type": "code",
"source": [
"# Load original data\n",
"train_set = torchvision.datasets.CIFAR10(root='.', train=True, transform=transform_base, download=True)\n",
"test_set = torchvision.datasets.CIFAR10(root='.', train=False, transform=transform_base, download=True)\n",
"\n",
"train_set.targets = torch.tensor(train_set.targets) # CIFAR-10 targets are a list by default\n",
"test_set.targets = torch.tensor(test_set.targets)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RJQ0mOt2oo7Z",
"outputId": "9bcc8bf8-24d1-4698-ae0d-f4c050b5d1e9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Files already downloaded and verified\n",
"Files already downloaded and verified\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Load grayscale copies of the same data (turned into ASCII later via img_to_text)\n",
"train_set_ascii = torchvision.datasets.CIFAR10(root='.', train=True,\n",
" transform=transform_grayscale, download=False)\n",
"test_set_ascii = torchvision.datasets.CIFAR10(root='.', train=False,\n",
" transform=transform_grayscale, download=False)\n",
"\n",
"train_set_ascii.targets = torch.tensor(train_set_ascii.targets)\n",
"test_set_ascii.targets = torch.tensor(test_set_ascii.targets)"
],
"metadata": {
"id": "TW6Xu-2Trxt5"
},
"execution_count": null,
"outputs": []
},
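{
"cell_type": "markdown",
"source": [
"The ASCII printouts below come from the notebook's own `img_to_text` and `pretty_print_ascii`, which are defined earlier and not shown in this section. To illustrate the general idea, here is a minimal sketch that bins grayscale intensities into a 10-character ramp. It assumes a 2D array with values in `[0, 1]`; a `1 x H x W` tensor from `transform_grayscale` would need `.squeeze(0)` first. The `gray_to_ascii` name, the ramp order and the equal-width binning are all assumptions, so the output will not match the notebook's character for character.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# Hypothetical ramp, low to high intensity: these are the characters that appear\n",
"# in the printouts, but their order here is a guess\n",
"RAMP = \"`':-=+*#%@\"\n",
"\n",
"\n",
"def gray_to_ascii(img, ramp=RAMP):\n",
"    # img: 2D array of grayscale intensities in [0, 1], e.g. 24 x 24\n",
"    img = np.clip(np.asarray(img, dtype=float), 0.0, 1.0)\n",
"    # Equal-width bins, one per character; an intensity of exactly 1.0 goes to the last bin\n",
"    idx = np.minimum((img * len(ramp)).astype(int), len(ramp) - 1)\n",
"    return '\\n'.join(''.join(ramp[i] for i in row) for row in idx)\n",
"\n",
"\n",
"gradient = np.tile(np.linspace(0, 1, 10), (2, 1))  # 2 x 10 left-to-right ramp\n",
"print(gray_to_ascii(gradient))  # two rows, each the full ramp\n",
"```"
],
"metadata": {}
},
```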
{
"cell_type": "code",
"source": [
"# CIFAR-10 class names, indexed by label (we evaluate on a subset of classes below)\n",
"class_names = ('plane', 'car', 'bird', 'cat',\n",
" 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')"
],
"metadata": {
"id": "8zyO_YyFpf8H"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Organize train and test into classes\n",
"num_classes = len(np.unique(train_set.targets)) # 10\n",
"\n",
"# TRAIN\n",
"indices_by_class = get_indices_by_class(train_set)\n",
"\n",
"# TEST\n",
"test_indices_by_class = get_indices_by_class(test_set)"
],
"metadata": {
"id": "1HE-RWKNp2jq"
},
"execution_count": null,
"outputs": []
},
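{
"cell_type": "markdown",
"source": [
"`get_indices_by_class` and `get_random_indices_by_class` are defined earlier in the notebook. Judging only from how they are used here, the first groups dataset indices by label. The second draws a seeded sample of `k` indices per requested class and returns them grouped by class, as the sorted ground-truth list in the MNIST evaluation suggests. Below is a hypothetical re-implementation of that behaviour under those assumptions. The notebook's versions may sample differently, so the same seed need not pick the same images.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"\n",
"def indices_by_class(targets):\n",
"    # Map each label to the sorted array of dataset indices that carry it\n",
"    targets = np.asarray(targets)\n",
"    return {int(c): np.flatnonzero(targets == c) for c in np.unique(targets)}\n",
"\n",
"\n",
"def sample_per_class(by_class, classes, n_per_class, seed):\n",
"    # Seeded draw of n_per_class distinct indices from each class, grouped by class\n",
"    rng = np.random.default_rng(seed)\n",
"    ix, y = [], []\n",
"    for c in classes:\n",
"        chosen = rng.choice(by_class[c], size=n_per_class, replace=False)\n",
"        ix.extend(chosen.tolist())\n",
"        y.extend([c] * n_per_class)\n",
"    return np.array(ix), np.array(y)\n",
"\n",
"\n",
"# Toy labels standing in for train_set.targets\n",
"toy_targets = [0, 1, 2, 0, 1, 2, 0, 1, 2, 3]\n",
"by_class = indices_by_class(toy_targets)\n",
"ix, y = sample_per_class(by_class, classes=[0, 1, 2], n_per_class=2, seed=42)\n",
"print(by_class[0], y)  # [0 3 6] [0 0 1 1 2 2]\n",
"```\n",
"\n",
"With `class_indices = [0, 1, 2]` and 4 samples per class, this kind of helper yields the 12 context examples used below."
],
"metadata": {}
},
```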
{
"cell_type": "code",
"source": [
"class_indices = [0, 1, 2] # These are the classes we'll evaluate -> planes vs cars vs birds\n",
"print([class_names[i] for i in class_indices])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wjndnBc6pliP",
"outputId": "42b9e68a-6cec-44f8-a815-cdaee12ed78f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['plane', 'car', 'bird']\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install matplotlib\n",
"\n",
"import matplotlib.pyplot as plt # so we can visualize the images below"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wWA2EzbisliH",
"outputId": "88fe462f-35d8-4d1e-879e-e221a04916e6"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.8/dist-packages (3.2.2)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.4.4)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (0.11.0)\n",
"Requirement already satisfied: numpy>=1.11 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (1.21.6)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (2.8.2)\n",
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.8/dist-packages (from matplotlib) (3.0.9)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.1->matplotlib) (1.15.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Visualize original image vs ASCII\n",
"args.seed_cls = 42\n",
"args.num_samples = 1\n",
"img_dims = {'width': args.width, 'height': args.width}\n",
"classes = range(num_classes)\n",
"\n",
"context_ix, context_y = get_random_indices_by_class(\n",
" indices_by_class,\n",
" class_indices,\n",
" args.num_samples,\n",
" args.seed_cls\n",
")\n",
"\n",
"# Show samples\n",
"for i, ix in enumerate(context_ix):\n",
" ascii_img, target = train_set_ascii[ix]\n",
" print('=' * 4, f'Sample {i}, class = {class_names[target]}', '=' * 4)\n",
" print('Original:')\n",
" plt.imshow(train_set[ix][0])\n",
" plt.show()\n",
"\n",
" print('\\nASCII: ')\n",
" pretty_print_ascii(img_to_text(ascii_img), **img_dims)\n",
" print('\\n')\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "OOrij5e8uB1L",
"outputId": "37ea0d0c-7752-4f3f-9e48-8f21e25430d2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"==== Sample 0, class = plane ====\n",
"Original:\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAS4ElEQVR4nO3dXWxd1ZUH8P+fYOcLCLFDTBIjCF9CphIBLL6pGFVUhBfoCyoPVSrQpA8gtVIfipiH8ohG01YIjSqlA2o6KlSVWkQEaKZMVBRVQgETAiQEJnwE7ODEwRbBIQRwsubBh8oEn7Uud99zzp3u/0+KbN/lfc6+23flfqyz96aZQUT+8Z3SdAdEpB5KdpFMKNlFMqFkF8mEkl0kE6fWebK+vj4bHBxsuz3JtmKdkHL8bu5bk9WYf+RKUHTfqrrvo6OjmJqamvcBkZTsJG8B8BCABQD+w8we9H5/cHAQW7ZsKY2fcor/QqO3t7c01tPT47ZNTTivb1G/o3jUt9T2ntQHZUr8xIkTbtvjx4+78WhcPNG5Ux8vMzMzbty7b1HfvDFdv359aazt0SK5AMC/A1gPYAjAnSSH2j2eiFQr5T37VQDeMrN3zOxzAH8AcFtnuiUinZaS7GsAjM75eay47StIbiQ5QnJkcnIy4XQikqLyT+PNbJOZDZvZcH9/f9WnE5ESKcm+H8A5c34eLG4TkS6UkuwvAriI5FqSvQC+D6D8o3YRaVTbpTczmyF5L4D/xmzp7VEz253SmSrLOKnlrRRVl3m89tGYRueO2qeUiaLSWuq5UzRZkky5317bpDq7mT0D4JmUY4hIPXS5rEgmlOwimVCyi2RCyS6SCSW7SCaU7CKZqHU+O+DXs6uc993kNNKobZXxqufSR+Pm1YQXLFjgtq2ylh31O3W+eXT86BoDjzdu3t9bz+wimVCyi2RCyS6SCSW7SCaU7CKZULKLZKKrSm8p0yUjVZagUss0qaW3lDGNpE6BTSkLNjnFNfXxktK3qGyn0puIuJTsIplQsotkQskukgklu0gmlOwimVCyi2Si1jo7SbdGmLobanTulGNXOcU19dxubTVxKmdUL075m1S9rXFK+yqX0AbSpnq3O+Z6ZhfJhJJdJBNKdpFMKNlFMqFkF8mEkl0kE0p2kUx0VZ09pabbzcs1p14/kFLjT62zN7kOQJN19tS58il/05S2Xiwp2UnuAzAN4DiAGTMbTjmeiFSnE8/s/2RmH3bgOCJSIb1nF8lEarIbgL+QfInkxvl+geRGkiMkRyYnJxNPJyLtSk32G8zsCgDrAdxD8tsn/4KZbTKzYTMb7u/vTzydiLQrKdnNbH/xdQLAEwCu6kSnRKTz2k52kktJnv7l9wC+C2BXpzomIp2V8mn8AIAnirreqQAeM7P/ihp5dfYq1yBPmXedeu7UWnWT20VHUurRVa5JD/h9S5lv3sq5o7iXB6lbWZdpO9nN7B0Al7XbXkTqpdKbSCaU7CKZULKLZELJLpIJJbtIJmqf4uqVNKKSQ8q0wEiTU1yrnp5bpSrPHR37+PHjbbeP/iZRSTF6rEbH7+3tLY1FpTXvfmvLZhFRsovkQskukgklu0gmlOwimVCyi2RCyS6SiVrr7IBff4zqpp6qa9Vev1OXiq5yOefUenGkyjp76rUT7dajW4lH4/bFF1+48Q8/LF+j9bTTTnPb9vT0uPEyemYXyYSSXSQTSnaRTCjZRTKhZBfJhJJdJBNKdpFM1F5n91Q5b7vKbZWPHj3qtj1w4IAb//jjj934ihUr3LhX843qvX19fW7cqwcDwAUXXODGvXH77LPP3LZvv/22G4/GZfny5W7cc+qpfmpEj5fovj311FOlsaGhIbft9ddfXxrTfHYRUbKL5ELJLpIJJbtIJpTsIplQsotkQskukona6+wp66+3e1wA+Pzzz914VCs/fPhwaWxqaspt++6777rxqJ68bNkyNz
4wMFAaW7x4sdt2enrajUfz4fv7+924V8ePjv3KK6+48ajG7/VtZmbGbRv1LaqzR4+nQ4cOlcb27Nnjtl29enVpzHuch8/sJB8lOUFy15zb+kg+S3Jv8bX9qxdEpBatvIz/LYBbTrrtPgBbzewiAFuLn0Wki4XJbmbbAJz8OvU2AJuL7zcDuL3D/RKRDmv3A7oBMxsvvj8AoPRNI8mNJEdIjkxOTrZ5OhFJlfxpvM2ulli6YqKZbTKzYTMbjj7MEZHqtJvsB0muAoDi60TnuiQiVWg32bcA2FB8vwHAk53pjohUJayzk3wcwE0AVpAcA/BzAA8C+CPJuwG8B+COTnQmqm16tfSojj4x4b/4GBsbc+O7d+8ujUXz1Y8cOeLGo/nsS5cudeNezda7PgAAPv30Uzd+ySWXuPGnn37ajadcO7F///6k+Msvv1wai/YoSHksAnGdfXR0tDQWXXfhrTHgXfMRJruZ3VkS+k7UVkS6hy6XFcmEkl0kE0p2kUwo2UUyoWQXyUTtU1xTtmX2phVG5avnnnvOje/atcuNe2WeY8eOuW2jMk1U5lm4cKEb95YtTi0h7d27142PjIy4cW+6ZlT2i6aRRldkrly5sjQWbZO9ZMkSNz44OOjGI97fZXh42G3rLSX92GOPlcb0zC6SCSW7SCaU7CKZULKLZELJLpIJJbtIJpTsIpmotc4+PT2Nbdu2lcYvvvhit723LHJUZ3/jjTfc+Pbt2924N2Wxt7fXbdvT0+PGo1p3tH2wN4X2jDPOcNtG1wiMj4+78ei6CW/cvK2mgXi76UWLFrnx008/vTT25ptvum2j6xNuueXkNVi/KtqyeefOnaWxNWvWuG3PPPPM0pg3pnpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNRaZ5+YmMDDDz9cGr/mmmvc9pdffnnb5/Zqk0Bc23z//fdLY1EtOqqjR/XmaF63V+s+77zz3LYpWwsD8ZbP3jUCZ511lts2Gpeolu1dexHNV7/wwgvd+PLl/sbF0dLmZ599dmnMe6xFvDUC9Mwukgklu0gmlOwimVCyi2RCyS6SCSW7SCaU7CKZqLXOfuzYMXcecbQ+urdVbbSt8cDAgBu/66673PiOHTtKYx988IHbtq+vz41HNdlorr23NbFXzwX8Od9AvH1wNOd89erVpbGoVh3Fb775Zje+du3a0lhUw4+uAYjmu0fXAFx55ZWlsWiNAW9ckuazk3yU5ATJXXNue4DkfpI7i3+3RscRkWa18jL+twDmW5bjV2a2rvj3TGe7JSKdFia7mW0DMFVDX0SkQikf0N1L8tXiZX7pmwiSG0mOkByJ3ueISHXaTfZfA7gAwDoA4wB+UfaLZrbJzIbNbDia0CEi1Wkr+8zsoJkdN7MTAH4D4KrOdktEOq2tZCe5as6P3wPg73csIo0L6+wkHwdwE4AVJMcA/BzATSTXATAA+wD8qJWTLVmyBOvWrSuNR+uve3N1ozXEo/24I95+3Ndee63b9txzz3XjExMTbnxsbMyNe3OzL730UrftzMyMG4/m+U9Ntf/ZbTRn/LLLLnPj1113nRv31hGIrg+I4tHjKVonwGsfjblbS3feKofJbmZ3znPzI1E7Eeku+sRMJBNKdpFMKNlFMqFkF8mEkl0kE7VOcV20aBGGhoZK49GSy950zGhr4qisFy2J7JW3Jicn3bbRdMdoO+kofsUVV5TGotJbVLLs7+9349EU2Iceeqg0FpUkzz//fDcejevo6GhpbN++fW7bqCQZbcMdXRruLf8dTe1dtmxZacybLq1ndpFMKNlFMqFkF8mEkl0kE0p2kUwo2UUyoWQXyUStdfYFCxa4Sz57tccoHtXJo1VyopqtV1c9ePBg222BuOYbLUXt1flfeOEFt623pTIQL7kcbensTUuOtrp+/vnn3Xh0Xcbhw4dLY9GYR3XyaHpudH2C93iMHstend09Z1utROT/HSW7SCaU7CKZULKLZELJLp
IJJbtIJpTsIpmotc5OEosXL3bjHq/OHs0/jpYG9mqyAHDkyJG2zx2J5uJfffXVbtyrhUdjGo1LtCSy9/cEgBtvvLE0Fq0xEG2FHbX/6KOPSmPR/Y6uP4iurYi2Xfbq8N7aCYC/9Lh3v/TMLpIJJbtIJpTsIplQsotkQskukgklu0gmlOwimai1zg74NelozrlXM45qrtG87Ii3vnrKPHwgbY3xqH107KgOH83F99byB/x6depa/9F9W7lyZWnMm2cPxHXy6PqDTz75xI179fBoLX9vTJPWjSd5Dsm/knyd5G6SPy5u7yP5LMm9xVd/ZXsRaVQrL+NnAPzUzIYAXAPgHpJDAO4DsNXMLgKwtfhZRLpUmOxmNm5mO4rvpwHsAbAGwG0ANhe/thnA7VV1UkTSfaMP6EieB+ByANsBDJjZl4uIHQAwUNJmI8kRkiPR+xgRqU7LyU7yNAB/AvATM/t4bszMDIDN187MNpnZsJkNe4tNiki1Wkp2kj2YTfTfm9mfi5sPklxVxFcBKJ+KIyKNC0tvnK3NPAJgj5n9ck5oC4ANAB4svj4ZHevEiRNuaSDilaCiKYlRGScqMXntoymuUXkrEh3fK0Gl9m32RVv7ca9v3hTUVkTn9kRluygeTUONlv/2pgZHJUnvfnuP41bq7NcD+AGA10juLG67H7NJ/keSdwN4D8AdLRxLRBoSJruZ/Q1A2X//3+lsd0SkKrpcViQTSnaRTCjZRTKhZBfJhJJdJBO1T3GNpmt6vCmwUc01qu+n1pM90fTaaGpv1N6rCUfXD6ROz43iXp0/qmVHUv5mUdto3KL4woUL3bi3RXi0hLY35lpKWkSU7CK5ULKLZELJLpIJJbtIJpTsIplQsotkotY6u5m5ddeU5Z6jumnKMtWR1DnhqVLqyVE8um/ROgKeqEafOm5eHT96PERbOkfXbUxPT7vxlMdbtDZDGT2zi2RCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJmqvs3u11ZSab1TvjeYfRzV+Lx7VTFPrySntU2vVqVtde1KvAUi5tiK6X9HjKdrSOeWxXNV1GXpmF8mEkl0kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTLSyP/s5AH4HYACAAdhkZg+RfADAPwM4VPzq/Wb2THAst/aZMuc8qnWnzjlPmRuduj97xDt+6j7kUY0/Ze331Ln00bh7x4/6HZ07ZS3/6PhRjd9bk949rnvUWTMAfmpmO0ieDuAlks8WsV+Z2b+1cAwRaVgr+7OPAxgvvp8muQfAmqo7JiKd9Y3es5M8D8DlALYXN91L8lWSj5JcXtJmI8kRkiNHjx5N6qyItK/lZCd5GoA/AfiJmX0M4NcALgCwDrPP/L+Yr52ZbTKzYTMbXrJkSQe6LCLtaCnZSfZgNtF/b2Z/BgAzO2hmx83sBIDfALiqum6KSKow2Tn78d4jAPaY2S/n3L5qzq99D8CuzndPRDqllU/jrwfwAwCvkdxZ3HY/gDtJrsNsOW4fgB+1ckKvNBCVUlKkLjVdZektZRvrVo7vSSk5ttK+qrZAWrk1dcyj8lgU9+57VXnQyqfxfwMw36i5NXUR6S66gk4kE0p2kUwo2UUyoWQXyYSSXSQTSnaRTNS6lDTg1xCjrWhTtgeOapdR3Dt3So2+FSn16GgJ7ajvVV4DEB079dwp06lTjg3EU2C9rcuj+93uVG89s4tkQskukgklu0gmlOwimVCyi2RCyS6SCSW7SCZY1faw856MPATgvTk3rQDwYW0d+Ga6tW/d2i9AfWtXJ/t2rpmdNV+g1mT/2snJETMbbqwDjm7tW7f2C1Df2lVX3/QyXiQTSnaRTDSd7JsaPr+nW/vWrf0C1Ld21dK3Rt+zi0h9mn5mF5GaKNlFMtFIspO8heSbJN8ieV8TfShDch/J10juJDnScF8eJTlBctec2/pIPktyb/F13j32GurbAyT3F2O3k+StDfXtHJJ/Jfk6yd0kf1zc3ujYOf2qZd
xqf89OcgGA/wVwM4AxAC8CuNPMXq+1IyVI7gMwbGaNX4BB8tsAjgD4nZl9q7jtXwFMmdmDxX+Uy83sZ13StwcAHGl6G+9it6JVc7cZB3A7gB+iwbFz+nUHahi3Jp7ZrwLwlpm9Y2afA/gDgNsa6EfXM7NtAKZOuvk2AJuL7zdj9sFSu5K+dQUzGzezHcX30wC+3Ga80bFz+lWLJpJ9DYDROT+Pobv2ezcAfyH5EsmNTXdmHgNmNl58fwDAQJOdmUe4jXedTtpmvGvGrp3tz1PpA7qvu8HMrgCwHsA9xcvVrmSz78G6qXba0jbedZlnm/G/a3Ls2t3+PFUTyb4fwDlzfh4sbusKZra/+DoB4Al031bUB7/cQbf4OtFwf/6um7bxnm+bcXTB2DW5/XkTyf4igItIriXZC+D7ALY00I+vIbm0+OAEJJcC+C66byvqLQA2FN9vAPBkg335im7Zxrtsm3E0PHaNb39uZrX/A3ArZj+RfxvAvzTRh5J+nQ/gleLf7qb7BuBxzL6s+wKzn23cDaAfwFYAewH8D4C+LurbfwJ4DcCrmE2sVQ317QbMvkR/FcDO4t+tTY+d069axk2Xy4pkQh/QiWRCyS6SCSW7SCaU7CKZULKLZELJLpIJJbtIJv4Pw1VAg23eYPsAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"ASCII: \n",
"::``````````````````````\n",
"::``````````````````````\n",
"::``````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"`````````````````````::`\n",
"``````:-=::::`':-```'-=`\n",
"````:-=+*+++=::-=::--=*=\n",
"```-+*###*****####*****-\n",
"-:`=#%%%%###%%@@%#%#=*=:\n",
"#+-=*#@@%@%%@@@%%*#+-=+-\n",
"%#*+-=+=+@%#@@##=--:----\n",
"**++==-:+#+==*#+::---=--\n",
"+++=====#%*+=*%===**+++=\n",
"++++++*+*****##***#**++=\n",
"+++++++++++++++++++****+\n",
"+++++=+++++++++++++*++++\n",
"++++++++++++++++++++++++\n",
"++++++++++++++++++++****\n",
"++++++=+++++++++++++***+\n",
"++++++==++++++++++++++++\n",
"+++++===++++++++++++++++\n",
"\n",
"\n",
"==== Sample 1, class = car ====\n",
"Original:\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"ASCII: \n",
"#%@%+*@@%%%%%%%%@%######\n",
"#@@**#@@@@@@@@%*%@%#%##*\n",
"#%#**%@@@@@@@@@#*%%###%#\n",
"#%*##%%%####%%%%*######*\n",
"+##*#%####%%##%%%####%*+\n",
"=+*+=***##%##%%%%%#==*+*\n",
"==+*+****######%###-``:-\n",
"-=+******%%%%%%%%%%#:'``\n",
":=*##*###%%%%%%##%##=```\n",
":+#**#*###***+++*###*:'`\n",
"-+***#*#*-+++++*####%+``\n",
":=*#####=:+*+==++*#%%*``\n",
"`:+##%###****=-=+#%##*:`\n",
"`:=*#%###*******#####*-`\n",
"`-*##%########%%%%%%##=`\n",
"`=#%%%%%#%%%%#%%%%%%#%=`\n",
"`-%%%%%%%%%%%%%@@%%#%%=`\n",
"`-#@%%%####%%%%%**%%%%-`\n",
"`:#@@@%%%*+###*#*#%@%#:`\n",
"`:*@@@@@@%%%####%%%@%#-`\n",
"`:*%%@@@@@%%#####@@%%#:`\n",
"::*%%%@@@@@#####%@%%#=``\n",
"::=%%%%%%%%##*##%%%#+:``\n",
"::-+%@@%%%%%%%%%%%*+-:``\n",
"\n",
"\n",
"==== Sample 2, class = bird ====\n",
"Original:\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"ASCII: \n",
"########################\n",
"#####################**#\n",
"#####################**#\n",
"#############***######*#\n",
"#############*++########\n",
"######****####*+########\n",
"#####**++****##**###****\n",
"####***+****+##**###+*#*\n",
"####**+*****++***###+*#*\n",
"#####*+**********##*+*##\n",
"######***********##**###\n",
"######*+**++=++**#%####*\n",
"######*+**+==++++*#####*\n",
"#######*+===+******####*\n",
"########*+**+#%%#*+*##**\n",
"##########%#*%@##*++*#**\n",
"##########**%@%**#***##*\n",
"#########++*%%***##**##*\n",
"########++*###**####**##\n",
"#######*++*###**#*##**##\n",
"######*+*#####**#*#**##*\n",
"#####*+*#####***##***#**\n",
"####*+*#######*##*******\n",
"###*=*########**********\n",
"\n",
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 1) Get samples to serve as \"context\"\n",
"args.seed_train = 42\n",
"args.train_samples_per_class = 4\n",
"img_dims = {'width': args.width, 'height': args.width}\n",
"\n",
"context_ix, context_y = get_random_indices_by_class(\n",
" indices_by_class,\n",
" class_indices,\n",
" args.train_samples_per_class,\n",
" args.seed_train\n",
")"
],
"metadata": {
"id": "9Qrhqhy0qRbV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 2) Get examples to few-shot classify\n",
"args.seed_eval = 42\n",
"args.eval_samples_per_group = 10\n",
"test_eval_ix, test_eval_y = get_random_indices_by_class(\n",
" test_indices_by_class,\n",
" class_indices,\n",
" args.eval_samples_per_group,\n",
" args.seed_eval\n",
")"
],
"metadata": {
"id": "QWW1ab8atP5Z"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 3) Get prompt context\n",
"for seed in range(20):\n",
" context_prompt, shuffled_context_ix = get_context_prompt(\n",
" train_set_ascii, context_ix,\n",
" start='', end='', sep='###\\n', shuffle=True,\n",
" seed=seed # May need to change this to shuffle\n",
" )\n",
"\n",
" # Inspect context order (want to avoid patterns between outputs, e.g., 0, 1, 0, 1, ...)\n",
" print(seed, train_set.targets[shuffled_context_ix])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "V8ZfWFBMtS-E",
"outputId": "ae34e695-dae9-4055-bbc6-c4844ad79ec2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 tensor([1, 2, 1, 2, 0, 2, 0, 1, 2, 0, 0, 1])\n",
"1 tensor([0, 0, 1, 2, 0, 1, 0, 1, 2, 2, 2, 1])\n",
"2 tensor([2, 1, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2])\n",
"3 tensor([1, 1, 0, 0, 2, 1, 1, 0, 0, 2, 2, 2])\n",
"4 tensor([0, 1, 1, 2, 2, 0, 2, 0, 0, 1, 1, 2])\n",
"5 tensor([1, 1, 0, 2, 2, 1, 2, 0, 0, 2, 1, 0])\n",
"6 tensor([2, 2, 0, 1, 0, 1, 1, 0, 1, 0, 2, 2])\n",
"7 tensor([1, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 1])\n",
"8 tensor([1, 2, 2, 2, 1, 2, 0, 0, 1, 0, 1, 0])\n",
"9 tensor([1, 2, 0, 0, 0, 1, 0, 2, 2, 2, 1, 1])\n",
"10 tensor([0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 2])\n",
"11 tensor([1, 0, 2, 2, 0, 2, 1, 1, 1, 0, 0, 2])\n",
"12 tensor([2, 1, 2, 2, 0, 1, 1, 0, 0, 0, 1, 2])\n",
"13 tensor([1, 2, 0, 1, 2, 0, 1, 2, 1, 2, 0, 0])\n",
"14 tensor([0, 2, 0, 1, 1, 0, 2, 0, 1, 1, 2, 2])\n",
"15 tensor([2, 1, 0, 0, 2, 0, 1, 0, 1, 2, 1, 2])\n",
"16 tensor([1, 0, 0, 2, 1, 2, 1, 0, 0, 1, 2, 2])\n",
"17 tensor([0, 0, 2, 2, 1, 1, 2, 1, 0, 2, 1, 0])\n",
"18 tensor([1, 1, 2, 0, 2, 0, 0, 1, 1, 2, 0, 2])\n",
"19 tensor([1, 0, 2, 2, 0, 1, 2, 1, 0, 2, 0, 1])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 3.1) Get prompt context (cont.)\n",
"context_prompt, shuffled_context_ix = get_context_prompt(\n",
" train_set_ascii, context_ix,\n",
" start='', end='', sep='###\\n', shuffle=True,\n",
" seed=7 # Pick a random-looking one from above\n",
")"
],
"metadata": {
"id": "_QXJgEJQtVCO"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 3.2) Visualize a prompt\n",
"# Format the shuffled context above, plus one eval query, for sanity checking\n",
"for ix in shuffled_context_ix:\n",
" ascii_img, label = train_set_ascii[ix]\n",
" print('Input:')\n",
" pretty_print_ascii(img_to_text(ascii_img), **img_dims)\n",
" print(f'Output: {label}\\n###\\n')\n",
"\n",
"# Eval sample (first test index in test_indices_by_class[1], i.e. class 1 = 'car')\n",
"print('Input:')\n",
"pretty_print_ascii(\n",
" img_to_text(test_set_ascii[test_indices_by_class[1][0]][0]),\n",
" **img_dims\n",
")\n",
"print('Output:')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "R4DOPULMyKKa",
"outputId": "235beff4-ccb7-482f-cfb7-5d66028d76d1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Input:\n",
"+=+##%#%@@@%@@%@@@@%#=--\n",
"###%%%#%@@@@@%%%%%%##=``\n",
"%%%@@@#%@%@@%%%%%%#**-'`\n",
"%%%@@%#%%%%%%%%%%%#++-'`\n",
"%%%%%%#%%%%%%%%%%%#**-'`\n",
"%%%%%%##%#*#%%%%%%#**:'`\n",
"%@%%#*+***+**#%@%%%%%-'`\n",
"@@%*+=-*###*##+#@%%%%='`\n",
"@%#++=:*@%%%#%*=%@@##='`\n",
"%#+===+*%@@%%%#++%@%*='`\n",
"#*=-=++*%%%%%%#*++%##='`\n",
"#+==+**#%%###*****###='`\n",
"#=-=++*#*+--++++*+*+*-'`\n",
"*=++=+++++=+**++++++=-``\n",
"#***++++++**####****++-`\n",
"***+****+***+==+=*###+==\n",
"+=+==+***%#*=::-=#%##++=\n",
"=-+===++****+===+***++*=\n",
":-+==+**#=+##****=:``=#=\n",
"-+##***#*==*#*###+-=+*#=\n",
":=+#%%%#*-=*#**#%%%%%%#=\n",
":::-*%@@%--*####%%%%%@*-\n",
"-----=+*%*+%@@@@@@@@@%+-\n",
"---==--=+*##%%%%##**++==\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"`=+=:````'`''```'''''`=*\n",
"-+**=:````````````````-+\n",
"+###*-'`:-:::`````````=#\n",
"*###+-::+*==-`:-`'''':*#\n",
"####*+++*#+**++*-```:=#%\n",
"##*##*###%##*+*+=---==#%\n",
"#######%%###*=:``:::::*%\n",
"******+**++++=:``:````+*\n",
"==-=======--=----+=-::--\n",
"=====---===----=====----\n",
"==============--====----\n",
"=======++***+=--=--=----\n",
"+====+*******+=-======--\n",
"++==+####*****+========-\n",
"+*+=*%%%%##*****========\n",
"**+=###%%%%####*+=======\n",
"#*=+*++%%%%%###*+=======\n",
"*+==++*##%#####*+=======\n",
"=====++++*++**+=========\n",
"+++==+++++++============\n",
"++++++++++++++++++++++++\n",
"+++++++++*+++++**++*++++\n",
"++++++++**+++++++++*++++\n",
"++++++++**+++++++*+++++=\n",
"Output: 2\n",
"###\n",
"\n",
"Input:\n",
"********+++++++++++*****\n",
"**+++++++++++++++*+++++*\n",
"++++++++++++++++*#++++++\n",
"+++++++++++*++++##++++++\n",
"++++++++++#%+=++##++++++\n",
"++++=====+%%+==+%*++=+++\n",
"=========#%#===*%#+=====\n",
"========+%%#**##%*======\n",
"======--*%%%%%%%%*-=====\n",
"-------=%%%%%%%%%*------\n",
"-------*%%%%%%%%%*------\n",
"-----:=%%%%%%%%%%+=++**+\n",
"::::::*%%%%####%%##%%##=\n",
":::::=%%%%%####%%%%#*=-:\n",
":::::*#%%%#####%##+-::::\n",
":::`=**%%###%%#*=:```:::\n",
"```:**#%####*+-````````:\n",
"```=#%%%##+=:```````````\n",
"``:#%%%*+:``````````````\n",
"``=%#+-:````````````````\n",
"`:+=:```````````````````\n",
"`::`````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"--===--------------==+++\n",
"===++==----------===+*#-\n",
"++:++==============++**:\n",
"+=:=+==============++**:\n",
"+=`=+===============+**:\n",
"+=`-=====-===-==--==+**:\n",
"+=`-=----:----------=+*:\n",
"*+`---:-----:--:::---++:\n",
"*+`:--=---------:::--=+:\n",
"*+:::*#+=++===+--:::-=+:\n",
"=::`=@@%%%#####+-=-:-=+:\n",
"=```:=====++***+==:`:=+:\n",
"*:::`::::::::::=:`````::\n",
"*---:------------::```::\n",
"=--+#+------:-----+#=:-:\n",
"--=%#%+----------=%##=-:\n",
"::+#+**--:-:::::-**=#=::\n",
"--+*=+#--:--:-:::#+-*+::\n",
"*#%*=*%+++++++==+%*=*#=:\n",
"%@@%+#@@@@@@@@@@@@#+%@#:\n",
"+**#**##*###**#**#***+=:\n",
":::::::::::::::::::::```\n",
"::::::::::::::::::``:::`\n",
":-:::---:::::-::::::::::\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"::``````````````````````\n",
"::``````````````````````\n",
"::``````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"````````````````````````\n",
"`````````````````````::`\n",
"``````:-=::::`':-```'-=`\n",
"````:-=+*+++=::-=::--=*=\n",
"```-+*###*****####*****-\n",
"-:`=#%%%%###%%@@%#%#=*=:\n",
"#+-=*#@@%@%%@@@%%*#+-=+-\n",
"%#*+-=+=+@%#@@##=--:----\n",
"**++==-:+#+==*#+::---=--\n",
"+++=====#%*+=*%===**+++=\n",
"++++++*+*****##***#**++=\n",
"+++++++++++++++++++****+\n",
"+++++=+++++++++++++*++++\n",
"++++++++++++++++++++++++\n",
"++++++++++++++++++++****\n",
"++++++=+++++++++++++***+\n",
"++++++==++++++++++++++++\n",
"+++++===++++++++++++++++\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"```````````````````````:\n",
"```````````````````````:\n",
"`````````````::::`::::::\n",
"::``:++*-```:::::`::::::\n",
"::-=+#%%*=-::::::```````\n",
"--+##%%%%##++---+=-==-::\n",
"###%%%#%#%%##*#+##*##++*\n",
"%%%%%%%%%%%#=+#*+****#%%\n",
"#%%%%%%%%%#+:=#*+++*#%#*\n",
"*#%%%%%%%%+--=++****%@%#\n",
"+=#@%%%#*==+++**###%%@@%\n",
"++#%%#*=--=++*##%##%@@@%\n",
"#**##=+==++#%%%####%@@@%\n",
"*##%%=+++*%@@@@@@%%%%@@@\n",
"#%@@@**#*#%@@@@@%##**###\n",
"%@@@@##@%%%%%%#***#*****\n",
"%%%%####**********#%****\n",
"##%#**********##%##%####\n",
"#############%%@@@@@%%%%\n",
"************##%%%%%#####\n",
"************************\n",
"************************\n",
"****++++++++*****+**++++\n",
"**************++***+++++\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"###**********###########\n",
"#***********############\n",
"********#***############\n",
"**#######****###########\n",
"**######*++===+*###*****\n",
"###**+==--==-===*##*****\n",
"*++===----=--====*#***##\n",
"====+==-=====+++++**####\n",
"+*++++===+++*#*+****####\n",
"+**#*++==++=###****#####\n",
"==+**++==+++*****++*####\n",
"++++**+=++*+***#**+*####\n",
"++**++*+++*+*****######*\n",
"+++**+++++*##+****######\n",
"++++****++**++*#########\n",
"+++*+**##**++*#%########\n",
"**+++***#**+**#%########\n",
"***+**********##########\n",
"*******###****##*#******\n",
"*****######***##********\n",
"*****####%#**##*********\n",
"**+***####+*#%#*********\n",
"*****#####+*#%#*********\n",
"#***###%#**###**********\n",
"Output: 2\n",
"###\n",
"\n",
"Input:\n",
"########################\n",
"#####################**#\n",
"#####################**#\n",
"#############***######*#\n",
"#############*++########\n",
"######****####*+########\n",
"#####**++****##**###****\n",
"####***+****+##**###+*#*\n",
"####**+*****++***###+*#*\n",
"#####*+**********##*+*##\n",
"######***********##**###\n",
"######*+**++=++**#%####*\n",
"######*+**+==++++*#####*\n",
"#######*+===+******####*\n",
"########*+**+#%%#*+*##**\n",
"##########%#*%@##*++*#**\n",
"##########**%@%**#***##*\n",
"#########++*%%***##**##*\n",
"########++*###**####**##\n",
"#######*++*###**#*##**##\n",
"######*+*#####**#*#**##*\n",
"#####*+*#####***##***#**\n",
"####*+*#######*##*******\n",
"###*=*########**********\n",
"Output: 2\n",
"###\n",
"\n",
"Input:\n",
"%%%%%%%%%%%%%%%%%%%%#%%#\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%%%%%%%%%%%%%%%%%%%%\n",
"%%%%%##%%%%%%%%%%%%*#%%%\n",
"%%%%%#=-*#%%%%%%%%+-*%%%\n",
"%#+++++==+***##**=:=#%%%\n",
"%%*+=--==*#*+=-::``-#%%#\n",
"%###*+++==**++====-=*+=-\n",
"%+=--=++++*++==----=-:::\n",
"%%#+----::::=+-:::::--::\n",
"%##***++-::-=*+==-======\n",
"%#+=======++**+====++***\n",
"%#*+++++++*****+++=+*###\n",
"%%#********####***+*####\n",
"%%##*****#######****####\n",
"%%##########%%####*##%%%\n",
"%%%#######%%%%#######%%%\n",
"%%%%%####%%%%%%#####%%%%\n",
"%%%%%%%%%%%%%%%%####%%%%\n",
"%%%%%%%%%%%%%%%%%###%%%%\n",
"Output: 0\n",
"###\n",
"\n",
"Input:\n",
"`::::::::::`:`::::::::::\n",
":=+==++++++=---=====+===\n",
":+++=+++===--:-===++====\n",
":=======----::--=====+==\n",
":------=+****+===--=====\n",
"::::::--**##**+==------=\n",
"::::---:+####**=---:---:\n",
":-:-+**+*####***=--:::::\n",
":-+*++**#**+==--::```:::\n",
"::+###**#+*++==---::````\n",
"::+######++*++******+:``\n",
"::*#######*+-=+*%%##*-``\n",
"::=#%%##%%#****##***+=:`\n",
":`-#%%%%%#######*===++-`\n",
":`-%%#%%%**%%##%%***##+:\n",
":`-%*=+*#**%%##%@%%%%#+`\n",
":``-:::-=*#%%%%#*++*%#:`\n",
":`'''```:+%%+===---=#+`'\n",
":`'''''''-++----:::::``'\n",
":`''''''``````````````''\n",
"::``````````````````````\n",
"`::::::::::::::::::::::`\n",
"`::----:````````````````\n",
"`:--==-:`'''''''''''''''\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"-==-------------::--::::\n",
"==----------::--::-:::::\n",
"=====-------::::::-:::::\n",
"======------::::--::::::\n",
"=====-------:::=*=::::::\n",
"======--------:#%=::::::\n",
"======--------=%%-::::::\n",
"-=========----#%+-::::::\n",
"-===========-=%%=---::::\n",
"-===========-+@%=---::::\n",
"-============*@%=---::::\n",
"---------====#@#----::::\n",
"::::---::---+#@*--------\n",
"::::::::::--#%%=--------\n",
":::::::::::=@%*+=-------\n",
"::::::::::-%@@#=-=------\n",
"::::::::::*@@#*+==------\n",
":::::::::+@@@==+-------:\n",
"::--::::-#@@*---------::\n",
"::--::::+%#*-:::::----::\n",
":--:::::*@*-::::::------\n",
"--:::::-*#=::::::::-:---\n",
"::::::::++:::::::::::---\n",
"::::::::::::::::::::----\n",
"Output: 2\n",
"###\n",
"\n",
"Input:\n",
"#%@%+*@@%%%%%%%%@%######\n",
"#@@**#@@@@@@@@%*%@%#%##*\n",
"#%#**%@@@@@@@@@#*%%###%#\n",
"#%*##%%%####%%%%*######*\n",
"+##*#%####%%##%%%####%*+\n",
"=+*+=***##%##%%%%%#==*+*\n",
"==+*+****######%###-``:-\n",
"-=+******%%%%%%%%%%#:'``\n",
":=*##*###%%%%%%##%##=```\n",
":+#**#*###***+++*###*:'`\n",
"-+***#*#*-+++++*####%+``\n",
":=*#####=:+*+==++*#%%*``\n",
"`:+##%###****=-=+#%##*:`\n",
"`:=*#%###*******#####*-`\n",
"`-*##%########%%%%%%##=`\n",
"`=#%%%%%#%%%%#%%%%%%#%=`\n",
"`-%%%%%%%%%%%%%@@%%#%%=`\n",
"`-#@%%%####%%%%%**%%%%-`\n",
"`:#@@@%%%*+###*#*#%@%#:`\n",
"`:*@@@@@@%%%####%%%@%#-`\n",
"`:*%%@@@@@%%#####@@%%#:`\n",
"::*%%%@@@@@#####%@%%#=``\n",
"::=%%%%%%%%##*##%%%#+:``\n",
"::-+%@@%%%%%%%%%%%*+-:``\n",
"Output: 1\n",
"###\n",
"\n",
"Input:\n",
"#***++++****#**###%%%%%%\n",
"++*******#*########%@%%%\n",
"*****##########%#%%@@%%%\n",
"***#############*#%@%#%%\n",
"######%%########*#%%##%%\n",
"%%%%%%%#########%@%##%%%\n",
"%%%%%#####**####%@%%%%%%\n",
"%%%%%#####***++*#%##%%%%\n",
"%%%#%####*=-----=%%##%%%\n",
"%#%#####+++**#%=`+@@#%%%\n",
"####%%#*###@@@%=:-#@%%%%\n",
"###########%%*=:::*%@%#%\n",
"#####**####=:```:+%%@*=#\n",
"#%##%##%%%*``::`-#%#*==#\n",
"%%+=#%%@@%+:-::::-+====#\n",
"%#:-#%%@@%+---+#+-:-==*%\n",
"%**+*%@@%#+--+%%@*-=+*@@\n",
"%#%#*%@@#*=:-*%%%%++*@@%\n",
"%%#**%@%#*---#%%%@*=#@@%\n",
"%%##*%#*==---#%*#@%*%@@@\n",
"%%##*##*+++++#%##@@%@@@@\n",
"%%%%@@@@@@@@@@%#%@@@@@@@\n",
"%%%%%%%@%%@@@@@%%@@@@@@%\n",
"%%#**#%%####%%%%@@@@@@%%\n",
"Output:\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 4) Build the full prompt (few-shot context + test example) for each test image and query GPT-3\n",
"model_name = 'text-davinci-003'\n",
"all_responses = []\n",
"all_targets = []\n",
"for test_ix in test_eval_ix:\n",
" eval_prompt, target = get_eval_prompt(test_set_ascii, ix=test_ix, start='')\n",
" _prompt = context_prompt + eval_prompt\n",
" # This API call is billed by OpenAI\n",
" response = get_gpt_output(_prompt, model=model_name)\n",
" all_responses.append(response)\n",
" all_targets.append(target)"
],
"metadata": {
"id": "-BxHv8bZtgAW"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 5) Evaluate\n",
"\n",
"# Report accuracy metrics\n",
"avg_acc, class_accs = eval_predictions(\n",
" [decode_response(r) for r in all_responses],\n",
" [test_set.targets[test_ix] for test_ix in test_eval_ix],\n",
")\n",
"print('')\n",
"\n",
"# Report predicted vs true targets\n",
"print(f'GPT-3 {model_name} few-shot predictions')\n",
"print([decode_response(r) for r in all_responses])\n",
"print('')\n",
"print('Ground-truth labels')\n",
"print([test_set.targets[test_ix].item() for test_ix in test_eval_ix])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_J3VoNmIuZXH",
"outputId": "40a094f4-4b9a-4063-eb1d-2154688f4100"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average acc: 50.0%\n",
"Class accs:\n",
"- 7 / 10 = 70.0%\n",
"- 0 / 10 = 0.0%\n",
"- 8 / 10 = 80.0%\n",
"\n",
"GPT-3 text-davinci-003 few-shot predictions\n",
"[0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 0, 2, 2, 2, 2, 2]\n",
"\n",
"Ground-truth labels\n",
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### Results: CIFAR-10\n",
"Still better than chance in aggregate (50% vs. 33% for three classes), but it's questionable whether this actually works. Curiously, the model never outputs the car label (`1`), even though my prior was that planes and birds would be the confusing pair.\n",
"\n",
"* The results are probably also brittle and depend heavily on sample selection and prompt format.\n",
"\n",
"Lots of battle-testing to do!"
],
"metadata": {
"id": "QaLxKRhrzuVR"
}
}
]
}
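As a sanity check on the reported numbers, the overall and per-class accuracies can be recomputed directly from the two label lists printed in step 5, along with a breakdown of what the model predicted for the car images. This is a minimal sketch in plain Python: the helpers `per_class_accuracy` and `confusion_counts` are hypothetical and are not the notebook's own `eval_predictions`, and the mapping 0 = airplane, 1 = automobile, 2 = bird assumes the standard CIFAR-10 label ids.

```python
from collections import Counter

# Predictions and ground-truth labels copied from the printed output of step 5.
# Assumed CIFAR-10 label ids: 0 = airplane, 1 = automobile, 2 = bird.
preds = [0, 2, 2, 0, 0, 0, 0, 0, 2, 0, 2, 0, 2, 0, 2, 2, 2, 2, 2, 0,
         2, 2, 2, 0, 0, 2, 2, 2, 2, 2]
targets = [0] * 10 + [1] * 10 + [2] * 10


def per_class_accuracy(preds, targets):
    """Hypothetical helper: return {label: (correct, total)} for each true label."""
    stats = {}
    for p, t in zip(preds, targets):
        correct, total = stats.get(t, (0, 0))
        stats[t] = (correct + int(p == t), total + 1)
    return stats


def confusion_counts(preds, targets):
    """Hypothetical helper: count (true_label, predicted_label) pairs."""
    return Counter(zip(targets, preds))


stats = per_class_accuracy(preds, targets)
overall = sum(correct for correct, _ in stats.values()) / len(targets)
print(f"Average acc: {overall:.1%}")
for label in sorted(stats):
    correct, total = stats[label]
    print(f"- {correct} / {total} = {correct / total:.1%}")

# Which labels did the model emit at all, and where did the cars (label 1) go?
print(Counter(preds))
cm = confusion_counts(preds, targets)
print({pred: n for (true, pred), n in cm.items() if true == 1})
```

On the printed lists this reproduces the 50.0% / 70.0% / 0.0% / 80.0% figures, and the last two lines show that the model only ever emits `0` or `2`, with 7 of the 10 car images labeled as birds and 3 as airplanes.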