Created
July 15, 2023 12:40
-
-
Save radekosmulski/00dc4b4915ea38e79c1c647c90d757c3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "b284d8a4", | |
"metadata": {}, | |
"source": [ | |
"## F1 score on bigrams" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "4dff0ed2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import re\n", | |
"import string\n", | |
"from nltk.util import bigrams\n", | |
"from collections import Counter\n", | |
"\n", | |
"regex_punctuation = re.compile('[%s]' % re.escape(string.punctuation))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "d3d7f995", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"gt = \"Vanilla is the best ice cream flavor in the world.\"\n", | |
"pred = \"Vanilla.\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "c225e095", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def string_to_tokens(s):\n", | |
"    '''\n", | |
"    Naive tokenizer: lowercase the text, strip ASCII punctuation, split on whitespace.\n", | |
"\n", | |
"    The words are wrapped in '<bos>' / '<eos>' sentinel tokens.\n", | |
"    Good enough for toy examples, probably too crude for real life datasets.\n", | |
"    '''\n", | |
"    lowered = s.lower()\n", | |
"    without_punctuation = regex_punctuation.sub('', lowered)\n", | |
"    return ['<bos>', *without_punctuation.split(), '<eos>']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "1f0bb882", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['<bos>', 'vanilla', 'is', 'the', 'best', 'ice', 'cream', 'flavor', 'in', 'the', 'world', '<eos>'] ['<bos>', 'vanilla', '<eos>']\n" | |
] | |
} | |
], | |
"source": [ | |
"gt_words = string_to_tokens(gt)\n", | |
"pred_words = string_to_tokens(pred)\n", | |
"print(gt_words, pred_words)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "27f1721c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[('<bos>', 'vanilla'), ('vanilla', 'is'), ('is', 'the'), ('the', 'best'), ('best', 'ice'), ('ice', 'cream'), ('cream', 'flavor'), ('flavor', 'in'), ('in', 'the'), ('the', 'world'), ('world', '<eos>')] [('<bos>', 'vanilla'), ('vanilla', '<eos>')]\n" | |
] | |
} | |
], | |
"source": [ | |
"gt_bigrams = list(bigrams(gt_words)) \n", | |
"pred_bigrams = list(bigrams(pred_words)) \n", | |
"print(gt_bigrams, pred_bigrams)" | |
] | |
}, | |
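{ | |
"cell_type": "markdown", | |
"id": "ff6c9cff", | |
"metadata": {}, | |
"source": [ | |
"Precision is the share of predicted bigrams that also occur in the ground truth, recall is the share of ground-truth bigrams that were predicted, and F1 is their harmonic mean: $F_1 = 2PR / (P + R)$." | |
] | |
}, | |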
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "eda31f54", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def f1(pred_bigrams, gt_bigrams):\n", | |
"    '''\n", | |
"    F1 score between predicted and ground-truth n-grams, counting duplicates.\n", | |
"\n", | |
"    Both arguments may be any iterable of hashable n-grams: lists, or the\n", | |
"    generator returned by `bigrams` without wrapping it in `list` first.\n", | |
"    Always returns a float; 0.0 when the two sides share no n-gram\n", | |
"    (which includes empty input).\n", | |
"    '''\n", | |
"    pred_counts = Counter(pred_bigrams)\n", | |
"    gt_counts = Counter(gt_bigrams)\n", | |
"\n", | |
"    # Multiset intersection: an n-gram is matched at most as often as it occurs on both sides.\n", | |
"    num_same = sum((pred_counts & gt_counts).values())\n", | |
"    if num_same == 0:\n", | |
"        return 0.0\n", | |
"\n", | |
"    # Totals are taken from the counters so that one-shot iterators work too.\n", | |
"    precision = num_same / sum(pred_counts.values())\n", | |
"    recall = num_same / sum(gt_counts.values())\n", | |
"    return (2 * precision * recall) / (precision + recall)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "8613bb0e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.15384615384615385" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"f1(pred_bigrams, gt_bigrams)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "056863bd", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8181818181818182" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pred_bigrams = list(bigrams(string_to_tokens('Chocolate is the best ice cream flavor in the world.')))\n", | |
"f1(pred_bigrams, gt_bigrams)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "8462c989", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.8695652173913043" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"pred_bigrams = list(bigrams(string_to_tokens('Vanilla is not the best ice cream flavor in the world.')))\n", | |
"f1(pred_bigrams, gt_bigrams)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "d850c8a3", | |
"metadata": {}, | |
"source": [ | |
"## BERTscore" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "2dc1d5f8", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# !pip install evaluate\n", | |
"# !pip install bert_score\n", | |
"# !pip install pyarrow==11.0.0" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"id": "9da20328", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/radek/miniconda3/envs/pytorch/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | |
" from .autonotebook import tqdm as notebook_tqdm\n" | |
] | |
} | |
], | |
"source": [ | |
"import evaluate\n", | |
"\n", | |
"# Load the BERTScore metric through the `evaluate` library.\n", | |
"bertscore = evaluate.load('bertscore')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"id": "ae052b41", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': [0.6896928548812866],\n", | |
" 'recall': [0.7891448736190796],\n", | |
" 'f1': [0.7360748052597046],\n", | |
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bertscore.compute(\n", | |
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n", | |
" references=['Vanilla.'],\n", | |
" model_type=\"distilbert-base-uncased\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"id": "400c40ff", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': [0.9863595962524414],\n", | |
" 'recall': [0.9863595962524414],\n", | |
" 'f1': [0.9863595962524414],\n", | |
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bertscore.compute(\n", | |
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n", | |
" references=['Chocolate is the best ice cream flavor in the world.'],\n", | |
" model_type=\"distilbert-base-uncased\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"id": "fe4e3421", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'precision': [0.9881718158721924],\n", | |
" 'recall': [0.9713627099990845],\n", | |
" 'f1': [0.979695200920105],\n", | |
" 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.29.2)'}" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"bertscore.compute(\n", | |
" predictions=['Vanilla is the best ice cream flavor in the world.'],\n", | |
" references=['Vanilla is not the best ice cream flavor in the world.'],\n", | |
" model_type=\"distilbert-base-uncased\"\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "cd316b0c", | |
"metadata": {}, | |
"source": [ | |
"## Perplexity" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"id": "a56cd1cf", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from transformers import AutoTokenizer, AutoModelForCausalLM\n", | |
"import transformers\n", | |
"import torch\n", | |
"from torch import nn\n", | |
"import numpy as np\n", | |
"\n", | |
"# Code from a wonderful Kaggle Notebook by Philipp Singer\n", | |
"# https://www.kaggle.com/code/philippsinger/h2ogpt-perplexity-ranking\n", | |
"\n", | |
"class Perplexity(nn.Module):\n", | |
"    '''\n", | |
"    Next-token cross-entropy of a causal language model, computed per sequence.\n", | |
"\n", | |
"    With the default settings the result is the mean cross-entropy in nats,\n", | |
"    i.e. the *log* of the perplexity. Pass `exponentiate=True` to obtain the\n", | |
"    perplexity itself.\n", | |
"\n", | |
"    reduce: when True, average the per-sequence values into one scalar;\n", | |
"        otherwise return one value per sequence.\n", | |
"    exponentiate: when True, apply `exp` to each per-sequence loss before\n", | |
"        the optional averaging.\n", | |
"    '''\n", | |
"    def __init__(self, reduce: bool = True, exponentiate: bool = False):\n", | |
"        super().__init__()\n", | |
"        self.loss_fn = nn.CrossEntropyLoss()\n", | |
"        self.reduce = reduce\n", | |
"        self.exponentiate = exponentiate\n", | |
"\n", | |
"    def forward(self, logits, labels):\n", | |
"        '''\n", | |
"        logits: expected as (batch, seq_len, vocab) scores.\n", | |
"        labels: expected as (batch, seq_len) token ids.\n", | |
"        '''\n", | |
"        # The logits at position t score the token at position t + 1, so drop the\n", | |
"        # last prediction and the first label to line the two up.\n", | |
"        shift_logits = logits[..., :-1, :].contiguous()\n", | |
"        shift_labels = labels[..., 1:].contiguous()\n", | |
"\n", | |
"        # One mean loss per sequence, so every sequence carries the same weight.\n", | |
"        per_sequence = torch.stack(\n", | |
"            [self.loss_fn(shift_logits[i], shift_labels[i]) for i in range(labels.shape[0])],\n", | |
"            dim=0,\n", | |
"        )\n", | |
"        if self.exponentiate:\n", | |
"            per_sequence = torch.exp(per_sequence)\n", | |
"        if self.reduce:\n", | |
"            return torch.mean(per_sequence)\n", | |
"        return per_sequence\n", | |
"\n", | |
"perp = Perplexity()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"id": "1d7fa1e6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def perplexity(model, prompt):\n", | |
"    '''\n", | |
"    Score `prompt` with the causal language model named `model`.\n", | |
"\n", | |
"    model: model name or path passed to `from_pretrained`; the tokenizer and\n", | |
"        the weights are loaded again on every call.\n", | |
"    prompt: the text to score.\n", | |
"\n", | |
"    Returns the value computed by `perp` for the prompt as a Python float.\n", | |
"    '''\n", | |
"    tokenizer = AutoTokenizer.from_pretrained(model)\n", | |
"\n", | |
"    # trust_remote_code runs the modeling code shipped with the checkpoint;\n", | |
"    # only use it with repositories you trust.\n", | |
"    lm = AutoModelForCausalLM.from_pretrained(\n", | |
"        model,\n", | |
"        torch_dtype=torch.float16,\n", | |
"        device_map=\"auto\",\n", | |
"        trust_remote_code=True,\n", | |
"    )\n", | |
"\n", | |
"    with torch.no_grad():\n", | |
"        # Send the inputs to the device the model was placed on instead of hard-coding 'cuda'.\n", | |
"        inputs = tokenizer(\n", | |
"            [prompt], return_tensors=\"pt\", add_special_tokens=False, truncation=True\n", | |
"        ).to(lm.device)\n", | |
"        logits = lm(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"]).logits\n", | |
"        labels = inputs[\"input_ids\"]\n", | |
"        return perp(logits[0].unsqueeze(0), labels[0].unsqueeze(0)).item()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"id": "af413bfa", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"prompt = \"\"\"\n", | |
" Q: Which ice cream flavor is the best?\n", | |
" A: Vanilla is the best ice cream flavor in the world.\n", | |
"\"\"\"" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"id": "28d6c4a8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.19s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"2.111328125" | |
] | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"perplexity(\"tiiuae/falcon-7b\", prompt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"id": "da110e26", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Instantiating an MPTForCausalLM model from /home/radek/.cache/huggingface/modules/transformers_modules/mosaicml/mpt-7b/72e5f594ce36f9cabfa2a9fd8f58b491eb467ee7/modeling_mpt.py\n", | |
"You are using config.init_device='cpu', but you can also use config.init_device=\"meta\" with Composer + FSDP for fast initialization.\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:06<00:00, 3.07s/it]\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"1.77734375" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"perplexity(\"mosaicml/mpt-7b\", prompt)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment