TF-IDF in Python
@AWeirdDev, created April 26, 2024 13:32

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# TF-IDF in Python\n",
"\n",
"Simple TF-IDF implementation in Python."
],
"metadata": {
"id": "yJb55YCmf7dA"
}
},
{
"cell_type": "markdown",
"source": [
"First, we'll get the sample data from \"Random Articles.\""
],
"metadata": {
"id": "G4nosQ1fgCA4"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "lR02kJyYbHhx"
},
"outputs": [],
"source": [
"import requests\n",
"\n",
"r = requests.get(\n",
" 'https://gist.githubusercontent.com/AWeirdScratcher/'\n",
" '100c5ea2f53f98f59857cb63755a9b18/raw/'\n",
" 'c32f490b13e5bad5bad4a8574cd42622bd90752a/data.json'\n",
")\n",
"data = r.json()"
]
},
{
"cell_type": "markdown",
"source": [
"We'll use the `nltk` library to tokenize our words."
],
"metadata": {
"id": "DKic3_24gHt0"
}
},
{
"cell_type": "code",
"source": [
"import nltk\n",
"\n",
"nltk.download('punkt')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zn62y4-8bOrZ",
"outputId": "fb56636c-4a76-4d55-d212-cb4528b2b9be"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"nltk.word_tokenize(\"I love chocolate! I'm happy!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aWtjoDHUgNi6",
"outputId": "c0432067-c85a-4612-e971-e68f07d727da"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['I', 'love', 'chocolate', '!', 'I', \"'m\", 'happy', '!']"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"source": [
"Next, we'll re-organize the data (`dict`) and get its tokens. We'll filter out punctuation tokens for better results. (The variable below is named `stop_words`, but it actually holds punctuation characters; real English stopwords like \"the\" are kept.)"
],
"metadata": {
"id": "GYkAO0mjgXiK"
}
},
{
"cell_type": "code",
"source": [
"# Punctuation characters to filter out (single-character tokens, not English stopwords)\n",
"stop_words = '.,!?<>~&*$#@()_-+=/;\'\"'"
],
"metadata": {
"id": "Udhg1YadgqGG"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"documents = [\n",
" {\n",
" 'headline': row['headline'], # Title\n",
" 'tokens': [\n",
" t\n",
" for t in nltk.word_tokenize(row['article'].lower()) # Iterate through every article\n",
" if t not in stop_words # If the token is not a stop word\n",
" ]\n",
" } for row in data\n",
"]"
],
"metadata": {
"id": "vaQxdBxGbSwZ"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"documents[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1J3jZsGrcY5_",
"outputId": "199a0c88-1a4a-4b62-f675-30ce7153e717"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'headline': \"Study Finds Link Between Coffee Consumption and Reduced Risk of Alzheimer's Disease\",\n",
" 'tokens': ['a',\n",
" 'recent',\n",
" 'study',\n",
" 'suggests',\n",
" 'that',\n",
" 'regular',\n",
" 'coffee',\n",
" 'consumption',\n",
" 'could',\n",
" 'reduce',\n",
" 'the',\n",
" 'risk',\n",
" 'of',\n",
" 'developing',\n",
" 'alzheimer',\n",
" \"'s\",\n",
" 'disease',\n",
" 'researchers',\n",
" 'found',\n",
" 'that',\n",
" 'certain',\n",
" 'compounds',\n",
" 'in',\n",
" 'coffee',\n",
" 'may',\n",
" 'have',\n",
" 'neuroprotective',\n",
" 'effects',\n",
" 'potentially',\n",
" 'slowing',\n",
" 'down',\n",
" 'the',\n",
" 'progression',\n",
" 'of',\n",
" 'cognitive',\n",
" 'decline']}"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"Neat! Next up, we'll learn how to calculate TF-IDF."
],
"metadata": {
"id": "unxGHPWlg5Ik"
}
},
{
"cell_type": "markdown",
"source": [
"## TF-IDF Formula\n",
"\n",
"$\n",
"\\text{TF} = \\dfrac{\\text{word_count}}{\\text{words_in_document}}\n",
"$\n",
"\n",
"<br />\n",
"\n",
"$\n",
"\\text{IDF} = 1 + \\log({\\dfrac{\\text{doc_count}}{\\text{docs_with_word}}})\n",
"$\n",
"\n",
"<br />\n",
"\n",
"$\n",
"\\text{TF-IDF} = \\text{TF}\\ \\times\\ \\text{IDF}\n",
"$\n",
"\n",
"<br />\n",
"\n",
"Simple! (Because we have the formula)"
],
"metadata": {
"id": "U-BdjC80g8wb"
}
},
{
"cell_type": "markdown",
"source": [
"Now, let's create a function that calculates TF. Call it `calculate_tf`."
],
"metadata": {
"id": "2V0lUDjwibz0"
}
},
{
"cell_type": "code",
"source": [
"def calculate_tf(query: str, tokens: list[str]):\n",
" return tokens.count(query.lower()) / len(tokens)"
],
"metadata": {
"id": "7qyrkKCEcaXb"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tf0 = calculate_tf(\n",
" \"chocolate\",\n",
" \"Welcome to Willy Wonka's Chocolate factory! We've got the best chocolate out there!\"\n",
")\n",
"tf0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zLVnUvmHcm3H",
"outputId": "3174c7a8-0c1b-4b70-d29f-c639afb17076"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.012048192771084338"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"source": [
"You can say the word \"chocolate\" makes up about 1.2% of the text. (Careful: we passed a raw string here, so `count` and `len` counted substrings and characters. Pass a token list from `nltk.word_tokenize` instead to get the fraction of tokens, which is what TF actually means.)"
],
"metadata": {
"id": "301FGvX7ik3j"
}
},
{
"cell_type": "markdown",
"source": [
"Next, IDF. We'll import the `math` module to use $\\log$, and define a function named `calculate_idf`.\n",
"\n",
"By the way, `docs_with_word` may be `0`, so to prevent a `ZeroDivisionError`, we'll add an extra check."
],
"metadata": {
"id": "efRtVcWTisJI"
}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def calculate_idf(doc_count: int, docs_with_word: int):\n",
" # If docs_with_word == 0, we'd get a\n",
" # ZeroDivisionError, so just return 0.0 instead.\n",
" if not docs_with_word:\n",
" return 0.0\n",
"\n",
" return 1.0 + math.log(doc_count / docs_with_word)"
],
"metadata": {
"id": "atDy9LMZcr3U"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Assume we have 100 documents and only 5 of them contain a specific word:"
],
"metadata": {
"id": "ewlX2_iJjcAu"
}
},
{
"cell_type": "code",
"source": [
"idf0 = calculate_idf(100, 5)\n",
"idf0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "N_BxzCk8c-kk",
"outputId": "aeb3c6b6-38af-4c5f-da41-25708c23dbeb"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3.995732273553991"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"source": [
"...we get an IDF score of $3.99 ≈ 4.0$. Neat!"
],
"metadata": {
"id": "H-TGNyUBjlag"
}
},
{
"cell_type": "code",
"source": [
"# No error here!\n",
"idf1 = calculate_idf(100, 0)\n",
"idf1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gEN-khTljCdg",
"outputId": "1c7a788b-ba0f-4d64-d38d-fd3c656f3827"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"source": [
"Finally, let's get the TF-IDF score. Define a function named `get_tf_idf`."
],
"metadata": {
"id": "tAprwXnmjHUP"
}
},
{
"cell_type": "code",
"source": [
"def get_tf_idf(tf: float, idf: float):\n",
" return tf * idf"
],
"metadata": {
"id": "5Pi40FtAdBQu"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"get_tf_idf(tf0, idf0)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LBExLzUbdKbZ",
"outputId": "4ffa29ec-03fa-4310-c29f-1ca410951b46"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.04814135269342158"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"# ...along with our idf1 (value is 0)\n",
"get_tf_idf(tf0, idf1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mdLhDx1_jVoh",
"outputId": "e99db1f0-7bd3-49c0-88c7-e636099d0cc4"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"source": [
"## TF-IDF\n",
"\n",
"Finally! We've been through the hard part; now we can search through our documents using TF-IDF.\n",
"\n",
"We'll create a new function named `search` that implements TF-IDF search.\n",
"\n",
"<br />\n",
"\n",
"> $\\text{docs_with_word} = \\text{doc_count} - \\text{docs_without_word}$\n",
"> <br />\n",
">\n",
"> ...where `docs_without_word` is the number of documents with a TF score of `0.0` (the word never appears in them)."
],
"metadata": {
"id": "7t--bjAbj7q_"
}
},
{
"cell_type": "code",
"source": [
"def search(query: str) -> list[float]:\n",
" # Get the TF's for every document\n",
" tf_s = [calculate_tf(query, doc['tokens']) for doc in documents]\n",
"\n",
" # Get the IDF (a single value for this query, shared by all documents)\n",
" idf = calculate_idf(\n",
" len(documents), # Document count\n",
" len(tf_s) - tf_s.count(0.0) # (Document count) - (docs w/out the word) = docs w/ the word\n",
" )\n",
"\n",
" # Get TF-IDF scores\n",
" tf_idfs = [get_tf_idf(tf, idf) for tf in tf_s]\n",
" return tf_idfs"
],
"metadata": {
"id": "pac5KMtPdN2Y"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"scores = search(\"the\")\n",
"scores[:10]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_mjIM-SzfV44",
"outputId": "059d718c-a191-44ef-ea46-21a417bfa889"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[0.061408917536545905,\n",
" 0.061408917536545905,\n",
" 0.0,\n",
" 0.03684535052192754,\n",
" 0.058176869245148755,\n",
" 0.13398309280700926,\n",
" 0.06316345803759008,\n",
" 0.09211337630481886,\n",
" 0.12632691607518015,\n",
" 0.030704458768272953]"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"# Get the top 5 results\n",
"sorted(\n",
" tuple(\n",
" zip([d['headline'] for d in documents], scores)\n",
" ),\n",
" key=lambda t: t[1],\n",
" reverse=True\n",
")[:5]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WQtn982Id_R8",
"outputId": "d01daeae-96eb-4e7d-b2e0-921344382cd8"
},
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[('World Health Organization Approves First Malaria Vaccine for Use in Africa',\n",
" 0.1441774585640643),\n",
" ('WHO Declares New Global Health Emergency Amidst Ebola Outbreak in Africa',\n",
" 0.13480006288510077),\n",
" ('Scientists Discover New Species of Deep-Sea Octopus in the Pacific Ocean',\n",
" 0.13398309280700926),\n",
" ('Tokyo Olympics Committee Announces Strict COVID-19 Protocols for Athletes',\n",
" 0.12632691607518015),\n",
" ('Amazon Rainforest Deforestation Reaches Record Levels, Sparking Global Concern',\n",
" 0.12632691607518015)]"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"You can try to implement a search engine like Google using TF-IDF.\n",
"\n",
"Explore the datasets on [🤗 HuggingFace](https://huggingface.co/datasets)."
],
"metadata": {
"id": "o2jA15Rel5bo"
}
}
]
}
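The notebook's pieces can be pulled together into one self-contained sketch. This is a minimal reimplementation for illustration only: it swaps `nltk.word_tokenize` for a plain `str.split` so it runs without any downloads, and the three-line corpus below is made up.

```python
import math

def calculate_tf(query: str, tokens: list[str]) -> float:
    # Term frequency: fraction of this document's tokens matching the query.
    return tokens.count(query.lower()) / len(tokens)

def calculate_idf(doc_count: int, docs_with_word: int) -> float:
    # Inverse document frequency; guard against ZeroDivisionError
    # when no document contains the word.
    if not docs_with_word:
        return 0.0
    return 1.0 + math.log(doc_count / docs_with_word)

def search(query: str, docs: list[list[str]]) -> list[float]:
    tfs = [calculate_tf(query, tokens) for tokens in docs]
    # Documents with a TF of 0.0 don't contain the word at all.
    idf = calculate_idf(len(docs), len(tfs) - tfs.count(0.0))
    return [tf * idf for tf in tfs]

corpus = [
    "the cat sat on the mat",
    "dogs and cats are friends",
    "the stock market fell today",
]
docs = [text.lower().split() for text in corpus]

scores = search("cat", docs)
# Only the first document contains the exact token "cat"
# ("cats" doesn't match, since this does no stemming).
best = max(range(len(corpus)), key=lambda i: scores[i])
print(corpus[best])  # → "the cat sat on the mat"
```

Note that with `str.split`, punctuation stays attached to words; that is exactly the gap `nltk.word_tokenize` fills in the notebook.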
$ pip install nltk
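A possible next step for the exercise above: instead of scoring one query word at a time, build a full TF-IDF vector per document and compare documents with cosine similarity. A rough sketch under the same assumptions as before (plain `split` tokenization, invented corpus):

```python
import math
from collections import Counter

corpus = [
    "the cat sat on the mat",
    "the dog chased the cat",
    "stocks fell on the news",
]
docs = [text.split() for text in corpus]
# Vocabulary: every distinct token across the corpus, in a fixed order.
vocab = sorted({t for d in docs for t in d})

def idf(word: str) -> float:
    # Number of documents containing the word.
    df = sum(word in d for d in docs)
    if not df:
        return 0.0
    return 1.0 + math.log(len(docs) / df)

def tfidf_vector(tokens: list[str]) -> list[float]:
    # One TF-IDF score per vocabulary word; most entries will be 0.0.
    counts = Counter(tokens)
    return [counts[w] / len(tokens) * idf(w) for w in vocab]

def cosine(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb) if na and nb else 0.0

vecs = [tfidf_vector(d) for d in docs]
# The two documents sharing "cat" should score as more similar
# to each other than either does to the finance sentence.
print(cosine(vecs[0], vecs[1]), cosine(vecs[0], vecs[2]))
```

This is the same idea scikit-learn's `TfidfVectorizer` implements (with slightly different smoothing and normalization defaults).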