Last active
December 24, 2015 10:09
-
-
Save mynameisfiber/6782583 to your computer and use it in GitHub Desktop.
A Counting Bloom filter and a Timing Bloom filter, implemented with Python arrays and Tornado PeriodicCallbacks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| import tornado.ioloop | |
| import tornado.testing | |
| import array | |
| import struct | |
| import math | |
| import mmh3 | |
| import time | |
class CountingBloomFilter(object):
    """
    A Bloom filter whose cells are counters instead of bits, so keys can be
    removed as well as added.

    Counters live in an ``array.array`` of typecode ``dtype``; with the
    default ``"B"`` each counter is one unsigned byte (0-255). ``add`` does
    not saturate: pushing a counter past the typecode's maximum raises
    ``OverflowError`` from ``array``.
    """

    # Binary header written by ``tofile``: capacity, error, num_bytes,
    # num_hashes, dtype (native byte order and alignment).
    _HEADER_FORMAT = "QdQQc"

    def __init__(self, capacity, error=0.005, dtype="B"):
        """
        :param capacity: expected number of distinct keys
        :param error: target false-positive probability
        :param dtype: ``array`` typecode used for every counter
        """
        self.capacity = capacity
        self.error = error
        # Number of counters, m = -n ln(p) / (ln 2)^2. These are bytes only
        # when dtype is a one-byte typecode.
        self.num_bytes = int(-capacity * math.log(error) / math.log(2) ** 2) + 1
        # Number of hashes, k = (m / n) ln 2. float() forces true division on
        # Python 2; a plain int `/` would truncate m / n.
        self.num_hashes = int(float(self.num_bytes) / capacity * math.log(2)) + 1
        self.dtype = dtype
        self.data = array.array(dtype, [0]) * self.num_bytes

    def _indexes(self, key):
        """
        Yield the ``num_hashes`` counter indices for ``key``.

        Uses double hashing (h1 + i * h2) over the two 64-bit halves of one
        MurmurHash3 digest. The digest does not depend on ``i``, so it is
        computed once. Indices may repeat for some keys.
        """
        h1, h2 = mmh3.hash64(key)
        for i in range(self.num_hashes):
            yield (h1 + i * h2) % self.num_bytes

    def add(self, key, N=1):
        """
        Add ``N`` to every counter indexed by ``key``; returns ``self``.
        """
        assert isinstance(key, str)
        for index in self._indexes(key):
            self.data[index] += N
        return self

    def remove(self, key, N=1):
        """
        Subtract ``N`` from every counter indexed by ``key``; returns ``self``.

        The removal is all-or-nothing. If any counter holds less than it would
        lose, nothing is changed, so counters never go negative and the filter
        is never left half-updated.
        """
        assert isinstance(key, str)
        # An index that occurs several times was incremented several times by
        # add(), so it must be decremented (and checked) by the total amount.
        amounts = {}
        for index in self._indexes(key):
            amounts[index] = amounts.get(index, 0) + N
        if all(self.data[index] >= amount for index, amount in amounts.items()):
            for index, amount in amounts.items():
                self.data[index] -= amount
        return self

    def remove_all(self, N=1):
        """
        Subtract ``N`` from every counter, clamping at zero. Useful for expirations.
        """
        data = self.data
        for i, count in enumerate(data):
            if count:
                # Clamp instead of skipping: a counter below N must still
                # reach zero, otherwise it would never expire.
                data[i] = max(count - N, 0)

    def contains(self, key):
        """
        Return True if every counter indexed by ``key`` is non-zero.
        """
        assert isinstance(key, str)
        return all(self.data[index] != 0 for index in self._indexes(key))

    def tofile(self, f):
        """
        Write the filter to the binary file object ``f``.

        Layout: packed ``_HEADER_FORMAT`` header, a newline byte, then the raw
        counter array.
        """
        dtype = self.dtype
        if not isinstance(dtype, bytes):
            # struct's "c" code needs a bytes object on Python 3.
            dtype = dtype.encode("ascii")
        header = struct.pack(self._HEADER_FORMAT, self.capacity, self.error,
                             self.num_bytes, self.num_hashes, dtype)
        f.write(header + b"\n")
        self.data.tofile(f)

    @classmethod
    def fromfile(cls, f):
        """
        Read a filter written by ``tofile`` from the binary file object ``f``.

        ``__init__`` is bypassed. Subclasses that set extra attributes in
        their own ``__init__`` will not have them on the returned object.
        """
        self = cls.__new__(cls)
        header_size = struct.calcsize(cls._HEADER_FORMAT)
        # Read the fixed-size header plus its trailing newline directly.
        # readline() would stop early if the packed integers contain a 0x0A
        # byte (e.g. capacity == 10).
        header = f.read(header_size + 1)[:header_size]
        (self.capacity, self.error, self.num_bytes,
         self.num_hashes, dtype) = struct.unpack(cls._HEADER_FORMAT, header)
        if not isinstance(dtype, str):
            # "c" unpacks to bytes on Python 3; array() wants a str typecode.
            dtype = dtype.decode("ascii")
        self.dtype = dtype
        self.data = array.array(self.dtype)
        self.data.fromfile(f, self.num_bytes)
        return self

    def __contains__(self, key):
        """Support ``key in bloom``."""
        return self.contains(key)

    def __add__(self, other):
        """``bloom + key`` adds ``key`` in place and returns the same filter."""
        return self.add(other)

    def __sub__(self, other):
        """``bloom - key`` removes ``key`` in place and returns the same filter."""
        return self.remove(other)
class TimingBloomFilter(CountingBloomFilter):
    """
    A Bloom filter whose keys expire about ``decay_time`` seconds after they
    were last added.

    Each cell stores a "tick" instead of a count: wall-clock time in units of
    ``dt``, folded onto the ring ``1..ring_size``. A value of 0 means empty.
    A cell is live while its tick lies within the last ``dN`` ticks, a span
    of ``decay_time`` seconds. A tornado ``PeriodicCallback`` periodically
    calls ``decay()`` to zero cells outside that window; call ``start()`` to
    begin.

    NOTE(review): ``ring_size`` is derived from the byte width of ``dtype``,
    so an unsigned typecode (e.g. "B", "H") appears to be required. A signed
    one could not store the largest tick.
    """

    def __init__(self, *args, **kwargs):
        """
        Takes the ``CountingBloomFilter`` arguments plus these keyword
        arguments:

        :param decay_time: seconds a key stays live (required)
        :param ioloop: IOLoop that runs the decay timer; defaults to
            ``tornado.ioloop.IOLoop.instance()``
        """
        self.decay_time = kwargs.pop("decay_time", None)
        self._ioloop = kwargs.pop("ioloop", None) or tornado.ioloop.IOLoop.instance()
        assert self.decay_time is not None, "Must provide decay_time parameter"
        super(TimingBloomFilter, self).__init__(*args, **kwargs)
        # Largest storable value (255 for "B"); ticks run over 1..ring_size.
        self.ring_size = (1 << (struct.calcsize(self.dtype) * 8)) - 1
        # Width of the live window in ticks. `//` keeps this an int on
        # Python 3 as well; Python 2's int `/` already floored.
        self.dN = self.ring_size // 2
        # Seconds per tick, so that dN ticks span decay_time.
        self.dt = self.decay_time / float(self.dN)
        self._setup_decay()

    def _setup_decay(self):
        """(Re)create the periodic timer that calls ``self.decay``; it is not started."""
        # Fire every half tick.
        period_ms = self.dt * 1000.0 / 2.0
        print("Going to decay every %f ms" % period_ms)
        self._callbacktimer = tornado.ioloop.PeriodicCallback(self.decay, period_ms, self._ioloop)

    def _tick(self):
        """Current tick: ``time.time()`` in units of ``dt``, mapped to 1..ring_size."""
        return int((time.time() // self.dt) % self.ring_size) + 1

    def _tick_range(self):
        """
        Return ``(tick_min, tick_max)``.

        The live window is the ``dN`` ticks with ``tick_min < t <= tick_max``,
        read cyclically on ``1..ring_size``.
        """
        tick_max = self._tick()
        tick_min = (tick_max - self.dN - 1) % self.ring_size + 1
        return tick_min, tick_max

    def _test_interval(self):
        """Return a predicate that is truthy for ticks inside the current live window."""
        tick_min, tick_max = self._tick_range()
        if tick_min < tick_max:
            return lambda x: x and tick_min < x <= tick_max
        # The window wraps around the ring: [1, tick_max] plus
        # (tick_min, ring_size]. ring_size itself is a tick that _tick() can
        # return, so the upper bound must be inclusive.
        ring_size = self.ring_size
        return lambda x: 0 < x <= tick_max or tick_min < x <= ring_size

    def add(self, key):
        """Stamp every cell of ``key`` with the current tick (refreshing its expiry); returns ``self``."""
        assert isinstance(key, str)
        tick = self._tick()
        for index in self._indexes(key):
            self.data[index] = tick
        return self

    def contains(self, key):
        """
        Return True if every cell of ``key`` holds a tick inside the live window.
        """
        assert isinstance(key, str)
        test_interval = self._test_interval()
        return all(test_interval(self.data[index]) for index in self._indexes(key))

    def decay(self):
        """Zero every non-empty cell whose tick has fallen outside the live window."""
        test_interval = self._test_interval()
        data = self.data
        for i, tick in enumerate(data):
            if tick and not test_interval(tick):
                data[i] = 0

    def start(self):
        """Start the periodic decay timer; returns ``self`` for chaining."""
        assert not self._callbacktimer._running, "Decay timer already running"
        self._callbacktimer.start()
        return self

    def stop(self):
        """Stop the periodic decay timer; returns ``self`` for chaining."""
        assert self._callbacktimer._running, "Decay timer not running"
        self._callbacktimer.stop()
        return self
import contextlib


@contextlib.contextmanager
def TimingBlock(name, N=None):
    """
    Context manager that prints how long its body took.

    :param name: label included in the printed line
    :param N: optional number of operations done in the body. When truthy,
        the throughput ``N / elapsed`` is printed as well.
    """
    start = time.time()
    yield
    dt = time.time() - start
    if N:
        # time.time() can have coarse resolution, so a very fast body may
        # measure 0s; report an infinite rate instead of dividing by zero.
        rate = N / dt if dt > 0 else float("inf")
        print("[%s][timing] %s: %fs (%f / s)" % (start, name, dt, rate))
    else:
        print("[%s][timing] %s: %fs" % (start, name, dt))
class TestTimingBloomFilter(tornado.testing.AsyncTestCase):
    """Tests for TimingBloomFilter, driven by the test case's own IOLoop."""

    def _sleep(self, seconds):
        """
        Run the IOLoop (and hence the decay timer) for ``seconds``.

        Nothing calls ``self.stop()``, so ``wait`` ends by timing out, which
        AsyncTestCase reports by raising ``self.failureException``. Only that
        expected timeout is swallowed; any other error propagates.
        """
        if seconds <= 0:
            return
        try:
            self.wait(timeout=seconds)
        except self.failureException:
            pass

    @staticmethod
    def _hit_rate(tbf, keys):
        """Fraction of ``keys`` (tested as ``str(key)``) that ``tbf`` reports as present."""
        hits = sum(1 for key in keys if str(key) in tbf)
        return hits / float(len(keys))

    def test_decay(self):
        decay_time = 4
        tbf = TimingBloomFilter(500, decay_time=decay_time, ioloop=self.io_loop).start()
        try:
            tbf += "hello"
            self.assertTrue(tbf.contains("hello"))
            self._sleep(decay_time)
            self.assertFalse(tbf.contains("hello"))
        finally:
            tbf.stop()

    def test_holistic(self):
        n = int(2e5)
        N = int(1e5)
        T = 10
        print("TimingBloom with capacity %e and expiration time %ds" % (n, T))
        with TimingBlock("Initialization"):
            tbf = TimingBloomFilter(n, decay_time=T, dtype="B", ioloop=self.io_loop)

        # Wrap decay() so every decay pass is timed, then rebuild the periodic
        # timer so that it calls the wrapper rather than the original method.
        orig_decay = tbf.decay

        def timed_decay(*args, **kwargs):
            with TimingBlock("Decaying"):
                return orig_decay(*args, **kwargs)

        tbf.decay = timed_decay
        tbf._setup_decay()
        tbf.start()
        try:
            print("num_hashes = %d, num_bytes = %d" % (tbf.num_hashes, tbf.num_bytes))
            print("sizeof(TimingBloom) = %d bytes" % (struct.calcsize(tbf.dtype) * tbf.num_bytes))

            with TimingBlock("Adding %d values" % N, N):
                for i in range(N):
                    tbf.add(str(i))
            last_insert = time.time()

            with TimingBlock("Testing %d positive values" % N, N):
                present = self._hit_rate(tbf, range(N))
            self.assertEqual(present, 1.0)

            with TimingBlock("Testing %d negative values" % N, N):
                tot_err = self._hit_rate(tbf, range(N, 2 * N))
            self.assertLessEqual(tot_err, tbf.error,
                                 "Error is too high: %f > %f" % (tot_err, tbf.error))

            # Every key was added at or before last_insert, so all of them
            # should be expired once T seconds have passed since then.
            self._sleep(T - (time.time() - last_insert))

            with TimingBlock("Testing %d expired values" % N, N):
                tot_err = self._hit_rate(tbf, range(N))
            self.assertLessEqual(tot_err, tbf.error,
                                 "Error is too high: %f > %f" % (tot_err, tbf.error))
        finally:
            tbf.stop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment