Skip to content

Instantly share code, notes, and snippets.

@asottile
Created June 16, 2016 16:11
Show Gist options
  • Save asottile/296cd1dcafc40d880426b9baafed5361 to your computer and use it in GitHub Desktop.
Save asottile/296cd1dcafc40d880426b9baafed5361 to your computer and use it in GitHub Desktop.
Check whether a yelp-cheetah template is trivially upgradable to 0.17.0
import argparse
import ast
import io
from Cheetah.compile import compile_source
from Cheetah.Template import Template
# Names of every attribute on Cheetah's ``Template`` class that does not
# start with an underscore.
template_self_vars = set(
    name for name in dir(Template) if not name.startswith('_')
)
class TemplateVisitor(ast.NodeVisitor):
    """Collect facts about a compiled template module from its AST.

    After ``visit()`` has been run over the parsed module:

    * ``extends`` is the module that ``YelpCheetahTemplate`` was imported
      from under the alias ``YelpCheetahBaseClass`` (``None`` if no such
      import was seen).
    * ``self_attrs`` holds the names assigned directly in class bodies plus
      the name of every function definition visited.
    * ``variables`` holds the string-literal first argument of every
      ``VFFSL(...)`` / ``VFFNS(...)`` call.
    """

    # Types accepted as a variable name.  ``type(u'')`` is ``unicode`` on
    # Python 2 and ``str`` on Python 3.
    _TEXT_TYPES = (str, type(u''))

    def __init__(self):
        self.extends = None
        self.self_attrs = set()
        self.variables = set()

    @classmethod
    def _string_literal(cls, node):
        """Return the text of a string-literal node, or ``None`` otherwise."""
        # Newer Pythons expose literals as ``ast.Constant`` with ``.value``;
        # older ones use ``ast.Str`` with ``.s``.  ``.value`` is tried first
        # so the deprecated ``.s`` alias is not touched when it isn't needed.
        for attr in ('value', 's'):
            candidate = getattr(node, attr, None)
            if isinstance(candidate, cls._TEXT_TYPES):
                return candidate
        return None

    def visit_ClassDef(self, node):
        """Record every name assigned directly in the class body."""
        for statement in node.body:
            if not isinstance(statement, ast.Assign):
                continue
            for target in statement.targets:
                # Only plain ``name = value`` assignments are expected here;
                # anything else is reported with a dump of the target.
                assert isinstance(target, ast.Name), ast.dump(target)
                self.self_attrs.add(target.id)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        """Record the function's name, then descend into its body."""
        self.self_attrs.add(node.name)
        self.generic_visit(node)

    def visit_Call(self, node):
        """Record the name looked up by a ``VFFSL`` / ``VFFNS`` call."""
        if (
                isinstance(node.func, ast.Name) and
                # Compat so you can run this on 0.16.1 or 0.17.0
                node.func.id in ('VFFSL', 'VFFNS') and
                node.args
        ):
            # A call without a string-literal first argument carries no
            # statically known name, so it is skipped rather than crashing.
            name = self._string_literal(node.args[0])
            if name is not None:
                self.variables.add(name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Remember which module the template base class is imported from."""
        if (
                len(node.names) == 1 and
                node.names[0].name == 'YelpCheetahTemplate' and
                node.names[0].asname == 'YelpCheetahBaseClass'
        ):
            self.extends = node.module
def _tmpl_name(modname):
return modname.replace('.', '/') + '.tmpl'
def _parse_filename(filename):
    """Compile and inspect the template at ``filename``.

    The template source is compiled with ``compile_source``, parsed with
    ``ast`` and walked with a ``TemplateVisitor``.

    Returns a ``(self_attrs, variables)`` tuple of sets.  How ``self_attrs``
    is built depends on the module the template's base class comes from:

    * ``Cheetah.Template``: ``template_self_vars`` plus the template's own
      attributes.
    * ``Cheetah.partial_template``: an empty set, and the template's own
      attributes are removed from ``variables``.
    * anything else: the parent template (located via ``_tmpl_name``) is
      parsed recursively and its attributes are merged with this one's.
    """
    with io.open(filename) as template_file:
        source = template_file.read()

    visitor = TemplateVisitor()
    visitor.visit(ast.parse(compile_source(source)))

    base = visitor.extends
    if base == 'Cheetah.Template':
        return template_self_vars | visitor.self_attrs, visitor.variables
    if base == 'Cheetah.partial_template':
        # Functions are relocated to globals so they aren't potentially self
        return set(), visitor.variables - visitor.self_attrs

    # Process extends: only the parent's attributes matter here, the
    # variables it references are checked when that file itself is checked.
    inherited_attrs, _ = _parse_filename(_tmpl_name(base))
    return visitor.self_attrs | inherited_attrs, visitor.variables
def check_file(filename):
    """Report template variables that are also ``self`` attributes.

    ``filename`` is parsed with ``_parse_filename``; any variable name that
    also appears in the template's set of ``self`` attributes is printed
    (sorted) under a header made of the filename.

    Returns 1 if at least one such name was found, otherwise 0.
    """
    self_attrs, variables = _parse_filename(filename)
    conflicts = variables & self_attrs
    if not conflicts:
        return 0

    print(filename)
    print('=' * len(filename))
    for variable in sorted(conflicts):
        print(' - ${}'.format(variable))
    return 1
def main(argv=None):
    """Check every template named on the command line.

    ``argv`` is the list of command-line arguments to parse; when it is
    ``None`` (the default) ``argparse`` falls back to ``sys.argv[1:]``.

    Returns 0 if ``check_file`` returned 0 for every file (or no files were
    given), otherwise the bitwise OR of the individual results.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('filenames', nargs='*')
    args = parser.parse_args(argv)

    retv = 0
    for filename in args.filenames:
        # Keep going after a failure so every offending file is reported.
        retv |= check_file(filename)
    return retv
if __name__ == '__main__':
    # ``exit`` is installed by the ``site`` module for interactive sessions
    # and is missing under ``python -S``; raising SystemExit always works.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment