Skip to content

Instantly share code, notes, and snippets.

@mynameisfiber
Last active December 24, 2015 10:09
Show Gist options
  • Select an option

  • Save mynameisfiber/6782583 to your computer and use it in GitHub Desktop.

Select an option

Save mynameisfiber/6782583 to your computer and use it in GitHub Desktop.
Counting Bloom filter and Timing Bloom filter using Python arrays and Tornado PeriodicCallbacks
#!/usr/bin/env python
import tornado.ioloop
import tornado.testing
import array
import struct
import math
import mmh3
import time
class CountingBloomFilter(object):
    """
    A counting Bloom filter backed by a flat ``array.array`` of counters.

    Each key maps to ``num_hashes`` counter slots derived from a single
    ``mmh3.hash64`` call via double hashing. Adding a key increments its
    slots, removing decrements them, and a key is reported as present when
    all of its slots are non-zero.
    """

    # Binary layout of the header written by `tofile` (native byte order):
    # capacity, error, num_bytes, num_hashes, dtype.
    _HEADER_FORMAT = "QdQQc"

    def __init__(self, capacity, error=0.005, dtype="B"):
        """
        :param capacity: expected number of distinct keys
        :param error: target false-positive rate
        :param dtype: ``array`` typecode used for each counter

        ``num_bytes`` is the number of counters; it equals the size in bytes
        only for 1-byte typecodes such as the default ``"B"``.
        """
        self.capacity = capacity
        self.error = error
        # Standard sizing: m = -n * ln(p) / (ln 2)^2 counters ...
        self.num_bytes = int(-capacity * math.log(error) / math.log(2) ** 2) + 1
        # ... and k = (m / n) * ln 2 hashes. Force true division so the m/n
        # ratio is not truncated by Python 2 integer division.
        self.num_hashes = int(float(self.num_bytes) / capacity * math.log(2)) + 1
        self.dtype = dtype
        self.data = array.array(dtype, [0]) * self.num_bytes

    def _indexes(self, key):
        """
        Generate the ``num_hashes`` counter indices for `key`.

        Index i is ``(h1 + i * h2) % num_bytes`` where ``(h1, h2)`` is the
        pair returned by ``mmh3.hash64(key)`` (double hashing).
        """
        # The hash does not depend on i, so compute it once per key.
        h1, h2 = mmh3.hash64(key)
        num_bytes = self.num_bytes
        for i in range(self.num_hashes):
            yield (h1 + i * h2) % num_bytes

    def add(self, key, N=1):
        """
        Add `N` counts to every counter slot of `key`. Returns self.
        """
        assert isinstance(key, str)
        for index in self._indexes(key):
            self.data[index] += N
        return self

    def remove(self, key, N=1):
        """
        Subtract `N` counts from every counter slot of `key`.

        Nothing is changed unless every slot holds at least `N`, so counters
        are never driven below zero. Returns self.
        """
        assert isinstance(key, str)
        indexes = list(self._indexes(key))
        if not any(self.data[index] < N for index in indexes):
            for index in indexes:
                self.data[index] -= N
        return self

    def remove_all(self, N=1):
        """
        Subtract `N` counts from every counter, clamping at zero.

        Useful for expirations. Returns self.
        """
        data = self.data
        for i, count in enumerate(data):
            # Clamp rather than skip: a counter holding fewer than N counts
            # must still be drained, otherwise it would never expire.
            data[i] = count - N if count >= N else 0
        return self

    def contains(self, key):
        """
        Return True if every counter slot of `key` is non-zero.
        """
        assert isinstance(key, str)
        return all(self.data[index] != 0 for index in self._indexes(key))

    def tofile(self, f):
        """
        Write the filter to the binary file object `f`.

        Layout: a fixed-size ``_HEADER_FORMAT`` header, one newline byte,
        then the raw counter array.
        """
        dtype = self.dtype
        if not isinstance(dtype, bytes):
            # struct's "c" code requires a length-1 byte string (Python 3).
            dtype = dtype.encode("ascii")
        header = struct.pack(self._HEADER_FORMAT, self.capacity, self.error,
                             self.num_bytes, self.num_hashes, dtype)
        # The newline separator is kept so existing files stay readable.
        # Readers must not rely on it to find the end of the header, because
        # the packed binary fields can themselves contain 0x0A bytes.
        f.write(header + b"\n")
        self.data.tofile(f)

    @classmethod
    def fromfile(cls, f):
        """
        Read a filter written by `tofile` from the binary file object `f`.

        :returns: a new instance (``__init__`` is not called)
        :raises struct.error: if the header is truncated
        """
        self = cls.__new__(cls)
        # Read the header by size, not with readline(): the packed integers
        # and double may contain newline bytes (e.g. capacity == 10).
        header = f.read(struct.calcsize(cls._HEADER_FORMAT))
        f.read(1)  # consume the newline separator written by `tofile`
        (self.capacity, self.error, self.num_bytes, self.num_hashes,
         dtype) = struct.unpack(cls._HEADER_FORMAT, header)
        if not isinstance(dtype, str):
            # Python 3 unpacks "c" as bytes; array() needs a str typecode.
            dtype = dtype.decode("ascii")
        self.dtype = dtype
        self.data = array.array(dtype)
        self.data.fromfile(f, self.num_bytes)
        return self

    def __contains__(self, key):
        return self.contains(key)

    def __add__(self, other):
        # Mutates self in place and returns it, so `bf += key` adds `key`.
        return self.add(other)

    def __sub__(self, other):
        # Mutates self in place and returns it, so `bf -= key` removes `key`.
        return self.remove(other)
class TimingBloomFilter(CountingBloomFilter):
    """
    A Bloom filter whose entries stop matching about `decay_time` seconds
    after they were last added.

    Instead of counts, each slot stores the "tick" at which it was last set.
    A tick is a coarse timestamp in the range 1..ring_size that wraps around,
    and 0 means the slot is empty. A slot is live while its tick lies in the
    window of the most recent `dN` ticks. Once started (see `start`), a
    tornado PeriodicCallback runs `decay`, which zeroes slots outside that
    window. Presumably this stops stale ticks from re-entering the window
    when the tick counter wraps.
    """

    def __init__(self, *args, **kwargs):
        """
        Takes the same arguments as CountingBloomFilter plus keywords:

        :param decay_time: (required) seconds an entry stays live
        :param ioloop: tornado IOLoop used for the decay timer; defaults to
            ``tornado.ioloop.IOLoop.instance()``
        """
        self.decay_time = kwargs.pop("decay_time", None)
        self._ioloop = kwargs.pop("ioloop", None) or tornado.ioloop.IOLoop.instance()
        assert self.decay_time is not None, "Must provide decay_time parameter"
        super(TimingBloomFilter, self).__init__(*args, **kwargs)
        # Largest value a slot can hold; ticks are stored as 1..ring_size.
        self.ring_size = (1 << (struct.calcsize(self.dtype) * 8)) - 1
        # Window length in ticks. Floor division keeps this an int on
        # Python 3 too (it was already floor division on Python 2).
        self.dN = self.ring_size // 2
        # Seconds per tick, chosen so that dN ticks span decay_time seconds.
        self.dt = self.decay_time / float(self.dN)
        self._setup_decay()

    def _setup_decay(self):
        """(Re)create the decay timer so that it fires twice per tick."""
        t = self.dt * 1000.0 / 2.0
        print("Going to decay every %f ms" % t)
        self._callbacktimer = tornado.ioloop.PeriodicCallback(self.decay, t, self._ioloop)

    def _tick(self):
        """Return the current tick, a value in 1..ring_size."""
        return int((time.time() // self.dt) % self.ring_size) + 1

    def _tick_range(self):
        """
        Return ``(tick_min, tick_max)``. The live window is the half-open
        ring interval (tick_min, tick_max], which holds dN ticks.
        """
        tick_max = self._tick()
        tick_min = (tick_max - self.dN - 1) % self.ring_size + 1
        return tick_min, tick_max

    def _test_interval(self):
        """
        Return a predicate telling whether a stored tick is currently live.
        """
        tick_min, tick_max = self._tick_range()
        if tick_min < tick_max:
            return lambda x: x and tick_min < x <= tick_max
        # The window wraps past ring_size back to 1. ring_size itself is a
        # valid tick (see _tick), so the upper segment must include it.
        ring_size = self.ring_size
        return lambda x: 0 < x <= tick_max or tick_min < x <= ring_size

    def add(self, key):
        """Stamp every slot of `key` with the current tick. Returns self."""
        assert isinstance(key, str)
        tick = self._tick()
        for index in self._indexes(key):
            self.data[index] = tick
        return self

    def contains(self, key):
        """
        Return True if every slot of `key` holds a tick inside the live window.
        """
        assert isinstance(key, str)
        test_interval = self._test_interval()
        return all(test_interval(self.data[index]) for index in self._indexes(key))

    def decay(self):
        """Zero every slot whose tick has fallen outside the live window."""
        test_interval = self._test_interval()
        data = self.data
        for i, value in enumerate(data):
            if not test_interval(value):
                data[i] = 0

    def start(self):
        """Start the periodic decay timer. Returns self."""
        assert not self._callbacktimer._running, "Decay timer already running"
        self._callbacktimer.start()
        return self

    def stop(self):
        """Stop the periodic decay timer. Returns self."""
        assert self._callbacktimer._running, "Decay timer not running"
        self._callbacktimer.stop()
        return self
import contextlib
@contextlib.contextmanager
def TimingBlock(name, N=None):
    """
    Context manager that prints how long its body took to run.

    :param name: label included in the printed line
    :param N: optional number of operations performed by the body; when
        truthy, a throughput in operations per second is printed as well
    """
    start = time.time()
    yield
    dt = time.time() - start
    # Guard the rate: a body faster than the clock resolution yields dt == 0,
    # which previously raised ZeroDivisionError.
    if N and dt > 0:
        print("[%s][timing] %s: %fs (%f / s)" % (start, name, dt, N / dt))
    else:
        print("[%s][timing] %s: %fs" % (start, name, dt))
class TestTimingBloomFilter(tornado.testing.AsyncTestCase):
    """Integration tests for TimingBloomFilter running on the test IOLoop."""

    def _sleep(self, seconds):
        """
        Run the IOLoop for `seconds` so the decay timer can fire.

        ``wait`` raises ``self.failureException`` when its timeout elapses
        without ``stop`` being called. That timeout is the point here, so
        only that exception is swallowed; anything else still propagates.
        """
        if seconds <= 0:
            return
        try:
            self.wait(timeout=seconds)
        except self.failureException:
            pass

    def test_decay(self):
        tbf = TimingBloomFilter(500, decay_time=4, ioloop=self.io_loop).start()
        try:
            tbf += "hello"
            assert tbf.contains("hello") == True
            self._sleep(4)
            assert tbf.contains("hello") == False
        finally:
            # Don't leave the periodic decay timer registered on the loop.
            tbf.stop()

    def test_holistic(self):
        n = int(2e5)
        N = int(1e5)
        T = 10
        print("TimingBloom with capacity %e and expiration time %ds" % (n, T))
        with TimingBlock("Initialization"):
            tbf = TimingBloomFilter(n, decay_time=T, dtype="B", ioloop=self.io_loop)

        # Wrap decay() so each periodic pass is timed, then rebuild the timer
        # so it calls the wrapper rather than the original bound method.
        orig_decay = tbf.decay

        def timed_decay(*args, **kwargs):
            with TimingBlock("Decaying"):
                return orig_decay(*args, **kwargs)

        tbf.decay = timed_decay
        tbf._setup_decay()
        tbf.start()
        try:
            print("num_hashes = %d, num_bytes = %d" % (tbf.num_hashes, tbf.num_bytes))
            print("sizeof(TimingBloom) = %d bytes" % (struct.calcsize(tbf.dtype) * tbf.num_bytes))

            with TimingBlock("Adding %d values" % N, N):
                for i in range(N):
                    tbf.add(str(i))
            last_insert = time.time()

            with TimingBlock("Testing %d positive values" % N, N):
                for i in range(N):
                    assert str(i) in tbf

            with TimingBlock("Testing %d negative values" % N, N):
                err = sum(1 for i in range(N, 2 * N) if str(i) in tbf)
            tot_err = err / float(N)
            assert tot_err <= tbf.error, "Error is too high: %f > %f" % (tot_err, tbf.error)

            # Wait until every inserted key is at least T seconds old.
            self._sleep(T - (time.time() - last_insert))

            with TimingBlock("Testing %d expired values" % N, N):
                err = sum(1 for i in range(N) if str(i) in tbf)
            tot_err = err / float(N)
            assert tot_err <= tbf.error, "Error is too high: %f > %f" % (tot_err, tbf.error)
        finally:
            tbf.stop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment