Skip to content

Instantly share code, notes, and snippets.

@agoose77
Last active May 27, 2022 16:14
Show Gist options
  • Save agoose77/ec37316c72214b47ebf06b1956a72072 to your computer and use it in GitHub Desktop.
Save agoose77/ec37316c72214b47ebf06b1956a72072 to your computer and use it in GitHub Desktop.
import ast
import inspect
import textwrap

import awkward as ak
import numba
from numba import literal_unroll
# Register Awkward Array's Numba extension so that ak.Array values (like
# `events` below) can be passed into numba.njit-compiled functions.
# NOTE(review): newer awkward releases may expose this under a different
# name; confirm against the pinned awkward version.
ak.numba.register()
class Transformer(ast.NodeTransformer):
    """Rewrite ``fields_of(<name>)`` calls into ``literal_unroll((<field>, ...))``.

    The replacement field names come from ``fields``, a mapping from the
    variable name passed to ``fields_of`` to an iterable of field-name
    strings. Any call that is not exactly ``fields_of(<single bare name>)``
    is left in place, and its children are still visited.
    """

    def __init__(self, fields):
        super().__init__()
        # Mapping: variable name -> iterable of field names to unroll over.
        self._fields = fields

    def visit_Call(self, node):
        """Replace a matching ``fields_of`` call; otherwise recurse normally.

        Raises:
            KeyError: if the call's argument name has no entry in the table.
        """
        if not self._is_fields_of_call(node):
            return self.generic_visit(node)

        arg_name = node.args[0].id
        try:
            fields = self._fields[arg_name]
        except KeyError:
            raise KeyError(
                f"fields_of({arg_name}) is used, but no field names were "
                f"provided for {arg_name!r}"
            ) from None

        replacement = ast.Call(
            func=ast.Name(id="literal_unroll", ctx=ast.Load()),
            args=[
                ast.Tuple(
                    elts=[ast.Constant(value=name) for name in fields],
                    ctx=ast.Load(),
                )
            ],
            keywords=[],
        )
        # Carry over the call's source position so compile/JIT errors point
        # at the original call site; child nodes are still filled in later
        # by ast.fix_missing_locations.
        return ast.copy_location(replacement, node)

    @staticmethod
    def _is_fields_of_call(node):
        """Return True if *node* is a call ``fields_of(<bare name>)``."""
        return (
            isinstance(node.func, ast.Name)
            and node.func.id == "fields_of"
            and len(node.args) == 1
            and isinstance(node.args[0], ast.Name)
        )
# For each row, print the row index and the name of every field equal to 10.
def process_events(events):
    """Print ``i, name`` for every field ``name`` of row ``i`` that equals 10.

    ``fields_of`` is not defined in this file: it is a placeholder that
    ``rewrite_fields`` replaces with ``literal_unroll((<field>, ...))``
    before JIT compilation. Call the rewritten, jitted version, not this one.
    """
    for i, row in enumerate(events):
        # After rewriting this iterates over a literal tuple of field names,
        # which numba's literal_unroll unrolls so each row[name] lookup uses a
        # compile-time-constant field name.
        for name in fields_of(events):
            if row[name] == 10:
                print(i, name)
# Factory function to patch fields
def rewrite_fields(func, table):
    """Return a copy of *func* with each ``fields_of(<name>)`` call unrolled.

    The source of *func* is re-parsed. Every ``fields_of(name)`` call is
    replaced by ``literal_unroll((<field>, ...))`` using ``table[name]``, and
    the rewritten definition is executed in a fresh namespace.

    Parameters
    ----------
    func : function
        A plain Python function whose source is retrievable with
        ``inspect.getsource``.
    table : mapping
        Maps the variable name passed to ``fields_of`` to its field names.

    Returns
    -------
    function
        The rewritten function (same name as *func*), e.g. for ``numba.njit``.
    """
    # Dedent so functions defined inside a class or another function, whose
    # source is indented, still parse as a top-level definition.
    source = textwrap.dedent(inspect.getsource(func))
    # Replace `fields_of(events)` with `literal_unroll(("this", "that"))` etc.
    tree = Transformer(table).visit(ast.parse(source))
    # Synthesized nodes need line/column info before compile().
    ast.fix_missing_locations(tree)
    code = compile(tree, filename="<internal>", mode="exec")
    # Resolve names the way *func* itself would (its own module globals),
    # while still providing this module's names such as `literal_unroll`,
    # which the rewrite introduces.
    namespace = {**globals(), **func.__globals__}
    exec(code, namespace)
    # Look up the redefined function by its real name, not a hard-coded one.
    return namespace[func.__name__]
# Create demo data: three records with integer fields "this" and "that".
events = ak.zip({"this": [1, 2, 10], "that": [10, 0, 0]})
# Generate the jitted function. The field names that `fields_of(events)` is
# unrolled over are read from the data itself, so they cannot drift out of
# sync with the record layout (here: ("this", "that"), in zip order).
process_events_jit = numba.njit(
    rewrite_fields(
        process_events,
        {"events": tuple(ak.fields(events))},
    )
)
# Run! Expected output: "0 that" then "2 this".
process_events_jit(events)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment