Last active
April 11, 2026 01:49
-
-
Save synodriver/dcf1fac7e553b7998523c7073f83b5a7 to your computer and use it in GitHub Desktop.
An asyncio counterpart of Go's singleflight package: concurrent calls with the same key share a single execution.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| """ | |
| Copyright (c) 2008-2026 synodriver <diguohuangjiajinweijun@gmail.com> | |
| """ | |
| from typing import Callable, ParamSpec, TypeVar, Awaitable, NoReturn | |
| from threading import Lock | |
| import asyncio | |
| P = ParamSpec('P') | |
| R = TypeVar('R') | |
class Caller:
    """Shared outcome of one in-flight call, awaited by every duplicate caller.

    The call's owner stores either a value (``set_result``) or an exception
    (``set_exception``).  ``result`` waits until one of them has been stored,
    then returns the value or raises the exception.
    """

    def __init__(self):
        self._val = None
        self._err = None
        # Traceback of the stored exception at the moment it was stored; it is
        # restored before every re-raise so waiters do not pile frames onto it.
        self._tb = None
        # A newly constructed asyncio.Event starts unset, so no clear() is needed.
        self._done = asyncio.Event()

    def set_result(self, val):
        """Store ``val`` as the shared result and wake every waiter."""
        self._val = val
        self._done.set()

    def set_exception(self, err):
        """Store ``err`` as the shared failure and wake every waiter."""
        self._err = err
        # Only exception instances carry a traceback; a class (or other raisable
        # value) is stored as-is and re-raised unchanged.
        self._tb = err.__traceback__ if isinstance(err, BaseException) else None
        self._done.set()

    async def result(self):
        """Wait for the outcome, then return the value or raise the exception.

        Returns:
            The value passed to ``set_result``.

        Raises:
            BaseException: whatever was passed to ``set_exception``.
        """
        await self._done.wait()
        err = self._err
        if err is not None:
            if isinstance(err, BaseException):
                # The same exception object is raised in every waiting task.  A
                # plain ``raise`` would prepend this task's frames to whatever
                # traceback the previous raise left behind, so the traceback
                # would grow with each waiter.  Reset it to the stored snapshot
                # first (asyncio.Future.result does the same).
                err = err.with_traceback(self._tb)
            raise err
        return self._val
class SingleFlight:
    """Collapse concurrent calls that share a key into one execution of ``func``.

    The first caller for a key runs ``func``.  Callers that arrive for the same
    key while that call is still in flight wait for it and receive the same
    outcome: its value, or the exception it raised (any ``BaseException``,
    including ``asyncio.CancelledError`` when the first caller is cancelled).
    The key is forgotten as soon as the call finishes, so a later call runs
    ``func`` again.
    """

    def __init__(self):
        # key -> Caller for the call currently in flight under that key.
        self._cached = {}
        # Guards ``_cached``.  It is never held across an ``await``.
        # NOTE(review): Caller wraps an asyncio.Event, which can only be awaited
        # from one event loop; sharing an instance across threads that run
        # different loops is presumably unsupported despite this threading.Lock.
        self._lock = Lock()

    async def do(self, key: str, func: Callable[P, Awaitable[R | NoReturn]],
                 *args: P.args, **kwargs: P.kwargs) -> R | NoReturn:
        """Run ``func(*args, **kwargs)`` unless a call for ``key`` is in flight.

        Args:
            key: identifies calls that should be merged.
            func: coroutine function to execute.
            *args: positional arguments forwarded to ``func``.
            **kwargs: keyword arguments forwarded to ``func``.

        Returns:
            The value returned by the single ``func`` execution for this key.

        Raises:
            BaseException: whatever that ``func`` execution raised; every caller
            that shared the execution receives it.
        """
        # ``with`` guarantees the lock is released even if something raises
        # while it is held; a leaked threading.Lock would block the event-loop
        # thread on every later call.
        with self._lock:
            in_flight = self._cached.get(key)
            if in_flight is None:
                caller = Caller()
                self._cached[key] = caller
        if in_flight is not None:
            return await in_flight.result()
        try:
            val = await func(*args, **kwargs)
        except BaseException as e:
            # BaseException so cancellation of this call is forwarded to the
            # waiters as well.
            caller.set_exception(e)
            raise
        else:
            caller.set_result(val)
            return val
        finally:
            # Forget the key (under the same lock as the lookups) so the next
            # call starts a fresh execution.
            with self._lock:
                del self._cached[key]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| """ | |
| Copyright (c) 2008-2026 synodriver <diguohuangjiajinweijun@gmail.com> | |
| """ | |
| import asyncio | |
| from unittest import IsolatedAsyncioTestCase | |
| import unittest | |
| from singleflight import SingleFlight | |
class SingleFlightTestCase(IsolatedAsyncioTestCase):
    """Tests for SingleFlight.do: deduplication, shared errors, cancellation."""

    def setUp(self) -> None:
        self.sf = SingleFlight()

    async def test_singleflight(self):
        """Concurrent calls for one key run func once and all get its value."""
        data = 0

        async def func():
            nonlocal data
            await asyncio.sleep(0.1)
            data += 1
            return 32

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        await asyncio.gather(t1, t2, t3)
        self.assertEqual(data, 1)
        self.assertEqual(t1.result(), 32)
        self.assertEqual(t2.result(), 32)
        self.assertEqual(t3.result(), 32)

        # After the first round finishes, run two keys concurrently: each key
        # is deduplicated on its own, so func runs exactly twice.
        data = 0
        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key2", func))
        t4 = asyncio.create_task(self.sf.do("key2", func))
        await asyncio.gather(t1, t2, t3, t4)
        self.assertEqual(data, 2)

    async def test_singleflight_err(self):
        """An exception raised by func reaches every caller sharing the key."""
        data = 0

        async def func():
            nonlocal data
            data += 1
            await asyncio.sleep(0.1)
            raise ValueError

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        # return_exceptions=True so the outcome of every task is checked, not
        # only the first exception gather() would propagate.
        results = await asyncio.gather(t1, t2, t3, return_exceptions=True)
        for res in results:
            self.assertIsInstance(res, ValueError)
        self.assertEqual(data, 1)

    async def test_cancel(self):
        """Cancelling the first caller also cancels callers waiting on its key."""
        data = 0

        async def func():
            nonlocal data
            await asyncio.sleep(5)
            data += 1

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        # Required: without this yield the three tasks have not been scheduled
        # yet, so t1 is cancelled before it starts and never gets the chance to
        # pass the cancellation on to t2 and t3.
        await asyncio.sleep(0)
        t1.cancel()
        try:
            await t1
        except asyncio.CancelledError:
            pass
        # The try/except above would also pass if t1 completed normally, so
        # check explicitly that it ended cancelled.
        self.assertTrue(t1.cancelled())
        with self.assertRaises(asyncio.CancelledError):
            await t2
        with self.assertRaises(asyncio.CancelledError):
            await t3
        self.assertEqual(data, 0)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the unittest runner.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment