Last active
August 27, 2020 19:37
-
-
Save peheje/98070f0b065c1ed10917b40dab30bd29 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.Exception | |
import java.nio.ByteBuffer | |
import java.security.MessageDigest | |
import java.util.* | |
import kotlin.math.abs | |
/**
 * Demo driver for [BloomFilter].
 *
 * Builds two arrays of random strings, inserts only the first array into a
 * filter, then evaluates the filter's membership answers on both arrays with
 * a [ConfusionMatrix] and prints the resulting rates.
 *
 * Throws if any inserted entry is reported as absent, since a Bloom filter
 * must never produce a false negative.
 */
fun main() {
    val bitCount = 1_000_000
    val sampleCount = 100_000
    val bloom = BloomFilter(bitCount, numberOfHashes = 4)

    // Only `members` are inserted; `nonMembers` are fresh random strings used as negatives.
    val members = Array(sampleCount) { randomString() }
    val nonMembers = Array(sampleCount) { randomString() }
    members.forEach { bloom.add(it) }

    val report = ConfusionMatrix(members, nonMembers) { candidate -> bloom.maybeExists(candidate) }
    report.printReport()

    val sawFalseNegative = report.falseNegativeRate > 0.0
    if (sawFalseNegative) {
        throw Exception("This should not happen, if it does the implementation of the bloom filter is wrong.")
    }
}
/**
 * A Bloom filter over strings backed by a [BitSet] of [size] bits.
 *
 * Each entry is mapped to `numberOfHashes` bit positions by hashing the entry
 * concatenated with a per-hash salt ("0", "1", ...) using SHA-1.
 *
 * [maybeExists] can return `true` for an entry that was never added (false
 * positive), but returns `true` for every entry that was added.
 *
 * @param size number of bits in the filter; must be positive.
 * @param numberOfHashes number of salted hashes computed per entry; must be positive.
 * @throws IllegalArgumentException if [size] or `numberOfHashes` is not positive.
 */
class BloomFilter(private val size: Int, numberOfHashes: Int) {
    init {
        // Validate before the property initializers below run, so a bad size
        // fails with a clear message instead of a late divide-by-zero.
        require(size > 0) { "size must be positive, was $size" }
        require(numberOfHashes > 0) { "numberOfHashes must be positive, was $numberOfHashes" }
    }

    private val flags = BitSet(size)
    private val salts = List(numberOfHashes) { it.toString() }
    private val sha = MessageDigest.getInstance("SHA-1")

    /** Sets the bit for every salted hash of [entry]. */
    fun add(entry: String) {
        for (salt in salts) {
            flags.set(hashedIndex(entry, salt))
        }
    }

    /**
     * Returns `false` if at least one bit for [entry] is unset (the entry was
     * definitely never added), `true` otherwise.
     */
    fun maybeExists(entry: String): Boolean =
        salts.all { salt -> flags[hashedIndex(entry, salt)] }

    /**
     * Maps [entry] + [salt] to a bit position in `0 until size`, using the
     * first four bytes of the SHA-1 digest as a signed Int.
     */
    private fun hashedIndex(entry: String, salt: String): Int {
        val digest = sha.digest((entry + salt).toByteArray())
        val word = ByteBuffer.wrap(digest).int
        // floorMod instead of abs(word) % size: abs(Int.MIN_VALUE) is still
        // negative, which would yield a negative index and crash the BitSet.
        return Math.floorMod(word, size)
    }
}
/**
 * Evaluates a boolean classifier against labelled samples and exposes the
 * resulting rates.
 *
 * [predict] is invoked once for every element of `positives` (in order) and
 * then once for every element of `negatives` (in order) during construction.
 *
 * @param positives samples for which `predict` is expected to return `true`; must not be empty.
 * @param negatives samples for which `predict` is expected to return `false`; must not be empty.
 * @property predict the classifier under evaluation.
 * @throws Exception if either sample array is empty.
 */
class ConfusionMatrix<T>(positives: Array<T>, negatives: Array<T>, val predict: (sample: T) -> Boolean) {
    /** Fraction of all samples that were classified correctly. */
    val accuracyRate: Double

    /** `1.0 - accuracyRate`. */
    val misclassificationRate: Double

    /** Fraction of positives predicted `true`. */
    val truePositiveRate: Double

    /** Fraction of negatives predicted `false`. */
    val trueNegativeRate: Double

    /** Fraction of negatives predicted `true`. */
    val falsePositiveRate: Double

    /** Fraction of positives predicted `false`. */
    val falseNegativeRate: Double

    init {
        if (positives.isEmpty()) throw Exception("positives must not be empty")
        if (negatives.isEmpty()) throw Exception("negatives must not be empty")

        val positiveTotal = positives.size
        val negativeTotal = negatives.size

        // Positives are evaluated first, then negatives; each sample is predicted exactly once.
        val hits = positives.count { sample -> predict(sample) }
        val misses = positiveTotal - hits
        val falseAlarms = negatives.count { sample -> predict(sample) }
        val correctRejections = negativeTotal - falseAlarms

        accuracyRate = (hits + correctRejections).toDouble() / (negativeTotal + positiveTotal)
        misclassificationRate = 1.0 - accuracyRate
        truePositiveRate = hits.toDouble() / positiveTotal
        trueNegativeRate = correctRejections.toDouble() / negativeTotal
        falsePositiveRate = falseAlarms.toDouble() / negativeTotal
        falseNegativeRate = misses.toDouble() / positiveTotal
    }

    /** Prints all six rates, one labelled row each, via [Printer]. */
    fun printReport() {
        // linkedMapOf keeps insertion order, so rows print in the order listed here.
        val rows = linkedMapOf(
            "Accuracy" to accuracyRate,
            "Misclassification rate" to misclassificationRate,
            "True positive rate" to truePositiveRate,
            "True negative rate" to trueNegativeRate,
            "False positive rate" to falsePositiveRate,
            "False negative rate" to falseNegativeRate
        )
        Printer(rows).print()
    }
}
/**
 * Prints a map of labelled ratios to standard output as an aligned table.
 *
 * Each row is `label:` followed by enough spaces to line the values up in one
 * column, then the value multiplied by 100 and formatted as `%6.2f` with a
 * trailing `%`.
 *
 * @param dataRows label-to-ratio pairs, printed in the map's iteration order.
 */
class Printer(private val dataRows: Map<String, Double>) {
    private val spacing = 2

    // Length of the longest label plus the spacing; 50 is the fallback when there are no labels.
    private val longestLabelLength = (dataRows.keys.maxOfOrNull { it.length } ?: 50) + spacing

    /** Writes one line per entry of the map to standard output. */
    fun print() {
        dataRows.forEach { (label, value) ->
            val gap = " ".repeat(longestLabelLength - label.length)
            val percentage = String.format("%6.2f", value * 100) + "%"
            println("$label:$gap$percentage")
        }
    }
}
/** Returns the string form of a freshly generated random [UUID]. */
private fun randomString(): String = "${UUID.randomUUID()}"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You can run it on try.kotlinlang.org, but set filterSize and numberOfEntries in main() to smaller values (or lower them further if you get timeouts).