Created
December 25, 2019 02:27
-
-
Save rjzak/543a7228a3aa505d786cde9c64014c4b to your computer and use it in GitHub Desktop.
A simple Bloom filter that depends only on the Python standard library and NumPy. Inputs are expected to be strings.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import os | |
import zlib | |
import pickle | |
import numpy as np | |
import unittest | |
# Parameters of the polynomial rolling hash computed by rabin_hash().
P_B = 227  # base: the running hash is multiplied by this for every symbol
P_M = 1000005  # modulus: rabin_hash() always returns a value in range(P_M)


def rabin_hash(ngram):
    """Return a polynomial rolling hash of *ngram*.

    The hash is ``sum(code_i * P_B ** (n - 1 - i)) % P_M`` over the symbols
    of *ngram*, evaluated incrementally so intermediate values stay small.

    Args:
        ngram: A ``str`` (each character contributes its code point) or a
            ``bytes``-like sequence of ints (each byte contributes its
            value), matching what the zlib-based sibling hashes accept.

    Returns:
        A non-negative int strictly less than ``P_M``; ``0`` for empty input.
    """
    r = 0
    for symbol in ngram:
        # Iterating bytes already yields ints; only str characters need ord().
        code = symbol if isinstance(symbol, int) else ord(symbol)
        # Reducing at every step gives the same result as reducing once at
        # the end, without building an arbitrarily large intermediate int.
        r = (r * P_B + code) % P_M
    return r
def crc32_hash(ngram):
    """Return the CRC-32 checksum of *ngram* as a non-negative int.

    Args:
        ngram: A ``str`` (encoded with the default UTF-8 codec before
            hashing) or a bytes-like object accepted by ``zlib.crc32``.

    Returns:
        The checksum computed by ``zlib.crc32``.
    """
    # isinstance() also covers str subclasses, which a `type(x) == str`
    # comparison would let through un-encoded and crash in zlib.
    if isinstance(ngram, str):
        ngram = ngram.encode()
    return abs(zlib.crc32(ngram))
def adler32_hash(ngram):
    """Return the Adler-32 checksum of *ngram* as a non-negative int.

    Args:
        ngram: A ``str`` (encoded with the default UTF-8 codec before
            hashing) or a bytes-like object accepted by ``zlib.adler32``.

    Returns:
        The checksum computed by ``zlib.adler32``.
    """
    # isinstance() also covers str subclasses, which a `type(x) == str`
    # comparison would let through un-encoded and crash in zlib.
    if isinstance(ngram, str):
        ngram = ngram.encode()
    return abs(zlib.adler32(ngram))
class BloomFilter:
    """Fixed-size Bloom filter backed by a NumPy boolean array.

    Every inserted item sets one bit per function in ``hash_functions``
    (index = ``hash(item) % length``). ``contains`` reports True only when
    all of an item's bits are set, so it can return false positives (bits
    set by other items) but never a false negative for an inserted item.

    Attributes:
        length: Number of bits in the filter.
        bits: NumPy boolean array of ``length`` entries.
        hash_functions: Tuple of callables mapping an item to a
            non-negative int.
    """

    def __init__(self, length=0, fname=None):
        """Create an empty filter, or load one previously written by save().

        Args:
            length: Number of bits for a new filter. Values below 10 are
                replaced by 1000 (a notice is printed). Ignored when
                *fname* is given.
            fname: Optional path of a pickle file produced by ``save``.
        """
        if fname is not None:
            # SECURITY: pickle.load can execute arbitrary code embedded in
            # the file. Only load filter files that come from a trusted source.
            with open(fname, "rb") as f:
                state = pickle.load(f)
            self.length = state["length"]
            self.bits = state["bits"]
        else:
            if length < 10:
                print("Bloom Filter size of %d does't make sense, chaned to 1000" % length)
                length = 1000
            self.length = length
            # The `np.bool` alias was removed in NumPy 1.24; the builtin
            # `bool` selects the same boolean dtype on every NumPy version.
            self.bits = np.zeros(length, dtype=bool)
        self.hash_functions = (rabin_hash, crc32_hash, adler32_hash)

    def insert(self, data):
        """Add *data* to the filter.

        Raises:
            AssertionError: If, after the insert, a quarter or more of the
                bits are set ("Bloom Filter exhausted"). The bits for
                *data* have already been set when this is raised.
        """
        for func in self.hash_functions:
            self.bits[func(data) % self.length] = True
        # Raised explicitly rather than via `assert` so the capacity check
        # still runs under `python -O`; the exception type is unchanged.
        # Note this scans the whole bit array on every insert.
        if np.count_nonzero(self.bits) >= self.length / 4:
            raise AssertionError("Bloom Filter exhausted")

    def contains(self, data):
        """Return True if every bit selected for *data* is set."""
        return all(
            self.bits[func(data) % self.length] for func in self.hash_functions
        )

    def __contains__(self, data):
        """Support ``data in bloom_filter`` as an alias for ``contains``."""
        return self.contains(data)

    def isempty(self):
        """Return True if no bit has been set."""
        return not self.bits.any()

    def save(self, fname):
        """Pickle the filter's length and bit array to the file *fname*."""
        state = {"length": self.length, "bits": self.bits}
        with open(fname, "wb") as f:
            pickle.dump(state, f)
class BloomFilterTest(unittest.TestCase):
    """Unit tests for BloomFilter insertion, lookup and serialization."""

    def setUp(self):
        # Large enough that the handful of inserts below stays far from the
        # filter's exhaustion threshold.
        self.filterSize = int(1e6)

    def test_empty_is_empty(self):
        """A freshly created filter reports itself as empty."""
        bf = BloomFilter(self.filterSize)
        self.assertTrue(bf.isempty())

    def test_not_empty(self):
        """A filter is no longer empty after one insert."""
        bf = BloomFilter(self.filterSize)
        bf.insert("1234")
        self.assertFalse(bf.isempty())

    def test_did_insert(self):
        """An inserted value is reported as contained."""
        bf = BloomFilter(self.filterSize)
        tempVal = "1234999"
        bf.insert(tempVal)
        self.assertTrue(bf.contains(tempVal))

    def test_not_contains(self):
        """A value that was never inserted is not reported as contained."""
        bf = BloomFilter(self.filterSize)
        for value in ("1234", "abc123", "xyz000", "qwertyuiiop"):
            bf.insert(value)
        self.assertFalse(bf.contains("9999999999"))

    def test_symbols(self):
        """Punctuation-only strings can be inserted and looked up."""
        bf = BloomFilter(self.filterSize)
        # The backslash is escaped explicitly: "\;" is an invalid escape
        # sequence. The runtime string is unchanged.
        bf.insert("{}|[]\\;':&*(&)")
        bf.insert("#$$^%")
        self.assertFalse(bf.contains("?>?<>:"))

    def test_serialization(self):
        """A filter survives a save/load round trip."""
        fname = "bloom_test.pkl"
        testValue = "qazwsxedcrfvtgbyhnujmik,il."
        bf = BloomFilter(self.filterSize)
        bf.insert(testValue)
        bf.save(fname)
        try:
            newBF = BloomFilter(fname=fname)
        finally:
            # Remove the scratch file even when loading fails, so a broken
            # run does not leave bloom_test.pkl behind.
            os.remove(fname)
        self.assertTrue(newBF.contains(testValue))
        self.assertEqual(bf.length, newBF.length)
# Run the test suite defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment