Skip to content

Instantly share code, notes, and snippets.

@davidblewett
Last active February 28, 2020 18:07
Show Gist options
  • Save davidblewett/59b3a7309f4d65feca20291c2d73be04 to your computer and use it in GitHub Desktop.
Save davidblewett/59b3a7309f4d65feca20291c2d73be04 to your computer and use it in GitHub Desktop.
Expiring Counter Implementation
import warnings
from collections import (
Counter, MutableMapping, OrderedDict, deque,
)
from datetime import timedelta
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback
class ExpiringCounter(MutableMapping):
    """Counter whose counts expire after a fixed number of ticks.

    The general idea for this class is that it is a limited-size
    deque of Counter instances ("epochs"), oldest first. It must be driven
    externally by calling `tick()` periodically, so that elements are
    rotated through the deque. We delegate the required abstract methods
    to our list of Counter instances.
    """

    def __init__(self, iterable=None, maxlen=1):
        """Build the epoch deque.

        :param iterable: optional iterable of Counter instances used as the
            initial epochs (oldest first). When None, a single empty
            Counter is used.
        :param maxlen: maximum number of epochs kept; handed to `deque`.
        """
        if iterable is None:
            iterable = [Counter()]
        self._epochs = deque(iterable, maxlen)

    def __contains__(self, key):
        """Return True if `key` is present in any epoch."""
        return any(key in epoch for epoch in self._epochs)

    def __delitem__(self, key):
        """Remove `key` from every epoch.

        Epochs are expected to be Counter instances, whose `__delitem__`
        does not raise for a missing key.
        """
        for epoch in self._epochs:
            del epoch[key]

    def __getitem__(self, key):
        """Return the sum of the counts of `key` across all epochs."""
        return sum(epoch[key] for epoch in self._epochs)

    def __iter__(self):
        """Iterate over the distinct keys of all epochs."""
        # Preserve order of appearance in epochs
        result = OrderedDict()
        for epoch in self._epochs:
            result.update(epoch)
        return iter(result)

    def __len__(self):
        """Return the number of distinct keys across all epochs."""
        # Start from an empty set instance: the unbound `set.union(*...)`
        # form raises TypeError when there are no epochs at all.
        return len(set().union(*self._epochs))

    def __setitem__(self, key, value):
        """Store `value` for `key` in the newest epoch.

        A value greater than 1 is treated as a sign that `+=` was used
        (which reads the total over all epochs and writes it back into the
        newest one), so a SyntaxWarning is emitted in that case.
        """
        msg = u"Using += on the base class is not supported; use `increment`"
        if value > 1:
            # stacklevel=2 attributes the warning to the calling code.
            warnings.warn(msg, SyntaxWarning, stacklevel=2)
        self._epochs[-1][key] = value

    def clear(self):
        """Drop every epoch and start over with a single empty one."""
        # This is a shortcut, instead of having to iterate over all keys
        self._epochs.clear()
        # Seed the fresh epoch directly rather than through `self.tick()`:
        # a subclass whose `tick` also reschedules itself would otherwise
        # gain an additional scheduling chain on every `clear()`.
        self._epochs.append(Counter())

    # You would be tempted to write this,
    # and you would be wrong...
    # def decrement(self, key, value=1):
    #     self._epochs[-1][key] -= value

    def increment(self, key, value=1):
        """Add `value` to the count of `key` in the newest epoch."""
        self._epochs[-1][key] += value

    def tick(self):
        """This should be called periodically (how frequently you want to
        expire keys; max key duration would be tick frequency * maxlen).
        """
        self._epochs.append(Counter())
class TornadoExpiringCounter(ExpiringCounter):
    """Implementation of ExpiringCounter that uses the Tornado IOLoop
    to drive the deque rotation.
    """

    def __init__(self,
                 loop=None,
                 max_duration=timedelta(minutes=5).total_seconds(),
                 granularity=timedelta(seconds=10).total_seconds(),
                 # Escape hatch for tests
                 maxlen=None):
        """Create the counter and schedule its ticks on the IOLoop.

        :param loop: IOLoop used for scheduling; defaults to
            `IOLoop.current()`.
        :param max_duration: how long, in seconds, a count is retained.
        :param granularity: seconds between ticks. When falsy or
            `gen.moment`, no PeriodicCallback is created and `tick`
            reschedules itself through `add_callback` instead.
        :param maxlen: number of epochs to keep; only used when no
            PeriodicCallback is created, otherwise it is derived from
            `max_duration` and `granularity`.
        """
        self._loop = IOLoop.current() if loop is None else loop
        self._max_duration = max_duration
        self._granularity = granularity

        periodic = bool(granularity) and granularity is not gen.moment
        if periodic:
            # Convert max_duration to maxlen. Keep at least one epoch:
            # a zero-length deque would discard every Counter and make
            # each write fail with IndexError.
            maxlen = max(1, int(max_duration / granularity))

        # Build the epochs before anything is scheduled, so a tick can
        # never run against a half-constructed instance.
        super(TornadoExpiringCounter, self).__init__(iterable=None,
                                                     maxlen=maxlen)

        if periodic:
            # granularity is in seconds; it is scaled by 1000 for the
            # callback interval.
            # NOTE(review): passing the loop as the third positional
            # argument looks like the Tornado < 5 signature - confirm
            # against the installed Tornado version.
            self._tick_pc = PeriodicCallback(self.tick,
                                             granularity * 1000,
                                             self._loop)
            self._tick_pc.start()
        else:
            self._tick_pc = None
            self._loop.add_callback(self.tick)

    def tick(self):
        """Rotate the epochs and, without a PeriodicCallback, re-arm."""
        try:
            super(TornadoExpiringCounter, self).tick()
        finally:
            # If no periodic callback is registered,
            # schedule next tick immediately
            if self._tick_pc is None:
                self._loop.add_callback(self.tick)
import unittest
import warnings
from collections import Counter
from cs.eyrie.vassal import ExpiringCounter, TornadoExpiringCounter
from tornado import gen
from tornado.testing import AsyncTestCase, gen_test, main
class TestExpiringCounter(unittest.TestCase):
    """Unit tests for the mapping behaviour of ExpiringCounter.

    Most tests come in last/middle/first variants that place the key in
    a different epoch of a three-epoch counter.
    """

    def test_init(self):
        empty_expiring_counter = ExpiringCounter()
        self.assertEqual(len(empty_expiring_counter._epochs), 1)
        primed_expiring_counter = ExpiringCounter([Counter(), Counter()])
        self.assertEqual(len(primed_expiring_counter._epochs), 2)

    def test_maxlen(self):
        expiring_counter = ExpiringCounter([Counter(), Counter(), Counter()],
                                           maxlen=2)
        self.assertEqual(expiring_counter._epochs.maxlen, 2)
        self.assertEqual(len(expiring_counter._epochs), 2)

    def test_contains_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_contains_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_contains_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        self.assertNotIn('baz', expiring_counter)

    def test_delitem_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_delitem_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertIn('bar', expiring_counter)
        del expiring_counter['bar']
        self.assertNotIn('bar', expiring_counter)

    def test_getitem_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)

    def test_getitem_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 5)

    def test_increment(self):
        iterable = [Counter(), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        expiring_counter['foo'] = 1
        self.assertEqual(expiring_counter['foo'], 1)
        expiring_counter.tick()
        expiring_counter.increment('foo', 1)
        self.assertEqual(expiring_counter['foo'], 2)
        expiring_counter.tick()
        expiring_counter.increment('foo', 1)
        self.assertEqual(expiring_counter['foo'], 3)
        # The third tick pushes out the epoch holding the first count
        expiring_counter.tick()
        self.assertEqual(expiring_counter['foo'], 2)

    def test_iter_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter], ['bar'])

    def test_iter_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=4)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter],
                         ['bar'])

    def test_iter_multi_distinct(self):
        iterable = [Counter(bar=3), Counter(baz=2), Counter(bing=4)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual([key for key in expiring_counter],
                         ['bar', 'baz', 'bing'])

    def test_len_last(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_middle(self):
        iterable = [Counter(), Counter(bar=3), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_first(self):
        iterable = [Counter(bar=3), Counter(), Counter()]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 1)

    def test_len_multi(self):
        iterable = [Counter(bar=3), Counter(), Counter(bar=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        # The same key in two epochs is still a single distinct key
        self.assertEqual(len(expiring_counter), 1)

    def test_len_multi_distinct(self):
        iterable = [Counter(bar=3), Counter(), Counter(baz=2)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(len(expiring_counter), 2)

    def test_setitem(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)

    def test_setitem_warning(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['bar'] = 1
        with warnings.catch_warnings(record=True) as result:
            # Cause all warnings to always be triggered. Done inside the
            # context manager so the previous filters are restored on
            # exit instead of being wiped for the whole process.
            warnings.simplefilter("always")
            # This should trigger a warning because Python
            # does a __getitem__ call, increment, then __setitem__
            # This sequence breaks if one of the elements is in an
            # earlier epoch
            expiring_counter['bar'] += 1
        # Verify some things
        self.assertEqual(len(result), 1)
        self.assertIs(result[-1].category, SyntaxWarning)
        self.assertIn("increment", str(result[-1].message))

    def test_clear(self):
        expiring_counter = ExpiringCounter(maxlen=3)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        expiring_counter.clear()
        self.assertEqual(expiring_counter['foo'], 0)

    def test_tick(self):
        iterable = [Counter(), Counter(), Counter(bar=3)]
        expiring_counter = ExpiringCounter(iterable, maxlen=3)
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 3)
        self.assertEqual(len(expiring_counter), 1)
        # maxlen ticks after insertion the key has been rotated out
        expiring_counter.tick()
        self.assertEqual(expiring_counter['bar'], 0)
        self.assertEqual(len(expiring_counter), 0)
class TestTornadoExpiringCounter(AsyncTestCase):
    """Tests for the IOLoop-driven counter, run on the test's own loop."""

    @gen_test
    def test_init(self):
        # Each of these granularity values must skip the PeriodicCallback
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=3)
        self.assertEqual(expiring_counter._epochs.maxlen, 3)
        self.assertIs(expiring_counter._tick_pc, None)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=gen.moment,
                                                  maxlen=3)
        self.assertIs(expiring_counter._tick_pc, None)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=0.0,
                                                  maxlen=3)
        self.assertIs(expiring_counter._tick_pc, None)

    @gen_test
    def test_tick(self):
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=1)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        # Let IOLoop advance one iteration
        yield None
        self.assertEqual(expiring_counter['foo'], 0)
        self.assertEqual(len(expiring_counter), 0)
        expiring_counter = TornadoExpiringCounter(self.io_loop,
                                                  granularity=None,
                                                  maxlen=2)
        expiring_counter['foo'] += 1
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        # Let IOLoop advance one iteration
        yield None
        self.assertEqual(expiring_counter['foo'], 1)
        self.assertEqual(len(expiring_counter), 1)
        # A second iteration rotates the key out of the two-epoch window
        yield None
        self.assertEqual(expiring_counter['foo'], 0)
        self.assertEqual(len(expiring_counter), 0)
def all():
    """Return a test suite holding every test defined in this module.

    NOTE(review): the name shadows the builtin `all`; presumably it is
    kept because the test runner looks it up by name - confirm before
    renaming.
    """
    return unittest.TestLoader().loadTestsFromName(__name__)
# When executed as a script, hand control to the `main` imported from
# tornado.testing.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment