Created
March 16, 2014 03:18
-
-
Save cooldaemon/9578081 to your computer and use it in GitHub Desktop.
Redis Mutex を Python で実装する ref: http://qiita.com/cooldaemon/items/a192c608a8ead1577881
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MutexError(Exception):
    """Base class for every error raised by Mutex."""


class DuplicateLockError(MutexError):
    """Raised when lock() is called again on a Mutex that already ran lock().

    Call unlock() first, or create a separate Mutex object.
    """


class HasNotLockError(MutexError):
    """Raised when unlock() is called on a Mutex that has not run lock() yet.

    unlock() must only be called after lock().
    """


class ExpiredLockError(MutexError):
    """Raised when unlock() is called after the lock's expiry has already passed."""


class SetnxError(MutexError):
    """Raised when the lock key keeps vanishing between SETNX and GET,
    so the current holder's expiry could not be read."""


class LockError(MutexError):
    """Raised when the lock could not be acquired within the configured retries."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime | |
import time | |
from functools import wraps | |
from .exception import (DuplicateLockError, | |
HasNotLockError, | |
ExpiredLockError, | |
SetnxError, | |
LockError) | |
class Mutex(object):
    """Cross-process mutex stored in a single Redis key.

    The key's value is the Unix time (float seconds) at which the current
    holder's lock expires. A lock is acquired with SETNX; if the key already
    exists and its stored expiry has passed, the stale lock is taken over
    with GETSET. No Redis TTL is set on the key: expiry is enforced only by
    comparing these timestamps against each client's local clock
    (NOTE(review): this presumes client clocks are roughly in sync).

    Usable via lock()/unlock(), as a context manager, or as a decorator.

    Args:
        client: Redis client; only setnx, get, getset and delete are used.
        key: Redis key that represents the lock.
        expire: seconds a lock stays valid after it is acquired.
        retry_count: acquisition attempts made by lock();
            retry_count * retry_sleep_sec is roughly the maximum wait.
        retry_setnx_count: attempts made when the key disappears between
            SETNX and GET.
        retry_sleep_sec: seconds to sleep between attempts while another
            holder's lock is still valid.
    """

    def __init__(self, client, key,
                 expire=10,
                 retry_count=6,  # retry_count * retry_sleep_sec = max wait time
                 retry_setnx_count=100,
                 retry_sleep_sec=0.25):
        # While held: the timestamp the stored expiry was computed from
        # (stored value == self._lock + expire). None when not held.
        self._lock = None
        self._r = client
        self._key = key
        self._expire = expire
        self._retry_count = retry_count
        self._retry_setnx_count = retry_setnx_count
        self._retry_sleep_sec = retry_sleep_sec

    def _get_now(self):
        """Return the current Unix time in seconds as a float."""
        # time.time() is portable; strftime('%s') is a non-standard
        # platform extension (unsupported on Windows).
        return time.time()

    def lock(self):
        """Acquire the lock, retrying up to retry_count times.

        Raises:
            DuplicateLockError: this object already holds the lock.
            LockError: the lock could not be acquired within the retries.
            SetnxError: the key kept vanishing between SETNX and GET.
        """
        if self._lock:
            raise DuplicateLockError(self._key)
        self._do_lock()

    def _do_lock(self):
        """Try to acquire the lock; set self._lock on success, else raise LockError."""
        for _ in range(self._retry_count):
            is_set, stamp = self._setnx()
            if is_set:
                # stamp is exactly the time the stored expiry was based on.
                self._lock = stamp
                return
            # stamp is the current holder's expiry: wait while it is valid.
            if self._need_retry(stamp):
                continue
            # The holder's lock has expired: take it over with GETSET. If
            # another client got there first, the returned expiry is still
            # in the future, so we wait and try again.
            now = self._get_now()
            if not self._need_retry(self._getset(now)):
                self._lock = now
                return
        raise LockError(self._key)

    def _setnx(self):
        """Attempt to create the lock key with SETNX.

        Returns:
            (True, now) when the key was created, where ``now`` is the time
            the stored expiry was computed from; otherwise
            (False, old_expire) with the current holder's expiry timestamp.

        Raises:
            SetnxError: the key vanished between SETNX and GET on every one
                of retry_setnx_count attempts.
        """
        for _ in range(self._retry_setnx_count):
            now = self._get_now()
            if self._r.setnx(self._key, now + self._expire):
                return True, now
            old_expire = self._r.get(self._key)
            # None means the holder released the key right after our SETNX.
            if old_expire is not None:
                return False, float(old_expire)
        raise SetnxError(self._key)

    def _need_retry(self, expire):
        """Return False if ``expire`` is in the past; otherwise sleep and return True."""
        if expire < self._get_now():
            return False
        time.sleep(self._retry_sleep_sec)
        return True

    def _getset(self, now=None):
        """Overwrite the key with a fresh expiry and return the previous one.

        Args:
            now: time the new expiry is based on; defaults to the current time.

        Returns:
            The previous expiry as a float, or 0 if the key did not exist.
        """
        if now is None:
            now = self._get_now()
        old_expire = self._r.getset(self._key, now + self._expire)
        if old_expire is None:
            return 0
        return float(old_expire)

    def unlock(self):
        """Release the lock held by this object by deleting the key.

        Raises:
            HasNotLockError: this object does not hold the lock.
            ExpiredLockError: the lock's expiry has already passed. The key
                is left alone (another client may own it now), and this
                object is reset so lock() can be called again.
        """
        if not self._lock:
            raise HasNotLockError(self._key)
        elapsed_time = self._get_now() - self._lock
        if self._expire <= elapsed_time:
            # We no longer own the lock; forget it rather than staying stuck
            # in the "locked" state where lock() raises DuplicateLockError.
            self._lock = None
            raise ExpiredLockError(self._key, elapsed_time)
        self._r.delete(self._key)
        self._lock = None

    def __enter__(self):
        """Acquire the lock when entering a ``with`` block."""
        self.lock()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the lock if still held; never suppresses exceptions."""
        if self._lock:
            self.unlock()
        # Truthy only when no exception occurred, so nothing is suppressed.
        return exc_type is None

    def __call__(self, func):
        """Decorate ``func`` so every call runs while holding this lock."""
        @wraps(func)
        def inner(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return inner
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import redis | |
import time | |
from multiprocessing import Process | |
from .mutex import Mutex | |
from .exception import (DuplicateLockError, | |
HasNotLockError, | |
ExpiredLockError, | |
LockError) | |
def _incr_under_lock(key, counter):
    """Increment the Redis counter 'ham' ``counter`` times, each under the mutex.

    Defined at module level (not nested inside the test) so it can be pickled
    as a multiprocessing target under the 'spawn' and 'forkserver' start
    methods, not only under 'fork'.
    """
    mutex = Mutex(redis.StrictRedis(), key, retry_count=100)
    for _ in range(counter):
        mutex.lock()
        ham = mutex._r.get('ham') or 0
        mutex._r.set('ham', int(ham) + 1)
        mutex.unlock()


class TestMutex(unittest.TestCase):
    """Integration tests for Mutex against a live Redis server."""

    def setUp(self):
        self.key = 'spam'
        self.r = redis.StrictRedis()
        self.mutex = Mutex(self.r, self.key)

    def tearDown(self):
        mutex = self.mutex
        try:
            if mutex._lock:
                mutex.unlock()
        finally:
            # Remove the lock key and the counter even if a test failed while
            # another Mutex still held the lock, so later tests start clean.
            mutex._r.delete(self.key, 'ham')

    def test_lock(self):
        mutex = self.mutex
        mutex.lock()
        self.assertIsNotNone(mutex._r.get(mutex._key))
        with self.assertRaises(DuplicateLockError):
            mutex.lock()

    def test_unlock(self):
        self.test_lock()
        mutex = self.mutex
        mutex.unlock()
        self.assertIsNone(mutex._r.get(mutex._key))
        with self.assertRaises(HasNotLockError):
            mutex.unlock()
        self.test_lock()
        time.sleep(10.5)  # just past the default 10 s expiry
        with self.assertRaises(ExpiredLockError):
            mutex.unlock()
        mutex._lock = None  # force-reset local state after the expired unlock

    def test_expire(self):
        mutex1 = self.mutex
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds
        with self.assertRaises(LockError):
            mutex1.lock()  # 6 retries * 0.25 s sleep = 1.5 s
        time.sleep(0.6)  # extra margin so mutex2's lock has expired
        mutex1.lock()
        self.assertIsNotNone(mutex1._r.get(mutex1._key))

    def test_with(self):
        mutex1 = self.mutex
        with mutex1:
            self.assertIsNotNone(mutex1._r.get(mutex1._key))
        self.assertIsNone(mutex1._r.get(mutex1._key))
        mutex2 = Mutex(self.r, self.key, expire=2)
        mutex2.lock()  # holds the lock for 2 seconds
        with self.assertRaises(LockError):
            with mutex1:  # 6 retries * 0.25 s sleep = 1.5 s
                pass
        mutex2.unlock()
        with mutex1:
            with self.assertRaises(DuplicateLockError):
                with mutex1:
                    pass

    def test_decorator(self):
        mutex = self.mutex

        @mutex
        def egg():
            self.assertIsNotNone(mutex._r.get(mutex._key))

        egg()
        self.assertIsNone(mutex._r.get(mutex._key))

    def test_multi_process(self):
        procs = 20
        counter = 100
        ps = [Process(target=_incr_under_lock, args=(self.key, counter))
              for _ in range(procs)]
        for p in ps:
            p.start()
        for p in ps:
            p.join()
        self.assertEqual(int(self.mutex._r.get('ham')), counter * procs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
>>> import redis
>>> from mutex import Mutex
>>> client = redis.StrictRedis()
>>> with Mutex(client, ':'.join(['EmitAccessToken', user_id])):
...     # do something ...
...     pass
>>> @Mutex(client, ':'.join(['EmitAccessToken', user_id]))
... def emit_access_token():
...     # do something ...
...     pass
>>> mutex = Mutex(client, ':'.join(['EmitAccessToken', user_id]))
>>> mutex.lock()
>>> # do something ...
>>> mutex.unlock()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment