AI Guide - pick a model, test a model
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "view-in-github", | |
| "colab_type": "text" | |
| }, | |
| "source": [ | |
| "<a href=\"https://colab.research.google.com/gist/skyfallsin/689ebceca5de47dfff2aecaa502c1bc9/ai-guide-pick-a-model-test-a-model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "OPw-GjAas0m0" | |
| }, | |
| "source": [ | |
| "#**First Steps with Language Models**\n", | |
| "\n", | |
| "Unlike other guides, this one is designed to help pick the right model for whatever it is you're trying to do, by:\n", | |
| "- teaching you how to always remain on the bleeding edge of published AI research\n", | |
| "- broadening your perspective on current open options for any given task\n", | |
| "- not be tied to a closed-source / closed-data large language model (ex OpenAI, Anthropic)\n", | |
| "- creating a data-led system for always identifying and using the state-of-the-art (SOTA) model for any particular task.\n", | |
| "\n", | |
| "We're going to hone in on \"text summarization\" as our first task." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "gAihu4Qw-Gzu" | |
| }, | |
| "source": [ | |
| "## So... why are we not using one of the popular large language models?\n", | |
| "\n", | |
| "Great question. Most available LLMs worth their salt can do many tasks, including summarization, but not all of them may be good at what specifically you want them to do. We should figure out how to evaluate whether they actually can or not, which are the next few steps we will learn.\n", | |
| "\n", | |
| "Also, many of the current popular ones are not open, are trained on undisclosed data and exhibit biases. Responsible AI use require careful choices, and we're here to help you make them.\n", | |
| "\n", | |
| "Finally, most large LLMs require powerful GPU compute to use. While there are many models that you can use as a service, most of them cost money per API call. Unnecessary when some of the more common tasks can be done at good quality with already available open models and off-the-shelf hardware." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "7e-c-RKd_pyj" | |
| }, | |
| "source": [ | |
| "## Why do using open models matter?\n", | |
| "\n", | |
| "Over the last few decades, engineers have been blessed with being able to onboard by starting with open source projects, and eventually shipping open source to production. This default state is now at risk.\n", | |
| "\n", | |
| "Yes, there are many open models available that do a great job. However, most guides don't discuss how to get started with them using simple steps and instead bias towards existing closed APIs.\n", | |
| "\n", | |
| "Funding is flowing to commercial AI projects, who have larger budgets than open source contributors to market their work, which inevitably leads to engineers starting with closed source projects and shipping expensive closed projects to production." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "N6LEcWmCt6v1" | |
| }, | |
| "source": [ | |
| "#Our First Project - Summarization\n", | |
| "\n", | |
| "We're going to:\n", | |
| "- Get some long documents to summarize.\n", | |
| "- Figure out how to summarize them using the current state-of-the-art open source models.\n", | |
| "- Write some code to do so." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [], | |
| "metadata": { | |
| "id": "fVCSy7zlZNdm" | |
| } | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "kcb9K3dLubjT" | |
| }, | |
| "source": [ | |
| "### Where can I grab some documents?\n", | |
| "For simplicity's sake, let's grab **Mozilla's Trustworthy AI Guidelines** in string form\n", | |
| "\n", | |
| "Note that in the real world, you will likely have use other libraries to extract content for any particular file type." | |
| ] | |
| }, | |
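| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As an aside, here is a minimal sketch of what such extraction could look like for a PDF, using the `pypdf` package. This is an assumption on our part (`pip install pypdf`), and `report.pdf` is a hypothetical file -- the rest of this guide doesn't depend on it." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: pulling text out of a PDF with pypdf (\"report.pdf\" is hypothetical)\n", | |
| "from pypdf import PdfReader\n", | |
| "\n", | |
| "reader = PdfReader(\"report.pdf\")\n", | |
| "pdf_text = \"\\n\".join(page.extract_text() for page in reader.pages)" | |
| ] | |
| }, | |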
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": { | |
| "id": "uHgg18k1t4sy", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "15d85e2d-94f5-477a-b28f-b874e5c9eb90" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Mozilla's \"Trustworthy AI\" Thinking Points: PRIVACY: How is data collected, stored, and shared? Our\n", | |
| "personal data powers everything from traffic maps to targeted advertising. Trustworthy AI should\n", | |
| "enable people to decide how their data is used and what decisions are made with it. FAIRNESS: We’ve\n", | |
| "seen time and again how bias shows up in computational models, data, and frameworks behind automated\n", | |
| "decision making. The values and goals of a system should be power aware and seek to minimize harm.\n", | |
| "Further, AI systems that depend on human workers should protect people from exploitation and\n", | |
| "overwork. TRUST: People should have agency and control over their data and algorithmic outputs,\n", | |
| "especially considering the high stakes for individuals and societies. For instance, when online\n", | |
| "recommendation systems push people towards extreme, misleading content, potentially misinforming or\n", | |
| "radicalizing them. SAFETY: AI systems can carry high risk for exploitation by bad actors.\n", | |
| "Developers need to implement strong measures to protect our data and personal security. Further,\n", | |
| "excessive energy consumption and extraction of natural resources for computing and machine learning\n", | |
| "accelerates the climate crisis. TRANSPARENCY: Automated decisions can have huge personal impacts,\n", | |
| "yet the reasons for decisions are often opaque. We need to mandate transparency so that we can fully\n", | |
| "understand these systems and their potential for harm.\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import textwrap\n", | |
| "\n", | |
| "content = \"\"\"Mozilla's \"Trustworthy AI\" Thinking Points:\n", | |
| "\n", | |
| "PRIVACY: How is data collected, stored, and shared? Our personal data powers everything from traffic maps to targeted advertising. Trustworthy AI should enable people to decide how their data is used and what decisions are made with it.\n", | |
| "\n", | |
| "FAIRNESS: We’ve seen time and again how bias shows up in computational models, data, and frameworks behind automated decision making. The values and goals of a system should be power aware and seek to minimize harm. Further, AI systems that depend on human workers should protect people from exploitation and overwork.\n", | |
| "\n", | |
| "TRUST: People should have agency and control over their data and algorithmic outputs, especially considering the high stakes for individuals and societies. For instance, when online recommendation systems push people towards extreme, misleading content, potentially misinforming or radicalizing them.\n", | |
| "\n", | |
| "SAFETY: AI systems can carry high risk for exploitation by bad actors. Developers need to implement strong measures to protect our data and personal security. Further, excessive energy consumption and extraction of natural resources for computing and machine learning accelerates the climate crisis.\n", | |
| "\n", | |
| "TRANSPARENCY: Automated decisions can have huge personal impacts, yet the reasons for decisions are often opaque. We need to mandate transparency so that we can fully understand these systems and their potential for harm.\"\"\"\n", | |
| "\n", | |
| "# let's take a look at the actual content\n", | |
| "print(\n", | |
| " textwrap.fill(\n", | |
| " content, 100\n", | |
| " )\n", | |
| ")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "W5Y9LBzfujnm" | |
| }, | |
| "source": [ | |
| "Great. Now we're ready to start summarizing." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "DM2n6byeDHss" | |
| }, | |
| "source": [ | |
| "### A brief pause for context.\n", | |
| "\n", | |
| "The AI space is moving so fast that it requires a tremendous amount of catching up on scientific papers every week to understand the lay of the land and the state of the art.\n", | |
| "\n", | |
| "It's some effort for an engineer who is brand new to AI to:\n", | |
| "* discover which open models are even out there\n", | |
| "* which models are appropriate for a particular task\n", | |
| "* which benchmarks are used to evaluate those models\n", | |
| "* which models are performing well based on evaluations\n", | |
| "* which models can actually run on available hardware\n", | |
| "\n", | |
| "For the working engineer on a deadline, this is problematic. There's not much centralized discourse on working with open source AI models. Instead there are fragmented X (formerly Twitter) threads, random private groups and lots of word-of-mouth transfer.\n", | |
| "\n", | |
| "However, once you master a framework on how to address all of the above, you will have the means to forever be on the bleeding age of published AI research.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "oG-p4zhyJ35-" | |
| }, | |
| "source": [ | |
| "### How do I get a list of available open summarization models?\n", | |
| "\n", | |
| "For now, we recommend [Huggingface](https://huggingface.co/models?pipeline_tag=summarization) and their large directory of open models broken down by task. This is a great starting point. Note that larger LLMs are also included in these lists, so we will have to filter.\n", | |
| "\n", | |
| "In this huge list of summarization models, which ones do we choose?\n", | |
| "\n", | |
| "We don't know what any of these models are trained on. For example, a summarizer trained on news articles vs Reddit posts will perform better on news articles.\n", | |
| "\n", | |
| "What we need is a set of metrics and benchmarks that we can use to do apples-to-apples comparisons of these models." | |
| ] | |
| }, | |
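| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "You can also query this directory programmatically. Here's a minimal sketch using the `huggingface_hub` client library (an assumption: `pip install huggingface_hub`; parameter names like `sort` and `direction` can shift between library versions):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: list the five most-downloaded summarization models on the Hub\n", | |
| "from huggingface_hub import list_models\n", | |
| "\n", | |
| "for m in list_models(filter=\"summarization\", sort=\"downloads\", direction=-1, limit=5):\n", | |
| "    print(m.modelId)" | |
| ] | |
| }, | |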
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "LJ0kYfaHYypn" | |
| }, | |
| "source": [ | |
| "### How do I evaluate summarization models?\n", | |
| "\n", | |
| "These steps below can be used to evaluate any available model for any task. It requires hopping between a few sources of data for now, but we will be making this a lot easier moving forward.\n", | |
| "\n", | |
| "Steps:\n", | |
| "1. Find the most common datasets used to train models for summarization.\n", | |
| "2. Find the most common metrics used to evaluate models for summarization across those datasets." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "Uv6_QnqCmHu1" | |
| }, | |
| "source": [ | |
| "#### Finding datasets\n", | |
| "\n", | |
| "The easiest way to do this is using _[Papers With Code](https://paperswithcode.com/methods)_, an excellent resource for finding the latest scientific papers by task that also have code repositories attached.\n", | |
| "\n", | |
| "First, filter _Papers With Code_'s \"Text Summarization\" datasets [by most cited text-based English datasets](https://paperswithcode.com/datasets?q=&v=lst&o=cited&lang=english&mod=texts&task=text-summarization&page=1).\n", | |
| "\n", | |
| "Let's pick (as of this writing) the most cited dataset -- the \"[CNN/DailyMail](https://paperswithcode.com/dataset/cnn-daily-mail-1)\" dataset. Usually most cited is one marker of popularity.\n", | |
| "\n", | |
| "Now, you don't need to download this dataset. But we're going to review the info _Papers With Code_ have provided to learn more about it for the next step. This dataset is also available on [Huggingface](https://huggingface.co/datasets/cnn_dailymail).\n", | |
| "\n", | |
| "You want to check N things:\n", | |
| "- license\n", | |
| "- recent papers\n", | |
| "- whether the data is traceable and the methods are transparent, etc etc (responsible AI stuff)\n", | |
| "\n", | |
| "First, check the license. In this case, it's MIT licensed, which means it can be used for both commercial and personal projects.\n", | |
| "\n", | |
| "Next, see if the papers using this dataset are recent. You can do this by sorting Papers in descending order. This particular dataset has many papers from 2023 - great!\n", | |
| "\n", | |
| "Now, let's dig into how we can evaluate models that use this dataset.\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "fCXBmqwZmMAh" | |
| }, | |
| "source": [ | |
| "#### Evaluating models\n", | |
| "\n", | |
| "Next, we look for measured metrics that are common across datasets for the summarization task. BUT, if you're not familiar with the literature on summarization, you have no idea what those are.\n", | |
| "\n", | |
| "To find out, pick a \"Subtask\" that's close to what you'd like to see. We'd like to summarize the CNN article we pulled down above, so let's choose \"[Abstractive Text Summarization](https://paperswithcode.com/sota/abstractive-text-summarization-on-cnn-daily)\".\n", | |
| "\n", | |
| "Now we're in business! This page contains a significant amount of new information.\n", | |
| "\n", | |
| "There are mentions of three new terms: ROUGE-1, ROUGE-2 and ROUGE-L. These are the metrics that are used to [measure summarization performance](https://en.wikipedia.org/wiki/ROUGE_(metric)).\n", | |
| "\n", | |
| "There are also a list of models and their scores on these three metrics - this is exactly what we're looking for.\n", | |
| "\n", | |
| "Assuming we're looking at ROUGE-1 as our metric, we now have the top 3 models that we can evaluate in more detail. All 3 are close to 50, which is a promising ROUGE score (read up on ROUGE)" | |
| ] | |
| }, | |
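| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "Since ROUGE is what we'll use to compare models, here's a minimal sketch of computing it yourself with the `rouge_score` package (an assumption: `pip install rouge-score`; leaderboard numbers come from each paper's own evaluation setup, so treat this as illustrative):" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: score a candidate summary against a reference with ROUGE\n", | |
| "from rouge_score import rouge_scorer\n", | |
| "\n", | |
| "scorer = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)\n", | |
| "\n", | |
| "reference = \"Trustworthy AI should let people decide how their data is used.\"\n", | |
| "candidate = \"People should be able to control how their data gets used.\"\n", | |
| "\n", | |
| "for name, score in scorer.score(reference, candidate).items():\n", | |
| "    print(name, round(score.fmeasure, 3))" | |
| ] | |
| }, | |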
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "1D3Kq1wlqPWH" | |
| }, | |
| "source": [ | |
| "### Testing out a model\n", | |
| "\n", | |
| "OK, we have a few candidates, so let's pick a model that will run on our local machines. Many models get their best performance when running on GPUs, but there are many that also generate summaries fast on CPUs. Let's pick one of those to start - Google's Pegasus." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": { | |
| "id": "21Ay3KNetit6", | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "outputId": "f71bafd1-fd05-4929-db74-87084adb849a" | |
| }, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.33.2)\n", | |
| "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (0.1.99)\n", | |
| "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n", | |
| "Requirement already satisfied: huggingface-hub<1.0,>=0.15.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.17.2)\n", | |
| "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n", | |
| "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n", | |
| "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n", | |
| "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n", | |
| "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n", | |
| "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n", | |
| "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.3)\n", | |
| "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n", | |
| "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (2023.6.0)\n", | |
| "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.15.1->transformers) (4.5.0)\n", | |
| "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.2.0)\n", | |
| "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", | |
| "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n", | |
| "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# first we install huggingface's transformers library\n", | |
| "%pip install transformers sentencepiece" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "id": "0Bv1cXlft3Tn" | |
| }, | |
| "source": [ | |
| "Then we [find Pegasus](https://huggingface.co/google/pegasus-cnn_dailymail) on Huggingface. Note that part of the datasets Pegasus was trained on includes CNN/DailyMail which bodes well for our article summarization." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from transformers import PegasusForConditionalGeneration, PegasusTokenizer\n", | |
| "import torch\n", | |
| "\n", | |
| "# Set the seed, this will help reproduce results. Changing the seed will\n", | |
| "# generate new results\n", | |
| "from transformers import set_seed\n", | |
| "set_seed(248602)\n", | |
| "\n", | |
| "# We're using the version of Pegasus specifically trained for summarization\n", | |
| "# using the CNN/DailyMail dataset\n", | |
| "model_name = \"google/pegasus-cnn_dailymail\"\n", | |
| "\n", | |
| "# If you're following along in Colab, switch your runtime to a\n", | |
| "# T4 GPU or other CUDA-compliant device for a speedup\n", | |
| "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
| "\n", | |
| "# Load the tokenizer\n", | |
| "tokenizer = PegasusTokenizer.from_pretrained(model_name)\n", | |
| "\n", | |
| "# Load the model\n", | |
| "model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)\n", | |
| "\n", | |
| "\n" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "L3AN0XBD6LTG", | |
| "outputId": "09b9517d-2934-4db4-aae3-843a2d051ac4" | |
| }, | |
| "execution_count": 3, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stderr", | |
| "text": [ | |
| "Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-cnn_dailymail and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']\n", | |
| "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Tokenize the entire content\n", | |
| "batch = tokenizer(content, padding=\"longest\", return_tensors=\"pt\").to(device)\n", | |
| "\n", | |
| "# Generate the summary as tokens\n", | |
| "summarized = model.generate(**batch)\n", | |
| "\n", | |
| "# Decode the tokens back into text\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "\n", | |
| "# Compare\n", | |
| "def compare(original, summarized_text):\n", | |
| " print(f\"Article text length: {len(original)}\\n\")\n", | |
| " print(textwrap.fill(summarized_text, 100))\n", | |
| " print()\n", | |
| " print(f\"Summarized length: {len(summarized_text)}\")\n", | |
| "\n", | |
| "compare(content, summarized_text)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "JcAfvCvjE7De", | |
| "outputId": "a1edf823-ab40-45a9-a922-42044aaed8d4" | |
| }, | |
| "execution_count": 4, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Trustworthy AI should enable people to decide how their data is used.<n>values and goals of a system\n", | |
| "should be power aware and seek to minimize harm.<n>People should have agency and control over their\n", | |
| "data and algorithmic outputs.<n>Developers need to implement strong measures to protect our data and\n", | |
| "personal security.\n", | |
| "\n", | |
| "Summarized length: 320\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Alright, we got something! Kind of short though. Let's see if we can make the summary longer..." | |
| ], | |
| "metadata": { | |
| "id": "-yjHN7F8BK8A" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)\n", | |
| "set_seed(860912)\n", | |
| "\n", | |
| "# Generate the summary as tokens, with a max_new_tokens\n", | |
| "summarized = model.generate(**batch, max_new_tokens=800)\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "\n", | |
| "compare(content, summarized_text)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "36xlxJHwBapQ", | |
| "outputId": "57115fc6-679b-44d8-c036-57c154ecb5c8" | |
| }, | |
| "execution_count": 5, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Trustworthy AI should enable people to decide how their data is used.<n>values and goals of a system\n", | |
| "should be power aware and seek to minimize harm.<n>People should have agency and control over their\n", | |
| "data and algorithmic outputs.<n>Developers need to implement strong measures to protect our data and\n", | |
| "personal security.\n", | |
| "\n", | |
| "Summarized length: 320\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Well, that didn't really work. Let's try a different approach called '**sampling**'. This allows the model to pick the next word according to its conditional probability distribution (specifically, the probability that said word follows the word before).\n", | |
| "\n", | |
| "We'll also be setting the '**temperature**'. This variable works to control the levels of randomness and creativity in the generated output." | |
| ], | |
| "metadata": { | |
| "id": "0DxUaSfoHZ9q" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)\n", | |
| "set_seed(118511)\n", | |
| "summarized = model.generate(**batch, do_sample=True, temperature=0.8, top_k=0)\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "compare(content, summarized_text)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "2HKM8i7HF2L1", | |
| "outputId": "4ef7d32a-0499-4647-a33a-09675cd4a17d" | |
| }, | |
| "execution_count": 6, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Mozilla's \"Trustworthy AI\" Thinking Points:.<n>People should have agency and control over their data\n", | |
| "and algorithmic outputs.<n>Developers need to implement strong measures to protect our data.\n", | |
| "\n", | |
| "Summarized length: 193\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Shorter, but the quality is higher. Adjusting the **temperature** up will likely help." | |
| ], | |
| "metadata": { | |
| "id": "WenTW2LuLrf0" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)\n", | |
| "set_seed(108814)\n", | |
| "summarized = model.generate(**batch, do_sample=True, temperature=1.0, top_k=0)\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "compare(content, summarized_text)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "R2N3zH4DHiNa", | |
| "outputId": "073e5993-d058-4a61-bc1b-5ba2a3f9fc5b" | |
| }, | |
| "execution_count": 7, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Mozilla's \"Trustworthy AI\" Thinking Points:.<n>People should have agency and control over their data\n", | |
| "and algorithmic outputs.<n>Developers need to implement strong measures to protect our data and\n", | |
| "personal security.<n>We need to mandate transparency so that we can fully understand these systems\n", | |
| "and their potential for harm.\n", | |
| "\n", | |
| "Summarized length: 325\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Now let's play with one other generation approach called **top_k** sampling -- instead of considering all possible next words in the vocabulary, the model only considers the top 'k' most probable next words.\n", | |
| "\n", | |
| "This technique helps to focus the model on likely continuations and reduces the chances of generating irrelevant or nonsensical text.\n", | |
| "\n", | |
| "It strikes a balance between creativity and coherence by limiting the pool of next word choices, but not so much that the output becomes deterministic." | |
| ], | |
| "metadata": { | |
| "id": "-Hi_spZxL5Tm" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)\n", | |
| "set_seed(226012)\n", | |
| "summarized = model.generate(**batch, do_sample=True, top_k=50)\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "compare(content, summarized_text)" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "_rVb-4BMLxXp", | |
| "outputId": "0cea6575-cc35-403f-dcd6-7cff4e905dae" | |
| }, | |
| "execution_count": 8, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Mozilla's \"Trustworthy AI\" Thinking Points look at ethical issues surrounding automated decision\n", | |
| "making.<n>values and goals of a system should be power aware and seek to minimize harm.<n>People\n", | |
| "should have agency and control over their data and algorithmic outputs.<n>Developers need to\n", | |
| "implement strong measures to protect our data and personal security.\n", | |
| "\n", | |
| "Summarized length: 355\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Finally, let's try **top_p** sampling -- also known as nucleus sampling, is a strategy where the model considers only the smallest set of top words whose cumulative probability exceeds a threshold 'p'.\n", | |
| "\n", | |
| "Unlike top-k which considers a fixed number of words, **top-p** adapts based on the distribution of probabilities for the next word. This makes it more dynamic and flexible. It helps create diverse and sensible text by allowing less probable words to be selected when the most probable ones don't add up to 'p'." | |
| ], | |
| "metadata": { | |
| "id": "WuLjk7tjSKKx" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "set_seed(21420041)\n", | |
| "summarized = model.generate(**batch, do_sample=True, top_p=0.9, top_k=50)\n", | |
| "summarized_decoded = tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "compare(content, summarized_text)\n", | |
| "\n", | |
| "# saving this for later.\n", | |
| "pegasus_summarized_text = summarized_text" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "z6KyxLQjNpAr", | |
| "outputId": "3a0cb93a-0655-45f1-9e51-569ed38db89c" | |
| }, | |
| "execution_count": 9, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Mozilla's \"Trustworthy AI\" Thinking Points:.<n>People should have agency and control over their data\n", | |
| "and algorithmic outputs.<n>Developers need to implement strong measures to protect our data and\n", | |
| "personal security.<n>We need to mandate transparency so that we can fully understand these systems\n", | |
| "and their potential for harm.\n", | |
| "\n", | |
| "Summarized length: 325\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Now, let's try out another model -- [Meta's \"BART\"](https://huggingface.co/docs/transformers/model_doc/bart).\n", | |
| "\n", | |
| "Looking at the [PapersWithCode graph](https://paperswithcode.com/sota/abstractive-text-summarization-on-cnn-daily?tag_filter=4), BART has solid results with ROUGE-1.\n", | |
| "\n", | |
| "Similar to Pegasus, BART has a [custom version](http://huggingface.co) finetuned on CNN data." | |
| ], | |
| "metadata": { | |
| "id": "988bxC3pTSeT" | |
| } | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "from transformers import BartTokenizer, BartForConditionalGeneration\n", | |
| "\n", | |
| "set_seed(120986)\n", | |
| "bart_model_name = \"facebook/bart-large-cnn\"\n", | |
| "\n", | |
| "# Load the tokenizer\n", | |
| "bart_tokenizer = BartTokenizer.from_pretrained(bart_model_name)\n", | |
| "\n", | |
| "# Load the model\n", | |
| "bart_model = BartForConditionalGeneration.from_pretrained(bart_model_name).to(device)\n" | |
| ], | |
| "metadata": { | |
| "id": "Wk0iM8xgZy5v" | |
| }, | |
| "execution_count": 10, | |
| "outputs": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "source": [ | |
| "# Using the same parameters as Pegasus, let's try running BART\n", | |
| "\n", | |
| "batch = bart_tokenizer(content, padding=\"longest\", return_tensors=\"pt\").to(device)\n", | |
| "summarized = bart_model.generate(**batch, do_sample=True, top_p=0.5, top_k=50, max_new_tokens=500)\n", | |
| "summarized_decoded = bart_tokenizer.batch_decode(summarized, skip_special_tokens=True)\n", | |
| "summarized_text = summarized_decoded[0]\n", | |
| "compare(content, summarized_text)\n", | |
| "\n", | |
| "bart_summarized_text = summarized_text" | |
| ], | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/" | |
| }, | |
| "id": "toA777JgbboS", | |
| "outputId": "27c559bb-bbc2-4889-9957-ebad51c3a4c8" | |
| }, | |
| "execution_count": 26, | |
| "outputs": [ | |
| { | |
| "output_type": "stream", | |
| "name": "stdout", | |
| "text": [ | |
| "Article text length: 1427\n", | |
| "\n", | |
| "Mozilla's \"Trustworthy AI\" Thinking Points: How is data collected, stored, and shared? Our personal\n", | |
| "data powers everything from traffic maps to targeted advertising. Trustworthy AI should enable\n", | |
| "people to decide how their data is used and what decisions are made with it.\n", | |
| "\n", | |
| "Summarized length: 271\n" | |
| ] | |
| } | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "source": [ | |
| "Is this the best that BART can do? Unlikely. You can take this as a starting point to experiment.\n", | |
| "\n", | |
| "At this point, you should have enough of a workflow to find, select and try these models out for not just summarization but any text-based use-case. Let's start learning and experimenting!\n", | |
| "\n", | |
| "There are [many other variables](https://huggingface.co/docs/transformers/main_classes/text_generation) that control the output, but this is a great stopping point to switch over to how to evaluate the results of these and any other model quantitatively for your own use-cases." | |
| ], | |
| "metadata": { | |
| "id": "ixijBTs-f5WB" | |
| } | |
| } | |
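| , | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "As a parting sketch, here are a few of those other `generate()` parameters in action: beam search with a length penalty and an n-gram repetition block. These are standard `transformers` generation arguments, but the specific values below are just starting points to tune, not recommendations." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Sketch: beam search knobs, reusing the BART model and batch from above\n", | |
| "summarized = bart_model.generate(\n", | |
| "    **batch,\n", | |
| "    num_beams=4,             # keep the 4 best partial sequences at each step\n", | |
| "    length_penalty=2.0,      # values > 1.0 nudge beam scoring towards longer output\n", | |
| "    no_repeat_ngram_size=3,  # never repeat the same trigram\n", | |
| "    max_new_tokens=200,\n", | |
| ")\n", | |
| "summarized_text = bart_tokenizer.batch_decode(summarized, skip_special_tokens=True)[0]\n", | |
| "compare(content, summarized_text)" | |
| ] | |
| } | |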
| ], | |
| "metadata": { | |
| "accelerator": "GPU", | |
| "colab": { | |
| "provenance": [], | |
| "gpuType": "T4", | |
| "include_colab_link": true | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "name": "python" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 0 | |
| } |