Created
October 26, 2017 13:52
-
-
Save csm10495/6f267f2b07871bf0fd57b35b1448be06 to your computer and use it in GitHub Desktop.
Play around with Python's abstract syntax tree to see whether we can tell that a function will call a specific function (and what the arguments to that call would be).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
''' | |
Brief: | |
ast_test.py - Quick and dirty test code of the Python abstract syntax tree. | |
The big thing is the ability to (sort of) see if a function calls another | |
specific function. | |
License: | |
MIT License - 2017 | |
Author(s): | |
Charles Machalow | |
''' | |
import ast | |
import inspect | |
import textwrap | |
def getSyntaxTree(func):
    '''
    returns a list of abstract syntax tree nodes for a given function.

    The nodes come from ast.walk over the dedented source of func, so the
    first node is the ast.Module wrapping the definition.

    Returns None when func is a builtin (function, class or module) that has
    no Python source. Any other error from inspect.getsource propagates.
    '''
    try:
        source = inspect.getsource(func)
    except TypeError as ex:
        # inspect reports builtins with a TypeError whose wording depends on
        # the object kind and Python version: "...got
        # builtin_function_or_method" for len, but "<class 'set'> is a
        # built-in class" (hyphenated) for builtin classes on newer versions.
        # Accept both spellings.
        message = str(ex)
        if 'builtin' in message or 'built-in' in message:
            return None  # can't decode a builtin... that is fine
        raise  # bare raise keeps the original traceback
    return list(ast.walk(ast.parse(textwrap.dedent(source))))
def _expressionToString(node):
    '''
    returns best-effort source text for an AST expression that is not a
    literal: the identifier for a plain name, ast.unparse() output where it
    exists (Python 3.9+), otherwise a "<NodeType>" placeholder.
    '''
    if isinstance(node, ast.Name):
        return node.id
    unparse = getattr(ast, 'unparse', None)
    if unparse is not None:
        return unparse(node)
    return '<%s>' % type(node).__name__


def keywordsToDict(keywords):
    '''
    takes a list of ast.keyword objects and attempts to convert them into
    a dictionary of key->value

    Literal values are evaluated with ast.literal_eval, so id=[1,2,3] maps
    'id' -> [1, 2, 3]. A value that is not a literal (a variable, a call, an
    attribute, ...) maps to the 2-tuple (text, None), where text is produced
    by _expressionToString; the tuple marks "expression, not a string value"
    for argsAndKeywordsToString.

    For a **mapping argument the key is None (ast.keyword.arg is None there).
    '''
    d = {}
    for k in keywords:
        try:
            d[k.arg] = ast.literal_eval(k.value)
        except (ValueError, TypeError):
            # ValueError: not a literal (name, call, attribute, ...).
            # TypeError: e.g. a dict literal with an unhashable key, {[1]: 2}.
            d[k.arg] = _expressionToString(k.value), None
    return d
def argsAndKeywordsToString(keywords, theArgs):
    '''
    takes the keywords and args for a function call and converts them to a string.

    keywords -- list of ast.keyword nodes (a Call's .keywords)
    theArgs  -- list of positional argument expressions (a Call's .args)

    Returns the arguments joined with ', ', positional ones first, for
    example '1, "a", x, id=[1, 2, 3]'. String literals are wrapped in double
    quotes and non-literal expressions (variable names etc.) are shown as
    unquoted source text, so the result reads like the call's argument list.
    '''
    def quoteIfString(value):
        # quote strings so they cannot be mistaken for variable names
        if isinstance(value, str):
            return '"%s"' % value
        return str(value)

    def describe(node):
        # best-effort source text for a positional arg that is not a literal
        if isinstance(node, ast.Name):
            return node.id
        unparse = getattr(ast, 'unparse', None)  # Python 3.9+
        if unparse is not None:
            return unparse(node)
        return '<%s>' % type(node).__name__

    def isExpressionMarker(value):
        # keywordsToDict stores non-literal values as (text, None); a literal
        # tuple such as id=(1, 2) must not be mistaken for that marker
        return (isinstance(value, tuple) and len(value) == 2
                and isinstance(value[0], str) and value[1] is None)

    parts = []
    for node in theArgs:
        try:
            value = ast.literal_eval(node)
        except (ValueError, TypeError):
            parts.append(describe(node))  # e.g. a variable name
        else:
            parts.append(quoteIfString(value))
    for key, value in keywordsToDict(keywords).items():
        if isExpressionMarker(value):
            value = value[0]  # expression text, printed unquoted
        else:
            value = quoteIfString(value)
        parts.append('%s=%s' % (key, value))
    # join rather than rstrip(', '), which could also strip trailing ','
    # or ' ' characters belonging to the last value
    return ', '.join(parts)
def getThisCallThatArgs(this, that, objs=None, _active=None):
    '''
    check if this function calls that one.
    objs is a list of common objects to scan through
    Returns the list of args/kwargs as strings

    The check is static: the source of this is parsed, and every call whose
    callee is a bare name equal to that.__name__ adds one entry (the call's
    arguments rendered by argsAndKeywordsToString). Every other bare-name
    call is followed recursively:
      - when objs is given, the name is looked up as an attribute of each
        object in objs;
      - otherwise it is looked up in the global namespace this was defined
        in, with builtins as a fallback.
    Calls whose callee is not a bare name, such as obj.method() or f()(),
    are skipped.

    Returns an empty list when this is a builtin (no source to scan).

    _active is internal bookkeeping: the ids of the functions currently being
    scanned on this recursion path. A function that is already on the path
    is not rescanned, so recursive code does not recurse forever.
    '''
    syntaxTree = getSyntaxTree(this)
    if not syntaxTree:
        return []  # builtin: nothing to scan

    if _active is None:
        _active = set()
    # getattr() on an object builds a new bound method every time, so key on
    # the underlying function to recognise "the same" callee.
    key = id(getattr(this, '__func__', this))
    if key in _active:
        return []  # already being scanned further up: a call cycle

    # Resolve callee names where `this` was defined, not in this module.
    namespace = getattr(this, '__globals__', None)
    if namespace is None:
        namespace = globals()

    retList = []
    _active.add(key)
    try:
        for thing in syntaxTree:
            if not isinstance(thing, ast.Call):
                continue
            name = getattr(thing.func, 'id', None)
            if name is None:
                continue  # callee is obj.attr, a call result, ...: not followed
            if name == that.__name__:
                retList.append(argsAndKeywordsToString(thing.keywords, thing.args))
            elif objs:
                for obj in objs:
                    if hasattr(obj, name):
                        retList += getThisCallThatArgs(
                            getattr(obj, name), that, objs, _active=_active)
            else:
                try:
                    # name is an identifier taken from the parsed source, so
                    # this eval only performs a name lookup (builtins included).
                    callee = eval(name, namespace)
                except NameError:
                    continue  # local variable, parameter, or undefined name
                retList += getThisCallThatArgs(callee, that, _active=_active)
    finally:
        _active.discard(key)
    return retList
def thisCallThat(this, that, objs=None, quiet=False):
    '''
    Tell whether `this` might call `that`.

    The search is done by getThisCallThatArgs; objs is forwarded to it
    unchanged. Unless quiet is true, one "<this> might call <that>(<args>)"
    line is printed for every call site found.

    Returns True if at least one call site was found, otherwise False.
    '''
    foundArgs = getThisCallThatArgs(this, that, objs)
    if not quiet:
        for argText in foundArgs:
            print("%s might call %s(%s)" % (this.__name__, that.__name__, argText))
    return bool(foundArgs)
############# | |
# Test Code # | |
############# | |
def lowLevelFunction(id):
    '''
    Test target: the function whose call sites are searched for in wrapper().

    The id argument is accepted, so calls like lowLevelFunction(id=...) are
    valid, but it is ignored; the result is always True.
    '''
    return True
def wrapper(id):
    '''
    Test fixture for the static call scan: its source contains two calls to
    lowLevelFunction, one of them in a branch that can never run.

    Returns lowLevelFunction(id=[1,2,3]), i.e. True. The id parameter is
    unused.
    '''
    if False:
        # Dead on purpose: a static scan of the source still sees this call
        # (and the set() call), even though it never executes.
        s = set()
        return lowLevelFunction(id=s)
    else:
        return lowLevelFunction(id=[1,2,3])
if __name__ == '__main__':
    # Smoke test: wrapper's source contains calls to lowLevelFunction, so the
    # scan must report at least one. quiet=True suppresses the printed lines.
    # Note: because this is an assert, the whole check is skipped under -O.
    assert thisCallThat(wrapper, lowLevelFunction, quiet=True)
############# | |
# Test Code # | |
############# |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment