Skip to content

Instantly share code, notes, and snippets.

@synodriver
Last active April 11, 2026 01:49
Show Gist options
  • Select an option

  • Save synodriver/dcf1fac7e553b7998523c7073f83b5a7 to your computer and use it in GitHub Desktop.

Select an option

Save synodriver/dcf1fac7e553b7998523c7073f83b5a7 to your computer and use it in GitHub Desktop.
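Like Go's singleflight: concurrent asyncio calls that share a key are coalesced into one execution whose result (or exception) is shared by all callers.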
# -*- coding: utf-8 -*-
"""
Copyright (c) 2008-2026 synodriver <diguohuangjiajinweijun@gmail.com>
"""
from typing import Callable, ParamSpec, TypeVar, Awaitable, NoReturn
from threading import Lock
import asyncio
P = ParamSpec('P')
R = TypeVar('R')
class Caller:
    """Shared outcome slot for one in-flight call.

    One task stores the outcome with :meth:`set_result` or
    :meth:`set_exception`; any number of tasks wait for it via :meth:`result`.
    """

    def __init__(self):
        self._val = None
        self._err = None
        # Traceback captured when the exception was stored. It is restored on
        # every re-raise so waiters don't keep extending one shared traceback.
        self._err_tb = None
        # A freshly created Event starts unset, so no clear() is needed.
        self._done = asyncio.Event()

    def set_result(self, val):
        """Store *val* as the outcome and wake every waiter."""
        self._val = val
        self._done.set()

    def set_exception(self, err):
        """Store *err* as the outcome and wake every waiter."""
        self._err = err
        self._err_tb = err.__traceback__
        self._done.set()

    async def result(self):
        """Wait for the outcome, then return the value or raise the stored exception.

        Returns:
            The value passed to :meth:`set_result`.

        Raises:
            BaseException: the exception passed to :meth:`set_exception`.
        """
        await self._done.wait()
        if self._err is not None:
            # The same exception object is raised in every waiter. A plain
            # ``raise`` would prepend each waiter's frames to the previous
            # waiter's traceback, so reset it first, as asyncio.Future does.
            raise self._err.with_traceback(self._err_tb)
        return self._val
class SingleFlight:
    """Coalesce concurrent calls that share a key into a single execution.

    The first task to call :meth:`do` for a key (the leader) runs ``func``.
    Tasks that call :meth:`do` with the same key while that call is in flight
    wait for, and receive, the leader's result or exception. The entry is
    removed when the leader finishes, so a later call starts a new execution.
    """

    def __init__(self):
        # key -> Caller for the call currently in flight under that key.
        self._cached = {}
        # Guards ``_cached``. NOTE(review): this is a threading.Lock, but the
        # Caller it hands out wraps an asyncio.Event, so sharing one instance
        # across threads/event loops is presumably unsupported — confirm.
        self._lock = Lock()

    async def do(
        self,
        key: str,
        func: Callable[P, Awaitable[R]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> R:
        """Run ``func(*args, **kwargs)`` once per in-flight *key* and share the outcome.

        Args:
            key: deduplication key; concurrent calls with an equal key share
                one execution.
            func: coroutine function; only the leader invokes it.
            *args: positional arguments forwarded to ``func``.
            **kwargs: keyword arguments forwarded to ``func``.

        Returns:
            The value returned by ``func``, for the leader and every waiter.

        Raises:
            BaseException: whatever ``func`` raises is re-raised in the leader
                and in every waiter. This includes ``asyncio.CancelledError``
                when the leader is cancelled.
        """
        # ``with`` guarantees the lock is released even if an exception
        # occurs while registering the flight.
        with self._lock:
            caller = self._cached.get(key)
            is_leader = caller is None
            if is_leader:
                caller = Caller()
                self._cached[key] = caller

        if not is_leader:
            return await caller.result()

        try:
            val = await func(*args, **kwargs)
        except BaseException as exc:
            # BaseException so that cancellation of the leader is also
            # forwarded to the waiters.
            caller.set_exception(exc)
            raise
        else:
            caller.set_result(val)
            return val
        finally:
            # Forget the finished flight so the next call runs func again.
            with self._lock:
                del self._cached[key]
# -*- coding: utf-8 -*-
"""
Copyright (c) 2008-2026 synodriver <diguohuangjiajinweijun@gmail.com>
"""
import asyncio
from unittest import IsolatedAsyncioTestCase
import unittest
from singleflight import SingleFlight
class SingleFlightTestCase(IsolatedAsyncioTestCase):
    """Behavioral tests for SingleFlight.do."""

    def setUp(self) -> None:
        self.sf = SingleFlight()

    async def test_singleflight(self):
        """Concurrent calls with one key run once; distinct keys run independently."""
        data = 0

        async def func():
            nonlocal data
            await asyncio.sleep(0.1)
            data += 1
            return 32

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        await asyncio.gather(t1, t2, t3)
        self.assertEqual(data, 1)
        self.assertEqual(t1.result(), 32)
        self.assertEqual(t2.result(), 32)
        self.assertEqual(t3.result(), 32)

        # The previous flight has finished, so "key" runs again; "key2" gets
        # its own independent execution.
        data = 0
        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key2", func))
        t4 = asyncio.create_task(self.sf.do("key2", func))
        results = await asyncio.gather(t1, t2, t3, t4)
        self.assertEqual(data, 2)
        self.assertEqual(results, [32, 32, 32, 32])

    async def test_singleflight_err(self):
        """An exception from the shared call reaches the leader and every waiter."""
        data = 0

        async def func():
            nonlocal data
            data += 1
            await asyncio.sleep(0.1)
            raise ValueError

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        # return_exceptions=True so every task's outcome is checked, not just
        # the first failure that gather would otherwise propagate.
        results = await asyncio.gather(t1, t2, t3, return_exceptions=True)
        for result in results:
            self.assertIsInstance(result, ValueError)
        self.assertEqual(data, 1)

    async def test_cancel(self):
        """Cancelling the leader cancels the shared call for every waiter."""
        data = 0

        async def func():
            nonlocal data
            await asyncio.sleep(5)
            data += 1

        t1 = asyncio.create_task(self.sf.do("key", func))
        t2 = asyncio.create_task(self.sf.do("key", func))
        t3 = asyncio.create_task(self.sf.do("key", func))
        # Required: let all three tasks start first. Without this yield none of
        # them has been scheduled yet, so t1 would be cancelled before it runs
        # and would never get the chance to pass the cancellation to t2 and t3.
        await asyncio.sleep(0)
        t1.cancel()
        try:
            await t1
        except asyncio.CancelledError:
            pass
        self.assertTrue(t1.cancelled())
        with self.assertRaises(asyncio.CancelledError):
            await t2
        with self.assertRaises(asyncio.CancelledError):
            await t3
        self.assertEqual(data, 0)
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which runs the TestCase classes defined in this module.
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment