AUROC / AUPRC / binary classifier statistics in TypeScript
import * as fs from "fs";
/**
 * Class representing evaluation metrics for a binary classifier system.
 * This class is designed to calculate and analyze the performance of a binary classifier
 * as its discrimination threshold is varied, including ROC curve analysis and Precision-Recall curve analysis.
 */
export class BinaryClassifierStatistics {
    /** Sorted array of unique thresholds from the scores, in descending order */
    thresholds: number[] = [];
    /**
     * Creates an instance of the BinaryClassifierStatistics class.
     * @param id - An identifier for this statistics instance, included when serializing.
     * @param trueLabels - The array of true binary labels of the instances (1 for positive and 0 for negative).
     * @param scores - The array of scores or probabilities as estimated by the model, corresponding to the true labels.
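     * @example
     * // A small instance built from synthetic labels and scores (illustrative values):
     * const stats = new BinaryClassifierStatistics("demo", [1, 0, 1, 0], [0.9, 0.4, 0.65, 0.2]);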
     */
    constructor(public id: string, public trueLabels: (1 | 0)[] = [], public scores: number[] = []) {
        this.validate();
        this.recalculateThresholds();
    }
    /**
     * Loads a `BinaryClassifierStatistics` instance from a specified file.
     * This static method creates a new instance of the class, reads data from the given file path,
     * deserializes the JSON content into the class properties, and returns the populated instance.
     *
     * @param path - The file path from which to load the serialized class instance data.
     * @returns A new instance of `BinaryClassifierStatistics` populated with the data from the file.
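     * @example
     * // Restore a previously saved instance (the path is illustrative):
     * const restored = BinaryClassifierStatistics.load("./stats.json");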
     */
    static load(path: string) {
        const res = new BinaryClassifierStatistics("");
        res.load(path);
        return res;
    }
    /**
     * Loads data into the current instance from a specified file.
     * This method reads data from the given file path, deserializes the JSON content,
     * and updates the current instance's properties with the deserialized data.
     *
     * @param path - The file path from which to load the serialized data.
     */
    load(path: string) {
        this.deserialize(fs.readFileSync(path).toString());
    }
    /**
     * Saves the current instance's data to a specified file.
     * This method serializes the instance's properties (id, trueLabels, scores, thresholds)
     * into JSON format and writes this data to the given file path, overwriting any existing content.
     *
     * @param path - The file path to which the serialized data will be saved.
     */
    save(path: string) {
        fs.writeFileSync(path, this.serialize());
    }
    /**
     * Serializes the current instance's data into a JSON string.
     * This method converts the instance's properties (id, trueLabels, scores, thresholds)
     * into a JSON string format for easy storage or transmission.
     *
     * @returns A JSON string representation of the instance's data.
     */
    serialize() {
        return JSON.stringify({
            id: this.id,
            trueLabels: this.trueLabels,
            scores: this.scores,
            thresholds: this.thresholds,
        });
    }
    /**
     * Deserializes data from a JSON string into the instance's properties.
     * This method parses a JSON string to update the instance's properties (id, trueLabels, scores, thresholds)
     * with the data from the JSON string, effectively loading the state from a serialized format.
     *
     * @param data - The JSON string from which to deserialize the data.
     */
    deserialize(data: string) {
        const deserialized = JSON.parse(data);
        this.id = deserialized.id;
        this.trueLabels = deserialized.trueLabels;
        this.scores = deserialized.scores;
        this.thresholds = deserialized.thresholds;
    }
    /**
     * Adds a data point to the internal buffer.
     * @param trueLabel - The known, true binary label of the data instance.
     * @param labelProbability - The inferred probability of the data instance.
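     * @example
     * // Stream predictions in one at a time (synthetic values, assuming
     * // `stats` is an existing instance):
     * stats.addData(true, 0.87);
     * stats.addData(false, 0.12);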
     */
    addData(trueLabel: boolean | 1 | 0, labelProbability: number) {
        this.trueLabels.push(trueLabel ? 1 : 0);
        this.scores.push(labelProbability);
        this.recalculateThresholds();
    }
    /**
     * Updates the data buffers of an instance of the BinaryClassifierStatistics class.
     * @param trueLabels - The array of true binary labels of the instances (1 for positive and 0 for negative).
     * @param scores - The array of scores or probabilities as estimated by the model, corresponding to the true labels.
     */
    setData(trueLabels: (1 | 0)[], scores: number[]) {
        this.trueLabels = trueLabels;
        this.scores = scores;
        this.validate();
        this.recalculateThresholds();
    }
    /** Throws if the label and score buffers have mismatched lengths. */
    validate() {
        if (this.trueLabels.length !== this.scores.length) {
            throw new Error("true label / prediction array length mismatch");
        }
    }
    /** Rebuilds the threshold list from the unique scores, sorted in descending order. */
    recalculateThresholds() {
        this.thresholds = Array.from(new Set(this.scores)).sort((a, b) => b - a);
    }
    /**
     * Calculates detailed statistics for each threshold.
     *
     * This method iterates over each unique threshold to determine these statistics, which are
     * crucial for evaluating the performance of a binary classification model.
     *
     * @returns An array of objects, each representing statistics at a specific threshold.
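     * @example
     * // Walk the full threshold sweep (assuming a populated `stats` instance):
     * for (const row of stats.calculateStatistics()) {
     *     console.log(row.threshold, row.TPR, row.FPR);
     * }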
     */
    calculateStatistics() {
        const results = this.thresholds.map(threshold => {
            return {
                /** threshold used to generate statistics */
                threshold: threshold,
                ...this.calculateStatisticsAtThreshold(threshold),
            };
        });
        return results;
    }
    /**
     * Calculates detailed statistics for a specified threshold.
     * Scores greater than or equal to the threshold are treated as positive predictions.
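     * @example
     * // Confusion-matrix counts and derived rates at a 0.5 cutoff
     * // (assuming a populated `stats` instance):
     * const { TP, FP, TN, FN, F1 } = stats.calculateStatisticsAtThreshold(0.5);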
     */
    calculateStatisticsAtThreshold(threshold: number) {
        let TP = 0, FP = 0, TN = 0, FN = 0;
        for (let i = 0; i < this.scores.length; i++) {
            if (this.scores[i] >= threshold) {
                if (this.trueLabels[i] === 1) {
                    TP++;
                } else {
                    FP++;
                }
            } else {
                if (this.trueLabels[i] === 1) {
                    FN++;
                } else {
                    TN++;
                }
            }
        }
        const TPR = TP + FN === 0 ? 0 : TP / (TP + FN);
        const FPR = FP + TN === 0 ? 0 : FP / (FP + TN);
        const FNR = TP + FN === 0 ? 0 : FN / (TP + FN);
        const TNR = TN + FP === 0 ? 0 : TN / (TN + FP);
        const PPV = TP + FP === 0 ? 0 : TP / (TP + FP);
        const NPV = TN + FN === 0 ? 0 : TN / (FN + TN);
        const FOR = FN + TN === 0 ? 0 : FN / (FN + TN);
        const FDR = TP + FP === 0 ? 0 : FP / (TP + FP);
        const Accuracy = TP + FN + FP + TN === 0 ? 0 : (TP + TN) / (TP + FN + FP + TN);
        const BalancedAccuracy = (TPR + TNR) / 2;
        const Informedness = TPR + TNR - 1;
        const Markedness = PPV + NPV - 1;
        const FM = PPV + TPR === 0 ? 0 : Math.sqrt(PPV * TPR);
        const MCC = (TP + FN) * (TP + FP) * (TN + FP) * (TN + FN) === 0 ? 0 : (TP * TN - FP * FN) / Math.sqrt((TP + FN) * (TP + FP) * (TN + FP) * (TN + FN));
        const PT = TPR === FPR ? 0 : (Math.sqrt(TPR * FPR) - FPR) / (TPR - FPR);
        const PLR = FPR === 0 ? Infinity : TPR / FPR;
        const NLR = TNR === 0 ? 0 : FNR / TNR;
        const DOR = NLR === 0 ? Infinity : PLR / NLR;
        const CSI = TP + FN + FP === 0 ? 0 : TP / (TP + FN + FP);
        return {
            /** True positives: The number of instances correctly identified as positive */
            TP,
            /** False positives: The number of instances incorrectly identified as positive */
            FP,
            /** True negatives: The number of instances correctly identified as negative */
            TN,
            /** False negatives: The number of instances incorrectly identified as negative */
            FN,
            /** The proportion of true results (both true positives and true negatives) among the
             * total number of cases examined. It measures the overall correctness of the model. */
            Accuracy,
            /** The average of the proportion of true results in each class (sensitivity and
             * specificity). It is particularly useful in situations where the classes are imbalanced. */
            BalancedAccuracy,
            /** True Positive Rate: Also known as sensitivity or recall, it measures the proportion
             * of actual positives that are correctly identified. A higher TPR indicates a model's
             * better performance in identifying positive cases. */
            TPR,
            /** False Positive Rate: It measures the proportion of actual negatives incorrectly identified
             * as positives. */
            FPR,
            /** False Negative Rate: It measures the proportion of actual positives incorrectly identified
             * as negatives. */
            FNR,
            /** True Negative Rate: Also known as specificity or selectivity, it measures the proportion
             * of actual negatives that are correctly identified. A higher TNR indicates a model's
             * better performance in identifying negative cases. */
            TNR,
            /** Positive Predictive Value: Also known as precision, it measures the proportion
             * of positive identifications that were actually correct. A higher PPV indicates a model's
             * better performance in predicting positive cases accurately. */
            PPV,
            /** Negative Predictive Value: It measures the proportion of negative identifications that were
             * actually correct. */
            NPV,
            /** False Omission Rate: Measures the proportion of negative predictions that were incorrect. */
            FOR,
            /** False Discovery Rate: The proportion of positive predictions that were incorrect. */
            FDR,
            /** Fowlkes-Mallows Index: A measure that combines precision and recall into
             * a single metric. It is the geometric mean of precision (PPV) and recall (TPR).
             *
             * An FM score ranges from 0 to 1, where 1 indicates perfect precision and recall. A higher FM score
             * suggests that the model effectively identifies positive instances and that the positive predictions
             * it makes are reliable. */
            FM,
            /** Matthews Correlation Coefficient: Also known as the phi coefficient. A correlation coefficient between
             * the observed and predicted classifications. It takes into account true and false positives and negatives
             * and is considered a balanced measure that can be used even if the classes are of very different sizes.
             *
             * The MCC value ranges from -1 to 1. A coefficient of +1 represents a perfect prediction, 0 an average
             * random prediction, and -1 an inverse prediction. This metric is particularly useful because it
             * remains informative even when the dataset is imbalanced. */
            MCC,
            /** Prevalence Threshold: Refers to the point at which the prevalence of the condition being tested for
             * makes the model's positive predictive value (PPV) equal to its sensitivity (TPR).
             *
             * PT provides insight into the effectiveness of a test or model across different prevalence rates,
             * highlighting the importance of considering disease prevalence when evaluating test performance. */
            PT,
            /** Positive Likelihood Ratio: Indicates how much the odds of the disease increase when a test is positive. */
            PLR,
            /** Negative Likelihood Ratio: Indicates how much the odds of the disease decrease when a test is negative. */
            NLR,
            /** Diagnostic Odds Ratio: The ratio of the odds of the test being positive if the subject has a condition
             * versus the odds of the test being positive if the subject does not have the condition. */
            DOR,
            /** Critical Success Index / Threat Score: Measures the proportion of correct positive predictions out of
             * all instances that were predicted positive or were actually positive. It is similar to the F1 score
             * but does not consider true negatives in its calculation.
             *
             * CSI values range from 0 to 1, where 1 indicates perfect performance in predicting positive instances.
             * It is particularly useful where the focus is on correctly predicting rare events. */
            CSI,
            /** Also known as Youden's index. Measures the probability that a prediction is informed in relation to the
             * actual class.
             *
             * It ranges from -1 to 1, where 1 indicates perfect knowledge (all predictions are correct), 0 indicates
             * no better than random guessing, and -1 indicates total disagreement between prediction and actual class. */
            Informedness,
            /** Measures the probability that the actual class is correctly informed by the prediction.
             *
             * Like Informedness, it ranges from -1 to 1. A value of 1 indicates perfect marking (all
             * actual classes are predicted correctly), 0 indicates no better than random marking, and
             * -1 indicates complete misclassification. */
            Markedness,
            /** F1 Score: The harmonic mean of precision and recall, providing a single metric to assess
             * the balance between them. A higher F1 score indicates a model's better overall performance
             * in terms of both precision and recall. */
            F1: (PPV + TPR) === 0 ? 0 : (2 * PPV * TPR) / (PPV + TPR),
        };
    }
    /**
     * Calculates detailed classifier statistics at threshold 0.5, alongside epidemiological statistics.
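     * @example
     * // Snapshot at the conventional 0.5 cutoff (assuming a populated `stats` instance):
     * const { Prevalence, Accuracy, TPR } = stats.calculateSnapshotStatistics();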
     */
    calculateSnapshotStatistics() {
        return {
            TotalPopulation: this.trueLabels.length,
            Prevalence: this.trueLabels.filter(value => value === 1).length / this.trueLabels.length,
            ...this.calculateStatisticsAtThreshold(0.5),
        };
    }
    /**
     * Calculates the Area Under the Receiver Operating Characteristic Curve (AUROC).
     * This method uses the trapezoidal rule to approximate the area under the ROC curve,
     * which is a measure of the model's ability to discriminate between the positive and negative classes.
     * @returns The calculated AUROC value.
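     * @example
     * // 0.5 corresponds to chance-level discrimination; values near 1 are better:
     * const auroc = stats.calculateAUROC();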
     */
    calculateAUROC() {
        // Get the (FPR, TPR) points for the ROC curve; thresholds are sorted in
        // descending order, so FPR is non-decreasing across consecutive points.
        // Prepend (0, 0) so the curve starts at the origin, which corresponds to
        // an implicit threshold above the highest score.
        const points = [{ FPR: 0, TPR: 0 }, ...this.calculateStatistics()];
        let auc = 0; // Initialize AUROC
        for (let i = 0; i < points.length - 1; i++) {
            const xDiff = points[i + 1].FPR - points[i].FPR; // Difference in FPR between consecutive points
            const yAvg = (points[i].TPR + points[i + 1].TPR) / 2; // Average TPR of the two points
            auc += xDiff * yAvg; // Increment AUROC using the trapezoidal rule
        }
        return auc; // Return the computed AUROC value
    }
    /**
     * Calculates the Area Under the Precision-Recall Curve (AUPRC).
     *
     * The AUPRC is a valuable metric for evaluating the performance of a binary classifier, especially in datasets
     * where the positive class is rare. This method approximates the AUPRC using the trapezoidal rule, based on
     * the precision (positive predictive value) and recall (true positive rate) at various thresholds. It provides
     * an aggregate measure of the model's ability to identify positive instances accurately across different
     * threshold settings, emphasizing the balance between precision and recall in the presence of class imbalance.
     *
     * @returns The calculated AUPRC value, representing the model's average precision across all levels
     *          of recall. Higher AUPRC values indicate better model performance, particularly in its ability to prioritize
     *          the correct identification of positive instances while minimizing false positives.
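     * @example
     * // The chance-level baseline for AUPRC is the positive-class prevalence,
     * // so compare against calculateSnapshotStatistics().Prevalence:
     * const auprc = stats.calculateAUPRC();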
     */
    calculateAUPRC() {
        const stats = this.calculateStatistics().sort((a, b) => a.TPR - b.TPR); // Ensure stats are sorted by recall
        let auprc = 0; // Initialize AUPRC
        for (let i = 0; i < stats.length - 1; i++) {
            // Calculate the difference in recall between consecutive points
            const recallDiff = stats[i + 1].TPR - stats[i].TPR;
            // Calculate the average precision of the two points
            const precisionAvg = (stats[i].PPV + stats[i + 1].PPV) / 2;
            // Increment AUPRC using the trapezoidal rule
            auprc += recallDiff * precisionAvg;
        }
        return auprc; // Return the computed AUPRC value
    }
    /**
     * Calculates the optimal threshold based on maximizing the F1 score.
     * The F1 score is the harmonic mean of precision and recall, providing a balance between the two.
     * This method iterates over all possible thresholds to find the one with the highest F1 score.
     *
     * @returns An object containing the optimal threshold, the corresponding F1 score, and the full statistics at that threshold.
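     * @example
     * // Pick the operating point that balances precision and recall:
     * const { optimalThreshold, maxF1 } = stats.calculateOptimalThresholdUsingF1Score();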
     */
    calculateOptimalThresholdUsingF1Score() {
        let optimalThreshold = 0;
        let maxF1 = 0;
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { F1 } = stats;
            if (F1 > maxF1) {
                maxF1 = F1;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxF1: maxF1,
            thresholdStats,
        };
    }
    /**
     * Calculates the optimal threshold based on maximizing Youden's index.
     * Youden's index is defined as J = sensitivity + specificity - 1, which maximizes
     * the classifier's performance by considering both true positive and true negative rates.
     *
     * @returns An object containing the optimal threshold, the corresponding Youden's index, and the full statistics at that threshold.
     */
    calculateOptimalThresholdUsingYoudensIndex() {
        let optimalThreshold = 0;
        let maxYoudenIndex = -1; // Initialize with -1, the minimum possible value for J
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { Informedness } = stats;
            if (Informedness > maxYoudenIndex) {
                maxYoudenIndex = Informedness;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxYoudenIndex: maxYoudenIndex,
            thresholdStats,
        };
    }
    /**
     * Calculates the optimal threshold based on maximizing the Matthews Correlation Coefficient (MCC).
     * MCC is considered a balanced measure which can be used even if the classes are of very different sizes,
     * ranging from -1 (total disagreement) to +1 (perfect prediction), with 0 indicating no better than random prediction.
     *
     * @returns An object containing the optimal threshold, the corresponding MCC, and the full statistics at that threshold.
     */
    calculateOptimalThresholdUsingMCC() {
        let optimalThreshold = 0;
        let maxMCC = -1; // Initialize with -1, the minimum possible value for MCC
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        this.thresholds.forEach(threshold => {
            const stats = this.calculateStatisticsAtThreshold(threshold);
            const { MCC } = stats;
            if (MCC > maxMCC) {
                maxMCC = MCC;
                optimalThreshold = threshold;
                thresholdStats = stats;
            }
        });
        return {
            optimalThreshold: optimalThreshold,
            maxMCC: maxMCC,
            thresholdStats,
        };
    }
    /**
     * Calculates the optimal threshold for a specified metric and optimization goal.
     * This method allows for flexible optimization based on a variety of metrics
     * such as accuracy, precision, recall, F1 score, Matthews Correlation Coefficient (MCC), etc.
     * It supports finding either the maximum or minimum value of the chosen metric across all possible thresholds,
     * which can be useful for tailoring the performance of the binary classifier to specific operational requirements.
     *
     * The method iterates over all possible thresholds, evaluates the classifier's performance at each threshold using
     * the specified metric, and identifies the threshold that optimizes (maximizes or minimizes) the metric's value.
     * This approach enables the fine-tuning of the classifier's decision boundary for optimal performance on
     * the given metric, which is particularly valuable in scenarios where trade-offs between different types
     * of classification errors must be carefully managed.
     *
     * @param metric - The name of the metric to optimize for. This should be a key from the object returned by
     *                 calculateStatisticsAtThreshold, representing a specific performance metric of the classifier.
     * @param optimisation - Specifies the optimization goal: "maximum" to find the threshold that maximizes the metric,
     *                       or "minimum" to find the threshold that minimizes the metric.
     * @param initialisation - (Optional) An initial value to start the optimization process. For maximum optimization,
     *                         this could be the lowest possible value (e.g., -Infinity) to ensure any real value is higher.
     *                         For minimum optimization, it could be the highest possible value (e.g., Infinity) to ensure
     *                         any real value is lower. If not provided, defaults to -Infinity for maximum optimization
     *                         and Infinity for minimum optimization.
     *
     * @returns An object containing the optimal threshold and the corresponding value of the optimized metric.
     *          The object has the structure { optimalThreshold, value, thresholdStats }, where optimalThreshold
     *          is the threshold that optimizes the specified metric, and value is the metric's optimized value.
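     * @example
     * // Minimize the false positive rate across all candidate thresholds:
     * const fprResult = stats.calculateOptimalThreshold("FPR", "minimum");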
     */
    calculateOptimalThreshold(metric: keyof ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]>, optimisation: "maximum" | "minimum", initialisation?: number) {
        let optimalThreshold = 0;
        let thresholdStats: ReturnType<BinaryClassifierStatistics["calculateStatisticsAtThreshold"]> | undefined;
        if (optimisation === "maximum") {
            let maxValue = initialisation ?? (-1 / 0);
            this.thresholds.forEach(threshold => {
                const stats = this.calculateStatisticsAtThreshold(threshold);
                const entry = stats[metric];
                if (entry > maxValue) {
                    maxValue = entry;
                    optimalThreshold = threshold;
                    thresholdStats = stats;
                }
            });
            return {
                optimalThreshold: optimalThreshold,
                value: maxValue,
                thresholdStats,
            };
        } else {
            let minValue = initialisation ?? (1 / 0);
            this.thresholds.forEach(threshold => {
                const stats = this.calculateStatisticsAtThreshold(threshold);
                const entry = stats[metric];
                if (entry < minValue) {
                    minValue = entry;
                    optimalThreshold = threshold;
                    thresholdStats = stats;
                }
            });
            return {
                optimalThreshold: optimalThreshold,
                value: minValue,
                thresholdStats,
            };
        }
    }
}
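// Minimal end-to-end sketch with synthetic labels/scores and an illustrative
// file path:
//
//   const stats = new BinaryClassifierStatistics(
//       "demo",
//       [1, 0, 1, 1, 0, 0, 1, 0],
//       [0.92, 0.35, 0.78, 0.61, 0.48, 0.10, 0.55, 0.60],
//   );
//   console.log("AUROC:", stats.calculateAUROC());
//   console.log("AUPRC:", stats.calculateAUPRC());
//   console.log(stats.calculateOptimalThresholdUsingYoudensIndex());
//   stats.save("./demo-stats.json");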