tf-idf tutorial
Gist LeopoldTal/bb3286ea5079671a47f5afaf638f66a4, created September 14, 2021 09:02
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <title>tf-idf tutorial</title>
  <script src="https://d3js.org/d3.v5.min.js"></script>
  <script src="./set-documents.js"></script>
</head>
<body>
  <h1>tf-idf</h1>
  <figure style="float: right;">
    <img src="machine-learning.png" alt="A robot reads a book" />
    <figcaption>Language processing</figcaption>
  </figure>
  <h2>Goal</h2>
  <p>You have a whole bunch of documents: articles, product descriptions, projects, etc. You want to find
    <strong>similarities</strong> between them.</p>
  <p>Examples: a search engine, a recommendation engine.</p>
  <p>For this tutorial, I'll use the 5917 featured articles on Wikipedia.</p>
  <h2>Principle</h2>
  <p>What does it mean for two documents to be similar, or related? Can we give a formal definition?</p>
  <p>We're looking for a concept of <strong>distance</strong>: two copies of the same document should be at
    distance 0, and similar documents should be closer together than totally unrelated ones.</p>
  <p>A distance… in what? In what vector space?</p>
  <p>Machines are stupid. They can't understand the documents; they can only see which words appear in them.
    Hey, what if we used word counts?</p>
  <p>That defines dimensions: the more times the word "cheese" appears in a document, the further along the
    "cheese" axis that document sits.</p>
  <p>We're using the <strong>frequency of the term</strong> "cheese": the "tf" in "tf-idf".</p>
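<p>The code in this tutorial assumes each document is an object with a <code>title</code> and a <code>terms</code> map from word to number of occurrences. This is presumably what <code>set-documents.js</code> provides, but its format isn't shown here. Below is a minimal sketch of building such an object from raw text. It assumes a naive tokeniser that lowercases the text and keeps only runs of the letters a to z. <code>toDocument</code> is a hypothetical helper, not part of the tutorial.</p>

```javascript
// Hypothetical helper: raw text -> { title, terms }, where terms maps word -> count.
// Assumption: naive tokenising (lowercase, keep runs of ASCII letters a-z only).
const toDocument = (title, text) => {
  // No prototype, so a word like "constructor" can't collide with Object.prototype.
  const terms = Object.create(null);
  for (const word of text.toLowerCase().match(/[a-z]+/g) || []) {
    terms[word] = (terms[word] || 0) + 1;
  }
  return { title, terms };
};

const doc = toDocument('Cheese', 'Cheese is made from milk. Some cheese is blue.');
console.log(doc.terms.cheese); // 2
console.log(doc.terms.horse || 0); // 0: a word that never appears
```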
  <figure>
    <figcaption>Frequency of the word "cheese" in all documents</figcaption>
    <div id="singleTerm1d"></div>
  </figure>
  <p>That defines an absolutely enormous space: one dimension per unique word.</p>
  <p>There are thousands of words in the English language, but if I draw diagrams with thousands of dimensions,
    they won't be very clear. So I'll just show 2: "horse" and "island".</p>
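<p>Here is "one dimension per unique word" made concrete. Collect every distinct term in the collection into a vocabulary. Each document then becomes a vector with one coordinate per vocabulary entry, and most of those coordinates are zero. The sketch below uses plain counts on a made-up two-document collection. <code>vocabulary</code>, <code>hasOwn</code> and <code>toVector</code> are illustrative names.</p>

```javascript
// Two made-up documents in the { title, terms } shape used throughout.
const corpus = [
  { title: 'A', terms: { horse: 3, island: 1 } },
  { title: 'B', terms: { island: 2, ferry: 1 } }
];

// Every distinct term in the collection gets its own axis (sorted for a stable order).
const vocabulary = [...new Set(corpus.flatMap(doc => Object.keys(doc.terms)))].sort();

// A document's coordinate on each axis is its count for that term, or 0 if absent.
const hasOwn = (object, key) => Object.prototype.hasOwnProperty.call(object, key);
const toVector = doc => vocabulary.map(term => (hasOwn(doc.terms, term) ? doc.terms[term] : 0));

console.log(vocabulary); // [ 'ferry', 'horse', 'island' ]
console.log(toVector(corpus[0])); // [ 0, 3, 1 ]
console.log(toVector(corpus[1])); // [ 1, 0, 2 ]
```

<p>With thousands of articles, nearly every coordinate is 0. That is one reason to keep sparse <code>terms</code> maps, as the code in this tutorial does, rather than building such vectors.</p>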
  <h2>Term frequency</h2>
  <p>I lied to you! Counting occurrences of the word isn't the only way to do it. Other ways:</p>
  <dl>
    <dt>simple</dt>
    <dd>Count occurrences.
      <pre><code>const freqSimple = (document, term) => document.terms[term] || 0;</code></pre>
    </dd>
    <dt>normalised</dt>
    <dd>Count occurrences, then divide by the total word count of the document, so it's actually a frequency.
      <pre><code>const totalTermCount = document => Object.values(document.terms).reduce((a, b) => a + b, 0);
const freqNorm = (document, term) => freqSimple(document, term) / totalTermCount(document);</code></pre>
    </dd>
    <dt>boolean</dt>
    <dd>Record whether the word appears or not (1 or 0).
      <pre><code>const freqBool = (document, term) => document.terms[term] ? 1 : 0;</code></pre>
    </dd>
    <dt>logarithmic</dt>
    <dd>Logarithm of 1 + the number of occurrences (the + 1 makes a missing word score 0).
      <pre><code>const freqLog = (document, term) => Math.log(1 + freqSimple(document, term));</code></pre>
    </dd>
    <dt>augmented</dt>
    <dd>Instead of dividing by the total word count, divide by the number of occurrences of the most frequent
      word, then add 1. Like normalisation, this keeps long documents from scoring higher just because they're
      long. Because of the + 1, a word that doesn't appear at all still scores 1. (A common textbook variant
      is 0.5 + 0.5 × count ÷ max.)
      <pre><code>const maxWordCount = document => Math.max(...Object.values(document.terms));
const freqAug = (document, term) => 1 + freqSimple(document, term) / maxWordCount(document);</code></pre>
    </dd>
  </dl>
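<p>To compare the five definitions side by side, here they are applied to one made-up document. The snippets above are collected so they run on their own. <code>reduce</code> gets a starting value of 0 so that an empty document doesn't throw.</p>

```javascript
// The five term-frequency variants defined above.
const freqSimple = (document, term) => document.terms[term] || 0;
const totalTermCount = document => Object.values(document.terms).reduce((a, b) => a + b, 0);
const freqNorm = (document, term) => freqSimple(document, term) / totalTermCount(document);
const freqBool = (document, term) => (document.terms[term] ? 1 : 0);
const freqLog = (document, term) => Math.log(1 + freqSimple(document, term));
const maxWordCount = document => Math.max(...Object.values(document.terms));
const freqAug = (document, term) => 1 + freqSimple(document, term) / maxWordCount(document);

// Made-up document: 10 words in total; "the" is the most frequent (6 occurrences).
const doc = { title: 'Example', terms: { the: 6, horse: 3, island: 1 } };

// Score of a word that appears ("horse") and of one that doesn't ("cheese").
for (const [name, freq] of Object.entries({ freqSimple, freqNorm, freqBool, freqLog, freqAug })) {
  console.log(name, freq(doc, 'horse'), freq(doc, 'cheese'));
}
```

<p>For "horse" this prints 3, 0.3, 1, about 1.386 (log 4) and 1.5. For the missing "cheese", every variant gives 0 except augmented, which gives 1 because of its + 1.</p>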
  <p><strong>Mouse over</strong> any point to see that document's top terms.</p>
  <figure>
    <figcaption>
      <select id="tf2d-select">
        <option value="simple">Simple</option>
        <option value="norm">Normalised</option>
        <option value="bool">Boolean</option>
        <option value="log">Logarithmic</option>
        <option value="aug">Augmented</option>
      </select>
      frequency of the terms "horse" and "island" in all documents
    </figcaption>
    <div id="tf2d"></div>
  </figure>
  <h2>Inverse frequency</h2>
  <p>Okay, that makes a pretty picture for those 2 words, but it doesn't really work: very common words
    like "the" always come out on top.</p>
  <p>I could build a stopword list, but I'm lazy. Besides, even with one, some words would still be far more
    common than others.</p>
  <p>The secret ingredient (Spärck Jones, 1972): multiply everything by the log of the total number of
    documents divided by the number of documents that contain the term (plus 1, so that a term found
    nowhere doesn't cause a division by zero).</p>
  <p><strong>The rarer the term is, the more it counts.</strong></p>
  <figure style="float: right;">
    <img src="spaerck_jones.jpg" alt="Karen Spärck Jones in 2002" />
    <figcaption>Professor Karen Spärck Jones</figcaption>
  </figure>
  <p>That's the "idf" ("inverse document frequency") in "tf-idf".</p>
  <pre><code>const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);</code></pre>
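<p>Here is a quick check of the formula on a made-up collection of four documents (<code>Math.log</code> is the natural logarithm). The rarer the term, the bigger its idf. The + 1 has a side effect: a term found in every document gets log(4/5), which is slightly negative.</p>

```javascript
// Inverse document frequency, as defined above.
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));

// Made-up collection: "the" is everywhere, "island" is in 3 of 4 documents, "horse" in just one.
const allDocuments = [
  { terms: { the: 5, island: 2, horse: 4 } },
  { terms: { the: 3, island: 1 } },
  { terms: { the: 7, island: 3 } },
  { terms: { the: 2 } }
];

for (const term of ['the', 'island', 'horse', 'cheese']) {
  console.log(term, docCount(allDocuments, term), idf(allDocuments, term).toFixed(3));
}
```

<p>This prints idf values of -0.223 for "the" (in all 4 documents), 0.000 for "island" (3 of 4), 0.693 for "horse" (1 of 4) and 1.386 for "cheese" (in none).</p>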
  <figure>
    <figcaption>
      <select id="tfIdf2d-select">
        <option value="simple">Simple</option>
        <option value="norm">Normalised</option>
        <option value="bool">Boolean</option>
        <option value="log">Logarithmic</option>
        <option value="aug">Augmented</option>
      </select>
      tf-idf of the terms "horse" and "island" in all documents
      (mouse over to see other terms)
    </figcaption>
    <div id="tfIdf2d"></div>
  </figure>
  <p>Success! We get terms strongly related to the topic of each article.</p>
  <h2>Similarities</h2>
  <p>So we've transformed each document into (term, tf-idf) pairs. How do you measure the distance between
    two of those?</p>
  <p>If a term has a high tf-idf in document 1, you want its tf-idf in document 2 to matter a lot, and
    vice versa. To model this, <strong>multiply</strong> the tf-idf values of the term in the
    two documents.</p>
  <p>The contribution of each term is independent, so we can just <strong>add</strong> them together.</p>
  <pre><code>const cosineSimilarity = (allDocuments, document1, document2, freq) => {
  const byTerms = Object.keys(document1.terms).map(term =>
    tfIdf(allDocuments, document1, term, freq) * tfIdf(allDocuments, document2, term, freq)
  );
  return byTerms.reduce((a, b) => a + b, 0);
};</code></pre>
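<p>Written out, the function computes the sum below. Here w(t, d) is the tf-idf of term t in document d, N is the number of documents, df(t) is the number of documents containing t, and log is the natural logarithm (<code>Math.log</code>). As a worked example with simple term frequency, take a made-up collection of four documents: "Horses" {horse: 3, island: 1}, "Stables" {horse: 1, stable: 2}, "Ferries" {island: 2, ferry: 1} and "Cheese" {cheese: 3}. "horse" and "island" each appear in 2 of the 4, so both have idf h = log(4/3) ≈ 0.288.</p>

```latex
\begin{aligned}
\mathrm{sim}(d_1, d_2) &= \sum_t w(t, d_1)\, w(t, d_2),
  \qquad w(t, d) = \mathrm{tf}(t, d)\,\log\frac{N}{1 + \mathrm{df}(t)} \\
\mathrm{sim}(\text{Horses}, \text{Stables}) &= (3h)(1h) + (1h)(0) = 3h^2 \approx 0.248 \\
\mathrm{sim}(\text{Horses}, \text{Ferries}) &= (3h)(0) + (1h)(2h) = 2h^2 \approx 0.166 \\
\mathrm{sim}(\text{Horses}, \text{Cheese}) &= (3h)(0) + (1h)(0) = 0
\end{aligned}
```

<p>"Stables" ranks above "Ferries" because "Horses" mentions "horse" three times but "island" only once.</p>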
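  <p>Say, that's a scalar product! Divide it by the lengths of the two documents' vectors and you get the cosine
    of the angle between them, hence the name "cosine similarity". (The code above skips that division, so
    strictly speaking it ranks by the raw scalar product, which favours documents with long tf-idf vectors.)</p>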
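<p>For comparison, here is a sketch of the normalised version. It builds each document's tf-idf vector explicitly (here as a <code>Map</code>), then divides the scalar product by the product of the two lengths. <code>tfIdfVector</code>, <code>dot</code>, <code>norm</code> and <code>normalisedCosine</code> are illustrative names, not part of the tutorial's code, and the collection is made up.</p>

```javascript
// Simple tf and the tutorial's idf, repeated so this runs on its own.
const freqSimple = (document, term) => document.terms[term] || 0;
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));

// Illustrative helper: a document's tf-idf vector as a Map from term to weight.
const tfIdfVector = (allDocuments, document, freq) => new Map(
  Object.keys(document.terms).map(term => [term, freq(document, term) * idf(allDocuments, term)])
);

// Scalar product: terms missing from v2 multiply by 0, so only shared terms contribute.
const dot = (v1, v2) => {
  let sum = 0;
  for (const [term, weight] of v1) sum += weight * (v2.get(term) || 0);
  return sum;
};
// Euclidean length of a vector.
const norm = v => Math.sqrt(dot(v, v));

// Cosine of the angle between the two tf-idf vectors (0 if either has length 0).
const normalisedCosine = (allDocuments, document1, document2, freq) => {
  const v1 = tfIdfVector(allDocuments, document1, freq);
  const v2 = tfIdfVector(allDocuments, document2, freq);
  const lengths = norm(v1) * norm(v2);
  return lengths === 0 ? 0 : dot(v1, v2) / lengths;
};

// Made-up collection: "More horses" has the same proportions as "Horses", twice as long.
const allDocuments = [
  { title: 'Horses', terms: { horse: 2, island: 1 } },
  { title: 'More horses', terms: { horse: 4, island: 2 } },
  { title: 'Ferries', terms: { ferry: 2, sea: 1 } },
  { title: 'Cheese', terms: { cheese: 3 } }
];
console.log(normalisedCosine(allDocuments, allDocuments[0], allDocuments[1], freqSimple)); // 1, up to rounding
console.log(normalisedCosine(allDocuments, allDocuments[0], allDocuments[2], freqSimple)); // 0
```

<p>Without the division (the plain sum computed by <code>cosineSimilarity</code> above), "More horses" would score 10h² against "Horses". That is twice the 5h² that "Horses" scores against itself (writing h for log(4/3)). The normalised version gives both pairs a score of 1.</p>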
  <pre><code>const getMostSimilar = (allDocuments, toDocument, freq) => {
  const otherDocuments = allDocuments.filter(document => document !== toDocument);
  const similarities = otherDocuments.map(document => ({
    document,
    similarity: cosineSimilarity(allDocuments, document, toDocument, freq)
  }));
  similarities.sort((document1, document2) => document2.similarity - document1.similarity);
  return similarities.slice(0, 10);
};</code></pre>
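<p>Here are the pieces put together on a made-up four-document collection. The sketch runs <code>getMostSimilar</code> as defined above, with simple term frequency. The definitions are repeated so the snippet runs on its own, and <code>reduce</code> gets a starting value of 0.</p>

```javascript
// Definitions from this tutorial, repeated so the snippet runs on its own.
const freqSimple = (document, term) => document.terms[term] || 0;
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);
const cosineSimilarity = (allDocuments, document1, document2, freq) =>
  Object.keys(document1.terms)
    .map(term => tfIdf(allDocuments, document1, term, freq) * tfIdf(allDocuments, document2, term, freq))
    .reduce((a, b) => a + b, 0);
const getMostSimilar = (allDocuments, toDocument, freq) => {
  const otherDocuments = allDocuments.filter(document => document !== toDocument);
  const similarities = otherDocuments.map(document => ({
    document,
    similarity: cosineSimilarity(allDocuments, document, toDocument, freq)
  }));
  similarities.sort((document1, document2) => document2.similarity - document1.similarity);
  return similarities.slice(0, 10);
};

// Made-up collection of four short "articles".
const allDocuments = [
  { title: 'Horses', terms: { horse: 3, island: 1 } },
  { title: 'Stables', terms: { horse: 1, stable: 2 } },
  { title: 'Ferries', terms: { island: 2, ferry: 1 } },
  { title: 'Cheese', terms: { cheese: 3 } }
];

const ranking = getMostSimilar(allDocuments, allDocuments[0], freqSimple);
for (const { document, similarity } of ranking) {
  console.log(`${document.title}: ${similarity.toFixed(3)}`);
}
// Stables: 0.248
// Ferries: 0.166
// Cheese: 0.000
```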
  <p>A cool feature of tf-idf: you can tell <strong>why</strong> two documents are similar by
    looking at the terms with the biggest products.</p>
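<p>Here is a sketch of that idea for one pair of documents. Keep each term's product instead of only the sum, sort, and keep the biggest few. The script at the end of this gist does much the same inside its <code>cosineSimilarity</code>, keeping the top <code>NB_TOP_TERMS</code> products as <code>why</code>. <code>explainSimilarity</code> is a hypothetical standalone helper, and the collection is made up.</p>

```javascript
// Simple tf and the tutorial's idf, repeated so this runs on its own.
const freqSimple = (document, term) => document.terms[term] || 0;
const docCount = (allDocuments, term) => allDocuments.filter(document => document.terms[term]).length;
const idf = (allDocuments, term) => Math.log(allDocuments.length / (1 + docCount(allDocuments, term)));
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);

// Hypothetical helper: each term's contribution to the similarity, biggest first.
const explainSimilarity = (allDocuments, document1, document2, freq, topN = 3) =>
  Object.keys(document1.terms)
    .map(term => ({
      term,
      weight: tfIdf(allDocuments, document1, term, freq) * tfIdf(allDocuments, document2, term, freq)
    }))
    .filter(({ weight }) => weight !== 0) // drop terms that contribute nothing
    .sort((a, b) => b.weight - a.weight)
    .slice(0, topN);

// Made-up collection of four short "articles".
const allDocuments = [
  { title: 'Island horses', terms: { horse: 3, island: 2, pony: 1 } },
  { title: 'Pony islands', terms: { pony: 3, island: 1, ferry: 1 } },
  { title: 'Stables', terms: { horse: 1, stable: 2 } },
  { title: 'Cheese', terms: { cheese: 3 } }
];

const why = explainSimilarity(allDocuments, allDocuments[0], allDocuments[1], freqSimple);
console.log(why);
```

<p>Here "pony" (3h², with h = log(4/3)) outranks "island" (2h²). "horse" drops out because "Pony islands" never mentions it.</p>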
  <div id="similarities">
    <p>
      Articles similar to
      <cite id="document-name"></cite>
      <button type="button" id="random-article">change</button>
      by
      <select id="similarity-select">
        <option value="simple">simple</option>
        <option value="norm">normalised</option>
        <option value="bool">boolean</option>
        <option value="log">logarithmic</option>
        <option value="aug">augmented</option>
      </select>
      tf-idf
    </p>
    <table id="similarity-table"></table>
  </div>
  <p>Question time!</p>
  <script src="./tough-dough.js"></script>
</body>
</html>
// Terms for demo
const TERM_SINGLE = 'cheese';
const TERM_X = 'horse';
const TERM_Y = 'island';
const NB_TOP_TERMS = 3;
// Term frequency functions
// Own-property check: if terms is a plain object, words like "constructor" would
// otherwise be found on Object.prototype and look present in every document
const hasTerm = (document, term) => Object.prototype.hasOwnProperty.call(document.terms, term);
const freqSimple = (document, term) => (hasTerm(document, term) ? document.terms[term] : 0);
const totalTermCount = document => { // memoised
  if (document.totalTermCount === undefined) {
    document.totalTermCount = Object.values(document.terms).reduce((a, b) => a + b, 0);
  }
  return document.totalTermCount;
};
const freqNorm = (document, term) => freqSimple(document, term) / totalTermCount(document);
const freqBool = (document, term) => (freqSimple(document, term) ? 1 : 0);
const freqLog = (document, term) => Math.log(1 + freqSimple(document, term));
const maxWordCount = document => { // memoised
  if (document.maxWordCount === undefined) {
    document.maxWordCount = Math.max(...Object.values(document.terms));
  }
  return document.maxWordCount;
};
// Note the + 1: a term missing from the document still scores 1
const freqAug = (document, term) => 1 + freqSimple(document, term) / maxWordCount(document);
// Inverse document frequency
// Memoised in a Map: an idf of exactly 0 is still cached, and keys can't clash with
// Object.prototype properties such as "constructor"
const idfMap = new Map();
const docCount = (allDocuments, term) => allDocuments.filter(
  document => Object.prototype.hasOwnProperty.call(document.terms, term)
).length;
const idf = (allDocuments, term) => {
  if (!idfMap.has(term)) {
    idfMap.set(term, Math.log(allDocuments.length / (1 + docCount(allDocuments, term))));
  }
  return idfMap.get(term);
};
// tf-idf
const tfIdf = (allDocuments, document, term, freq) => freq(document, term) * idf(allDocuments, term);
// Similarity
// targetTfIdf holds the pre-computed tf-idf values of the target document's terms
const cosineSimilarity = (allDocuments, targetTfIdf, otherDocument, freq) => {
  const byTerms = targetTfIdf.map(({ term, tfIdfValue }) => ({
    term,
    weight: tfIdfValue * tfIdf(allDocuments, otherDocument, term, freq)
  }));
  byTerms.sort((a, b) => b.weight - a.weight);
  const similarity = byTerms.reduce((acc, term) => acc + term.weight, 0);
  const why = byTerms.slice(0, NB_TOP_TERMS);
  return { similarity, why };
};
const getMostSimilar = (allDocuments, toDocument, freq) => {
  const otherDocuments = allDocuments.filter(document => document !== toDocument);
  const precompiledTfIdf = Object.keys(toDocument.terms).map(term => ({
    term,
    tfIdfValue: tfIdf(allDocuments, toDocument, term, freq)
  }));
  precompiledTfIdf.sort(({ tfIdfValue: val1 }, { tfIdfValue: val2 }) => val2 - val1);
  const targetTfIdf = precompiledTfIdf.slice(0, 500); // only use the 500 most significant terms
  const similarities = otherDocuments.map(document => ({
    document,
    ...cosineSimilarity(allDocuments, targetTfIdf, document, freq)
  }));
  similarities.sort((document1, document2) => document2.similarity - document1.similarity);
  return similarities.slice(0, 10);
};
// Interactive examples
const getTicks = values => {
  const maxValue = Math.max(...values);
  const maxTick = parseFloat(maxValue.toPrecision(1));
  const nbTicks = 5;
  const ticks = [];
  for (let ii = 0; ii < nbTicks; ii++) {
    const tickValue = ii / nbTicks * maxTick;
    const tickLabel = (Math.round(tickValue * 100) / 100).toString();
    ticks.push(tickLabel);
  }
  return ticks;
};
const drawXAxis = (
  axisGroup,
  label,
  data,
  scale,
  { width, height }
) => {
  axisGroup.append('svg:line')
    .attr('x1', 0)
    .attr('y1', height)
    .attr('x2', width)
    .attr('y2', height)
    .attr('stroke', 'black')
    .attr('class', 'xTicks');
  const labelHeight = height + 15;
  const ticks = getTicks(data.map(point => point.x));
  axisGroup.selectAll('text.xAxisBottom')
    .data(ticks)
    .enter()
    .append('svg:text')
    .text(count => count)
    .attr('x', scale)
    .attr('y', labelHeight)
    .attr('text-anchor', 'middle')
    .attr('class', 'xAxisBottom');
  axisGroup.append('svg:text')
    .text(label)
    .attr('x', width - 20)
    .attr('y', labelHeight)
    .attr('text-anchor', 'middle');
};
const drawYAxis = (
  axisGroup,
  label,
  data,
  scale,
  { width, height }
) => {
  axisGroup.append('svg:line')
    .attr('x1', width)
    .attr('y1', 0)
    .attr('x2', width)
    .attr('y2', height)
    .attr('stroke', 'black')
    .attr('class', 'yTicks');
  const labelLeft = width - 30;
  const ticks = getTicks(data.map(point => point.y));
  // Left-aligned: labels start at x = labelLeft ('start' is a valid SVG text-anchor value)
  axisGroup.selectAll('text.yAxisLeft')
    .data(ticks)
    .enter()
    .append('svg:text')
    .text(count => count)
    .attr('x', labelLeft)
    .attr('y', scale)
    .attr('text-anchor', 'start')
    .attr('class', 'yAxisLeft');
  axisGroup.append('svg:text')
    .text(label)
    .attr('x', labelLeft)
    .attr('y', 20)
    .attr('text-anchor', 'start');
};
const makeTooltip = visRoot => {
  const tooltip = visRoot.append('div')
    .style('display', 'none')
    .style('position', 'absolute')
    .style('background-color', 'white')
    .style('border', 'solid')
    .style('border-width', '1px')
    .style('padding', '3px')
    .attr('class', 'tooltip');
  const showTooltip = () => tooltip.style('display', 'block');
  const setTooltipText = point => {
    tooltip
      .html(point.getTitle())
      .style('left', (d3.event.pageX + 5) + 'px')
      .style('top', d3.event.pageY + 'px');
  };
  return { showTooltip, setTooltipText };
};
// Move points slightly to avoid overlap
const jiggle = coord => coord - 2 + 4 * Math.random();
const drawCircles = (nodes, { scaleX, scaleY }, { showTooltip, setTooltipText }) => nodes
  .append('svg:circle')
  .attr('class', 'nodes')
  .attr('cx', point => jiggle(scaleX(point.x)))
  .attr('cy', point => jiggle(scaleY(point.y)))
  .attr('r', '6px')
  .attr('stroke', 'black')
  .attr('fill', 'white')
  .on('mouseover', showTooltip)
  .on('mousemove', setTooltipText);
// 1D: simple frequency of a single term
const set1DExample = () => {
  const points = window.allDocuments.map(document => ({
    getTitle: () => `${document.title} (${freqSimple(document, TERM_SINGLE)})`,
    x: freqSimple(document, TERM_SINGLE),
    y: 40
  }));
  const visRoot = d3.select('#singleTerm1d');
  const vis = visRoot
    .append('svg:svg')
    .attr('width', 620)
    .attr('height', 80);
  const scale = coord => 10 + 33 * coord;
  const axisGroup = vis.append('svg:g');
  drawXAxis(axisGroup, TERM_SINGLE, points, scale, {
    width: 600,
    height: 45
  });
  drawCircles(
    vis.selectAll('circle.nodes').data(points).enter(),
    { scaleX: scale, scaleY: y => y },
    makeTooltip(visRoot)
  );
};
set1DExample();
// 2D example: raw term frequencies
const getFreq = shortName => {
  const freqMap = {
    simple: freqSimple,
    norm: freqNorm,
    bool: freqBool,
    log: freqLog,
    aug: freqAug
  };
  return freqMap[shortName];
};
const getTopTerms = (document, freq) => {
  const withFreqs = Object.keys(document.terms).map(term => ({
    term,
    termFreq: freq(document, term)
  }));
  withFreqs.sort((term1, term2) => term2.termFreq - term1.termFreq);
  return withFreqs.slice(0, 5);
};
const setTf2DExample = freq => {
  const points = window.allDocuments.map(document => ({
    getTitle: () => `<p>${document.title}</p><table>${
      getTopTerms(document, freq)
        .map(({ term, termFreq }) => `<tr><td>${term}</td><td>${termFreq}</td></tr>`)
        .join('')
    }</table>`,
    x: freq(document, TERM_X),
    y: freq(document, TERM_Y)
  }));
  const visRoot = d3.select('#tf2d');
  visRoot.selectAll('*').remove();
  const vis = visRoot
    .append('svg:svg')
    .attr('width', 650)
    .attr('height', 590);
  const maxCoord = Math.max(
    ...points.map(point => point.x),
    ...points.map(point => point.y)
  );
  const scaleX = coord => 40 + 500 * coord / maxCoord;
  const scaleY = coord => 545 - 500 * coord / maxCoord;
  const axisGroup = vis.append('svg:g');
  drawXAxis(axisGroup, TERM_X, points, scaleX, {
    width: 600,
    height: 550
  });
  drawYAxis(axisGroup, TERM_Y, points, scaleY, {
    width: 45,
    height: 550
  });
  drawCircles(
    vis.selectAll('circle.nodes').data(points).enter(),
    { scaleX, scaleY },
    makeTooltip(visRoot)
  );
};
const updateTf2DExample = e => setTf2DExample(getFreq(e.target.value));
document.getElementById('tf2d-select').addEventListener('change', updateTf2DExample);
setTf2DExample(freqSimple);
// 2D example: tf-idf
const setTfIdf2DExample = freq => {
  const toTfIdf = (document, term) => tfIdf(window.allDocuments, document, term, freq);
  const points = window.allDocuments.map(document => ({
    getTitle: () => `<p>${document.title}</p><table>${
      getTopTerms(document, toTfIdf)
        .map(({ term, termFreq }) => `<tr><td>${term}</td><td>${termFreq}</td></tr>`)
        .join('')
    }</table>`,
    x: toTfIdf(document, TERM_X),
    y: toTfIdf(document, TERM_Y)
  }));
  const visRoot = d3.select('#tfIdf2d');
  visRoot.selectAll('*').remove();
  const vis = visRoot
    .append('svg:svg')
    .attr('width', 650)
    .attr('height', 590);
  const maxCoord = Math.max(
    ...points.map(point => point.x),
    ...points.map(point => point.y)
  );
  const scaleX = coord => 40 + 500 * coord / maxCoord;
  const scaleY = coord => 545 - 500 * coord / maxCoord;
  const axisGroup = vis.append('svg:g');
  drawXAxis(axisGroup, TERM_X, points, scaleX, {
    width: 600,
    height: 550
  });
  drawYAxis(axisGroup, TERM_Y, points, scaleY, {
    width: 45,
    height: 550
  });
  drawCircles(
    vis.selectAll('circle.nodes').data(points).enter(),
    { scaleX, scaleY },
    makeTooltip(visRoot)
  );
};
const updateTfIdf2DExample = e => setTfIdf2DExample(getFreq(e.target.value));
document.getElementById('tfIdf2d-select').addEventListener('change', updateTfIdf2DExample);
setTfIdf2DExample(freqSimple);
// Similarities
const setSimilarities = (toDocument, freq) => {
  const nameDisplay = window.document.getElementById('document-name');
  nameDisplay.innerText = toDocument.title;
  const table = window.document.getElementById('similarity-table');
  table.innerHTML = '';
  const headerRow = window.document.createElement('tr');
  headerRow.innerHTML = `<th>Article</th><th>Similarity</th><th>Top terms</th>`;
  table.appendChild(headerRow);
  const similarities = getMostSimilar(window.allDocuments, toDocument, freq);
  similarities.forEach(({ document: { title }, similarity, why }) => {
    const row = window.document.createElement('tr');
    const topTerms = why.map(
      ({ term, weight }) => `${term}<small> (${weight.toPrecision(5)})</small>`
    ).join(', ');
    row.innerHTML = `<td>${title}</td><td>${similarity.toPrecision(6)}</td><td>${topTerms}</td>`;
    table.appendChild(row);
  });
};
const getRandomDocument = () => window.allDocuments[
  Math.floor(window.allDocuments.length * Math.random())
];
window.selectedDocument = getRandomDocument();
const changeSelectedDocument = () => {
  window.selectedDocument = getRandomDocument();
  setSimilarities(
    window.selectedDocument,
    getFreq(document.getElementById('similarity-select').value)
  );
};
document.getElementById('random-article').addEventListener('click', changeSelectedDocument);
const updateSimilarities = e => setSimilarities(window.selectedDocument, getFreq(e.target.value));
document.getElementById('similarity-select').addEventListener('change', updateSimilarities);
setSimilarities(window.selectedDocument, freqSimple);