llm_probability2.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNr/UpNAHfOy+HC0ztAfuHW",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/alexcpn/36e90cb3c78695e6b09eb97cdb277414/llm_probability2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kDxTsn5QhwKK"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "from transformers import (\n",
        "    AutoModelForCausalLM,\n",
        "    AutoTokenizer,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "def print_probability_distribution(current, probabilities, tokenizer, top_n=20, bar_width=50):\n",
        "    \"\"\"\n",
        "    Print the top N tokens and their probabilities in ASCII format.\n",
        "\n",
        "    Parameters:\n",
        "    - current: Current context as a string.\n",
        "    - probabilities: Probability distribution over the vocabulary.\n",
        "    - tokenizer: Tokenizer to decode token IDs to tokens.\n",
        "    - top_n: Number of top tokens to display.\n",
        "    - bar_width: Width of the ASCII bar representing probabilities.\n",
        "    \"\"\"\n",
        "    # Get top N tokens and their probabilities\n",
        "    top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
        "    top_probs = probabilities[top_indices]\n",
        "    top_tokens = [tokenizer.decode([i]).strip() for i in top_indices]\n",
        "\n",
        "    # Find the next token (highest probability token)\n",
        "    max_token = top_tokens[0]\n",
        "\n",
        "    # Display the current context\n",
        "    print(f\"Context: {current}\")\n",
        "    print(f\"Next Token Prediction: '{max_token}'\\n\")\n",
        "\n",
        "    # Print the top N tokens and their probabilities as an ASCII bar chart\n",
        "    for token, prob in zip(top_tokens, top_probs):\n",
        "        bar = \"#\" * int(prob * bar_width)\n",
        "        print(f\"{token:>15} | {bar} {prob:.4f}\")\n",
        "\n",
        "def plot_probability_distribution(current, probabilities, tokenizer, top_n=20):\n",
        "    # Get top N tokens and their probabilities\n",
        "    top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
        "    top_probs = probabilities[top_indices]\n",
        "    top_tokens = [tokenizer.decode([i]) for i in top_indices]\n",
        "\n",
        "    # Find the next token (highest probability token)\n",
        "    max_token = tokenizer.decode([top_indices[0]])\n",
        "\n",
        "    # Plot\n",
        "    plt.figure(figsize=(12, 7))\n",
        "    bars = plt.bar(top_tokens, top_probs, color=\"blue\")\n",
        "    bars[0].set_color(\"red\")  # Highlight the next token\n",
        "\n",
        "    # Add the current context inside the graph\n",
        "    plt.text(\n",
        "        0.5,\n",
        "        0.9,\n",
        "        f\"Context: {current}\\nNext Token: {max_token}\",\n",
        "        ha=\"center\",\n",
        "        va=\"center\",\n",
        "        transform=plt.gca().transAxes,\n",
        "        fontsize=12,\n",
        "        bbox=dict(facecolor=\"white\", alpha=0.8, edgecolor=\"black\"),\n",
        "    )\n",
        "\n",
        "    plt.xlabel(\"Tokens\")\n",
        "    plt.ylabel(\"Probabilities\")\n",
        "    plt.xticks(rotation=45)\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n"
      ],
      "metadata": {
        "id": "mxq_nsjshx9N"
      },
      "execution_count": null,
      "outputs": []
    },
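    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check of `print_probability_distribution` on a toy 5-token distribution (this cell is an addition, not part of the original notebook). The `StubTokenizer` class and its word list are illustrative stand-ins; the helper only needs an object with a `decode(ids)` method, such as the real tokenizer loaded in the next cell."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Added sanity check: exercise the ASCII printer on a toy distribution.\n",
        "# StubTokenizer is a made-up stand-in for illustration; the helper only\n",
        "# calls decode(ids) -> str on it.\n",
        "class StubTokenizer:\n",
        "    words = [\"York\", \"Orleans\", \"Zealand\", \"Jersey\", \"Year\"]\n",
        "    def decode(self, ids):\n",
        "        return \" \".join(self.words[i] for i in ids)\n",
        "\n",
        "toy_logits = np.array([4.0, 2.5, 2.3, 1.5, 1.0])\n",
        "toy_probs = np.exp(toy_logits) / np.exp(toy_logits).sum()  # softmax by hand\n",
        "print_probability_distribution(\"I love New\", toy_probs, StubTokenizer(), top_n=5)\n"
      ]
    },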
    {
      "cell_type": "code",
      "source": [
        "model_name = 'gpt2'\n",
        "#model_name = \"meta-llama/Llama-3.2-1B-Instruct\"  # try with this also\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    torch_dtype=torch.float16,\n",
        ")\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "model = model.to(device)\n",
        "model.eval()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gHJuOJHoiIo-",
        "outputId": "dac7e8ca-a337-44d6-84ed-18ec297b8b49"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "GPT2LMHeadModel(\n",
              "  (transformer): GPT2Model(\n",
              "    (wte): Embedding(50257, 768)\n",
              "    (wpe): Embedding(1024, 768)\n",
              "    (drop): Dropout(p=0.1, inplace=False)\n",
              "    (h): ModuleList(\n",
              "      (0-11): 12 x GPT2Block(\n",
              "        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "        (attn): GPT2SdpaAttention(\n",
              "          (c_attn): Conv1D(nf=2304, nx=768)\n",
              "          (c_proj): Conv1D(nf=768, nx=768)\n",
              "          (attn_dropout): Dropout(p=0.1, inplace=False)\n",
              "          (resid_dropout): Dropout(p=0.1, inplace=False)\n",
              "        )\n",
              "        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "        (mlp): GPT2MLP(\n",
              "          (c_fc): Conv1D(nf=3072, nx=768)\n",
              "          (c_proj): Conv1D(nf=768, nx=3072)\n",
              "          (act): NewGELUActivation()\n",
              "          (dropout): Dropout(p=0.1, inplace=False)\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
              "  )\n",
              "  (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ]
    },
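    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Before generating, it helps to see how the prompt is split into token IDs, since the loop below works purely on IDs (this inspection cell is an addition to the original notebook). `encode`, `convert_ids_to_tokens`, and `decode` are standard Hugging Face tokenizer methods; GPT-2's byte-pair encoding marks a leading space with the 'Ġ' character."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Added example: inspect how the prompt is tokenized.\n",
        "# convert_ids_to_tokens shows the raw BPE tokens; GPT-2 marks a\n",
        "# leading space with 'Ġ' (e.g. ' love' appears as 'Ġlove').\n",
        "ids = tokenizer.encode(\"I love New\")\n",
        "print(\"Token IDs: \", ids)\n",
        "print(\"BPE tokens:\", tokenizer.convert_ids_to_tokens(ids))\n",
        "print(\"Round trip:\", tokenizer.decode(ids))\n"
      ]
    },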
    {
      "cell_type": "code",
      "source": [
        "# Get the vocabulary as a dictionary {token: token_id}\n",
        "vocab = tokenizer.get_vocab()\n",
        "# Print the vocabulary size\n",
        "print(f\"Vocabulary Size: {len(vocab)}\")\n",
        "prompt_template = \"I love New\"\n",
        "\n",
        "if model_name == \"meta-llama/Llama-3.2-1B-Instruct\":\n",
        "    # Use the Llama chat format since we are using the Instruct model;\n",
        "    # the prompt template is as below\n",
        "    system_message = \"You complete sentences with funny words\"\n",
        "    question = \"Complete the sentence I love New\"\n",
        "    prompt_template = f'''\n",
        "    <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
        "    {system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
        "    {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
        "    '''\n",
        "\n",
        "print(f\"Original Text: {prompt_template}\")\n",
        "input_id_list = list(tokenizer.encode(prompt_template))\n",
        "text = input_id_list\n",
        "generated_tokens = []\n",
        "\n",
        "# Set the number of tokens to generate\n",
        "N = 10\n",
        "\n",
        "# Iterative generation\n",
        "for i in range(N):\n",
        "    current_input = torch.tensor([text], dtype=torch.long)\n",
        "\n",
        "    # Forward pass to get logits\n",
        "    with torch.no_grad():\n",
        "        outputs = model(current_input.to(device))\n",
        "        logits = outputs.logits\n",
        "\n",
        "    # Get probabilities for the last token\n",
        "    probabilities = torch.softmax(logits[0, -1], dim=0).cpu().numpy()\n",
        "    probabilities /= probabilities.sum()  # Normalize\n",
        "\n",
        "    # Find the token with the maximum probability\n",
        "    max_token_id = np.argmax(probabilities)\n",
        "    max_token = tokenizer.decode([max_token_id])\n",
        "    generated_tokens.append(max_token)\n",
        "\n",
        "    # Append the generated token to the input for the next iteration\n",
        "    text.append(max_token_id)\n",
        "\n",
        "    # Decode the current context for display\n",
        "    current = tokenizer.decode(text)\n",
        "    print(f\"Decoded Context: {current}\")\n",
        "    print(f\"Max Probability Token: '{max_token}' (ID: {max_token_id} word {i})\")\n",
        "\n",
        "    # Plot the probability distribution\n",
        "    #plot_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
        "    print_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
        "\n",
        "# Final Output\n",
        "final_generated_text = tokenizer.decode(text)\n",
        "print(f\"\\nFinal Generated Text: {final_generated_text}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lHcKnZdvhz9E",
        "outputId": "cdacf974-a11a-487d-d7a4-86d0812450b8"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Vocabulary Size: 50257\n",
            "Original Text: I love New\n",
            "Decoded Context: I love New York\n",
            "Max Probability Token: ' York' (ID: 1971 word 0)\n",
            "Context: I love New York\n",
            "Next Token Prediction: 'York'\n",
            "\n",
            "           York | ##################### 0.4355\n",
            "        Orleans | #### 0.0972\n",
            "        Zealand | #### 0.0885\n",
            "        England | ## 0.0504\n",
            "         Jersey | # 0.0393\n",
            "           Year | # 0.0278\n",
            "        Yorkers |  0.0191\n",
            "         Mexico |  0.0144\n",
            "      Hampshire |  0.0096\n",
            "          Years |  0.0085\n",
            "Decoded Context: I love New York.\n",
            "Max Probability Token: '.' (ID: 13 word 1)\n",
            "Context: I love New York.\n",
            "Next Token Prediction: '.'\n",
            "\n",
            "              . | ########## 0.2020\n",
            "              , | ######## 0.1783\n",
            "            and | #### 0.0955\n",
            "           City | ### 0.0792\n",
            "              ! | # 0.0398\n",
            "             ,\" | # 0.0351\n",
            "             .\" | # 0.0291\n",
            "             !\" | # 0.0273\n",
            "             so | # 0.0227\n",
            "             's | # 0.0227\n",
            "Decoded Context: I love New York. I\n",
            "Max Probability Token: ' I' (ID: 314 word 2)\n",
            "Context: I love New York. I\n",
            "Next Token Prediction: 'I'\n",
            "\n",
            "              I | ################ 0.3269\n",
            "             It | ###### 0.1202\n",
            "                | ## 0.0533\n",
            "             We | ## 0.0471\n",
            "            But | ## 0.0416\n",
            "            And | # 0.0390\n",
            "            The | # 0.0268\n",
            "            You |  0.0184\n",
            "             So |  0.0163\n",
            "             My |  0.0144\n",
            "Decoded Context: I love New York. I love\n",
            "Max Probability Token: ' love' (ID: 1842 word 3)\n",
            "Context: I love New York. I love\n",
            "Next Token Prediction: 'love'\n",
            "\n",
            "           love | ###################### 0.4473\n",
            "             'm | ### 0.0686\n",
            "            've | ## 0.0416\n",
            "           like | # 0.0324\n",
            "          think | # 0.0252\n",
            "           have | # 0.0223\n",
            "           live | # 0.0223\n",
            "           know |  0.0173\n",
            "            don |  0.0173\n",
            "           want |  0.0153\n",
            "Decoded Context: I love New York. I love the\n",
            "Max Probability Token: ' the' (ID: 262 word 4)\n",
            "Context: I love New York. I love the\n",
            "Next Token Prediction: 'the'\n",
            "\n",
            "            the | ########## 0.2021\n",
            "            New | ###### 0.1305\n",
            "             it | ## 0.0544\n",
            "             my | # 0.0330\n",
            "          being | # 0.0257\n",
            "           this | # 0.0257\n",
            "             to | # 0.0227\n",
            "            all | # 0.0227\n",
            "           that |  0.0188\n",
            "         living |  0.0156\n",
            "Decoded Context: I love New York. I love the city\n",
            "Max Probability Token: ' city' (ID: 1748 word 5)\n",
            "Context: I love New York. I love the city\n",
            "Next Token Prediction: 'city'\n",
            "\n",
            "           city | ######### 0.1903\n",
            "         people | ## 0.0580\n",
            "            way | # 0.0213\n",
            "           fact | # 0.0213\n",
            "          place |  0.0156\n",
            "            New |  0.0147\n",
            "          music |  0.0138\n",
            "        country |  0.0101\n",
            "          great |  0.0079\n",
            "        culture |  0.0079\n",
            "Decoded Context: I love New York. I love the city.\n",
            "Max Probability Token: '.' (ID: 13 word 6)\n",
            "Context: I love New York. I love the city.\n",
            "Next Token Prediction: '.'\n",
            "\n",
            "              . | ######################## 0.4858\n",
            "              , | ##### 0.1154\n",
            "            and | ### 0.0793\n",
            "             .\" | ### 0.0657\n",
            "             of | ## 0.0424\n",
            "             ,\" | # 0.0363\n",
            "           that | # 0.0201\n",
            "              I |  0.0156\n",
            "              ! |  0.0114\n",
            "             !\" |  0.0079\n",
            "Decoded Context: I love New York. I love the city. I\n",
            "Max Probability Token: ' I' (ID: 314 word 7)\n",
            "Context: I love New York. I love the city. I\n",
            "Next Token Prediction: 'I'\n",
            "\n",
            "              I | ############################ 0.5669\n",
            "            And | ### 0.0767\n",
            "             It | ## 0.0527\n",
            "            But | ## 0.0437\n",
            "                | # 0.0340\n",
            "             We | # 0.0206\n",
            "            The |  0.0182\n",
            "             So |  0.0110\n",
            "            New |  0.0110\n",
            "             My |  0.0098\n",
            "Decoded Context: I love New York. I love the city. I love\n",
            "Max Probability Token: ' love' (ID: 1842 word 8)\n",
            "Context: I love New York. I love the city. I love\n",
            "Next Token Prediction: 'love'\n",
            "\n",
            "           love | ####################################### 0.7949\n",
            "             'm | # 0.0255\n",
            "           like | # 0.0240\n",
            "           want |  0.0128\n",
            "          think |  0.0094\n",
            "         really |  0.0073\n",
            "            've |  0.0073\n",
            "           know |  0.0069\n",
            "            don |  0.0069\n",
            "           have |  0.0065\n",
            "Decoded Context: I love New York. I love the city. I love the\n",
            "Max Probability Token: ' the' (ID: 262 word 9)\n",
            "Context: I love New York. I love the city. I love the\n",
            "Next Token Prediction: 'the'\n",
            "\n",
            "            the | ################## 0.3635\n",
            "             it | ### 0.0632\n",
            "            New | ## 0.0594\n",
            "             my | # 0.0338\n",
            "            all | # 0.0280\n",
            "           that | # 0.0280\n",
            "             to | # 0.0218\n",
            "            how | # 0.0205\n",
            "          being | # 0.0205\n",
            "           this |  0.0170\n",
            "\n",
            "Final Generated Text: I love New York. I love the city. I love the\n"
          ]
        }
      ]
    },
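    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop above is greedy decoding written out by hand: take the softmax of the last position's logits and append the argmax token. As a cross-check (this cell is an addition to the original notebook), Hugging Face's built-in `model.generate` with `do_sample=False` applies the same argmax rule at each step, so it should reproduce the same continuation."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "execution_count": null,
      "outputs": [],
      "source": [
        "# Added cross-check: model.generate with do_sample=False is greedy\n",
        "# decoding, the same selection rule as the manual argmax loop above.\n",
        "inputs = tokenizer(\"I love New\", return_tensors=\"pt\").to(device)\n",
        "with torch.no_grad():\n",
        "    out = model.generate(\n",
        "        **inputs,\n",
        "        max_new_tokens=10,  # same N as the loop above\n",
        "        do_sample=False,  # greedy: pick the argmax token each step\n",
        "        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token\n",
        "    )\n",
        "print(tokenizer.decode(out[0]))\n"
      ]
    }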
  ]
}