Last active
May 2, 2023 09:12
-
-
Save qexat/d80bfe66cfa17322ae5b86b847d47be7 to your computer and use it in GitHub Desktop.
cursed exception divergence using impure-state tracker object
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# requires: >= 3.10 | |
from __future__ import annotations | |
from collections.abc import Callable | |
import functools | |
import inspect | |
from types import TracebackType | |
from typing import Concatenate, ParamSpec, TypeVar | |
P = ParamSpec("P") | |
R = TypeVar("R") | |
class CallerException(Exception): | |
@classmethod | |
def wrap(cls, exc: Exception): | |
return cls(exc) | |
class Caller: | |
def __init__(self) -> None: | |
self.__exceptions: list[CallerException] = [] | |
self.__warnings: list[Warning] = [] | |
@property | |
def exceptions(self) -> tuple[Exception, ...]: | |
return tuple(self.__exceptions) | |
@property | |
def warnings(self) -> tuple[Warning, ...]: | |
return tuple(self.__warnings) | |
@property | |
def last_exception(self) -> CallerException: | |
return self.__exceptions[-1] | |
@property | |
def last_warning(self) -> Warning: | |
return self.__warnings[-1] | |
def report_error(self, exception: Exception) -> None: | |
assert (cur_frame := inspect.currentframe()) is not None | |
assert (prev_frame := cur_frame.f_back) is not None | |
frame, lasti, lineno = (prev_frame, prev_frame.f_lasti, prev_frame.f_lineno) | |
final = exception.with_traceback(TracebackType(None, frame, lasti, lineno)) | |
self.__exceptions.append(CallerException.wrap(final)) | |
def report_warning(self, warning: Warning) -> None: | |
self.__warnings.append(warning) | |
@classmethod | |
def entry_point( | |
cls, | |
function: Callable[Concatenate[Caller, P], R], | |
) -> Callable[P, R]: | |
caller = cls() | |
@functools.wraps(function) | |
def inner(*args: P.args, **kwargs: P.kwargs) -> R: | |
try: | |
return function(caller, *args, **kwargs) | |
except CallerException as ce: | |
exc: Exception = ce.args[0] | |
raise exc from None | |
return inner | |
def fib(caller: Caller, n: int) -> int:
    """Return the *n*-th Fibonacci number (fib(0) == 0, fib(1) == 1).

    For negative *n*, nothing is raised here. A ``ValueError`` is reported
    to *caller* via ``report_error`` and ``-1`` is returned as a sentinel.
    The value is computed iteratively in O(n) time and O(1) space. The naive
    double recursion was exponential and limited by the recursion depth.
    """
    if n < 0:
        caller.report_error(ValueError("n must be >= 0"))
        return -1
    previous, current = 0, 1
    for _ in range(n):
        previous, current = current, previous + current
    return previous
@Caller.entry_point
def main(caller: Caller) -> None:
    """Demo: compute fib(-1), then raise whatever error fib reported.

    The raise happens here, but the exception's traceback was built inside
    ``fib`` by ``Caller.report_error``. It therefore appears to originate
    in ``fib``.
    """
    requested = -1
    value = fib(caller, requested)
    if not caller.exceptions:
        print(value)
        return
    raise caller.last_exception


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Output: notice how the exception appears to come from `fib`, whereas it is actually raised from `main`... or is it? x)