Last active
June 16, 2024 18:17
-
-
Save frostming/c9ef1c74ca28cf2f8fca3d4d15e9e245 to your computer and use it in GitHub Desktop.
Python implementation of JavaScript's Promise interface
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Python implementation of JavaScript's Promise interface | |
with the power of asyncio. | |
See https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Reference/Global_Objects/Promise | |
for the API specification. | |
Authored by: Frost Ming <[email protected]> | |
License: WTFPL | |
""" | |
import asyncio | |
import functools | |
import time | |
from typing import Any, Callable, Iterable, Optional | |
# Callable aliases for the user-supplied functions the Promise API accepts.
FinallyFunc = Callable[[], Any]  # handler for Promise.final(); takes no arguments
RejectFunc = Callable[[Exception], Any]  # receives the rejection reason
ResolveFunc = Callable[[Any], Any]  # receives the fulfilled value
# Executor given to Promise(...): called with the resolve and reject functions.
PromiseCallBack = Callable[[ResolveFunc, RejectFunc], Any]
class AggregateError(Exception):
    """Error carrying several underlying exceptions, like JavaScript's AggregateError.

    The individual exceptions are stored as the list ``errors``.
    """

    def __init__(self, errors: Iterable[Exception]) -> None:
        # Materialise first: callers may pass a one-shot iterator (e.g. a
        # ``filter`` object), which would be useless once consumed.
        self.errors = list(errors)
        # Hand the list (not the raw, possibly exhausted, iterable) to
        # Exception so that ``args`` and ``str()`` reflect the actual errors.
        super().__init__(self.errors)
class Promise:
    """A class that implements the JavaScript's Promise interface, with asyncio!

    ``Promise(callback)`` immediately submits ``callback(resolve, reject)`` to
    the event loop's default executor, so the callback runs in a worker
    thread and may block. ``resolve``/``reject`` are safe to call from that
    thread; only the first call settles the promise, later calls are ignored.

    Awaiting a promise returns its value (awaiting it in turn if the value is
    itself a ``Promise``) or raises its rejection reason.
    """

    def __init__(self, callback: PromiseCallBack) -> None:
        self.callback = callback
        try:
            self._loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop (e.g. constructed from a worker thread). This
            # fresh loop is not run here, so callbacks scheduled on it only
            # execute if something else runs it.
            self._loop = asyncio.new_event_loop()
        # Commit the task to the event loop immediately after Promise is created
        self._future = self.__run()

    def __repr__(self) -> str:
        future = self._future
        result = ""
        # Show the value only for a fulfilled promise. Checking the state up
        # front matters: result() on a cancelled future raises
        # CancelledError, which is not an Exception subclass on Python 3.8+.
        if future.done() and not future.cancelled() and future.exception() is None:
            try:
                result = f" {future.result()!r}"
            except Exception:
                # A failing repr() of the value must not break our own repr.
                result = ""
        return f"<{self.__class__.__name__}[{future._state}]{result}>"

    def __await__(self):
        yield from self._future
        result = self._future.result()
        # Await all chaining promises
        if isinstance(result, Promise):
            return (yield from result.__await__())
        return result

    def __run(self):
        """Create this promise's future and start the executor callback."""

        def _resolve(r):
            def _handle():
                if not future.done():
                    future.set_result(r)

            self._loop.call_soon_threadsafe(_handle)

        def _reject(e):
            def _handle():
                if not future.done():
                    future.set_exception(e)

            self._loop.call_soon_threadsafe(_handle)

        future = self._loop.create_future()
        handle = self._loop.run_in_executor(None, self.callback, _resolve, _reject)

        def _check_cancel(target):
            # Cancelling the promise also cancels the executor job. Any later
            # resolve/reject call is ignored because the future is done.
            if target.cancelled():
                handle.get_loop().call_soon_threadsafe(handle.cancel)

        future.add_done_callback(_check_cancel)
        return future

    @classmethod
    def resolve(cls, result: Any) -> "Promise":
        def callback(resolve_func, _):
            resolve_func(result)

        return cls(callback)

    @classmethod
    def reject(cls, exc: Exception) -> "Promise":
        def callback(_, reject_func):
            reject_func(exc)

        return cls(callback)

    # -- internal helpers -------------------------------------------------

    @staticmethod
    def _fulfill_soon(future: asyncio.Future, value: Any) -> None:
        """Thread-safely schedule ``future.set_result(value)`` on its loop.

        The done() check runs on the loop, so a future that was cancelled or
        settled meanwhile is skipped instead of raising InvalidStateError.
        """

        def _handle():
            if not future.done():
                future.set_result(value)

        future.get_loop().call_soon_threadsafe(_handle)

    @staticmethod
    def _reject_soon(future: asyncio.Future, exc: BaseException) -> None:
        """Thread-safely schedule ``future.set_exception(exc)`` on its loop."""

        def _handle():
            if not future.done():
                future.set_exception(exc)

        future.get_loop().call_soon_threadsafe(_handle)

    @staticmethod
    def _is_fulfilled(promise: "Promise") -> bool:
        """True if ``promise`` settled with a value (not cancelled/rejected)."""
        future = promise._future
        return future.done() and not future.cancelled() and future.exception() is None

    @staticmethod
    def _is_rejected(promise: "Promise") -> bool:
        """True if ``promise`` settled with an exception (not cancelled)."""
        future = promise._future
        return (
            future.done()
            and not future.cancelled()
            and future.exception() is not None
        )

    @staticmethod
    def _cancel_pending(promises: Iterable["Promise"]) -> None:
        """Thread-safely cancel every promise in ``promises`` not yet done."""
        for p in promises:
            if not p._future.done():
                p._future.get_loop().call_soon_threadsafe(p._future.cancel)

    @staticmethod
    def _subscribe(
        promise: "Promise", on_resolve: ResolveFunc, on_reject: RejectFunc
    ) -> None:
        """Call ``on_resolve(value)`` or ``on_reject(exc)`` once ``promise`` settles.

        Nothing is called if the promise is cancelled. The combinators run
        this from an executor thread, so registration is marshalled onto the
        promise's own loop instead of calling ``add_done_callback`` (or
        creating a chained Promise) from the worker thread.
        """
        future = promise._future

        def _dispatch(fut: asyncio.Future):
            if fut.cancelled():
                return
            exc = fut.exception()
            if exc is not None:
                on_reject(exc)
            else:
                on_resolve(fut.result())

        future.get_loop().call_soon_threadsafe(future.add_done_callback, _dispatch)

    def __chain_promise(self) -> "Promise":
        """Create a new pending promise; cancelling it cancels this one too."""
        this_future = self._future

        def _noop_callback(resolve_func, reject_func):
            # The chained promise is settled by then/final, not by its own
            # executor callback.
            pass

        def _check_cancel(future):
            if not future.cancelled():
                return
            if future.get_loop() is self._loop:
                this_future.cancel()
            else:
                self._loop.call_soon_threadsafe(this_future.cancel)

        new_promise = Promise(_noop_callback)
        new_promise._future.add_done_callback(_check_cancel)
        return new_promise

    # -- public chaining API ----------------------------------------------

    def then(
        self,
        resolve_func: Optional[ResolveFunc],
        reject_func: Optional[RejectFunc] = None,
    ) -> "Promise":
        """Return a promise settled by the handler matching this one's outcome.

        ``resolve_func(value)`` runs on fulfilment, ``reject_func(exc)`` on
        rejection. The handler's return value fulfils the returned promise and
        an exception it raises rejects it. A ``None`` handler passes the
        outcome through unchanged. If this promise is cancelled, the returned
        promise is left pending.
        """
        new_promise = self.__chain_promise()
        new_future = new_promise._future

        def receive_result(future: asyncio.Future):
            if future.cancelled():
                return
            exc = future.exception()
            if exc is not None:
                handler, arg = reject_func, exc
            else:
                handler, arg = resolve_func, future.result()
            if handler is None:
                if exc is not None:
                    self._reject_soon(new_future, exc)
                else:
                    self._fulfill_soon(new_future, arg)
                return
            try:
                result = handler(arg)
            except Exception as e:
                self._reject_soon(new_future, e)
            else:
                self._fulfill_soon(new_future, result)

        if not self._future.done():
            self._future.add_done_callback(receive_result)
        else:
            receive_result(self._future)
        return new_promise

    def catch(self, callback: RejectFunc) -> "Promise":
        """Handle a rejection; a fulfilled value passes through unchanged.

        Equivalent to ``self.then(None, callback)``.
        """
        return self.then(None, callback)

    def final(self, callback: FinallyFunc) -> "Promise":
        """Run ``callback()`` once this promise settles (JS ``finally``).

        The returned promise mirrors this one's outcome (a value or a
        rejection). If ``callback`` raises, that exception rejects the
        returned promise instead. Not called if this promise is cancelled.
        """
        new_promise = self.__chain_promise()
        new_future = new_promise._future

        def receive_result(future: asyncio.Future):
            if future.cancelled():
                return
            try:
                callback()
            except Exception as e:
                self._reject_soon(new_future, e)
                return
            exc = future.exception()
            if exc is not None:
                # Propagate the rejection; never fulfil with the exception.
                self._reject_soon(new_future, exc)
            else:
                self._fulfill_soon(new_future, future.result())

        if not self._future.done():
            self._future.add_done_callback(receive_result)
        else:
            receive_result(self._future)
        return new_promise

    # -- combinators --------------------------------------------------------

    @classmethod
    def all(cls, promises: Iterable["Promise"]) -> "Promise":
        """Fulfil with the list of all values, or reject with the first error.

        An empty input fulfils with ``[]``. On the first rejection, the
        still-pending inputs are cancelled.
        """
        promises = list(promises)

        def callback(resolve, reject):
            if not promises:
                resolve([])
                return

            def on_resolve(_):
                if all(cls._is_fulfilled(p) for p in promises):
                    resolve([p._future.result() for p in promises])

            def on_reject(err):
                cls._cancel_pending(promises)
                reject(err)

            for p in promises:
                cls._subscribe(p, on_resolve, on_reject)

        return cls(callback)

    @classmethod
    def any(cls, promises: Iterable["Promise"]) -> "Promise":
        """Fulfil with the first value, or reject with AggregateError if all fail.

        An empty input rejects with an empty AggregateError. On the first
        fulfilment, the still-pending inputs are cancelled.
        """
        promises = list(promises)

        def callback(resolve, reject):
            if not promises:
                reject(AggregateError([]))
                return

            def on_resolve(r):
                cls._cancel_pending(promises)
                resolve(r)

            def on_reject(err):
                if all(cls._is_rejected(p) for p in promises):
                    reject(AggregateError([p._future.exception() for p in promises]))

            for p in promises:
                cls._subscribe(p, on_resolve, on_reject)

        return cls(callback)

    @classmethod
    def race(cls, promises: Iterable["Promise"]) -> "Promise":
        """Settle like whichever input settles first; cancel the rest.

        An empty input never settles, as in JavaScript.
        """
        promises = list(promises)

        def callback(resolve, reject):
            def on_resolve(r):
                cls._cancel_pending(promises)
                resolve(r)

            def on_reject(err):
                cls._cancel_pending(promises)
                reject(err)

            for p in promises:
                cls._subscribe(p, on_resolve, on_reject)

        return cls(callback)
def is_even(number):
    """Return a Promise that fulfils with whether ``number`` is even.

    Rejects with ValueError when ``number`` is not an int.
    """

    def executor(resolve, reject):
        if isinstance(number, int):
            resolve(number % 2 == 0)
            return
        reject(ValueError(f"Invalid number: {number!r}"))

    return Promise(executor)
async def test():
    """Smoke-test is_even and promise chaining; AssertionError on failure."""
    assert await is_even(4)
    assert not (await is_even(5))

    caught = None
    try:
        await is_even("4")
    except Exception as exc:
        caught = exc
    assert str(caught) == "Invalid number: '4'"

    upper = await is_even(4).then(str).then(lambda s: s.upper())
    assert upper == "TRUE"
async def main():
    """Race an already-resolved promise against one that resolves after 1s."""

    def slow(resolve, reject):
        # Blocks only the executor thread that runs this callback.
        time.sleep(1)
        resolve("bar")

    contenders = [Promise.resolve("foo"), Promise(slow)]
    winner = await Promise.race(contenders)
    print(winner)
if __name__ == "__main__":
    # Runs the race demo; call asyncio.run(test()) instead to run the checks.
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment