Created
April 26, 2024 13:32
TF-IDF in Python
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# TF-IDF in Python\n",
"\n",
"Simple TF-IDF implementation in Python."
],
"metadata": {
"id": "yJb55YCmf7dA"
}
},
{
"cell_type": "markdown",
"source": [
"First, we'll get the sample data from \"Random Articles.\""
],
"metadata": {
"id": "G4nosQ1fgCA4"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "lR02kJyYbHhx"
},
"outputs": [],
"source": [
"import requests\n",
"\n",
"r = requests.get(\n",
"    'https://gist.githubusercontent.com/AWeirdScratcher/'\n",
"    '100c5ea2f53f98f59857cb63755a9b18/raw/'\n",
"    'c32f490b13e5bad5bad4a8574cd42622bd90752a/data.json'\n",
")\n",
"data = r.json()"
]
},
{
"cell_type": "markdown",
"source": [
"We'll use the `nltk` library to tokenize our words."
],
"metadata": {
"id": "DKic3_24gHt0"
}
},
{
"cell_type": "code",
"source": [
"import nltk\n",
"\n",
"nltk.download('punkt')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Zn62y4-8bOrZ",
"outputId": "fb56636c-4a76-4d55-d212-cb4528b2b9be"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"[nltk_data] Downloading package punkt to /root/nltk_data...\n",
"[nltk_data] Package punkt is already up-to-date!\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"execution_count": 2
}
]
},
{
"cell_type": "code",
"source": [
"nltk.word_tokenize(\"I love chocolate! I'm happy!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aWtjoDHUgNi6",
"outputId": "c0432067-c85a-4612-e971-e68f07d727da"
},
"execution_count": 3,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['I', 'love', 'chocolate', '!', 'I', \"'m\", 'happy', '!']"
]
},
"metadata": {},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"source": [
"Next, we'll re-organize the data (`dict`) and get each article's tokens. We'll filter out punctuation and symbol tokens for better results."
],
"metadata": {
"id": "GYkAO0mjgXiK"
}
},
{
"cell_type": "code",
"source": [
"stop_words = '.,!?<>~&*$#@()_-+=/;\\'\\\"'  # punctuation/symbol tokens to filter out"
],
"metadata": {
"id": "Udhg1YadgqGG"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"documents = [\n",
"    {\n",
"        'headline': row['headline'],  # title\n",
"        'tokens': [\n",
"            t\n",
"            for t in nltk.word_tokenize(row['article'].lower())  # tokenize the lowercased article\n",
"            if t not in stop_words  # skip punctuation/symbol tokens\n",
"        ]\n",
"    } for row in data\n",
"]"
],
"metadata": {
"id": "vaQxdBxGbSwZ"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"documents[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1J3jZsGrcY5_",
"outputId": "199a0c88-1a4a-4b62-f675-30ce7153e717"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'headline': \"Study Finds Link Between Coffee Consumption and Reduced Risk of Alzheimer's Disease\",\n",
" 'tokens': ['a',\n",
"  'recent',\n",
"  'study',\n",
"  'suggests',\n",
"  'that',\n",
"  'regular',\n",
"  'coffee',\n",
"  'consumption',\n",
"  'could',\n",
"  'reduce',\n",
"  'the',\n",
"  'risk',\n",
"  'of',\n",
"  'developing',\n",
"  'alzheimer',\n",
"  \"'s\",\n",
"  'disease',\n",
"  'researchers',\n",
"  'found',\n",
"  'that',\n",
"  'certain',\n",
"  'compounds',\n",
"  'in',\n",
"  'coffee',\n",
"  'may',\n",
"  'have',\n",
"  'neuroprotective',\n",
"  'effects',\n",
"  'potentially',\n",
"  'slowing',\n",
"  'down',\n",
"  'the',\n",
"  'progression',\n",
"  'of',\n",
"  'cognitive',\n",
"  'decline']}"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"Neat! Next up, we'll learn how to calculate TF-IDF."
],
"metadata": {
"id": "unxGHPWlg5Ik"
}
},
{
"cell_type": "markdown",
"source": [
"## TF-IDF Formula\n",
"\n",
"$\n",
"\\\\text{TF} = \\\\dfrac{\\\\text{word\\\\_count}}{\\\\text{words\\\\_in\\\\_document}}\n",
"$\n",
"\n",
"<br />\n",
"\n",
"$\n",
"\\\\text{IDF} = 1 + \\\\log\\\\left(\\\\dfrac{\\\\text{doc\\\\_count}}{\\\\text{docs\\\\_with\\\\_word}}\\\\right)\n",
"$\n",
"\n",
"<br />\n",
"\n",
"$\n",
"\\\\text{TF-IDF} = \\\\text{TF}\\\\ \\\\times\\\\ \\\\text{IDF}\n",
"$\n",
"\n",
"<br />\n",
"\n",
"Simple! (Because we have the formula)"
],
"metadata": {
"id": "U-BdjC80g8wb"
}
},
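{
"cell_type": "markdown",
"source": [
"As a quick sanity check of the formula (a minimal sketch with made-up numbers, not from our dataset): suppose a word appears 2 times in a 100-word document, and 4 out of 10 documents contain it."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"tf_example = 2 / 100  # word_count / words_in_document\n",
"idf_example = 1 + math.log(10 / 4)  # 1 + log(doc_count / docs_with_word)\n",
"tf_example * idf_example"
],
"metadata": {},
"execution_count": null,
"outputs": []
},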
{
"cell_type": "markdown",
"source": [
"Now, let's create a function that calculates TF. Call it `calculate_tf`."
],
"metadata": {
"id": "2V0lUDjwibz0"
}
},
{
"cell_type": "code",
"source": [
"def calculate_tf(query: str, tokens: list[str]) -> float:\n",
"    return tokens.count(query.lower()) / len(tokens)"
],
"metadata": {
"id": "7qyrkKCEcaXb"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tf0 = calculate_tf(\n",
"    \"chocolate\",\n",
"    \"Welcome to Willy Wonka's Chocolate factory! We've got the best chocolate out there!\"\n",
")\n",
"tf0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zLVnUvmHcm3H",
"outputId": "3174c7a8-0c1b-4b70-d29f-c639afb17076"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.012048192771084338"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "markdown",
"source": [
"In other words, the word \"chocolate\" makes up about 1.2% of the document. (Note that we passed a raw string here, so `count` and `len` operate on characters; with a token list, TF would be a fraction of the token count.)"
],
"metadata": {
"id": "301FGvX7ik3j"
}
},
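{
"cell_type": "markdown",
"source": [
"For comparison, here's a sketch of the same query with the sentence tokenized first, so TF becomes a fraction of the token count rather than the character count."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"tokens = nltk.word_tokenize(\n",
"    \"Welcome to Willy Wonka's Chocolate factory! We've got the best chocolate out there!\".lower()\n",
")\n",
"calculate_tf(\"chocolate\", tokens)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},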
{
"cell_type": "markdown",
"source": [
"Next up, IDF. We'll first import the `math` module to use $\\\\log$, and define a function named `calculate_idf`.\n",
"\n",
"By the way, `docs_with_word` may be `0`, so to prevent a `ZeroDivisionError`, we'll add an extra check."
],
"metadata": {
"id": "efRtVcWTisJI"
}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def calculate_idf(doc_count: int, docs_with_word: int) -> float:\n",
"    # If docs_with_word == 0, the division below would raise a\n",
"    # ZeroDivisionError, so just return 0.0.\n",
"    if not docs_with_word:\n",
"        return 0.0\n",
"\n",
"    return 1.0 + math.log(doc_count / docs_with_word)"
],
"metadata": {
"id": "atDy9LMZcr3U"
},
"execution_count": 9,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Assume we have 100 documents and only 5 of them contain a specific word..."
],
"metadata": {
"id": "ewlX2_iJjcAu"
}
},
{
"cell_type": "code",
"source": [
"idf0 = calculate_idf(100, 5)\n",
"idf0"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "N_BxzCk8c-kk",
"outputId": "aeb3c6b6-38af-4c5f-da41-25708c23dbeb"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"3.995732273553991"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "markdown",
"source": [
"...we get an IDF score of $3.99 ≈ 4.0$. Neat!"
],
"metadata": {
"id": "H-TGNyUBjlag"
}
},
{
"cell_type": "code",
"source": [
"# No error here!\n",
"idf1 = calculate_idf(100, 0)\n",
"idf1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gEN-khTljCdg",
"outputId": "1c7a788b-ba0f-4d64-d38d-fd3c656f3827"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "markdown",
"source": [
"Finally, let's get the TF-IDF score. Define a function named `get_tf_idf`."
],
"metadata": {
"id": "tAprwXnmjHUP"
}
},
{
"cell_type": "code",
"source": [
"def get_tf_idf(tf: float, idf: float) -> float:\n",
"    return tf * idf"
],
"metadata": {
"id": "5Pi40FtAdBQu"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"get_tf_idf(tf0, idf0)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LBExLzUbdKbZ",
"outputId": "4ffa29ec-03fa-4310-c29f-1ca410951b46"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.04814135269342158"
]
},
"metadata": {},
"execution_count": 13
}
]
},
{
"cell_type": "code",
"source": [
"# ...along with our idf1 (value is 0)\n",
"get_tf_idf(tf0, idf1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mdLhDx1_jVoh",
"outputId": "e99db1f0-7bd3-49c0-88c7-e636099d0cc4"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"source": [
"## TF-IDF\n",
"\n",
"Finally! We've been through the hard part; now we can search through our documents using TF-IDF.\n",
"\n",
"We'll create a new function named `search` that implements TF-IDF search.\n",
"\n",
"<br />\n",
"\n",
"> $\\\\text{docs\\\\_with\\\\_word} = \\\\text{doc\\\\_count} - \\\\text{docs\\\\_without\\\\_word}$\n",
"> <br />\n",
">\n",
"> ...where `docs_without_word` is the number of documents with a TF score of `0.0` (the word occupies 0% of the document)."
],
"metadata": {
"id": "7t--bjAbj7q_"
}
},
{
"cell_type": "code",
"source": [
"def search(query: str) -> list[float]:\n",
"    # Compute TF for the query in every document\n",
"    tf_s = [calculate_tf(query, doc['tokens']) for doc in documents]\n",
"\n",
"    # Compute the query's single, corpus-wide IDF\n",
"    idf = calculate_idf(\n",
"        len(documents),  # document count\n",
"        len(tf_s) - tf_s.count(0.0)  # docs with the word = all docs - docs without it\n",
"    )\n",
"\n",
"    # Combine TF and IDF into TF-IDF scores\n",
"    tf_idfs = [get_tf_idf(tf, idf) for tf in tf_s]\n",
"    return tf_idfs"
],
"metadata": {
"id": "pac5KMtPdN2Y"
},
"execution_count": 15,
"outputs": []
},
{
"cell_type": "code",
"source": [
"scores = search(\"the\")\n",
"scores[:10]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_mjIM-SzfV44",
"outputId": "059d718c-a191-44ef-ea46-21a417bfa889"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[0.061408917536545905,\n",
" 0.061408917536545905,\n",
" 0.0,\n",
" 0.03684535052192754,\n",
" 0.058176869245148755,\n",
" 0.13398309280700926,\n",
" 0.06316345803759008,\n",
" 0.09211337630481886,\n",
" 0.12632691607518015,\n",
" 0.030704458768272953]"
]
},
"metadata": {},
"execution_count": 16
}
]
},
{
"cell_type": "code",
"source": [
"# Get the top 5 results\n",
"sorted(\n",
"    zip([d['headline'] for d in documents], scores),\n",
"    key=lambda t: t[1],\n",
"    reverse=True\n",
")[:5]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WQtn982Id_R8",
"outputId": "d01daeae-96eb-4e7d-b2e0-921344382cd8"
},
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[('World Health Organization Approves First Malaria Vaccine for Use in Africa',\n",
"  0.1441774585640643),\n",
" ('WHO Declares New Global Health Emergency Amidst Ebola Outbreak in Africa',\n",
"  0.13480006288510077),\n",
" ('Scientists Discover New Species of Deep-Sea Octopus in the Pacific Ocean',\n",
"  0.13398309280700926),\n",
" ('Tokyo Olympics Committee Announces Strict COVID-19 Protocols for Athletes',\n",
"  0.12632691607518015),\n",
" ('Amazon Rainforest Deforestation Reaches Record Levels, Sparking Global Concern',\n",
"  0.12632691607518015)]"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "markdown",
"source": [
"## Exercise\n",
"\n",
"You can try to implement a search engine like Google using TF-IDF.\n",
"\n",
"Explore the datasets on [🤗 HuggingFace](https://huggingface.co/datasets)."
],
"metadata": {
"id": "o2jA15Rel5bo"
}
}
]
}