Skip to content

Instantly share code, notes, and snippets.

@yonatanzunger
Created August 11, 2022 00:22
Show Gist options
  • Save yonatanzunger/12bca358330f3889f24fd35f7c556997 to your computer and use it in GitHub Desktop.
Save yonatanzunger/12bca358330f3889f24fd35f7c556997 to your computer and use it in GitHub Desktop.
Python Codegen Example: Main code
import dis
import io
from types import CodeType, FunctionType
from typing import Any, Callable, List, NamedTuple, Tuple
# Opcodes
# Numeric values of the bytecode instructions emitted by makeFastComparator, resolved by name from
# the running interpreter's opcode table (dis.opname is indexed by opcode number). list.index
# raises ValueError at import time if the interpreter has no instruction with the given name.
# Operand is the index of the argument being tested.
_LOAD_FAST = dis.opname.index('LOAD_FAST')
# Operand is an index into the generated code object's constants tuple (the test value).
_LOAD_CONST = dis.opname.index('LOAD_CONST')
# Operand is the comparison operator, given as an index into dis.cmp_op.
_COMPARE_OP = dis.opname.index('COMPARE_OP')
# Operand is the jump target; used to short-circuit to the final RETURN_VALUE.
_JUMP_IF_FALSE_OR_POP = dis.opname.index('JUMP_IF_FALSE_OR_POP')
# Emitted once at the end of the generated function; its operand byte is always written as 0.
_RETURN_VALUE = dis.opname.index('RETURN_VALUE')
class Comparison(NamedTuple):
    """A single comparison of the form 'args[variableIndex] OP testValue'.

    The operation is identified by the index of the corresponding operator in dis.cmp_op; e.g.,
    dis.cmp_op.index('==') == 2.
    """

    # Zero-based position of the function argument that appears on the left-hand side.
    variableIndex: int
    # Index into dis.cmp_op of the comparison operator to apply.
    comparisonOp: int
    # The constant that appears on the right-hand side of the comparison.
    testValue: Any
def makeFastComparator(
    numArguments: int,
    comparisons: List[Comparison],
    filename: str = '__generated-code__',
    funcName: str = '_compiled_CompareKey',
) -> Callable:
    """Generate a function which will evaluate a bunch of comparisons very efficiently.

    This function will return a function which takes (numArguments) positional arguments, and
    which is equivalent to

        def funcName(x1, x2, ... xN):
            return (x[i1] OP1 test1) and (x[i2] OP2 test2) and ...

    where (i1, OP1, test1) ... are the tuples passed in 'comparisons'. So for example,

        makeFastComparator(2, [(0, dis.cmp_op.index('=='), b'1234'),
                               (1, dis.cmp_op.index('<='), b'3333')])

    returns a function equivalent to

        def myComparator(arg1, arg2):
            return arg1 == b'1234' and arg2 <= b'3333'

    If 'comparisons' is empty, the returned function simply returns True.

    The resulting function is significantly more efficient than the function you could generate
    naively using for loops, etc.; in fact, it is bytecode-identical to what would be produced
    from the 'and' statement described above. It can therefore be used in an innermost loop and
    be really fast.

    Args:
        numArguments: Number of positional arguments the generated function accepts.
        comparisons: The comparisons to AND together, in evaluation order. Each entry is a
            (variableIndex, comparisonOp, testValue) triple; comparisonOp is an index into
            dis.cmp_op.
        filename: Value stored as the co_filename of the generated code object.
        funcName: Value stored as the co_name of the generated code object, and therefore the
            __name__ of the returned function.

    Returns:
        The generated function.

    Raises:
        RuntimeError: If the jump to the final return would not fit in a single-byte operand,
            i.e. if more than 32 comparisons are given.
        ValueError: If a comparison refers to an argument index outside [0, numArguments), to an
            argument index above 255, or to an invalid comparison operation.

    DEBUGGING NOTES:
    - The resulting function will have its __name__ set to funcName, and its code object's
      co_filename set to filename. Importantly, this will affect how the function shows up in
      profiles, so you can set these if you want different return values from
      makeFastComparator to show up as different profiling items. (The function is created with
      an empty globals dict, so filename is carried by co_filename, not by __module__.)
    - This function will have fake "line numbers" in it which will show up in tracebacks if
      something goes wrong. Line numbers 1...N correspond to evaluating the corresponding
      comparison operation; line number N+1 is the final 'return' statement. This can help you
      debug problems if e.g. you passed a bogus argument to the function.
    """
    constants: Tuple[Any, ...]
    if not comparisons:
        # Simple case: If there are no comparisons, this is just the function that returns True.
        constants = (None, True)
        bytecode = bytes([_LOAD_CONST, 1, _RETURN_VALUE, 0])
        lnotab = bytes([0, 1])
    else:
        # We're going to generate bytecode for each comparison (index, operation, value) that
        # looks like
        #   1:   LOAD_FAST <index1>
        #        LOAD_CONST <value1>
        #        COMPARE_OP <operation1>
        #        JUMP_IF_FALSE_OR_POP <return>
        #   2:   LOAD_FAST <index2>
        #        LOAD_CONST <value2>
        #        COMPARE_OP <operation2>
        #        .... repeated for each value except the last one ...
        #   N:   LOAD_FAST <indexN>
        #        LOAD_CONST <valueN>
        #        COMPARE_OP <operationN>
        #   N+1: RETURN_VALUE
        #
        # Here <return> is the offset of the RETURN_VALUE instruction, which (since all
        # instructions are exactly two bytes) = 8 * (ncmps - 1) + 6 = 8 * ncmps - 2. The numbers
        # on the left are line numbers, so that you can debug the output more easily!
        constants = (None, *(compare[2] for compare in comparisons))

        COMPARISON_LENGTH = 4 * 2
        writer = io.BytesIO()
        returnAddress = COMPARISON_LENGTH * len(comparisons) - 2
        if returnAddress > 255:
            raise RuntimeError(
                'Hmm. Handling jumps of more than 255 bytes would require more '
                'intelligent code than we\'ve written here.'
            )

        # Helpers for writing out the table of line numbers. This format is documented in
        # https://github.com/python/cpython/blob/master/Objects/lnotab_notes.txt
        # but we're just using its simplest form. incrementLine() basically says "the point
        # where we're about to write is the next line of code."
        linenoWriter = io.BytesIO()
        lastLineStart = 0

        def incrementLine() -> None:
            nonlocal lastLineStart
            currentPos = writer.tell()
            # Each entry is (bytecode offset delta, line number delta).
            linenoWriter.write(bytes([currentPos - lastLineStart, 1]))
            lastLineStart = currentPos

        lastOpNum = len(comparisons) - 1
        for opNum, comparison in enumerate(comparisons):
            index, cmpOp, _ = comparison
            if index < 0 or index >= numArguments:
                raise ValueError(
                    f'Got bad index {index} for comparison request with only '
                    f'{numArguments} arguments!'
                )
            # Reject out-of-range operations and the 'BAD' placeholder. Checking for the
            # placeholder itself (instead of assuming it is the last entry) keeps the last real
            # operator usable on interpreters whose dis.cmp_op has no such placeholder.
            if cmpOp < 0 or cmpOp >= len(dis.cmp_op) or dis.cmp_op[cmpOp] == 'BAD':
                raise ValueError(f'Got bad comparison operation {cmpOp}')
            # Every operand is written as a single byte; without this check a large index would
            # surface as an opaque "bytes must be in range(0, 256)" error from bytes() below.
            if index > 255:
                raise ValueError(
                    f'Argument index {index} does not fit in a single-byte operand '
                    f'(maximum is 255).'
                )

            incrementLine()
            # Constant 0 is None, so the test value for comparison opNum lives at opNum + 1.
            writer.write(
                bytes([_LOAD_FAST, index, _LOAD_CONST, opNum + 1, _COMPARE_OP, cmpOp])
            )
            if opNum != lastOpNum:
                # Short-circuit: a falsy result jumps straight to the return and is returned.
                writer.write(bytes([_JUMP_IF_FALSE_OR_POP, returnAddress]))
        else:
            # The loop has no 'break', so this always runs once the last comparison is written.
            incrementLine()
            writer.write(bytes([_RETURN_VALUE, 0]))

        bytecode = writer.getvalue()
        lnotab = linenoWriter.getvalue()

    # The CodeType class isn't properly documented, but its call syntax is defined by the function
    # code_new() in https://github.com/python/cpython/blob/master/Objects/codeobject.c . That file
    # also documents what legal values of the flags are. Note that the arguments are just the
    # values of the fields of a code object: co_argcount, co_kwonlyargcount, co_nlocals,
    # co_stacksize, co_flags, co_code, co_consts, co_names, co_varnames, co_filename, co_name,
    # co_firstlineno, co_lnotab, and two optional arguments, co_freevars and co_cellvars.
    # Some important but non-obvious reminders:
    # - Arguments count as locals, so if co_nlocals < co_argcount + co_kwonlyargcount, very
    #   surprising things will happen to you.
    # - co_varnames should have co_nlocals elements in it, or various debug operations may fail.
    # - Many things in the CPython codebase assume that co_constants[0] is None.
    # You can find useful tips for instantiating CodeType objects at
    # https://stackoverflow.com/questions/16064409/how-to-create-a-code-object-in-python
    # NOTE(review): this 13-argument positional form appears to match the constructor of CPython
    # 3.7 and earlier; later releases added fields (e.g. co_posonlyargcount in 3.8) and changed
    # the bytecode format -- confirm the target interpreter version.
    code = CodeType(
        numArguments,  # Normal arguments
        0,  # kw-only arguments
        numArguments,  # local variables
        2,  # stack size: at most the loaded argument plus the loaded constant
        0,  # flags
        bytecode,
        constants,
        tuple(),  # global variable names used; none.
        tuple(f'arg{i}' for i in range(numArguments)),  # names for our locals for debug
        filename,
        funcName,
        0,  # firstlineno
        lnotab,  # line number table
    )

    # The arguments to FunctionType are the CodeType object and the dict of globals which may be
    # used by this function, which should have the same keys as the 'co_names' argument to
    # CodeType. mypy doesn't seem to understand the correct arguments to this, though.
    return FunctionType(code, {})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment