Skip to content

Instantly share code, notes, and snippets.

@cadl
Created May 24, 2018 06:30
Show Gist options
  • Save cadl/2a9774e117cff4b554345936738e84d5 to your computer and use it in GitHub Desktop.
Save cadl/2a9774e117cff4b554345936738e84d5 to your computer and use it in GitHub Desktop.
atfield
# coding: utf-8
import functools
import time
from threading import local
from collections import deque
class ATFieldException(Exception):
    """Base class of the exceptions defined by this module."""


class ATFieldPolicyException(ATFieldException):
    """Raised by ``ATField.activate`` when no policy at all is supplied."""


class ATFieldExecuteException(ATFieldException):
    """Raised by the default fallback handler of ``FallbackPolicy``."""
class CircuitStatus(object):
    """Integer constants naming the two states of a circuit breaker."""

    # close: the breaker lets calls through to the wrapped function.
    # open: the breaker is tripped and calls are answered by the fallback.
    close, open = 1, 2
class ATFieldContext(object):
    """Runtime state kept for one protected function.

    Holds the circuit-breaker status, a bounded window of recent call
    outcomes (1 = failure, 0 = success) and a dict of counters.
    """

    # Counter names, in the order they appear in the stats dict.
    _COUNTERS = ('circuit_breaker_count', 'success_count',
                 'failure_count', 'retry_count')

    def __init__(self, retry_policy, circuit_breaker_policy, fallback_policy):
        """Start with a closed breaker, an empty window and zeroed counters.

        Only ``circuit_breaker_policy`` is stored; ``retry_policy`` and
        ``fallback_policy`` are accepted but not used by this class.
        """
        self.circuit_breaker_policy = circuit_breaker_policy
        self.circuit_status = CircuitStatus.close
        self.last_open_ts = None
        # The deque silently drops its oldest entry once ``period`` outcomes
        # are stored, which gives a sliding window over the latest calls.
        self.circular_queue = deque(maxlen=circuit_breaker_policy.period)
        self.stats = self._default_stats()

    def _default_stats(self):
        """Return a fresh counter dict with every counter at zero."""
        return dict((name, 0) for name in self._COUNTERS)

    def _bump(self, counter):
        """Add one to the named counter."""
        self.stats[counter] += 1

    def incr_circuit_breaker_count(self):
        """Count one call that was short-circuited by the open breaker."""
        self._bump('circuit_breaker_count')

    def incr_success_count(self):
        """Count one successful call."""
        self._bump('success_count')

    def incr_failure_count(self):
        """Count one failed call."""
        self._bump('failure_count')

    def incr_retry_count(self):
        """Count one retry attempt."""
        self._bump('retry_count')

    def init_stats(self):
        """Reset every counter to zero."""
        self.stats = self._default_stats()

    def is_circuit_breaker_open(self):
        """Tell whether calls must currently be short-circuited.

        An open breaker closes itself, as a side effect of this check, once
        more than ``recover_seconds`` whole seconds have passed since it
        opened.
        """
        if self.circuit_status == CircuitStatus.close:
            return False
        now = int(time.time())
        recover_after = self.last_open_ts + self.circuit_breaker_policy.recover_seconds
        if now <= recover_after:
            return True
        self.close_circuit_breaker()
        return False

    def close_circuit_breaker(self):
        """Close the breaker and forget the recorded outcomes."""
        self.circuit_status = CircuitStatus.close
        self.circular_queue.clear()

    def open_circuit_breaker(self):
        """Open the breaker and remember when (whole seconds) it happened."""
        self.last_open_ts = int(time.time())
        self.circuit_status = CircuitStatus.open

    def record_success(self):
        """Push a success (0) into the outcome window."""
        self.circular_queue.append(0)

    def record_failure(self):
        """Push a failure (1) into the window and trip the breaker if needed.

        The breaker opens when it is currently closed and the number of
        failures in the window has reached ``failure_threshold``.
        """
        self.circular_queue.append(1)
        if self.circuit_status != CircuitStatus.close:
            return
        failures = sum(self.circular_queue)
        if failures >= self.circuit_breaker_policy.failure_threshold:
            self.open_circuit_breaker()
class ATFieldContextManager(local):
    """Registry mapping a function key to its ``ATFieldContext``.

    The class derives from ``threading.local``, so the attributes assigned
    on an instance are private to the thread that assigned them.
    """

    # Class-level defaults; ``__init__`` shadows both with instance values.
    initialized = False
    context_map = {}

    def __init__(self):
        """Give the instance its own empty map and mark it initialized."""
        self.context_map = {}
        self.initialized = True

    def init_context(self, key, retry_policy, circuit_breaker_policy, fallback_policy):
        """Create a new context for ``key``, register it and return it.

        An existing context stored under the same key is replaced.
        """
        context = ATFieldContext(retry_policy, circuit_breaker_policy, fallback_policy)
        self.context_map[key] = context
        return context

    def put_context(self, key, context):
        """Store ``context`` under ``key``, replacing any previous entry."""
        self.context_map[key] = context

    def get_context(self, key):
        """Return the context stored under ``key``, or ``None`` if absent."""
        return self.context_map.get(key)

    def iterate_context(self):
        """Return an iterator over the ``(key, context)`` pairs."""
        # ``dict.iteritems`` exists only on Python 2; ``items`` is available
        # on both major versions, and ``iter`` keeps the return value an
        # iterator as before.
        return iter(self.context_map.items())
class ATField(object):
    """Decorator guarding a function with retry, fallback and a circuit breaker.

    Per-function state is kept in ``context_manager`` under the key
    ``'<module>.<name>'`` of the wrapped function.
    """

    # Registry of per-function contexts, shared by every ATField instance.
    context_manager = ATFieldContextManager()

    def __init__(self, retry_policy, circuit_breaker_policy, fallback_policy):
        """Keep the three policy objects consulted by ``call``."""
        self.retry_policy = retry_policy
        self.circuit_breaker_policy = circuit_breaker_policy
        self.fallback_policy = fallback_policy

    def __call__(self, func):
        """Decorate ``func`` so that every invocation goes through ``call``."""
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return self.call(func, *args, **kw)
        return wrapper

    @staticmethod
    def activate(retry_policy=None, circuit_breaker_policy=None, fallback_policy=None):
        """Build an ``ATField``, filling each missing policy with its default.

        Raises:
            ATFieldPolicyException: if all three policies are falsy.
        """
        if not any([retry_policy, circuit_breaker_policy, fallback_policy]):
            raise ATFieldPolicyException('at least one policy must be given')
        return ATField(retry_policy or RetryPolicy.default(),
                       circuit_breaker_policy or CircuitBreakerPolicy.default(),
                       fallback_policy or FallbackPolicy.default())

    @classmethod
    def _install_context_manager(cls):
        """Return the shared context manager.

        The manager is replaced by a new one first when its ``initialized``
        flag is falsy.
        """
        if not getattr(cls.context_manager, 'initialized', None):
            cls.context_manager = ATFieldContextManager()
        return cls.context_manager

    def call(self, func, *args, **kw):
        """Run ``func(*args, **kw)`` under the configured policies.

        Returns the function's result. The fallback handler's result is
        returned instead when the breaker is open, or when one of the
        fallback exceptions is still raised after the retries. Any other
        ``Exception`` is recorded as a failure and re-raised unchanged.
        """
        self._install_context_manager()
        key = '%s.%s' % (func.__module__, func.__name__)
        context = self.context_manager.get_context(key)
        if context is None:
            context = self.context_manager.init_context(
                key, self.retry_policy, self.circuit_breaker_policy,
                self.fallback_policy)

        if context.is_circuit_breaker_open():
            # Breaker tripped: do not touch the wrapped function at all.
            context.incr_circuit_breaker_count()
            return self.fallback_policy.handler_func()

        try:
            ret_val = self._call_with_retry(context,
                                            self.retry_policy.max_retries,
                                            self.retry_policy.exceptions,
                                            func, *args, **kw)
        except self.fallback_policy.exceptions:
            context.record_failure()
            context.incr_failure_count()
            return self.fallback_policy.handler_func()
        except Exception:
            context.record_failure()
            context.incr_failure_count()
            # Bare ``raise`` keeps the original traceback; ``raise e`` would
            # restart it at this line on Python 2.
            raise
        else:
            context.record_success()
            context.incr_success_count()
            return ret_val

    def _call_with_retry(self, context, retry_cnt, exceptions, func, *args, **kw):
        """Call ``func`` and retry up to ``retry_cnt`` times on ``exceptions``.

        Each retry is counted on ``context``. Once the retries are used up
        the last exception propagates; exceptions outside ``exceptions``
        propagate immediately.
        """
        # A loop instead of recursion: no extra stack frame per retry and no
        # chain of nested exception contexts on Python 3.
        while True:
            try:
                return func(*args, **kw)
            except exceptions:
                if retry_cnt < 1:
                    raise
                retry_cnt -= 1
                context.incr_retry_count()

    @classmethod
    def stats(cls):
        """Return ``{key: copy of that context's counters}`` for every context."""
        cls._install_context_manager()
        return dict((key, context.stats.copy())
                    for key, context in cls.context_manager.iterate_context())

    @classmethod
    def reset_stats(cls):
        """Zero the counters of every registered context."""
        # Same preparation step as ``stats`` so both entry points agree.
        cls._install_context_manager()
        for _, context in cls.context_manager.iterate_context():
            context.init_stats()
class RetryPolicy(object):
    """How many times a failed call is retried, and on which exceptions."""

    def __init__(self, max_retries, exceptions):
        """Store the retry budget and the exceptions that trigger a retry.

        Args:
            max_retries: number of additional attempts after the first call.
            exceptions: an iterable of exception classes, or a single
                exception class.
        """
        self.max_retries = max_retries
        # Accept a lone exception class the way an ``except`` clause does;
        # ``tuple(SomeError)`` would otherwise fail with an opaque TypeError.
        if isinstance(exceptions, type) and issubclass(exceptions, BaseException):
            exceptions = (exceptions,)
        self.exceptions = tuple(exceptions)

    @classmethod
    def default(cls):
        """Return the policy that never retries."""
        return cls(0, [])
class CircuitBreakerPolicy(object):
    """Plain value object holding the three circuit-breaker settings."""

    def __init__(self, failure_threshold, period, recover_seconds):
        """Store the settings unchanged.

        Args:
            failure_threshold: failure count at which the breaker opens.
            period: size of the window of recent call outcomes.
            recover_seconds: seconds an open breaker waits before closing.
        """
        self.recover_seconds = recover_seconds
        self.period = period
        self.failure_threshold = failure_threshold

    @classmethod
    def default(cls):
        """Return the policy with threshold 1, period 0 and no recovery delay."""
        return cls(failure_threshold=1, period=0, recover_seconds=0)
class FallbackPolicy(object):
    """Which handler replaces a failed call, and for which exceptions."""

    def __init__(self, handler_func, exceptions):
        """Store the handler and the exceptions it stands in for.

        Args:
            handler_func: callable invoked without arguments; its return
                value replaces the result of the failed call.
            exceptions: an iterable of exception classes, or a single
                exception class.
        """
        self.handler_func = handler_func
        # Accept a lone exception class the way an ``except`` clause does;
        # ``tuple(SomeError)`` would otherwise fail with an opaque TypeError.
        if isinstance(exceptions, type) and issubclass(exceptions, BaseException):
            exceptions = (exceptions,)
        self.exceptions = tuple(exceptions)

    @staticmethod
    def default_fallback_func():
        """Handler used when no fallback is configured: always raises.

        Raises:
            ATFieldExecuteException: unconditionally.
        """
        raise ATFieldExecuteException

    @classmethod
    def default(cls):
        """Return the policy that catches nothing and whose handler raises."""
        return cls(FallbackPolicy.default_fallback_func, [])
0 1527142944.33
run
run
run
run
fallback
1 1527142945.33
run
run
run
run
fallback
2 1527142946.33
fallback
3 1527142947.34
fallback
4 1527142948.34
fallback
5 1527142949.34
fallback
6 1527142950.35
fallback
7 1527142951.35
fallback
8 1527142952.35
fallback
9 1527142953.35
fallback
10 1527142954.36
fallback
11 1527142955.36
fallback
12 1527142956.37
run
run
run
run
fallback
13 1527142957.37
run
run
run
run
fallback
14 1527142958.37
fallback
15 1527142959.38
fallback
16 1527142960.38
fallback
17 1527142961.39
fallback
18 1527142962.39
fallback
19 1527142963.39
fallback
{'__main__.foo': {'retry_count': 12, 'failure_count': 4, 'success_count': 0, 'circuit_breaker_count': 16}}
None
{'__main__.foo': {'retry_count': 0, 'failure_count': 0, 'success_count': 0, 'circuit_breaker_count': 0}}
# coding: utf-8
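# At least one retry / fallback / circuit-breaker policy must be given. Any policy left out uses its default: no retries, no fallback handling, circuit breaker disabled.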
import time
from atfield import ATField, RetryPolicy, CircuitBreakerPolicy, FallbackPolicy
class FooException(Exception):
    """Exception raised by the demo functions in this script."""
def fallback():
    """Fallback handler of the demo: prints ``fallback`` and returns None."""
    # Parenthesised single-argument form prints the same text on Python 2
    # and is valid syntax on Python 3, unlike the bare print statement.
    print('fallback')
# Retry on FooException, at most 3 times.
# Use ``fallback`` as the fallback handler.
# Circuit-breaker rule: once 2 out of 5 executions have failed the breaker
# opens, and calls within 10 seconds are all handled in fallback mode.
@ATField.activate(retry_policy=RetryPolicy(3, [FooException]),
                  fallback_policy=FallbackPolicy(fallback, [FooException]),
                  circuit_breaker_policy=CircuitBreakerPolicy(2, 5, 10))
def foo():
    """Demo target that prints ``run`` and then always raises FooException."""
    # print(...) with one argument behaves the same on Python 2 and 3.
    print('run')
    raise FooException
    print('success')  # unreachable while the raise above is in place
# Retry on FooException, at most 3 times.
@ATField.activate(retry_policy=RetryPolicy(3, [FooException]))
def just_retry():
    """Demo target with only a retry policy; always raises FooException."""
    # print(...) with one argument behaves the same on Python 2 and 3.
    print('run')
    raise FooException
    print('success')  # unreachable while the raise above is in place
if __name__ == '__main__':
    # Call foo once per second, printing the iteration number and the
    # current timestamp before each call.
    for i in range(20):
        # '%s %s' reproduces the space-separated output of the Python 2
        # print statement while staying valid syntax on Python 3.
        print('%s %s' % (i, time.time()))
        foo()
        time.sleep(1)
    # Show the counters, reset them, then show them again.
    print(ATField.stats())
    ATField.reset_stats()
    print(ATField.stats())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment