Skip to content

Instantly share code, notes, and snippets.

@v--
Created September 16, 2024 17:48
Show Gist options
  • Save v--/3f98161b5ff770f1eec16848fc83726e to your computer and use it in GitHub Desktop.
Save v--/3f98161b5ff770f1eec16848fc83726e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "a87cb44e-ade1-4b93-8703-22d6b8f3ad2b",
"metadata": {},
"source": [
"# Recursion prevention\n",
"\n",
"Here are two utilities for preventing accidental recursive calls in places where recursion should be forbidden. They grew out of a very peculiar problem we had at work. The solutions are not elegant in themselves, but they are useful for debugging and for safeguarding code that can do nasty things when called recursively."
]
},
{
"cell_type": "markdown",
"id": "7a90419b-d120-484a-9131-c318356289ea",
"metadata": {},
"source": [
"## Recursive functions\n",
"\n",
"The following decorator wraps a function and \"poisons\" the call stack so that the function cannot be called again from a deeper level of the same stack; any such nested call raises a `RecursionError`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7afead00-6d12-4acf-9f5e-d79c3d29e9ed",
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import inspect\n",
"from collections.abc import Callable\n",
"from typing import Never\n",
"\n",
"\n",
"def prevent_recursion[**P, R](fun: Callable[P, R]) -> Callable[P, R]:\n",
" # We simply need a sentinel object with a unique id\n",
" sentinel = set[Never]()\n",
"\n",
" @functools.wraps(fun)\n",
" def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:\n",
" call_stack = inspect.stack()\n",
"\n",
" for frame_info in call_stack[1:]:\n",
" if frame_info.frame.f_locals.get('sentinel') is sentinel:\n",
" raise RecursionError(f'Did not expect a recursive function call for {fun!r}.')\n",
"\n",
" return fun(*args, **kwargs)\n",
"\n",
" return wrapper"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f3c1965b-841c-4b64-af0e-9b64e4b1752c",
"metadata": {},
"outputs": [],
"source": [
"@prevent_recursion\n",
"def fibonacci(n: int) -> int:\n",
" assert n >= 0\n",
"\n",
" if n < 2:\n",
" return n\n",
"\n",
" return fibonacci(n - 1) + fibonacci(n - 2)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "166a9493-9a66-4f47-ab84-e827a972e8f0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fibonacci(1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a14afb57-2b2e-45be-9fad-acf35c48fdf5",
"metadata": {
"lines_to_next_cell": 2
},
"outputs": [
{
"ename": "RecursionError",
"evalue": "Did not expect a recursive function call for <function fibonacci at 0x7255289c7b00>.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRecursionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mfibonacci\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[1], line 19\u001b[0m, in \u001b[0;36mprevent_recursion.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msentinel\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m sentinel:\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDid not expect a recursive function call for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[2], line 8\u001b[0m, in \u001b[0;36mfibonacci\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m n \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m2\u001b[39m:\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m n\n\u001b[0;32m----> 8\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfibonacci\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;241m+\u001b[39m fibonacci(n \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m2\u001b[39m)\n",
"Cell \u001b[0;32mIn[1], line 17\u001b[0m, in \u001b[0;36mprevent_recursion.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m frame_info \u001b[38;5;129;01min\u001b[39;00m call_stack[\u001b[38;5;241m1\u001b[39m:]:\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124msentinel\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mis\u001b[39;00m sentinel:\n\u001b[0;32m---> 17\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mDid not expect a recursive function call for \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfun\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fun(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"\u001b[0;31mRecursionError\u001b[0m: Did not expect a recursive function call for <function fibonacci at 0x7255289c7b00>."
]
}
],
"source": [
"fibonacci(3)"
]
},
{
"cell_type": "markdown",
"id": "09e0e115-afad-4cd8-9ea7-e480c907f2e3",
"metadata": {},
"source": [
"## Call stack singletons\n",
"\n",
"The following mixin prevents a class from being instantiated if a frame on the call stack already has a live instance **assigned to a variable**."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "af6e9438-e185-4447-b4e9-ac1a683ed1a3",
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"from typing import Any, Self\n",
"\n",
"\n",
"class CallStackSingletonMixin:\n",
" def __new__(cls, *args: Any, **kwargs: Any) -> Self:\n",
" call_stack = inspect.stack()\n",
"\n",
" for frame_info in call_stack[1:]:\n",
" for value in frame_info.frame.f_locals.values():\n",
" if isinstance(value, cls):\n",
" raise RecursionError(f'Only one instance of {cls!r} allowed in the call stack.')\n",
"\n",
" return super().__new__(cls)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "bd78a022-cd0f-4525-a514-0cc067536075",
"metadata": {},
"outputs": [],
"source": [
"class Test(CallStackSingletonMixin):\n",
" def __init__(self) -> None:\n",
" pass"
]
},
{
"cell_type": "markdown",
"id": "3af5f6c3-2106-4ff8-a129-1fe246a127a4",
"metadata": {},
"source": [
"The following is fine because the first instance is not accessible (directly). The capture magic ensures that no instance is assigned to a variable by Jupyter."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "9d1f6dfb-c035-4457-88b4-61f755b6879a",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"Test()\n",
"Test()"
]
},
{
"cell_type": "markdown",
"id": "6634a25d-2622-40d7-be5b-9e629325aca7",
"metadata": {},
"source": [
"The following is also accepted, because the existing instance is stored inside a set rather than assigned directly to a variable, so the check does not see it (although ideally it should be rejected)."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "194db4bb-8c9d-4491-9672-61b6c26a7826",
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"test_set = {Test()}\n",
"Test()"
]
},
{
"cell_type": "markdown",
"id": "bbdbe409-1541-4848-9d0e-be915bd8922c",
"metadata": {},
"source": [
"The following is rejected with a `RecursionError` because \"t\" holds a live instance."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "88ca9b24-7f23-464d-8fb2-66a78c71b23f",
"metadata": {},
"outputs": [
{
"ename": "RecursionError",
"evalue": "Only one instance of <class '__main__.Test'> allowed in the call stack.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRecursionError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m t \u001b[38;5;241m=\u001b[39m Test()\n\u001b[0;32m----> 2\u001b[0m \u001b[43mTest\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[5], line 12\u001b[0m, in \u001b[0;36mCallStackSingletonMixin.__new__\u001b[0;34m(cls, *args, **kwargs)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m value \u001b[38;5;129;01min\u001b[39;00m frame_info\u001b[38;5;241m.\u001b[39mframe\u001b[38;5;241m.\u001b[39mf_locals\u001b[38;5;241m.\u001b[39mvalues():\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(value, \u001b[38;5;28mcls\u001b[39m):\n\u001b[0;32m---> 12\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRecursionError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mOnly one instance of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mcls\u001b[39m\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m allowed in the call stack.\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 14\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__new__\u001b[39m(\u001b[38;5;28mcls\u001b[39m)\n",
"\u001b[0;31mRecursionError\u001b[0m: Only one instance of <class '__main__.Test'> allowed in the call stack."
]
}
],
"source": [
"%%capture\n",
"\n",
"t = Test()\n",
"Test()"
]
}
],
"metadata": {
"jupytext": {
"formats": "ipynb,py:percent"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment