Created October 23, 2020 14:11
Save gvx/e57f0f6babdd56c6d1ef5f1787ab4c7e to your computer and use it in GitHub Desktop.
Create closures in Python without default arguments
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dis | |
from typing import Iterable, TypeVar, Any | |
from collections.abc import Callable | |
from types import CodeType | |
# Type of the decorated function; bound to Callable so that decorating a
# function gives back a value of the very same function type.
F = TypeVar('F', bound=Callable)


def extract_mapping(names: tuple[str, ...], mapping: dict[str, int]) -> dict[int, int]:
    """Re-key *mapping* by each key's position inside *names*.

    Keys of *mapping* that do not appear in *names* are skipped; if a key
    appears in *names* more than once, its first position is used.
    """
    positions: dict[int, int] = {}
    for name, value in mapping.items():
        if name not in names:
            continue
        positions[names.index(name)] = value
    return positions
# Numeric opcode that replaces every rewritten variable load.
LOAD_CONST = dis.opmap['LOAD_CONST']


def new_opcodes(code: CodeType, global_overrides: dict[int, int], enclosing_overrides: dict[int, int]) -> Iterable[int]:
    """Yield *code*'s bytecode as bytes, turning selected loads into LOAD_CONST.

    Every instruction reported by ``dis.get_instructions`` contributes exactly
    two values, its opcode and a one-byte oparg, so the result can be fed to
    ``bytes()`` and each instruction keeps its size.

    :param code: code object whose instructions are rewritten.
    :param global_overrides: maps the oparg of a LOAD_NAME / LOAD_GLOBAL
        instruction to the ``co_consts`` index that should be loaded instead.
    :param enclosing_overrides: maps the oparg of a LOAD_DEREF instruction to
        the ``co_consts`` index that should be loaded instead.
    :raises ValueError: if a replacement constant index cannot be encoded in
        the oparg slot (including EXTENDED_ARG prefixes) of the instruction it
        replaces.

    NOTE(review): the oparg meanings above match CPython 3.9/3.10; 3.11 changed
    the LOAD_GLOBAL/LOAD_DEREF oparg encoding and added inline cache entries.
    """
    for instr in dis.get_instructions(code):
        if instr.opname in ('LOAD_NAME', 'LOAD_GLOBAL') and instr.arg in global_overrides:
            const_index = global_overrides[instr.arg]
        elif instr.opname == 'LOAD_DEREF' and instr.arg in enclosing_overrides:
            const_index = enclosing_overrides[instr.arg]
        else:
            yield instr.opcode
            # dis folds preceding EXTENDED_ARG prefixes into ``arg`` (and yields
            # those prefixes as their own instructions), so only the low byte
            # belongs to this code unit.
            yield (instr.arg or 0) & 0xFF
            continue
        # The replacement reuses the EXTENDED_ARG prefixes already emitted for the
        # original instruction, so its high bits must match the original oparg's.
        if (const_index >> 8) != (instr.arg >> 8):
            raise ValueError(
                f'cannot rewrite {instr.opname} at offset {instr.offset}: constant '
                f'index {const_index} does not fit the slot of oparg {instr.arg}')
        yield LOAD_CONST
        yield const_index & 0xFF
def closure(**kwargs: Any) -> Callable[[F], F]:
    """Return a decorator that bakes *kwargs* into a function as constants.

    In the decorated function, every global load (LOAD_GLOBAL / LOAD_NAME) and
    every free-variable load (LOAD_DEREF) of a name that is a key of *kwargs*
    is rewritten to load the corresponding value from ``co_consts``.  The value
    is therefore captured when the decorator is applied rather than looked up
    at call time, and the function's signature is left unchanged.

    The function's ``__code__`` is replaced in place and the same function
    object is returned.  Only the function's own code object is rewritten;
    loads inside nested functions, lambdas or comprehensions are not.

    :raises RuntimeError: on Python 3.11+, whose bytecode layout this rewrite
        does not handle.
    :raises AssertionError: (only when ``__debug__`` is true) if a keyword is
        never loaded as a global or free variable by the decorated function.
    """
    import sys

    if sys.version_info >= (3, 11):
        # 3.11 added inline caches and changed the LOAD_GLOBAL/LOAD_DEREF opargs;
        # rewriting that bytecode with this scheme would produce corrupt code.
        raise RuntimeError('closure() only supports the CPython 3.9/3.10 bytecode layout')

    def _closure(f: F) -> F:
        code = f.__code__
        # The baked-in values are appended after the existing constants, in kwargs order.
        constant_indexes = {k: i for i, k in enumerate(kwargs, len(code.co_consts))}
        global_overrides = extract_mapping(code.co_names, constant_indexes)
        # LOAD_DEREF indexes co_cellvars + co_freevars.  Only free variables come
        # from an enclosing scope; cell variables are f's own locals (captured by
        # inner functions) and must keep reading their cell, so only free-variable
        # positions are overridden, shifted past the cell variables.
        n_cells = len(code.co_cellvars)
        enclosing_overrides = {
            n_cells + i: v
            for i, v in extract_mapping(code.co_freevars, constant_indexes).items()
        }
        if __debug__:
            used = set(global_overrides.values()) | set(enclosing_overrides.values())
            unused = sorted(k for k, v in constant_indexes.items() if v not in used)
            assert not unused, f'some variables ({", ".join(unused)}) are defined but not used in the function {f.__qualname__}'
        f.__code__ = code.replace(
            co_consts=code.co_consts + tuple(kwargs.values()),
            co_code=bytes(new_opcodes(code, global_overrides, enclosing_overrides)))
        return f

    return _closure
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from closure_decorator import closure | |
# Demo 1: capture the current value of a loop variable without a default argument.
l = []
for i in range(10):
    # Bake the value ``i`` has *now* into foo.  Without the decorator, every foo
    # would look up the global ``i`` when called and see its final value, 9.
    @closure(i=i)
    def foo() -> None:
        print(i)
    l.append(foo)
for z in l:
    z()
# prints 0 ... 9 instead of all nines!
# Unlike the default-argument trick (``def foo(i=i)``), foo takes no
# parameters, so calling z(i=9) raises TypeError instead of silently
# overriding the captured value.

# Demo 2: the captured name cannot clash with keyword arguments, so **kwargs
# can still accept a keyword with the same name:
@closure(collector_type=dict)
def collect_some_args(**kwargs):
    # ``collector_type`` here is the baked-in dict; the caller's
    # collector_type=list below simply lands in kwargs.
    return collector_type(kwargs)
print(collect_some_args(collector_type=list))
# prints {'collector_type': <class 'list'>}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.