Skip to content

Instantly share code, notes, and snippets.

@Techcable
Last active July 13, 2025 02:27
Show Gist options
  • Select an option

  • Save Techcable/2d5762d23da60c4e7395358319c81526 to your computer and use it in GitHub Desktop.

Select an option

Save Techcable/2d5762d23da60c4e7395358319c81526 to your computer and use it in GitHub Desktop.
Java ASDL AST Generator
import asdl
from contextlib import contextmanager
from abc import ABCMeta
from pathlib import Path
from argparse import ArgumentParser
_wrapper_types = {
"int": "Integer",
"boolean": "Boolean"
}
# ASDL builtin type name -> Java type used for it (unboxed form).
_builtin_java_types = {
    "identifier": "String",
    "string": "String",
    "bytes": "String",
    "int": "int",
    "object": "Object",
    "singleton": "boolean",
    "constant": "Object",
}


def get_java_type(name, optional=False, sequence=False, boxed=False):
    """Return the Java type spelling for the ASDL type *name*.

    - ``sequence``: ``List<T>`` whose element type is boxed (and prefixed
      with ``@Nullable`` when ``optional`` is also set).
    - ``optional``: ``@Nullable`` followed by the boxed type.
    - ``boxed``: primitives are replaced by their wrapper class
      (``int`` -> ``Integer``, ``boolean`` -> ``Boolean``).
    - otherwise: the builtin mapping above, or ``<Name>Node`` (PascalCase)
      for user-defined ASDL types.
    """
    if sequence:
        element_type = get_java_type(name, optional=optional, boxed=True)
        return "List<{}>".format(element_type)
    if optional:
        return "@Nullable " + get_java_type(name, boxed=True)
    if boxed:
        plain = get_java_type(name)
        return _wrapper_types.get(plain, plain)
    if name in _builtin_java_types:
        return _builtin_java_types[name]
    # Any non-builtin type is a generated node class.
    return format_name("{}Node".format(name), firstUpper=True)
def format_name(name, firstUpper=False):
    """Convert a snake_case ASDL name to camelCase (or PascalCase).

    Each ``_`` is dropped and the character after it is upper-cased. The
    first character is lower-cased, or upper-cased when ``firstUpper`` is
    true; all other characters are kept unchanged, so ``"end_lineno"``
    becomes ``"endLineno"`` and ``"BinOp"`` becomes ``"binOp"``.

    A trailing ``_`` is simply dropped (previously it crashed with
    StopIteration), and an empty name yields ``""`` instead of IndexError.
    """
    if not name:
        return ""
    first = name[0].upper() if firstUpper else name[0].lower()
    result = [first]
    chars = iter(name[1:])
    for c in chars:
        if c == '_':
            # Default "" so a trailing underscore does not raise StopIteration.
            result.append(next(chars, "").upper())
        else:
            result.append(c)
    return ''.join(result)
# Preamble written at the top of each generated .java file when
# JavaPackage.create(..., includeHeader=True) is used; the {package_name}
# placeholder is filled via str.format with the dotted package name.
# The imports cover what generated classes reference: List (sequence
# fields), Objects (toString of optional fields) and @Nullable.
header = """/*
* Autogenerated with java_asdl.py
*/
package {package_name};
import java.util.List;
import java.util.Objects;
import javax.annotation.Nullable;
"""
class JavaPackage:
    """A Java package whose source files live in one directory on disk.

    ``name`` is the dotted package name written into each file's ``package``
    declaration; ``folder`` is the directory the ``.java`` files go into.
    """

    def __init__(self, name, folder):
        self.folder = Path(folder)
        self.name = name

    @contextmanager
    def create(self, name, includeHeader=True):
        """Create (or truncate) ``<folder>/<name>.java`` and yield a JavaWriter.

        Missing parent directories are created. When ``includeHeader`` is
        true, the shared ``header`` template (package declaration + imports)
        is written first. The file is closed when the ``with`` block exits,
        including when it exits with an exception.
        """
        file = Path(self.folder, name + ".java")
        file.parent.mkdir(parents=True, exist_ok=True)
        # Write-only ('w'): the file is never read back. Use an explicit
        # UTF-8 encoding instead of the platform default, so output is the
        # same on every OS and matches javac's default on JDK 18+ (JEP 400).
        with file.open('w', encoding='utf-8') as output:
            if includeHeader:
                output.write(header.format(package_name=self.name))
            yield JavaWriter(output)
class JavaWriter:
    """Line-oriented writer for Java source with 4-space indentation.

    Wraps a text stream and provides helpers that emit the boilerplate for
    generated AST node classes: fields, constructor, getters and toString().
    """

    def __init__(self, output):
        self.output = output
        self.current_indentation = 0  # current indent width, in spaces

    @contextmanager
    def indent(self):
        """Indent lines written inside the ``with`` block by 4 more spaces."""
        self.current_indentation += 4
        try:
            yield
        finally:
            # Restore even if generation fails mid-block, so the writer is
            # not left permanently over-indented.
            self.current_indentation -= 4

    def write(self, line=None):
        """Write one indented line, or a bare newline when *line* is None."""
        if line is not None:
            self.write_lines([line])
        else:
            self.output.write('\n')

    def write_lines(self, lines):
        """Write each string in *lines* on its own line at the current indent.

        Whitespace-only strings are rejected by an ``assert``; an empty string
        passes and produces a line containing only the indentation.
        """
        indent = ' ' * self.current_indentation
        output = self.output
        for line in lines:
            assert not line.isspace()
            output.write(indent)
            output.write(line)
            output.write('\n')

    def generate_body(self, className, fields, super_fields=()):
        """Emit fields, a constructor and getters for class *className*.

        Each entry of ``fields`` becomes a ``private final`` member with a
        public getter. ``super_fields`` are only constructor parameters, which
        are forwarded to ``super(...)``; presumably the superclass declares
        them. Optional fields are annotated ``@Nullable``.
        """
        field_names = [format_name(field.name) for field in fields]
        for field in fields:
            if field.opt:
                self.write("@Nullable")
            # Optional fields are boxed so that primitives can hold null.
            self.write("private final {} {};".format(
                get_java_type(field.type, boxed=field.opt, sequence=field.seq),
                format_name(field.name)
            ))
        # Constructor parameters: superclass fields first, then our own.
        self.write("public {}({}) {{".format(
            className,
            ', '.join(format_fields(super_fields) + format_fields(fields))
        ))
        with self.indent():
            self.write("super({});".format(
                ', '.join(format_name(field.name) for field in super_fields)
            ))
            for name in field_names:
                self.write("this.{} = {};".format(name, name))
        self.write("}")
        for field in fields:
            if field.opt:
                self.write("@Nullable")
            self.write("public {} get{}() {{".format(
                get_java_type(field.type, boxed=field.opt, sequence=field.seq),
                format_name(field.name, firstUpper=True)
            ))
            with self.indent():
                self.write("return this.{};".format(
                    format_name(field.name)
                ))
            self.write("}")

    def generate_tostring(self, name, fields):
        """Emit a ``toString()`` imitating Python's ``ast.dump()`` format.

        The generated method produces ``name(a=..., items=[x, y], opt=None)``.
        Values are read through the generated getters. Missing optional
        values print as ``None``, and sequences print as ``[e1, e2, ...]``.
        """
        self.write("@Override")
        self.write("public String toString() {")
        with self.indent():
            self.write("// python-style ast.dump()")
            self.write("StringBuilder builder = new StringBuilder();")
            self.write('builder.append("{}(");'.format(name))
            for index, field in enumerate(fields):
                if index > 0:
                    self.write('builder.append(", ");')
                self.write('builder.append("{}=");'.format(field.name))
                getter = "this.get{}()".format(format_name(field.name, firstUpper=True))
                if field.seq:
                    self._write_sequence(getter)
                elif field.opt:
                    self.write('builder.append(Objects.toString({}, "None"));'.format(getter))
                else:
                    self.write('builder.append({});'.format(getter))
            self.write('builder.append(")");')
            self.write('return builder.toString();')
        self.write("}")

    def _write_sequence(self, getter):
        """Emit Java that appends the list returned by *getter* as ``[a, b]``.

        ``String.join`` only accepts ``CharSequence`` elements and so does not
        compile for lists of nodes or boxed ints. Instead, each element is
        appended through ``StringBuilder.append(Object)``. The loop sits in its
        own ``{ }`` scope so several sequence fields can reuse the local names.
        """
        self.write('builder.append("[");')
        self.write("{")
        with self.indent():
            self.write('String separator = "";')
            self.write("for (Object element : {}) {{".format(getter))
            with self.indent():
                self.write("builder.append(separator);")
                self.write("builder.append(element);")
                self.write('separator = ", ";')
            self.write("}")
        self.write("}")
        self.write('builder.append("]");')
class ASDLVisitor(asdl.VisitorBase, metaclass=ABCMeta):
    """Base visitor that walks an ASDL module tree depth-first.

    Each hook below just recurses into its children. Subclasses override
    the hooks they care about. ``visitField`` is a no-op leaf.
    """

    def _visit_each(self, nodes):
        """Visit every node in *nodes*, in order, discarding the results."""
        for node in nodes:
            self.visit(node)

    def visitModule(self, mod):
        self._visit_each(mod.dfns)

    def visitSum(self, s, name):
        self._visit_each(s.types)

    def visitType(self, tp):
        # Dispatch on the definition's value (a Sum or Product) and pass its
        # name along, since the value itself does not carry it.
        self.visit(tp.value, tp.name)

    def visitProduct(self, prod, name):
        self._visit_each(prod.fields)

    def visitConstructor(self, cons):
        self._visit_each(cons.fields)

    def visitField(self, field):
        pass
def is_simple_sum(s: asdl.Sum):
    """Return True when none of the sum's constructors has any fields.

    Such a sum (a plain list of alternatives like ``A | B | C``) is generated
    as a Java enum rather than as a class hierarchy.
    """
    return all(not constructor.fields for constructor in s.types)
def format_fields(fields, prefix=""):
    """Render *fields* as Java parameter declarations, e.g. ``"ExprNode left"``.

    Types come from get_java_type: optional fields get ``@Nullable`` and
    sequences become ``List<...>``. Names are camelCased. *prefix* is prepended
    to the ``str.format`` template, so any braces in it are interpreted as
    format fields. Returns a tuple of strings.
    """
    names = [format_name(f.name) for f in fields]
    java_types = [get_java_type(f.type, optional=f.opt, sequence=f.seq) for f in fields]
    template = prefix + "{} {}"
    return tuple(
        template.format(java_type, java_name)
        for java_type, java_name in zip(java_types, names)
    )
class JavaASTVisitorGenerator(ASDLVisitor):
    """Generates the body of the ``ASTVisitor<T>`` interface.

    Methods are emitted as follows:

    - Each constructor of a non-simple sum, and each product type, gets a
      ``default T visit<Name>(<Name>Node node)`` method. It visits the
      node's node-typed children and returns null.
    - A simple sum (an enum) gets a single leaf ``visit<Name>`` method.
    """

    def __init__(self, output: JavaWriter):
        super().__init__()
        self.output = output

    def start(self):
        """Open the ``ASTVisitor`` interface declaration."""
        self.output.write("public interface ASTVisitor<T> {")

    def end(self):
        """Close the interface declaration opened by start()."""
        self.output.write("}")

    def visitSum(self, s, name):
        if is_simple_sum(s):
            # Enums have no children: a single leaf visit method.
            self.generate_visitor(name)
        else:
            for cons in s.types:
                self.visit(cons, get_java_type(name), s.attributes)

    def visitConstructor(self, cons, superType, attributes):
        # The sum's attributes are exposed through public getters on the
        # shared superclass, so walk them too. Primitive ones are skipped in
        # generate_visitor.
        self.generate_visitor(cons.name, list(cons.fields) + list(attributes))

    def generate_visitor(self, name, fields=()):
        """Emit ``default T visit<Name>(<Name>Node node) { ...; return null; }``.

        The body visits every field whose Java type is a generated node,
        iterating over sequences and null-checking optional values. Builtin
        types (String, int, Object, ...) are skipped.
        """
        output = self.output
        output.write("default T visit{}({} node) {{".format(
            format_name(name, firstUpper=True),
            get_java_type(name)
        ))
        with output.indent():
            for field in fields:
                # Only generated node types end in "Node"; builtins map to
                # String/int/Object/boolean and have no visit() method.
                if not get_java_type(field.type).endswith("Node"):
                    continue  # Ignore primitives
                getter = "node.get{}()".format(format_name(field.name, firstUpper=True))
                if field.seq:
                    output.write("for ({} element : {}) {{".format(
                        get_java_type(field.type, boxed=True),
                        getter
                    ))
                    with output.indent():
                        output.write("element.visit(this);")
                    output.write("}")
                elif field.opt:
                    output.write("if ({} != null) {{".format(getter))
                    with output.indent():
                        output.write("{}.visit(this);".format(getter))
                    output.write("}")
                else:
                    output.write("{}.visit(this);".format(getter))
            output.write("return null;")
        output.write("}")

    def visitProduct(self, product, name):
        # Same field order the generated product class uses: attributes first.
        self.generate_visitor(name, list(product.attributes) + list(product.fields))
class JavaASDLVisitor(ASDLVisitor):
    """Generates one Java source file per ASDL definition / constructor.

    - A simple sum (no constructor has fields) becomes an ``enum``.
    - Any other sum becomes an abstract base class that holds the sum's
      attributes, plus one concrete subclass per constructor.
    - A product becomes a concrete class holding its attributes and fields.

    Every generated type implements ``ASTNode``. Its ``visit`` method
    dispatches to the matching ``ASTVisitor.visit<Name>`` method.
    """

    def __init__(self, package: JavaPackage):
        super().__init__()
        self.package = package

    def visitSum(self, s, name):
        if is_simple_sum(s):
            self._generate_enum(s, name)
        else:
            typeName = self._generate_abstract_base(s, name)
            for cons in s.types:
                self.visit(cons, typeName, s.attributes)

    def _generate_enum(self, s, name):
        """Write the enum for a simple sum: one upper-cased constant per constructor."""
        typeName = get_java_type(name)
        with self.package.create(typeName) as output:
            output.write("public enum {} implements ASTNode {{".format(typeName))
            with output.indent():
                last = len(s.types) - 1
                for index, cons in enumerate(s.types):
                    # Java requires ';' after the constant list before any
                    # member declarations; ending with ',' does not compile.
                    terminator = ";" if index == last else ","
                    output.write("{}{}".format(cons.name.upper(), terminator))
                if not s.types:
                    output.write(";")
                output.write("@Override")
                output.write("public <T> T visit(ASTVisitor<T> visitor) {")
                with output.indent():
                    output.write("return visitor.visit{}(this);".format(format_name(name, firstUpper=True)))
                output.write("}")
            output.write("}")

    def _generate_abstract_base(self, s, name):
        """Write the abstract superclass for a non-simple sum; return its Java name."""
        typeName = get_java_type(name)
        # If a constructor maps to the same Java name as the sum itself (sum
        # `foo` with constructor `Foo` -> both "FooNode"), rename the base
        # class. Compare constructor *names*; the Constructor objects
        # themselves do not map to a meaningful Java type.
        if typeName in {get_java_type(cons.name) for cons in s.types}:
            typeName = "Abstract" + typeName
        # NOTE(review): fields typed as this sum elsewhere still use
        # get_java_type(name), i.e. the constructor class, not this renamed
        # base. Clashes with names from *other* definitions are not detected
        # either; the later file overwrites the earlier one.
        with self.package.create(typeName) as output:
            output.write("public abstract class {} implements ASTNode {{".format(
                typeName
            ))
            with output.indent():
                output.generate_body(
                    typeName,
                    s.attributes
                )
            output.write("}")
        return typeName

    def visitConstructor(self, cons, superType, attributes):
        """Write a concrete subclass of *superType* for one sum constructor.

        The constructor's own fields are declared here. The sum's attributes
        are only forwarded to ``super(...)`` but are included in toString().
        """
        with self.package.create(get_java_type(cons.name)) as output:
            output.write("public class {} extends {} {{".format(
                get_java_type(cons.name),
                superType
            ))
            with output.indent():
                output.generate_body(
                    get_java_type(cons.name),
                    cons.fields,
                    attributes
                )
                output.generate_tostring(cons.name, cons.fields + attributes)
                output.write("@Override")
                output.write("public <T> T visit(ASTVisitor<T> visitor) {")
                with output.indent():
                    output.write("return visitor.visit{}(this);".format(format_name(cons.name, firstUpper=True)))
                output.write("}")
            output.write("}")

    def visitProduct(self, product, name):
        """Write a concrete class for a product type (attributes, then fields)."""
        fields = tuple(product.attributes) + tuple(product.fields)
        with self.package.create(get_java_type(name)) as output:
            output.write("public class {} implements ASTNode {{".format(
                get_java_type(name),
            ))
            with output.indent():
                output.generate_body(
                    get_java_type(name),
                    fields
                )
                output.generate_tostring(name, fields)
                output.write("@Override")
                output.write("public <T> T visit(ASTVisitor<T> visitor) {")
                with output.indent():
                    output.write("return visitor.visit{}(this);".format(format_name(name, firstUpper=True)))
                output.write("}")
            output.write("}")
from sys import argv, stderr, exit
def main():
    """Command-line entry point: generate Java AST sources from an ASDL file.

    The following files are written into ``<output>/<package as path>/``:

    - ``ASTNode.java``
    - ``ASTVisitor.java``
    - ``package-info.java``
    - one file per ASDL definition / constructor
    """
    parser = ArgumentParser(description="Generates java AST nodes from an ASDL file")
    parser.add_argument("--output", "-o", help="The directory to output the sources in", type=Path, default=Path("generated_ast"))
    parser.add_argument("asdl_file", help="The ASDL file to use")
    parser.add_argument("package", help="The package to generate the AST in")
    args = parser.parse_args()

    # e.g. "com.example.ast" -> <output>/com/example/ast
    package_dir = Path(args.output, args.package.replace('.', '/'))
    package = JavaPackage(args.package, package_dir)
    asdl_module = asdl.parse(args.asdl_file)

    _write_node_interface(package)
    _write_visitor_interface(package, asdl_module)
    _write_package_info(package)
    JavaASDLVisitor(package).visit(asdl_module)


def _write_node_interface(package):
    """Write ASTNode.java, the interface every generated node implements."""
    with package.create("ASTNode") as writer:
        writer.write_lines((
            "public interface ASTNode {",
            "<T> T visit(ASTVisitor<T> visitor);",
            "}"
        ))


def _write_visitor_interface(package, asdl_module):
    """Write ASTVisitor.java with one default visit method per node type."""
    with package.create("ASTVisitor") as writer:
        generator = JavaASTVisitorGenerator(writer)
        generator.start()
        with writer.indent():
            generator.visit(asdl_module)
        generator.end()


def _write_package_info(package):
    """Write package-info.java (no shared header; it only declares the package)."""
    with package.create("package-info", includeHeader=False) as writer:
        writer.write_lines((
            "/*",
            " * Autogenerated AST nodes, generated by java_asdl.py",
            " */",
            "package {};".format(package.name)
        ))


if __name__ == "__main__":
    main()
@abtExp
Copy link

abtExp commented Sep 15, 2021

Hi, thanks for this — I've been looking for something like it for a long time. Could you please share a sample ASDL grammar file, if you have one?

@XiaoXiaoYi123
Copy link

Hi, thanks for this — I've been looking for something like it for a long time. Could you please share a sample ASDL grammar file, if you have one?

Me, too. Thanks very much

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment