@alexcpn · Created November 28, 2024 13:05
llm_probability2.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNr/UpNAHfOy+HC0ztAfuHW",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/alexcpn/36e90cb3c78695e6b09eb97cdb277414/llm_probability2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kDxTsn5QhwKK"
},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from transformers import (\n",
" AutoModelForCausalLM,\n",
" AutoTokenizer,\n",
")\n"
]
},
{
"cell_type": "code",
"source": [
"\n",
"def print_probability_distribution(current, probabilities, tokenizer, top_n=20, bar_width=50):\n",
" \"\"\"\n",
" Print the top N tokens and their probabilities in ASCII format.\n",
"\n",
" Parameters:\n",
" - current: Current context as a string.\n",
" - probabilities: Probability distribution over the vocabulary.\n",
" - tokenizer: Tokenizer to decode token IDs to tokens.\n",
" - top_n: Number of top tokens to display.\n",
" - bar_width: Width of the ASCII bar representing probabilities.\n",
" \"\"\"\n",
" # Get top N tokens and their probabilities\n",
" top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
" top_probs = probabilities[top_indices]\n",
" top_tokens = [tokenizer.decode([i]).strip() for i in top_indices]\n",
"\n",
" # Find the next token (highest probability token)\n",
" max_token = top_tokens[0]\n",
"\n",
" # Display the current context\n",
" print(f\"Context: {current}\")\n",
" print(f\"Next Token Prediction: '{max_token}'\\n\")\n",
"\n",
" # Print the top N tokens and their probabilities as an ASCII bar chart\n",
" for token, prob in zip(top_tokens, top_probs):\n",
" bar = \"#\" * int(prob * bar_width)\n",
" print(f\"{token:>15} | {bar} {prob:.4f}\")\n",
"\n",
"def plot_probability_distribution(current, probabilities, tokenizer, top_n=20):\n",
" # Get top N tokens and their probabilities\n",
" top_indices = np.argsort(probabilities)[-top_n:][::-1]\n",
" top_probs = probabilities[top_indices]\n",
" top_tokens = [tokenizer.decode([i]) for i in top_indices]\n",
"\n",
" # Find the next token (highest probability token)\n",
" max_token = tokenizer.decode([top_indices[0]])\n",
"\n",
" # Plot\n",
" plt.figure(figsize=(12, 7))\n",
" bars = plt.bar(top_tokens, top_probs, color=\"blue\")\n",
" bars[0].set_color(\"red\") # Highlight the next token\n",
"\n",
" # Add the current context inside the graph\n",
" plt.text(\n",
" 0.5,\n",
" 0.9,\n",
" f\"Context: {current}\\nNext Token: {max_token}\",\n",
" ha=\"center\",\n",
" va=\"center\",\n",
" transform=plt.gca().transAxes,\n",
" fontsize=12,\n",
" bbox=dict(facecolor=\"white\", alpha=0.8, edgecolor=\"black\"),\n",
" )\n",
"\n",
" plt.xlabel(\"Tokens\")\n",
" plt.ylabel(\"Probabilities\")\n",
" plt.xticks(rotation=45)\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n"
],
"metadata": {
"id": "mxq_nsjshx9N"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model_name = 'gpt2'\n",
"#model_name = \"meta-llama/Llama-3.2-1B-Instruct\" # try with this also\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" model_name,\n",
" torch_dtype=torch.float16,\n",
")\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model = model.to(device)\n",
"model.eval()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gHJuOJHoiIo-",
"outputId": "dac7e8ca-a337-44d6-84ed-18ec297b8b49"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"GPT2LMHeadModel(\n",
" (transformer): GPT2Model(\n",
" (wte): Embedding(50257, 768)\n",
" (wpe): Embedding(1024, 768)\n",
" (drop): Dropout(p=0.1, inplace=False)\n",
" (h): ModuleList(\n",
" (0-11): 12 x GPT2Block(\n",
" (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (attn): GPT2SdpaAttention(\n",
" (c_attn): Conv1D(nf=2304, nx=768)\n",
" (c_proj): Conv1D(nf=768, nx=768)\n",
" (attn_dropout): Dropout(p=0.1, inplace=False)\n",
" (resid_dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" (mlp): GPT2MLP(\n",
" (c_fc): Conv1D(nf=3072, nx=768)\n",
" (c_proj): Conv1D(nf=768, nx=3072)\n",
" (act): NewGELUActivation()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",
")"
]
},
"metadata": {},
"execution_count": 8
}
]
},
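{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick tokenization sanity check (an illustrative sketch added here, not part of the original gist): GPT-2 uses byte-pair encoding, so a word that follows a space is a single token that carries the leading space, e.g. ' York' rather than 'York'. This is why the generation loop below works with token IDs rather than words."
]
},
{
"cell_type": "code",
"source": [
"# Illustrative sketch: inspect how the prompt is tokenized.\n",
"ids = tokenizer.encode(\"I love New\")\n",
"print(ids)                                   # token IDs\n",
"print(tokenizer.convert_ids_to_tokens(ids))  # BPE tokens; 'Ġ' marks a leading space\n",
"print([tokenizer.decode([i]) for i in ids])  # decoded pieces\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
},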
{
"cell_type": "code",
"source": [
"# Get the vocabulary as a dictionary {token: token_id}\n",
"vocab = tokenizer.get_vocab()\n",
"# Print the vocabulary size\n",
"print(f\"Vocabulary Size: {len(vocab)}\")\n",
"prompt_template = \"I love New\"\n",
"\n",
"if model_name == \"meta-llama/Llama-3.2-1B-Instruct\" :\n",
" # use its format as we are using the Instuct model, the prompt template is as below\n",
" system_message =\"You complete sentences with funny words\"\n",
" question = \"Complete the sentence I love New\"\n",
" prompt_template=f'''\n",
" <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
" {system_message}<|eot_id|><|start_header_id|>user<|end_header_id|>\n",
" {question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n",
" '''\n",
"\n",
"print(f\"Original Text: {prompt_template}\")\n",
"input_id_list = list(tokenizer.encode(prompt_template))\n",
"text =input_id_list\n",
"generated_tokens = []\n",
"\n",
"# Set the number of tokens to generate\n",
"N = 10\n",
"\n",
"# Iterative generation\n",
"for i in range(N):\n",
" current_input = torch.tensor([text], dtype=torch.long)\n",
"\n",
" # Forward pass to get logits\n",
" with torch.no_grad():\n",
" outputs = model(current_input.to(device))\n",
" logits = outputs.logits\n",
"\n",
" # Get probabilities for the last token\n",
" probabilities = torch.softmax(logits[0, -1], dim=0).cpu().numpy()\n",
" probabilities /= probabilities.sum() # Normalize\n",
"\n",
" # Find the token with the maximum probability\n",
" max_token_id = np.argmax(probabilities)\n",
" max_token = tokenizer.decode([max_token_id])\n",
" generated_tokens.append(max_token)\n",
"\n",
" # Append the generated token to the input for the next iteration\n",
" text.append(max_token_id)\n",
"\n",
" # Decode current context for display\n",
" current = tokenizer.decode(text)\n",
" print(f\"Decoded Context: {current}\")\n",
" print(f\"Max Probability Token: '{max_token}' (ID: {max_token_id} word {i})\")\n",
"\n",
" # Plot the probability distribution\n",
" #plot_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
" print_probability_distribution(current, probabilities, tokenizer, top_n=10)\n",
"\n",
"# Final Output\n",
"final_generated_text = tokenizer.decode(text)\n",
"print(f\"\\nFinal Generated Text: {final_generated_text}\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "lHcKnZdvhz9E",
"outputId": "cdacf974-a11a-487d-d7a4-86d0812450b8"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Vocabulary Size: 50257\n",
"Original Text: I love New\n",
"Decoded Context: I love New York\n",
"Max Probability Token: ' York' (ID: 1971 word 0)\n",
"Context: I love New York\n",
"Next Token Prediction: 'York'\n",
"\n",
" York | ##################### 0.4355\n",
" Orleans | #### 0.0972\n",
" Zealand | #### 0.0885\n",
" England | ## 0.0504\n",
" Jersey | # 0.0393\n",
" Year | # 0.0278\n",
" Yorkers | 0.0191\n",
" Mexico | 0.0144\n",
" Hampshire | 0.0096\n",
" Years | 0.0085\n",
"Decoded Context: I love New York.\n",
"Max Probability Token: '.' (ID: 13 word 1)\n",
"Context: I love New York.\n",
"Next Token Prediction: '.'\n",
"\n",
" . | ########## 0.2020\n",
" , | ######## 0.1783\n",
" and | #### 0.0955\n",
" City | ### 0.0792\n",
" ! | # 0.0398\n",
" ,\" | # 0.0351\n",
" .\" | # 0.0291\n",
" !\" | # 0.0273\n",
" so | # 0.0227\n",
" 's | # 0.0227\n",
"Decoded Context: I love New York. I\n",
"Max Probability Token: ' I' (ID: 314 word 2)\n",
"Context: I love New York. I\n",
"Next Token Prediction: 'I'\n",
"\n",
" I | ################ 0.3269\n",
" It | ###### 0.1202\n",
" | ## 0.0533\n",
" We | ## 0.0471\n",
" But | ## 0.0416\n",
" And | # 0.0390\n",
" The | # 0.0268\n",
" You | 0.0184\n",
" So | 0.0163\n",
" My | 0.0144\n",
"Decoded Context: I love New York. I love\n",
"Max Probability Token: ' love' (ID: 1842 word 3)\n",
"Context: I love New York. I love\n",
"Next Token Prediction: 'love'\n",
"\n",
" love | ###################### 0.4473\n",
" 'm | ### 0.0686\n",
" 've | ## 0.0416\n",
" like | # 0.0324\n",
" think | # 0.0252\n",
" have | # 0.0223\n",
" live | # 0.0223\n",
" know | 0.0173\n",
" don | 0.0173\n",
" want | 0.0153\n",
"Decoded Context: I love New York. I love the\n",
"Max Probability Token: ' the' (ID: 262 word 4)\n",
"Context: I love New York. I love the\n",
"Next Token Prediction: 'the'\n",
"\n",
" the | ########## 0.2021\n",
" New | ###### 0.1305\n",
" it | ## 0.0544\n",
" my | # 0.0330\n",
" being | # 0.0257\n",
" this | # 0.0257\n",
" to | # 0.0227\n",
" all | # 0.0227\n",
" that | 0.0188\n",
" living | 0.0156\n",
"Decoded Context: I love New York. I love the city\n",
"Max Probability Token: ' city' (ID: 1748 word 5)\n",
"Context: I love New York. I love the city\n",
"Next Token Prediction: 'city'\n",
"\n",
" city | ######### 0.1903\n",
" people | ## 0.0580\n",
" way | # 0.0213\n",
" fact | # 0.0213\n",
" place | 0.0156\n",
" New | 0.0147\n",
" music | 0.0138\n",
" country | 0.0101\n",
" great | 0.0079\n",
" culture | 0.0079\n",
"Decoded Context: I love New York. I love the city.\n",
"Max Probability Token: '.' (ID: 13 word 6)\n",
"Context: I love New York. I love the city.\n",
"Next Token Prediction: '.'\n",
"\n",
" . | ######################## 0.4858\n",
" , | ##### 0.1154\n",
" and | ### 0.0793\n",
" .\" | ### 0.0657\n",
" of | ## 0.0424\n",
" ,\" | # 0.0363\n",
" that | # 0.0201\n",
" I | 0.0156\n",
" ! | 0.0114\n",
" !\" | 0.0079\n",
"Decoded Context: I love New York. I love the city. I\n",
"Max Probability Token: ' I' (ID: 314 word 7)\n",
"Context: I love New York. I love the city. I\n",
"Next Token Prediction: 'I'\n",
"\n",
" I | ############################ 0.5669\n",
" And | ### 0.0767\n",
" It | ## 0.0527\n",
" But | ## 0.0437\n",
" | # 0.0340\n",
" We | # 0.0206\n",
" The | 0.0182\n",
" So | 0.0110\n",
" New | 0.0110\n",
" My | 0.0098\n",
"Decoded Context: I love New York. I love the city. I love\n",
"Max Probability Token: ' love' (ID: 1842 word 8)\n",
"Context: I love New York. I love the city. I love\n",
"Next Token Prediction: 'love'\n",
"\n",
" love | ####################################### 0.7949\n",
" 'm | # 0.0255\n",
" like | # 0.0240\n",
" want | 0.0128\n",
" think | 0.0094\n",
" really | 0.0073\n",
" 've | 0.0073\n",
" know | 0.0069\n",
" don | 0.0069\n",
" have | 0.0065\n",
"Decoded Context: I love New York. I love the city. I love the\n",
"Max Probability Token: ' the' (ID: 262 word 9)\n",
"Context: I love New York. I love the city. I love the\n",
"Next Token Prediction: 'the'\n",
"\n",
" the | ################## 0.3635\n",
" it | ### 0.0632\n",
" New | ## 0.0594\n",
" my | # 0.0338\n",
" all | # 0.0280\n",
" that | # 0.0280\n",
" to | # 0.0218\n",
" how | # 0.0205\n",
" being | # 0.0205\n",
" this | 0.0170\n",
"\n",
"Final Generated Text: I love New York. I love the city. I love the\n"
]
}
]
}
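,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loop above is plain greedy decoding: at every step it picks the argmax token, which is why rerunning it always yields the same text. The sketch below (added for illustration, not part of the original gist) shows two alternatives using only standard Hugging Face / PyTorch calls: `model.generate` with `do_sample=False`, which should reproduce the greedy result, and temperature sampling, where the logits are divided by a temperature before the softmax so lower-probability tokens can also be chosen. The temperature value 0.8 is an arbitrary choice for this example."
]
},
{
"cell_type": "code",
"source": [
"# Sketch 1: greedy decoding via generate() should match the manual loop above.\n",
"input_ids = torch.tensor([tokenizer.encode(\"I love New\")], dtype=torch.long).to(device)\n",
"greedy_ids = model.generate(input_ids, max_new_tokens=10, do_sample=False, pad_token_id=tokenizer.eos_token_id)\n",
"print(\"Greedy:\", tokenizer.decode(greedy_ids[0]))\n",
"\n",
"# Sketch 2: temperature sampling for the next token only.\n",
"temperature = 0.8  # assumed value, purely for illustration\n",
"with torch.no_grad():\n",
"    logits = model(input_ids).logits[0, -1]\n",
"probs = torch.softmax(logits.float() / temperature, dim=0)\n",
"sampled_id = torch.multinomial(probs, num_samples=1).item()\n",
"print(f\"Sampled next token: '{tokenizer.decode([sampled_id])}'\")\n"
],
"metadata": {},
"execution_count": null,
"outputs": []
}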
]
}