Distance functions
// These methods were generated with ChatGPT and have not been tested.
// Cosine similarity: dot(A, B) / (|A| * |B|); ranges from -1 to 1.
const cosineSimilarity = (vecA, vecB) => {
  if (vecA.length !== vecB.length) {
    throw new Error("Vectors must be of the same length.");
  }
  const dotProduct = vecA.reduce((sum, val, i) => sum + val * vecB[i], 0);
  const normA = vecA.reduce((sum, val) => sum + val * val, 0);
  const normB = vecB.reduce((sum, val) => sum + val * val, 0);
  if (normA === 0 || normB === 0) {
    throw new Error("Cannot compute cosine similarity for a zero vector.");
  }
  return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
}
// Euclidean (L2) distance between two vectors.
const euclideanDistance = (vecA, vecB) => {
  if (vecA.length !== vecB.length) {
    throw new Error("Vectors must be of the same length.");
  }
  let distance = 0;
  for (let i = 0; i < vecA.length; i++) {
    distance += (vecA[i] - vecB[i]) ** 2;
  }
  return Math.sqrt(distance);
}
// Jaccard similarity for binary (0/1) vectors: |intersection| / |union|.
const jaccardSimilarity = (vecA, vecB) => {
  if (vecA.length !== vecB.length) {
    throw new Error("Vectors must be of the same length.");
  }
  let intersection = 0;
  let union = 0;
  for (let i = 0; i < vecA.length; i++) {
    if (vecA[i] === 1 && vecB[i] === 1) intersection++;
    if (vecA[i] === 1 || vecB[i] === 1) union++;
  }
  if (union === 0) {
    throw new Error("Cannot compute Jaccard similarity for two zero vectors.");
  }
  return intersection / union;
}
// Pearson correlation coefficient between two vectors; ranges from -1 to 1.
const pearsonCorrelation = (vecA, vecB) => {
  if (vecA.length !== vecB.length) {
    throw new Error("Vectors must be of the same length.");
  }
  let sumA = 0, sumB = 0, sumA2 = 0, sumB2 = 0, sumAB = 0;
  const length = vecA.length;
  for (let i = 0; i < length; i++) {
    sumA += vecA[i];
    sumB += vecB[i];
    sumA2 += vecA[i] * vecA[i];
    sumB2 += vecB[i] * vecB[i];
    sumAB += vecA[i] * vecB[i];
  }
  const numerator = sumAB - (sumA * sumB / length);
  const denominator = Math.sqrt((sumA2 - sumA * sumA / length) * (sumB2 - sumB * sumB / length));
  if (denominator === 0) {
    throw new Error("Denominator is zero, correlation is undefined.");
  }
  return numerator / denominator;
}
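
// A minimal sketch of how the similarity functions above might be called,
// using made-up example vectors (not from the original gist):
// const a = [1, 0, 1, 1];
// const b = [1, 1, 1, 0];
// console.log(cosineSimilarity(a, b));   // ≈ 0.667
// console.log(euclideanDistance(a, b));  // ≈ 1.414
// console.log(jaccardSimilarity(a, b));  // 0.5
// console.log(pearsonCorrelation(a, b)); // ≈ -0.333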
// Single-label k-nearest-neighbours classifier.
// Expects a similarity function where a larger value means "more similar".
const kNearestNeighbors = (embeddings, labels, queryEmbedding, k, similarityFunc) => {
  // Calculate the similarity between the query and all embeddings
  const similarities = embeddings.map((embedding, index) => ({
    similarity: similarityFunc(embedding, queryEmbedding),
    label: labels[index]
  }));
  // Sort by similarity, most similar first
  similarities.sort((a, b) => b.similarity - a.similarity);
  // Take the top k similarities
  const kNearest = similarities.slice(0, k);
  // Count the frequency of labels among the k nearest points
  const counts = kNearest.reduce((acc, value) => {
    acc[value.label] = (acc[value.label] || 0) + 1;
    return acc;
  }, {});
  // Sort the labels by frequency
  const sortedLabels = Object.keys(counts).sort((a, b) => counts[b] - counts[a]);
  // Return the label with the highest frequency
  return sortedLabels[0];
}
// Example usage:
// Assuming you have text embeddings and their respective labels
// const embeddings = [...]; // array of vector embeddings
// const labels = [...]; // array of labels corresponding to each embedding
// const queryEmbedding = [...]; // a vector embedding of the text you want to classify
// const k = 3; // for instance, looking at the 3 nearest neighbors
// const predictedLabel = kNearestNeighbors(embeddings, labels, queryEmbedding, k, cosineSimilarity);
// console.log(predictedLabel);
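
// Note: kNearestNeighbors sorts by descending score, so it expects a similarity
// function (larger = closer), e.g. cosineSimilarity. A distance such as
// euclideanDistance (smaller = closer) would need to be adapted first. A minimal
// sketch of one possible adapter (a hypothetical helper, not part of the original gist):
const distanceToSimilarity = (distanceFunc) => (vecA, vecB) => -distanceFunc(vecA, vecB);
// const predictedByDistance = kNearestNeighbors(embeddings, labels, queryEmbedding, k, distanceToSimilarity(euclideanDistance));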
// Multi-label k-nearest-neighbours classifier: each training point can carry
// several labels; returns all observed labels ordered by frequency.
const kNearestNeighborsMultiLabel = (embeddings, labels, queryEmbedding, k, similarityFunc) => {
  // Calculate the similarity for each embedding
  const similarities = embeddings.map((embedding, index) => ({
    similarity: similarityFunc(embedding, queryEmbedding),
    labels: labels[index] // Each index now has an array of labels instead of just one
  }));
  // Sort by similarity, most similar first
  similarities.sort((a, b) => b.similarity - a.similarity);
  // Take the top k similarities
  const kNearest = similarities.slice(0, k);
  // Count the frequency of each label among the k nearest neighbors
  const labelCounts = {};
  kNearest.forEach(neighbor => {
    neighbor.labels.forEach(label => {
      labelCounts[label] = (labelCounts[label] || 0) + 1;
    });
  });
  // Sort the labels by frequency and return them
  return Object.keys(labelCounts).sort((a, b) => labelCounts[b] - labelCounts[a]);
}
// Example usage:
// const embeddings = [...]; // array of vector embeddings
// const labels = [['label1', 'label3'], ['label2', 'label3'], ...]; // array of arrays of labels
// const queryEmbedding = [...]; // a vector embedding of the text you want to classify
// const k = 3; // the number of nearest neighbors to consider
// const predictedLabels = kNearestNeighborsMultiLabel(embeddings, labels, queryEmbedding, k, cosineSimilarity);
// console.log(predictedLabels); // Might output: ['label3', 'label1', 'label2']
// Element-wise mean of a list of equal-length embeddings (the centroid).
function calculateMean(embeddings) {
  const mean = new Array(embeddings[0].length).fill(0);
  for (const embedding of embeddings) {
    embedding.forEach((value, i) => {
      mean[i] += value;
    });
  }
  return mean.map(value => value / embeddings.length);
}
// Linear interpolation between vecA and vecB; t = 0 returns vecA, t = 1 returns vecB.
function interpolate(vecA, vecB, t) {
  return vecA.map((a, i) => a + t * (vecB[i] - a));
}
// Scale a vector to unit length.
function normalize(vec) {
  const norm = Math.sqrt(vec.reduce((acc, val) => acc + val * val, 0));
  if (norm === 0) {
    throw new Error("Cannot normalize a zero vector.");
  }
  return vec.map(val => val / norm);
}
// Weighted element-wise mean of embeddings; weights[i] applies to embeddings[i].
function weightedMean(embeddings, weights) {
  const mean = new Array(embeddings[0].length).fill(0);
  let weightSum = 0;
  embeddings.forEach((embedding, index) => {
    const weight = weights[index];
    embedding.forEach((value, i) => {
      mean[i] += value * weight;
    });
    weightSum += weight;
  });
  if (weightSum === 0) {
    throw new Error("Sum of weights must not be zero.");
  }
  return mean.map(value => value / weightSum);
}
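
// A minimal sketch of how the helper functions above might be used, with
// made-up example vectors (not from the original gist):
// const vectors = [[1, 2], [3, 4]];
// calculateMean(vectors);              // [2, 3]
// interpolate([0, 0], [2, 4], 0.5);    // [1, 2]
// normalize([3, 4]);                   // [0.6, 0.8]
// weightedMean(vectors, [1, 3]);       // [2.5, 3.5]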