Created
May 24, 2018 06:30
-
-
Save cadl/2a9774e117cff4b554345936738e84d5 to your computer and use it in GitHub Desktop.
atfield
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import functools | |
import time | |
from threading import local | |
from collections import deque | |
class ATFieldException(Exception):
    """Root of the ATField exception hierarchy; catch this to handle any ATField error."""
class ATFieldPolicyException(ATFieldException):
    """Signals an unusable policy configuration.

    ``ATField.activate`` raises it when none of the retry, circuit-breaker
    or fallback policies is supplied.
    """
class ATFieldExecuteException(ATFieldException):
    """Signals that a guarded call could not produce a result.

    Raised by ``FallbackPolicy.default_fallback_func``, i.e. whenever the
    fallback path is taken but no custom fallback handler was configured.
    """
class CircuitStatus(object):
    """Integer constants naming the two states a circuit breaker can be in."""

    # Normal operation: calls reach the wrapped function.
    close = 1
    # Tripped: calls are answered by the fallback handler instead.
    open = 2
class ATFieldContext(object):
    """Runtime state for one guarded function: breaker status, outcome window, stats.

    Only ``circuit_breaker_policy`` is consulted here; ``retry_policy`` and
    ``fallback_policy`` are accepted so every policy can be passed uniformly,
    but they are not stored.
    """

    def __init__(self, retry_policy, circuit_breaker_policy, fallback_policy):
        self.circuit_status = CircuitStatus.close
        # Wall-clock time (float seconds since the epoch) when the breaker
        # last opened; None until it opens for the first time.
        self.last_open_ts = None
        self.circuit_breaker_policy = circuit_breaker_policy
        # Sliding window of the most recent ``period`` outcomes
        # (1 = failure, 0 = success); older entries fall off automatically.
        self.circular_queue = deque(maxlen=self.circuit_breaker_policy.period)
        self.stats = self._default_stats()

    def _default_stats(self):
        """Return a new stats dict with every counter at zero."""
        return {'circuit_breaker_count': 0,
                'success_count': 0,
                'failure_count': 0,
                'retry_count': 0}

    def incr_circuit_breaker_count(self):
        """Count one call that was short-circuited by the open breaker."""
        self.stats['circuit_breaker_count'] += 1

    def incr_success_count(self):
        """Count one call that returned normally."""
        self.stats['success_count'] += 1

    def incr_failure_count(self):
        """Count one call that ended in an exception."""
        self.stats['failure_count'] += 1

    def incr_retry_count(self):
        """Count one retry attempt."""
        self.stats['retry_count'] += 1

    def init_stats(self):
        """Reset every stats counter to zero (breaker state is untouched)."""
        self.stats = self._default_stats()

    def is_circuit_breaker_open(self):
        """Return True while the breaker should short-circuit calls.

        When the breaker is open and at least ``recover_seconds`` have passed
        since it opened, it is closed (clearing the outcome window) and False
        is returned.
        """
        if self.circuit_status == CircuitStatus.close:
            return False
        elapsed = time.time() - self.last_open_ts
        if elapsed >= self.circuit_breaker_policy.recover_seconds:
            self.close_circuit_breaker()
            return False
        return True

    def close_circuit_breaker(self):
        """Close the breaker and forget previously recorded outcomes."""
        self.circular_queue.clear()
        self.circuit_status = CircuitStatus.close

    def open_circuit_breaker(self):
        """Open the breaker and remember when it happened."""
        self.circuit_status = CircuitStatus.open
        # Keep sub-second precision: truncating both timestamps with int()
        # made the effective recovery window anywhere between
        # recover_seconds and recover_seconds + 1.
        self.last_open_ts = time.time()

    def record_success(self):
        """Push a success into the outcome window."""
        self.circular_queue.append(0)

    def record_failure(self):
        """Push a failure and open the breaker once the window holds enough failures."""
        self.circular_queue.append(1)
        if self.circuit_status == CircuitStatus.close:
            if sum(self.circular_queue) >= self.circuit_breaker_policy.failure_threshold:
                self.open_circuit_breaker()
class ATFieldContextManager(local):
    """Registry mapping a function key to its ATFieldContext, kept per thread.

    As a ``threading.local`` subclass, every thread that touches a shared
    instance gets its own attribute namespace, and ``__init__`` runs again
    for that thread, so each thread starts with an empty ``context_map``.
    """

    # Class-level fallbacks; __init__ shadows both on each thread's view.
    initialized = False
    context_map = {}

    def __init__(self):
        self.context_map = {}
        self.initialized = True

    def init_context(self, key, retry_policy, circuit_breaker_policy, fallback_policy):
        """Create, register and return a fresh ATFieldContext under ``key``."""
        context = ATFieldContext(retry_policy, circuit_breaker_policy, fallback_policy)
        self.context_map[key] = context
        return context

    def put_context(self, key, context):
        """Register ``context`` under ``key``, replacing any existing one."""
        self.context_map[key] = context

    def get_context(self, key):
        """Return the context registered under ``key``, or None."""
        return self.context_map.get(key)

    def iterate_context(self):
        """Return an iterator of ``(key, context)`` pairs.

        ``dict.iteritems`` exists only on Python 2; ``items()`` works on both.
        Iterating over a snapshot keeps the loop safe if a context is
        registered while the caller is still consuming the iterator.
        """
        return iter(list(self.context_map.items()))
class ATField(object):
    """Decorator adding retry, fallback and circuit breaking to a function.

    Create instances with ``ATField.activate(...)`` and apply the result as a
    decorator. Runtime state lives in an ``ATFieldContext`` looked up by
    ``"<module>.<function name>"`` in a thread-local registry.
    """

    context_manager = ATFieldContextManager()

    def __init__(self, retry_policy, circuit_breaker_policy, fallback_policy):
        self.retry_policy = retry_policy
        self.circuit_breaker_policy = circuit_breaker_policy
        self.fallback_policy = fallback_policy

    def __call__(self, func):
        """Return a wrapper that routes every call of ``func`` through :meth:`call`."""
        @functools.wraps(func)
        def wrapper(*args, **kw):
            return self.call(func, *args, **kw)
        return wrapper

    @staticmethod
    def activate(retry_policy=None, circuit_breaker_policy=None, fallback_policy=None):
        """Build an ATField; any policy left as None gets its ``default()``.

        Raises:
            ATFieldPolicyException: if no policy at all is given.
        """
        if not any([retry_policy, circuit_breaker_policy, fallback_policy]):
            raise ATFieldPolicyException
        atfield = ATField(retry_policy or RetryPolicy.default(),
                          circuit_breaker_policy or CircuitBreakerPolicy.default(),
                          fallback_policy or FallbackPolicy.default())
        return atfield

    @classmethod
    def _install_context_manager(cls):
        """Return the class registry, re-creating it if it is not initialized."""
        if not getattr(cls.context_manager, 'initialized', None):
            cls.context_manager = ATFieldContextManager()
        return cls.context_manager

    def call(self, func, *args, **kw):
        """Invoke ``func`` under this field's policies and return its result.

        Flow:
        1. While the breaker is open, the call is short-circuited to the
           fallback handler.
        2. Otherwise ``func`` runs with retries.
        3. If an exception listed by the fallback policy escapes, it is
           recorded as a failure and the handler's value is returned.
        4. Any other exception is recorded and re-raised unchanged.
        """
        self._install_context_manager()
        # NOTE(review): the key ignores the owning class, so same-named
        # methods of different classes in one module share one context.
        key = '%s.%s' % (func.__module__, func.__name__)
        context = self.context_manager.get_context(key)
        if context is None:
            context = self.context_manager.init_context(
                key, self.retry_policy, self.circuit_breaker_policy,
                self.fallback_policy)
        if context.is_circuit_breaker_open():
            context.incr_circuit_breaker_count()
            return self.fallback_policy.handler_func()
        try:
            ret_val = self._call_with_retry(context,
                                            self.retry_policy.max_retries,
                                            self.retry_policy.exceptions,
                                            func, *args, **kw)
        except self.fallback_policy.exceptions:
            context.record_failure()
            context.incr_failure_count()
            return self.fallback_policy.handler_func()
        except Exception:
            context.record_failure()
            context.incr_failure_count()
            # Bare ``raise`` keeps the original traceback; ``raise e`` resets
            # it to this line on Python 2.
            raise
        else:
            context.record_success()
            context.incr_success_count()
            return ret_val

    def _call_with_retry(self, context, retry_cnt, exceptions, func, *args, **kw):
        """Call ``func``, retrying up to ``retry_cnt`` more times on ``exceptions``.

        Each retry bumps the context's retry counter. Once retries are
        exhausted, the last matching exception is re-raised. Non-matching
        exceptions propagate immediately. A loop is used instead of recursion
        so large retry counts cannot hit the recursion limit.
        """
        while True:
            try:
                return func(*args, **kw)
            except exceptions:
                if retry_cnt < 1:
                    raise
                retry_cnt -= 1
                context.incr_retry_count()

    @classmethod
    def stats(cls):
        """Return ``{key: stats_dict}`` copies for this thread's contexts."""
        ret = {}
        cls._install_context_manager()
        for key, context in cls.context_manager.iterate_context():
            ret[key] = context.stats.copy()
        return ret

    @classmethod
    def reset_stats(cls):
        """Zero the stats counters of every context registered in this thread."""
        # Same guard as stats(), so both class-level entry points behave alike.
        cls._install_context_manager()
        for _, context in cls.context_manager.iterate_context():
            context.init_stats()
class RetryPolicy(object):
    """Which exceptions trigger a retry and how many retries are allowed.

    Args:
        max_retries: extra attempts after the first call; values below 1
            mean the call is never retried.
        exceptions: a single exception class or an iterable of them.
    """

    def __init__(self, max_retries, exceptions):
        self.max_retries = max_retries
        # Accept a lone class too: tuple(SomeError) would raise TypeError.
        if isinstance(exceptions, type):
            exceptions = (exceptions,)
        self.exceptions = tuple(exceptions)

    @classmethod
    def default(cls):
        """Return a policy that never retries."""
        return cls(0, [])
class CircuitBreakerPolicy(object):
    """When the circuit breaker trips and how long it stays tripped.

    Args:
        failure_threshold: failures within the recent-outcome window needed
            to open the breaker.
        period: size of that recent-outcome window.
        recover_seconds: how long the breaker stays open before closing.
    """

    def __init__(self, failure_threshold, period, recover_seconds):
        self.failure_threshold, self.period, self.recover_seconds = (
            failure_threshold, period, recover_seconds)

    @classmethod
    def default(cls):
        """Return an effectively disabled breaker: a zero-length window never fills."""
        return cls(failure_threshold=1, period=0, recover_seconds=0)
class FallbackPolicy(object):
    """Which exceptions are answered by a fallback, and the fallback handler.

    Args:
        handler_func: zero-argument callable whose return value replaces the
            guarded function's result.
        exceptions: a single exception class or an iterable of them.
    """

    def __init__(self, handler_func, exceptions):
        self.handler_func = handler_func
        # Accept a lone class too: tuple(SomeError) would raise TypeError.
        if isinstance(exceptions, type):
            exceptions = (exceptions,)
        self.exceptions = tuple(exceptions)

    @staticmethod
    def default_fallback_func():
        """Default handler: there is no fallback value, so raise."""
        raise ATFieldExecuteException

    @classmethod
    def default(cls):
        """Return a policy that catches nothing and whose handler raises."""
        return cls(cls.default_fallback_func, [])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0 1527142944.33 | |
run | |
run | |
run | |
run | |
fallback | |
1 1527142945.33 | |
run | |
run | |
run | |
run | |
fallback | |
2 1527142946.33 | |
fallback | |
3 1527142947.34 | |
fallback | |
4 1527142948.34 | |
fallback | |
5 1527142949.34 | |
fallback | |
6 1527142950.35 | |
fallback | |
7 1527142951.35 | |
fallback | |
8 1527142952.35 | |
fallback | |
9 1527142953.35 | |
fallback | |
10 1527142954.36 | |
fallback | |
11 1527142955.36 | |
fallback | |
12 1527142956.37 | |
run | |
run | |
run | |
run | |
fallback | |
13 1527142957.37 | |
run | |
run | |
run | |
run | |
fallback | |
14 1527142958.37 | |
fallback | |
15 1527142959.38 | |
fallback | |
16 1527142960.38 | |
fallback | |
17 1527142961.39 | |
fallback | |
18 1527142962.39 | |
fallback | |
19 1527142963.39 | |
fallback | |
{'__main__.foo': {'retry_count': 12, 'failure_count': 4, 'success_count': 0, 'circuit_breaker_count': 16}} | |
None | |
{'__main__.foo': {'retry_count': 0, 'failure_count': 0, 'success_count': 0, 'circuit_breaker_count': 0}} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# 需要指定重试/fallback/熔断规则,若未指定,默认不进行重试,没有 fallback 处理,不开启熔断机制 | |
import time | |
from atfield import ATField, RetryPolicy, CircuitBreakerPolicy, FallbackPolicy | |
class FooException(Exception):
    """Demo error raised unconditionally by ``foo`` and ``just_retry``."""
def fallback():
    """Demo fallback handler: announce that the fallback path was taken."""
    # A single-argument print(...) prints the same on Python 2 and 3;
    # the old print statement is a SyntaxError on Python 3.
    print('fallback')
# Retry on FooException, at most 3 retries.
# Use ``fallback`` as the fallback handler.
# Circuit-breaker rule: if 2 of the last 5 executions failed, open the
# breaker; calls during the next 10 seconds are all handled by the fallback.
@ATField.activate(retry_policy=RetryPolicy(3, [FooException]),
                  fallback_policy=FallbackPolicy(fallback, [FooException]),
                  circuit_breaker_policy=CircuitBreakerPolicy(2, 5, 10))
def foo():
    """Demo target that always fails, to exercise retry, fallback and the breaker."""
    print('run')
    # The former ``print 'success'`` after this raise could never run.
    raise FooException
# Retry on FooException, at most 3 retries.
@ATField.activate(retry_policy=RetryPolicy(3, [FooException]))
def just_retry():
    """Demo target that always fails.

    Only a retry policy is configured. The default fallback policy catches
    nothing, so after the retries the FooException reaches the caller.
    """
    print('run')
    # The former ``print 'success'`` after this raise could never run.
    raise FooException
if __name__ == '__main__':
    # Call foo once per second for 20 seconds to watch the breaker open,
    # recover after its timeout, and open again.
    for i in range(20):
        # '%s %s' reproduces Python 2's ``print i, t`` output on both versions.
        print('%s %s' % (i, time.time()))
        foo()
        time.sleep(1)
    print(ATField.stats())
    ATField.reset_stats()
    print(ATField.stats())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment