Created
September 16, 2024 17:48
-
-
Save v--/3f98161b5ff770f1eec16848fc83726e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "a87cb44e-ade1-4b93-8703-22d6b8f3ad2b", | |
"metadata": {}, | |
"source": [ | |
"# Recursion prevention\n", | |
"\n", | |
"Here are two utilities for preventing accidental recursive calls where they should be forbidden. They grew out of a very peculiar problem we had at work. The solutions themselves are not elegant, but they are useful for debugging and for hardening code that can do nasty things when called recursively." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "7a90419b-d120-484a-9131-c318356289ea", | |
"metadata": {}, | |
"source": [ | |
"## Recursive functions\n", | |
"\n", | |
"The following is a decorator that wraps a function and \"poisons\" the call stack so that the function cannot be called again from deeper levels of that stack." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "7afead00-6d12-4acf-9f5e-d79c3d29e9ed", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import functools\n", | |
"import inspect\n", | |
"from collections.abc import Callable\n", | |
"from typing import Never\n", | |
"\n", | |
"\n", | |
"def prevent_recursion[**P, R](fun: Callable[P, R]) -> Callable[P, R]:\n", | |
" # We simply need a sentinel object with a unique id\n", | |
" sentinel = set[Never]()\n", | |
"\n", | |
" @functools.wraps(fun)\n", | |
" def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:\n", | |
" call_stack = inspect.stack()\n", | |
"\n", | |
" for frame_info in call_stack[1:]:\n", | |
" if frame_info.frame.f_locals.get('sentinel') is sentinel:\n", | |
" raise RecursionError(f'Did not expect a recursive function call for {fun!r}.')\n", | |
"\n", | |
" return fun(*args, **kwargs)\n", | |
"\n", | |
" return wrapper" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f3c1965b-841c-4b64-af0e-9b64e4b1752c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@prevent_recursion\n", | |
"def fibonacci(n: int) -> int:\n", | |
" assert n >= 0\n", | |
"\n", | |
" if n < 2:\n", | |
" return n\n", | |
"\n", | |
" return fibonacci(n - 1) + fibonacci(n - 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "166a9493-9a66-4f47-ab84-e827a972e8f0", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"fibonacci(1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "a14afb57-2b2e-45be-9fad-acf35c48fdf5", | |
"metadata": { | |
"lines_to_next_cell": 2 | |
}, | |
"outputs": [ | |
{ | |
"ename": "RecursionError", | |
"evalue": "Did not expect a recursive function call for <function fibonacci at 0x7255289c7b00>.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRecursionError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mfibonacci\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n", | |
"Cell \u001b[0;32mIn[1], line 19\u001b[0m, in \u001b[0;36mprevent_recursion.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msentinel\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m sentinel:\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDid not expect a recursive function call for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"Cell \u001b[0;32mIn[2], line 8\u001b[0m, in \u001b[0;36mfibonacci\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m n\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfibonacci\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m fibonacci(n \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m2\u001b[39m)\n", | |
"Cell \u001b[0;32mIn[1], line 17\u001b[0m, in \u001b[0;36mprevent_recursion.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m frame_info \u001b[38;5;129;01min\u001b[39;00m call_stack[\u001b[38;5;241m1\u001b[39m:]:\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msentinel\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m sentinel:\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDid not expect a recursive function call for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fun(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n", | |
"\u001b[0;31mRecursionError\u001b[0m: Did not expect a recursive function call for <function fibonacci at 0x7255289c7b00>." | |
] | |
} | |
], | |
"source": [ | |
"fibonacci(3)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "09e0e115-afad-4cd8-9ea7-e480c907f2e3", | |
"metadata": {}, | |
"source": [ | |
"## Call stack singletons\n", | |
"\n", | |
"The following mixin prevents a class from being instantiated if any frame in the call stack already holds a live instance **assigned to a local variable**." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "af6e9438-e185-4447-b4e9-ac1a683ed1a3", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import inspect\n", | |
"from typing import Any, Self\n", | |
"\n", | |
"\n", | |
"class CallStackSingletonMixin:\n", | |
" def __new__(cls, *args: Any, **kwargs: Any) -> Self:\n", | |
" call_stack = inspect.stack()\n", | |
"\n", | |
" for frame_info in call_stack[1:]:\n", | |
" for value in frame_info.frame.f_locals.values():\n", | |
" if isinstance(value, cls):\n", | |
" raise RecursionError(f'Only one instance of {cls!r} allowed in the call stack.')\n", | |
"\n", | |
" return super().__new__(cls)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "bd78a022-cd0f-4525-a514-0cc067536075", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Test(CallStackSingletonMixin):\n", | |
" def __init__(self) -> None:\n", | |
" pass" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "3af5f6c3-2106-4ff8-a129-1fe246a127a4", | |
"metadata": {}, | |
"source": [ | |
"The following is fine because the first instance is not directly accessible. The capture magic ensures that no instance is assigned to a variable by Jupyter." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "9d1f6dfb-c035-4457-88b4-61f755b6879a", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"\n", | |
"Test()\n", | |
"Test()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6634a25d-2622-40d7-be5b-9e629325aca7", | |
"metadata": {}, | |
"source": [ | |
"The following is also allowed because the existing instance is stored inside a set rather than assigned directly to a variable (although ideally this case should be rejected too)." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "194db4bb-8c9d-4491-9672-61b6c26a7826", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%%capture\n", | |
"\n", | |
"test_set = {Test()}\n", | |
"Test()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "bbdbe409-1541-4848-9d0e-be915bd8922c", | |
"metadata": {}, | |
"source": [ | |
"The following raises a `RecursionError` because \"t\" is a live instance assigned to a variable." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "88ca9b24-7f23-464d-8fb2-66a78c71b23f", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"ename": "RecursionError", | |
"evalue": "Only one instance of <class '__main__.Test'> allowed in the call stack.", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRecursionError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[9], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m t \u001b[38;5;241m=\u001b[39m Test()\n\u001b[0;32m----> 2\u001b[0m \u001b[43mTest\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", | |
"Cell \u001b[0;32mIn[5], line 12\u001b[0m, in \u001b[0;36mCallStackSingletonMixin.__new__\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m value \u001b[38;5;129;01min\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mvalues():\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mcls\u001b[39m):\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOnly one instance of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m allowed in the call stack.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n", | |
"\u001b[0;31mRecursionError\u001b[0m: Only one instance of <class '__main__.Test'> allowed in the call stack." | |
] | |
} | |
], | |
"source": [ | |
"%%capture\n", | |
"\n", | |
"t = Test()\n", | |
"Test()" | |
] | |
} | |
], | |
"metadata": { | |
"jupytext": { | |
"formats": "ipynb,py:percent" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment