Last active
February 28, 2020 18:07
-
-
Save davidblewett/59b3a7309f4d65feca20291c2d73be04 to your computer and use it in GitHub Desktop.
Expiring Counter Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings
from collections import Counter, OrderedDict, deque
from datetime import timedelta

try:
    from collections.abc import MutableMapping
except ImportError:  # Python 2: the ABCs still live in ``collections``
    from collections import MutableMapping

from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback
class ExpiringCounter(MutableMapping):
    """A Counter-like mapping whose counts expire after a number of ticks.

    Internally this is a bounded deque of ``Counter`` "epochs". New counts
    always go into the newest (right-most) epoch. Each call to ``tick()``
    appends a fresh, empty epoch; once the deque is full, the oldest epoch
    falls off the left end and takes its counts with it. The object must be
    driven externally by calling ``tick()`` periodically.

    The abstract ``MutableMapping`` methods delegate to the epochs. Reads
    aggregate across every epoch, and writes go to the newest one.

    Reads sum over all epochs while writes only target the newest one.
    Therefore ``counter[key] += n`` double-counts anything already recorded
    in older epochs. Use ``increment()`` instead.
    """

    def __init__(self, iterable=None, maxlen=1):
        """Create the counter.

        :param iterable: optional initial epochs, oldest first. They must be
            Counter-like (missing keys must read as 0). Defaults to a single
            empty ``Counter``.
        :param maxlen: maximum number of epochs kept, with ``deque``
            semantics. If ``iterable`` is longer, only its last ``maxlen``
            items are kept. ``None`` means unbounded.
        """
        if iterable is None:
            self._epochs = deque([Counter()], maxlen)
        else:
            self._epochs = deque(iterable, maxlen)

    def __contains__(self, key):
        # True if any live epoch holds the key, even with a zero count.
        return any(key in epoch for epoch in self._epochs)

    def __delitem__(self, key):
        # Counter.__delitem__ ignores missing keys, so this never raises,
        # even when only some (or none) of the epochs contain ``key``.
        for epoch in self._epochs:
            del epoch[key]

    def __getitem__(self, key):
        # Total across all live epochs; missing keys contribute 0.
        return sum(epoch[key] for epoch in self._epochs)

    def __iter__(self):
        # Each distinct key once, in order of first appearance from the
        # oldest epoch to the newest.
        result = OrderedDict()
        for epoch in self._epochs:
            result.update(epoch)
        return iter(result)

    def __len__(self):
        # Number of distinct keys across all epochs. Start from an empty set
        # so that this also works when there are no epochs at all: the
        # unbound ``set.union()`` with no arguments raises TypeError.
        return len(set().union(*self._epochs))

    def __setitem__(self, key, value):
        # A write replaces the value in the newest epoch only. ``c[k] += n``
        # expands to ``c[k] = c[k] + n``, and ``c[k]`` sums *all* epochs, so
        # the augmented assignment copies older counts into the newest epoch.
        # A value above 1 is treated as a likely symptom of that pattern.
        msg = u"Using += on the base class is not supported; use `increment`"
        if value > 1:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(msg, SyntaxWarning, stacklevel=2)
        self._epochs[-1][key] = value

    def clear(self):
        """Drop every epoch and start over with a single empty one.

        This is a shortcut for ``MutableMapping.clear()``, which would delete
        keys one at a time. The fresh epoch is appended directly rather than
        through ``self.tick()``. Subclasses may override ``tick()`` with
        scheduling side effects, and those must not be triggered by a clear.
        """
        self._epochs.clear()  # deque.clear() preserves maxlen
        self._epochs.append(Counter())

    # There is deliberately no ``decrement``. The tempting implementation,
    # ``self._epochs[-1][key] -= value``, subtracts only from the newest
    # epoch. If the count being removed lives in an older epoch, the newest
    # epoch goes negative. Once the older epoch expires, the key's total is
    # left negative.

    def increment(self, key, value=1):
        """Add ``value`` to ``key``'s count in the newest epoch."""
        self._epochs[-1][key] += value

    def tick(self):
        """Rotate the epochs by appending a fresh, empty one.

        Call this periodically, as often as keys should expire. When the
        deque is full, the oldest epoch is discarded, so the maximum key
        lifetime is roughly the tick interval times ``maxlen``.
        """
        self._epochs.append(Counter())
class TornadoExpiringCounter(ExpiringCounter):
    """ExpiringCounter whose epoch rotation is driven by a Tornado IOLoop.

    With a positive ``granularity`` (in seconds), a ``PeriodicCallback``
    calls ``tick()`` every ``granularity`` seconds. The number of epochs is
    then derived from ``max_duration / granularity``.

    With a falsy ``granularity`` (``None``/``0``) or ``gen.moment``,
    ``tick()`` instead re-schedules itself with ``add_callback``, so it runs
    about once per IOLoop iteration. In that mode the number of epochs comes
    from ``maxlen``.
    """

    def __init__(self,
                 loop=None,
                 max_duration=timedelta(minutes=5).total_seconds(),
                 granularity=timedelta(seconds=10).total_seconds(),
                 # Escape hatch for tests
                 maxlen=None):
        """Create the counter and start driving ``tick()`` on the IOLoop.

        :param loop: IOLoop that drives the ticks. Defaults to
            ``IOLoop.current()``.
        :param max_duration: how long, in seconds, a count should live. Only
            used when ``granularity`` is a positive number.
        :param granularity: seconds between ticks. A falsy value or
            ``gen.moment`` selects per-iteration ticking.
        :param maxlen: number of epochs in per-iteration mode. It is ignored
            when ``granularity`` is positive. NOTE(review): if left as
            ``None`` in per-iteration mode, the deque is unbounded and grows
            by one epoch on every loop iteration.
        """
        self._loop = IOLoop.current() if loop is None else loop
        self._max_duration = max_duration
        self._granularity = granularity
        per_iteration = not granularity or granularity is gen.moment
        if not per_iteration:
            # Convert the duration to a number of epochs. Keep at least one,
            # so that the newest epoch (``_epochs[-1]``) always exists even
            # when max_duration < granularity.
            maxlen = max(1, int(max_duration / granularity))
        # Build the epoch deque before anything is scheduled that could call
        # tick(), so a tick can never observe a missing ``_epochs``.
        super(TornadoExpiringCounter, self).__init__(iterable=None,
                                                     maxlen=maxlen)
        if per_iteration:
            self._tick_pc = None
            self._loop.add_callback(self.tick)
        else:
            # PeriodicCallback no longer accepts an io_loop argument (it was
            # removed in Tornado 5.0, and the third positional slot became
            # ``jitter`` in 5.1). Tornado 5+ binds the callback to
            # IOLoop.current() at start() time. Starting it from a callback
            # on ``self._loop`` therefore binds it to the requested loop.
            self._tick_pc = PeriodicCallback(self.tick, granularity * 1000)
            self._loop.add_callback(self._tick_pc.start)

    def tick(self):
        """Rotate the epochs and, in per-iteration mode, schedule the next tick.

        The re-scheduling is in ``finally`` so that the tick chain keeps
        running even if the base rotation raises.
        """
        try:
            super(TornadoExpiringCounter, self).tick()
        finally:
            # No periodic callback registered: schedule the next tick for
            # the next loop iteration.
            if self._tick_pc is None:
                self._loop.add_callback(self.tick)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import warnings | |
from collections import Counter | |
from cs.eyrie.vassal import ExpiringCounter, TornadoExpiringCounter | |
from tornado import gen | |
from tornado.testing import AsyncTestCase, gen_test, main | |
class TestExpiringCounter(unittest.TestCase):
    """Unit tests for the synchronous ExpiringCounter."""

    def test_init(self):
        empty_expiring_counter = ExpiringCounter()
        self.assertEqual(len(empty_expiring_counter._epochs), 1)
        # maxlen defaults to 1, which would keep only the last of the two
        # primed epochs, so size the deque explicitly.
        primed_expiring_counter = ExpiringCounter([Counter(), Counter()],
                                                  maxlen=2)
        self.assertEqual(len(primed_expiring_counter._epochs), 2)

    def test_maxlen(self):
        expiring_counter = ExpiringCounter([Counter(), Counter(), Counter()],
                                           maxlen=2)
        self.assertEqual(expiring_counter._epochs.maxlen, 2)
        self.assertEqual(len(expiring_counter._epochs), 2)

    def test_contains_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_contains_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_contains_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_delitem_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_getitem_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 5)

    def test_increment(self):
        iterable = [Counter(), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        expiring_counter['foo'] = 1
        self.assertEqual(expiring_counter['foo'], 1)
        expiring_counter.tick()
        expiring_counter.increment('foo', 1)
        self.assertEqual(expiring_counter['foo'], 2)
        expiring_counter.tick()
        expiring_counter.increment('foo', 1)
        self.assertEqual(expiring_counter['foo'], 3)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['foo'], 2)

    def test_iter_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=4)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter],
                         ['bar'])

    def test_iter_multi_distinct(self):
        iterable = [Counter(bar=3), Counter(baz=2), Counter(bing=4)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter],
                         ['bar', 'baz', 'bing'])

    def test_len_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        # 'bar' appears in two epochs but is still one distinct key,
        # consistent with test_iter_multi yielding ['bar'] once.
        self.assertEqual(len(expiring_counter), 1)

    def test_len_multi_distinct(self):
        iterable = [Counter(bar=3), Counter(), Counter(baz=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 2)

    def test_setitem(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)

    def test_setitem_warning(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['bar'] = 1
        # catch_warnings() restores the global filter list on exit, so the
        # "always" filter set inside it cannot leak into other tests.
        # The old warnings.resetwarnings() call also wiped the default filters.
        with warnings.catch_warnings(record=True) as result:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            # This should trigger a warning because Python does a
            # __getitem__ call, then the addition, then __setitem__.
            # That sequence breaks if the key has counts in an earlier epoch.
            expiring_counter['bar'] += 1
        # Verify some things
        self.assertEqual(len(result), 1)
        self.assertIs(result[-1].category, SyntaxWarning)
        self.assertIn("increment", str(result[-1].message))

    def test_clear(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        expiring_counter.clear()
        self.assertEqual(expiring_counter['foo'], 0)

    def test_tick(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 0)
        self.assertEqual(len(expiring_counter), 0)
class TestTornadoExpiringCounter(AsyncTestCase):
    """Tests for TornadoExpiringCounter in per-iteration tick mode."""

    @gen_test
    def test_init(self):
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=3)
        self.assertEqual(expiring_counter._epochs.maxlen, 3)
        self.assertIs(expiring_counter._tick_pc, None)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=gen.moment,
                                                  maxlen=3)
        self.assertIs(expiring_counter._tick_pc, None)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=0.0,
                                                  maxlen=3)
        self.assertIs(expiring_counter._tick_pc, None)

    @gen_test
    def test_tick(self):
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=1)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        # Let IOLoop advance one iteration
        yield None
        self.assertEqual(expiring_counter['foo'], 0)
        self.assertEqual(len(expiring_counter), 0)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=2)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        # Let IOLoop advance one iteration
        yield None
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        yield None
        self.assertEqual(expiring_counter['foo'], 0)
        # Previously a duplicated 'foo' assertion; check the key count too.
        self.assertEqual(len(expiring_counter), 0)
def all():
    """Build and return a TestSuite holding every test case in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)
if __name__ == '__main__':
    # Script entry point: hand off to tornado.testing's test runner.
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment