Lately: an Earley parser in Python 3
import regex


# Lexer results

class Token:
    def __init__(self, kind, value=None):
        self.kind = kind
        self.value = value

    def __repr__(self):
        r = "Token({}".format(repr(self.kind))
        if self.value is not None: r += ", {}".format(repr(self.value))
        r += ")"
        return r

    def __eq__(self, other):
        return (
            isinstance(other, Token) and
            self.kind == other.kind and
            self.value == other.value
        )

    def __ne__(self, other):
        return not self == other

    def match(self, token):
        return (
            isinstance(token, Token) and
            self.kind == token.kind and
            (self.value is None or self.value == token.value)
        )
    def transform(self, token):
        return token.value or token.kind
        # nb. unreachable earlier version, kept for reference:
        # if self.value is None:
        #     if token.value is not None:
        #         return token.value
        # return Ugly(token.kind)
    def stringify(self):
        return "«{}»".format(self.kind)


class Ugly:
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "<{}>".format(str(self.value))


class RawToken(Token):
    def transform(self, token):
        return token
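# A quick sketch of the matching semantics above: a Token with no value acts
# as a wildcard matcher for its kind, while a valued Token must match exactly.
#
#     >>> Token("word").match(Token("word", "cat"))
#     True
#     >>> Token("word", "cat").match(Token("word", "dog"))
#     False
#     >>> Token("word").transform(Token("word", "cat"))
#     'cat'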
# Lexer

class Terminal:
    def __init__(self, kind, pattern=None): #, process=None):
        self.kind = kind
        self.pattern = getattr(pattern, "pattern", pattern)
        # assert process is None or callable(process)
        # self.process = process
        self._re = regex.compile(self.pattern)  # nb. handles pre-compiled patterns too
        if self._re.match(""):
            raise ValueError("pattern must not match empty string")
        if kind == "":
            raise ValueError("kind must not be empty")
        self.groups = self._re.groups

    def __repr__(self):
        r = "Terminal({}, {}".format(repr(self.kind), repr(self.pattern))
        # if self.process: r += ", {}".format(repr(self.process))
        r += ")"
        return r
    def token(self, value):
        # if self.process:
        #     value = self.process(value)
        return Token(self.kind, value)


class Literal(Terminal):
    def __init__(self, value): #, process=None):
        Terminal.__init__(self, value, regex.escape(value)) #, process)

    def __repr__(self):
        r = "Literal({}".format(repr(self.kind))
        # if self.process: r += ", {}".format(repr(self.process))
        r += ")"
        return r

    def token(self, value):
        # if self.process:
        #     value = self.process(value)
        return Token(self.kind)
class Lexer:
    def __init__(self, terminals):
        self.terminals = terminals
        self.master_regex = regex.compile(
            "|".join("({})".format(t.pattern) for t in terminals)
        )
        # Each terminal owns one group of the master regex; if a terminal's
        # own pattern contains groups, its value comes from the first of them.
        index = 0
        for terminal in terminals:
            if terminal.groups:
                terminal._index = index + 1
            else:
                terminal._index = index
            index += terminal.groups + 1
        self._terminals_by_kind = {}
        for terminal in self.terminals:
            kind = terminal.kind
            if kind not in self._terminals_by_kind:
                self._terminals_by_kind[kind] = []
            self._terminals_by_kind[kind].append(terminal)

    def __repr__(self):
        return "Lexer([\n{}])".format(
            "".join("    {},\n".format(repr(t)) for t in self.terminals)
        )
    def tokenize(self, source):
        return list(self.generate_tokens(source))

    def generate_tokens(self, source):
        remaining = str(source)
        master = self.master_regex
        terminals = self.terminals
        while remaining:
            m = master.match(remaining)
            if not m or m.end() == 0:
                raise TokenizeError(remaining)
            groups = m.groups()
            for terminal in terminals:
                index = terminal._index
                if groups[index] is not None:
                    # if terminal.groups: XXX
                    #     value = groups[index + 1]
                    # else:
                    value = groups[index]
                    token = terminal.token(value)
                    if terminal.kind is not None:
                        yield token
                    break
            remaining = remaining[m.end():]
    def _is_valid_token(self, token):
        for terminal in self._terminals_by_kind.get(token.kind, []):
            if terminal._re.fullmatch(token.value):
                return token.value

    def _build(self, grammar_ways):
        for tokens in grammar_ways:
            for token in tokens:
                if not self._is_valid_token(token):
                    break
            else:
                likely_separators = (" ",)
                for sep in likely_separators:
                    if self._is_valid_token(Token(None, sep)):
                        break
                else:
                    sep = ""
                yield sep.join(t.value for t in tokens)


class TokenizeError(Exception):
    """Raised when lexing encounters an invalid character sequence."""
    pass
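# Example (a minimal sketch): terminals with kind None are matched but never
# yielded, which is how whitespace gets skipped.
#
#     >>> lex = Lexer([
#     ...     Terminal(None, r' +'),
#     ...     Terminal("num", r'[0-9]+'),
#     ...     Literal("+"),
#     ... ])
#     >>> lex.tokenize("1 + 2")
#     [Token('num', '1'), Token('+'), Token('num', '2')]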
# Grammar

def is_rule_name(s):
    """Returns True if a symbol represents a rule name."""
    return isinstance(s, str)


class Rule:
    """A set of symbols making up a nonterminal."""
    def __init__(self, name, symbols, process):
        """Name used by other rules to refer to this nonterminal."""
        self.name = name
        assert name is not None
        """List of terminals and nonterminals (rule names)."""
        self.symbols = symbols
        """Function to run on the resulting list of tokens and node values."""
        self.process = process
        assert callable(process)

    def __repr__(self):
        return "Rule({}, {}, {})".format(
            repr(self.name),
            repr(self.symbols),
            repr(self.process),
        )

    def stringify(self):
        return "{} → {} {{% {} %}}".format(
            self.name,
            " ".join((s.stringify() if hasattr(s, "stringify") else s)
                     for s in self.symbols),
            repr(self.process),
        )
class Grammar:
    """A collection of rules defining a language."""
    def __init__(self, rules):
        self.toplevel = rules[0].name if rules else None
        self._rules = []
        self._rules_by_name = {}
        for rule in rules:
            self.add_rule(rule)

    def __repr__(self):
        return "Grammar([\n{}])".format(
            "".join("    {},\n".format(repr(rule)) for rule in self._rules)
        )

    def stringify(self):
        return "\n".join(("" if r is None else r.stringify())
                         for r in self._rules)

    def add_rule(self, rule):
        if rule is not None:
            name = rule.name
            if name not in self._rules_by_name:
                self._rules_by_name[name] = []
            self._rules_by_name[name].append(rule)
        self._rules.append(rule)  # nb. None is kept, as a blank-line separator

    def add_rules(self, rules):
        for rule in rules:
            self.add_rule(rule)

    def remove_rule(self, rule):
        assert rule is not None
        name = rule.name
        self._rules_by_name[name].remove(rule)
        if not self._rules_by_name[name]:
            self._rules_by_name.pop(name)
        self._rules.remove(rule)

    def remove_rules(self, rules):
        for rule in rules:
            self.remove_rule(rule)

    @property
    def rules(self):
        return list(r for r in self._rules if r is not None)
    def copy(self):
        g = Grammar([])
        g._rules_by_name = {name: list(rules)
                            for name, rules in self._rules_by_name.items()}
        g._rules = list(self._rules)
        g.toplevel = self.toplevel
        return g

    # Magic parser
    @classmethod
    def from_bnf(cls, source):
        return parse_bnf(source) # TODO
    # Sugar
    def _parser(self, tokens):
        p = Parser(self)
        p.feed(tokens)
        p.finish()
        return p

    def parse(self, tokens):
        return self._parser(tokens).result()

    def all_parses(self, tokens):
        return list(self._parser(tokens).results())
    # Magic builder
    def build(self, lexer, value):
        ways = self._build(self.toplevel, value)
        return lexer._build(ways)

    def _build(self, rule_name, value):
        if isinstance(rule_name, Token):
            # Terminal symbol: the built part must already be its source text.
            if isinstance(value, str):
                yield [Token(rule_name.kind, value)]
            return
        rules = self._rules_by_name[rule_name]
        for rule in rules:
            try:
                parts = rule.process.build(value)
            except AttributeError:
                continue
            ways = [[]]
            for symbol, part in zip(rule.symbols, parts):
                alternatives = list(self._build(symbol, part))
                new_ways = []
                for tokens in ways:
                    for extra in alternatives:
                        new_ways.append(tokens + extra)
                ways = new_ways
            for tokens in ways:
                yield tokens
# Parser

class State:
    """A rule with a starting point in the input stream."""
    def __init__(self, rule, origin, position=0, value=None):
        self.rule = rule
        """The start index of this state in the input stream."""
        self.origin = origin
        """The index reached so far in `rule.symbols`."""
        self.position = position
        """List of resulting values and tokens."""
        self.value = value or []
        self.is_complete = (position == len(rule.symbols))
        self.expect = None if self.is_complete else rule.symbols[position]

    def __repr__(self):
        words = [self.rule.name, "→"]
        for index, symbol in enumerate(self.rule.symbols):
            if index == self.position:
                words.append("•")
            if is_rule_name(symbol):
                words.append(symbol)
            # elif isinstance(symbol, Matcher):
            #     words.append(repr(symbol.kind))
            else:
                words.append(repr(symbol))
        if self.is_complete:
            words.append("•")
        words.append(" <{}>".format(id(self)))
        return " ".join(words)

    def consume(self, value):
        """Return new state after consuming one token or nonterminal.

        Called when advancing and completing, respectively.
        """
        new_value = list(self.value)
        new_value.append(value)
        s = State(self.rule, self.origin, self.position + 1, new_value)
        if s.is_complete:
            try:
                s.value = s.rule.process(*s.value)
            except Exception as e:
                message = ("Error processing rule {}\nat {}\nwith value {}"
                           .format(repr(s.rule), repr(s), repr(s.value)))
                raise ProcessError(message) from e
        return s
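# A State is a classic Earley item: for example
#     sum → sum • Token('+') Token('num')  <...>
# means the parser has recognised the first symbol of this rule so far,
# starting at column `origin`; feed() below predicts, advances and completes
# these items one input column at a time.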
class Parser:
    def __init__(self, grammar):
        self.grammar = grammar
        self.table = []
        self.tokens = []
        self.line_number = None

    def feed(self, tokens):
        table = self.table
        grammar = self.grammar

        nullable_rules = {}
        for name, rules in grammar._rules_by_name.items():
            for r in rules:
                if r.symbols:
                    continue
                if name not in nullable_rules:
                    nullable_rules[name] = []
                nullable_rules[name].append(r)

        start_index = len(table) # The index where feed() started
        if start_index == 0:
            rules = grammar._rules_by_name[grammar.toplevel]
            column = list(map(lambda rule: State(rule, origin=0), rules))
            table.append(column)
        else: # continue parsing
            column = self.table[-1]

        for index, token in enumerate(tokens):
            new_column = []
            try:
                self.line_number = token.line_number
            except AttributeError:
                pass

            predicted_rules = set()
            if index == 0:
                predicted_rules.add(grammar.toplevel)

            for state in column: # nb. column grows while iterating
                if not state.is_complete:
                    expect = state.expect # state.rule.symbols[state.position]
                    if is_rule_name(expect):
                        # Predict: add component states
                        if expect not in predicted_rules:
                            rules = grammar._rules_by_name.get(expect, [])
                            if not rules:
                                self.do_raise(UndefinedRuleError(expect))
                            for rule in rules:
                                # nb. len(table) - 1 is this column's global
                                # position (it equals `index` on the first
                                # feed, and stays correct on later feeds)
                                column.append(State(rule, origin=len(table) - 1))
                            predicted_rules.add(expect)
                        # Magical completion to fix nullables
                        rules = nullable_rules.get(expect, [])
                        for rule in rules:
                            value = rule.process()
                            column.append(state.consume(value))
                    else:
                        if isinstance(token, EOFToken): continue
                        # Advance: consume token
                        if expect.match(token):
                            value = expect.transform(token)
                            new_column.append(state.consume(value))
                else:
                    # Complete: progress earlier states
                    name, value = state.rule.name, state.value
                    old_column = table[state.origin] # ✝ this will break!!
                    for other in old_column:
                        expect = other.expect
                        if is_rule_name(expect) and expect == name:
                            new_state = other.consume(value)
                            column.append(new_state)

            if isinstance(token, EOFToken):
                assert not new_column
                return
            if not new_column:
                message = "Unexpected {}".format(repr(token))
                self.do_raise(ParseError(message))

            #print("\n".join(repr(s) for s in column))
            #print()

            self.tokens.append(token)
            self.table.append(new_column)
            # hold on, if this column isn't already appended to the table
            # then completing indirectly-nullable things will bug out
            # because we'll complete a state in the same column as its origin
            # and line ✝ will fail
            column = new_column
    def finish(self):
        self.feed([EOFToken()])

    def do_raise(self, exception):
        """Raise exception, adding parser debug information."""
        index = len(self.table)
        exception.index = index
        args = list(exception.args)
        args[0] = "{} at {}".format(args[0], index)
        if self.line_number:
            args[0] += " on line {}".format(self.line_number)
        exception.args = tuple(args)
        raise exception

    def result(self):
        gen_results = self.results()
        result = next(gen_results)
        try:
            next(gen_results)
        except StopIteration:
            return result
        raise AmbiguityError()

    def results(self):
        toplevel = self.grammar.toplevel
        column = self.table[-1]
        count = 0
        for state in column:
            if not state.is_complete: continue
            if state.origin > 0: continue
            if state.rule.name == toplevel:
                yield state.value
                count += 1
        if not count:
            self.do_raise(ParseError("No complete parses"))
class EOFToken:
    def __repr__(self):
        return "EOFToken()"


class ParseError(Exception):
    """Raised when parsing results in a syntax error."""
    pass


class AmbiguityError(Exception):
    """Raised by result() when there is more than one result."""
    pass


class ProcessError(Exception):
    """Raised when a rule processor function raises an error."""
    pass


class UndefinedRuleError(Exception):
    """Raised when parsing encounters a rule that isn't defined."""
    pass


# Re-export the submodules (imported last, to avoid circular imports).
from . import util
from . import mixfix
# -----------------------------------------------------------------------------
# Next file: the BNF front-end (defines parse_bnf, used by Grammar.from_bnf)
# -----------------------------------------------------------------------------
# nb. Pattern and Newline (and util's `literal`) are assumed from the fuller
# project; they aren't defined in this gist.
from . import Lexer, Pattern, Literal, Newline
from . import Rule, Grammar, Token
from .util import *

BNF_lexer = Lexer([
    Newline("nl", '\n'),
    Pattern(None, r' +'),
    Pattern("python_code", "{% *(.?([^%][^}])+) *%}"),
    Literal('::='),
    Literal('|'),
    Pattern("nonterminal", r'\<([^> ]+)\>'),
    Literal("["), Literal("]"), # optional
    Literal("{"), Literal("}"), # repeated
    Pattern("terminal_char", r'"(.)"'),
    Pattern("symbol", r'([^> ]+)'),
])

def name_rules(name, _, rules):
    return [Rule(name, s, p) for (s, p) in rules]
BNF_grammar = Grammar([
    Rule("grammar", ["rules"], Grammar),
    Rule("rules", ["rules", "newline", "rule"], extend2),
    Rule("rules", ["rule"], identity), # nb. rule is a *list*!
    Rule("rule", ["name", "::=", "alternatives"], name_rules),
    Rule("name", ["symbol"], literal),
    Rule("alternatives", ["alternatives", "OR", "alt"], push2),
    Rule("alternatives", ["alt"], box),
    Rule("OR", [Token("|")], ignore),
    Rule("OR", ["newline", "|"], ignore),
    Rule("alt", ["sequence", "python"], tupleify),
    Rule("alt", ["sequence"], lambda s: (s, tupleify)),
    Rule("sequence", ["sequence", "item"], push),
    Rule("sequence", ["item"], box),
    Rule("item", [Token("nonterminal")], literal),
    Rule("item", [Token("terminal_char")], literal),
    Rule("python", ["python_code"], eval),
    Rule("newline", ["newline", Token("nl")], ignore),
    Rule("newline", [Token("nl")], ignore),
])

# nb. grammar_rules() — rules that let bare terminals like "::=" be used as
# rule names above — is another helper that isn't included in this gist.
BNF_grammar.add_rules(BNF_lexer.grammar_rules())
def parse_bnf(source):
    source = source.strip()
    tokens = BNF_lexer.tokenize(source)
    grammar = BNF_grammar.parse(tokens)
    return grammar
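# Intended usage, as a sketch (nb. it depends on the missing helpers above):
#
#     >>> g = parse_bnf('item ::= <item> "," <thing> {% push2 %}')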
# http://cui.unige.ch/db-research/Enseignement/analyseinfo/AboutBNF.html
bootstrap = """
syntax ::= { rule }
rule ::= identifier "::=" expression
expression ::= term { "|" term }
term ::= factor { factor }
factor ::= identifier |
           quoted_symbol |
           "(" expression ")" | # ???
           "[" expression "]" | # optional
           "{" expression "}" # repetitive
identifier ::= letter { letter | digit }
quoted_symbol ::= '"' { any_character } '"'
"""

__all__ = ["parse_bnf"]
# -----------------------------------------------------------------------------
# Next file: mixfix.py — mixfix expression parsing (imported above as .mixfix)
# -----------------------------------------------------------------------------
"""For mixfix expression parsing! | |
Usage: | |
>>> l, g = lately.mixfix.make(""" | |
... | |
... _ "↔" _ | |
... | |
... _ "→" _ | |
... | |
... _ "∧" _ | |
... | |
... _ "∨" _ | |
... | |
... "¬" _ | |
... | |
... "(" _ ")" | |
... /[A-Z]+/ | |
... | |
... """) | |
>>> n = g.parse(l.tokenize('¬P ∧ Q ∨ R')) | |
>>> print(n.repr_p()) | |
(¬('P') ∧ ('Q' ∨ 'R')) | |
""" | |
import regex | |
from string import ascii_uppercase | |
from . import Lexer, Terminal, Literal | |
from . import Rule, Grammar, Token | |
from .util import * | |
def aux(index):
    r = ''
    while index > 0:
        r = ascii_uppercase[(index - 1) % 26] + r
        index = (index - 1) // 26
    return r
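# Spreadsheet-style names for generated rules:
#     aux(1) == 'A',  aux(26) == 'Z',  aux(27) == 'AA'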
class Slot:
    def __init__(self, _):
        pass

    def __repr__(self):
        return "Slot()"


mixfix_lexer = Lexer([
    Terminal(None, r' +'),
    Terminal("terminal", r'"((?:\\[\'"\\]|[^"])*)"'),
    Terminal("regex", r'/((?:\\[/\\]|[^/])*)/'),
    #Terminal("label", r'{[a-z]+}'),
    #Terminal("name", r'[A-Z]+'),
    Literal("("), Literal(")"), Literal("*"),
    Literal("+"),  # nb. added: the grammar below has a rule for "+"
    Literal("_"),
    Literal("\n"),
    Terminal("terminal", r'[^ \n]+'),
])
mixfix_grammar = Grammar([
    Rule("file", ["level"], box),
    Rule("file", ["file", "NLs", "level"], push2),
    Rule("NLs", [Token("\n"), Token("\n")], ignore),
    Rule("NLs", ["NLs", Token("\n")], ignore),
    Rule("level", ["line"], box),
    Rule("level", ["level", Token("\n"), "line"], push2),
    Rule("line", ["symbols"], identity),
    Rule("line", ["terminal"], identity),
    Rule("terminal", [Token("regex")], regex.compile),
    Rule("symbols", ["s"], box),
    Rule("symbols", ["symbols", "s"], push),
    Rule("s", [Token("_")], Slot),
    Rule("s", ["s", Token("*")], postfix("*")),
    Rule("s", ["s", Token("+")], postfix("+")),
    Rule("s", [Token("terminal")], Literal),
    Rule("s", [Token("("), "symbols", Token(")")], brackets),
])
class mixfix(Builder):
    """Create a function f(*args) which wraps `cls`."""
    def __init__(self, kind, slots):
        summary = "".join(("_" if is_slot else "X") for is_slot in slots)
        self.style = None
        if summary == "_X_":
            self.style = 'infix'
        elif summary == "X_":
            self.style = 'prefix'
        elif summary == "_X":
            self.style = 'postfix'
        Builder.__init__(self, kind)
        self.slots = slots

    def __repr__(self):
        return "<mixfix>" #({}, {})>".format(repr(self.kind), repr(self.slots))

    def __call__(self, *values):
        assert len(values) == len(self.slots)
        args = []
        for arg, is_slot in zip(values, self.slots):
            if is_slot:
                args.append(arg)
        #if len(args) == 1 and isinstance(args[0], list):
        #    args = args[0]
        return self.cls(*args)

    def build(self, node):
        assert self.kind
        values = [self.kind]
        for is_slot in self.slots:
            if is_slot:
                values.append(Token("_"))
        return values
def make(text):
    tokens = mixfix_lexer.tokenize(text.strip())
    levels = mixfix_grammar.parse(tokens)

    num_terminals = 0
    rules = []
    terminals = [
        Terminal(None, r' +'),
    ]
    toplevel = None
    for level, symbols_or_terminal in enumerate(levels, 1):
        level_name = "{}".format(level)
        toplevel = toplevel or level_name
        nextlevel = "{}".format(level + 1)
        nextlevel_used = False
        highest_aux_index = 0

        def build_symbols(name, thing, sub=False):
            nonlocal highest_aux_index, nextlevel_used
            if len(thing) == 1 and isinstance(thing[0], list):
                thing = thing[0]

            node_words = []
            node_slots = []
            for sym in thing:
                is_word = isinstance(sym, Literal)
                node_slots.append(not is_word)
                if is_word:
                    node_words.append(sym.kind)
                #else:
                #    node_words.append("_")

            if sub and node_slots == [True]:
                builder = identity
            elif sub and len([is_slot for is_slot in node_slots if is_slot]):
                builder = nth(node_slots.index(True))
            else:
                builder = mixfix(" ".join(node_words), node_slots)

            is_infix = (node_slots == [True, False, True])
            is_prefix = (node_slots == [False, True])

            symbols = []
            aux_rules = []
            for index, sym in enumerate(thing):
                if isinstance(sym, Literal):
                    terminals.append(sym)
                    sym = Token(sym.kind)
                elif isinstance(sym, Slot):
                    if is_infix:
                        if index == 2:
                            sym = nextlevel
                            nextlevel_used = True
                        else:
                            sym = level_name
                    elif is_prefix:
                        if index == 1:
                            sym = nextlevel
                            nextlevel_used = True
                        else:
                            sym = level_name
                    else:
                        if index == 0 or index == len(thing) - 1:
                            sym = level_name
                        else:
                            sym = toplevel
                elif isinstance(sym, Node):
                    highest_aux_index += 1
                    w = aux(highest_aux_index)
                    item_name = "{}-{}".format(level_name, w)
                    if sym.name == "*":
                        aux_name = "{}_star".format(item_name)
                        aux_rules.append(Rule(aux_name, [], empty))
                    elif sym.name == "+":
                        aux_name = "{}_plus".format(item_name)
                        aux_rules.append(Rule(aux_name, [item_name], box))
                    # shared by both: aux → aux item
                    aux_rules.append(Rule(aux_name, [aux_name, item_name], push))
                    assert len(sym.args) == 1
                    aux_rules.extend(
                        build_symbols(item_name, sym.args, sub=True)
                    )
                    sym = aux_name
                elif isinstance(sym, list):
                    highest_aux_index += 1
                    w = aux(highest_aux_index)
                    aux_name = "{}-{}".format(level_name, w)
                    aux_rules.extend(
                        build_symbols(aux_name, sym, sub=True)
                    )
                    sym = aux_name
                symbols.append(sym)
            return [Rule(name, symbols, builder)] + aux_rules

        for thing in symbols_or_terminal:
            if hasattr(thing, 'subn'): # Regex
                num_terminals += 1
                kind = "t{}".format(num_terminals)
                terminals.append(Terminal(kind, thing.pattern))
                rules.append(Rule(level_name, [Token(kind)], identity))
            else:
                sym_rules = build_symbols(level_name, thing)
                rules.extend(sym_rules)
                if len(sym_rules) > 1:
                    rules.append(None)

        if level < len(levels):
            rules.append(Rule(level_name, [nextlevel], identity))
            rules.append(None)
        rules.append(None)

    l = Lexer(terminals)
    g = Grammar(rules)
    if nextlevel_used and nextlevel not in g._rules_by_name:
        g.add_rule(Rule(nextlevel, [level_name], identity))
    return l, g
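# For a flavour of what make() generates, a sketch: given the two-level spec
#
#     _ "+" _
#
#     /[0-9]+/
#
# level 1 comes out roughly as
#     1 → 1 «+» 2      (infix: left side recurses, right side drops a level)
#     1 → 2
#     2 → «t1»
# so "+" parses left-associative and binds looser than the number terminal.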
# -----------------------------------------------------------------------------
# Next file: util.py — helpers for rule `process` functions (imported above)
# -----------------------------------------------------------------------------
"""Useful functions for constructing parse trees.""" | |
from functools import partial | |
from pprint import pformat | |
class PrettyFunction: | |
"""Magic to make the repr for the helper functions more readable.""" | |
def __init__(self, pretty, func): | |
self.pretty = pretty | |
self.func = func | |
def __call__(self, *args): | |
return self.func(*args) | |
def __repr__(self): | |
return self.pretty | |
def pretty(func): | |
pretty = "{}.{}".format(func.__module__, func.__name__) | |
return PrettyFunction(pretty, func) | |
def pretty_wrapper(func): | |
def wrapped(*args): | |
pretty = "{}.{}({})".format( | |
func.__module__, | |
func.__name__, | |
", ".join(map(repr, args)), | |
) | |
f = func(*args) | |
p = PrettyFunction(pretty, f) | |
p.from_func = func | |
p.from_args = args | |
return p | |
wrapped.func = func | |
return wrapped | |
# -----------------------------------------------------------------------------
# Lexer helpers

class int_or_float:
    __name__ = "int_or_float"

    def __call__(self, x):
        try:
            return int(x)
        except ValueError:
            return float(x)

    def build(self, x):
        return [str(x)]

int_or_float = pretty(int_or_float())
# -----------------------------------------------------------------------------
# Grammar helpers

def nth(index, name):
    def f(*args):
        return args[index]
    pretty = "{}.{}".format(f.__module__, name)
    return PrettyFunction(pretty, f)

nth_words = ["first", "second", "third", "fourth", "fifth"]
for index, name in enumerate(nth_words):
    locals()[name] = nth(index, name)
del nth, nth_words, index, name

@pretty_wrapper
def nth(index):
    def f(*args):
        return args[index]
    return f
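# e.g. nth(1)("a", "+", "b") == "+"; `second` above is the same picker.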
class identity:
    def __call__(self, x):
        return x

    def build(self, x):
        return [x]

    def __repr__(self):
        return "{}.{}".format(self.__module__, "identity")

identity = identity()


@pretty
def tupleify(*args):
    """Returns argument-tuple."""
    return tuple(args)


@pretty
def ignore(*args):
    return None


@pretty_wrapper
def constant(value):
    def f(_):
        return value
    return f
# Lists

@pretty
def empty():
    """Returns empty list"""
    return []


@pretty
def box(x):
    """Returns single-item list"""
    return [x]


@pretty
def push(l, item):
    """Push item onto end of list l"""
    l = list(l)
    l.append(item)
    return l


@pretty
def push2(l, _, item):
    """Push item onto end of list l. Useful when _ is a separator"""
    l = list(l)
    l.append(item)
    return l


@pretty
def cons(item, l):
    """Add item to beginning of list l."""
    l = list(l)
    l.insert(0, item)
    return l


@pretty
def cons2(item, _, l):
    """Add item to beginning of list l. Useful when _ is a separator"""
    l = list(l)
    l.insert(0, item)
    return l


@pretty
def cons_and_push(front, l, back):
    """Add `front` to beginning and `back` to end of list l."""
    l = list(l)
    l.insert(0, front)
    l.append(back)
    return l


@pretty
def extend(l, right):
    """Concatenate `l` and `right` together."""
    l = list(l)
    l.extend(right)
    return l


@pretty
def extend2(l, _, right):
    """Concatenate `l` and `right` together. Useful when _ is a separator"""
    l = list(l)
    l.extend(right)
    return l


# Syntax

@pretty
def brackets(_, x, _2):
    """Returns second argument."""
    return x
# AST helpers

def indent(text):
    return text.replace("\n", "\n    ")


def pretty_join(sep, *words, fmt="\n    {}\n"):
    max_width = max(map(len, words))
    if max_width > 40 or any("\n" in w for w in words):
        sep = "{}\n".format(sep.strip())
        return fmt.format(indent(sep.join(words)))
    return sep.join(words)
class Node:
    """For representing parse trees.

    Usage::

        from functools import partial
        Plus = partial(Node, "Plus")

    """
    def __init__(self, name, *args, style=None):
        self.name = name
        self.args = list(args)
        self.style = style

    def __repr__(self):
        return "Node({}{})".format(
            repr(self.name),
            "".join(", {}".format(repr(a)) for a in self.args),
        )

    # TODO magical grammar-based pretty repr
    def repr_g(self, grammar, name=None):
        for rule in grammar.rules:
            process = rule.process
            if not hasattr(process, "build"):
                continue
            if getattr(process, "kind", None) == self.name:
                return process.build(self)
            #func = getattr(process, "from_func", None)
            #style = {
            #    infix.func: 'infix',
            #    prefix.func: 'prefix',
            #    postfix.func: 'postfix',
            #}.get(func)
            #if not style:
            #    continue
            #args = process.from_args
            #print(style, args)
    def repr_p(self):
        """Return extra-readable function-style pretty representation.

        Uses style hints provided by the infix/prefix/postfix helper functions.

        Looks like: 3 * 4 + f(1, 2, …)
        """
        def rec(a, func=pformat):
            return a.repr_p() if hasattr(a, "repr_p") else func(a)
        name = rec(self.name, str)
        first = rec(self.args[0])
        if self.style == 'infix':
            assert len(self.args) == 2
            second = rec(self.args[1])
            return "({})".format(pretty_join(" ", first, name, second))
        elif self.style == 'prefix':
            assert len(self.args) == 1
            return "({})".format(pretty_join(" ", name, first))
        elif self.style == 'postfix':
            assert len(self.args) == 1
            return "({})".format(pretty_join(" ", first, name))
        else:
            return self.repr_f()

    def repr_f(self):
        """Return function-style pretty representation.

        Looks like: +(*(3, 4), f(1, 2, …))
        """
        def rec(a, func=pformat):
            return a.repr_f() if hasattr(a, "repr_f") else func(a)
        return "{}({})".format(
            rec(self.name, str),
            pretty_join(", ", *(rec(x) for x in self.args)),
        )

    def repr_s(self):
        """Return S-expression pretty representation.

        Looks like: (+ (* 3 4) (f 1 2 …))
        """
        def rec(a, func=pformat):
            return a.repr_s() if hasattr(a, "repr_s") else func(a)
        return "({})".format(pretty_join(" ",
            rec(self.name, str),
            *(rec(x) for x in self.args),
            fmt="{}\n"
        ))

    def repr_rpn(self):
        """Return Reverse Polish Notation pretty representation.

        Looks like: 3 4 * 1 2 … f +
        """
        def rec(a):
            return a.repr_rpn() if hasattr(a, "repr_rpn") else pformat(a)
        return pretty_join(" ",
            *([rec(x) + " " for x in self.args] + [self.name])
        )
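# Example of the pretty representations (style hints come from the
# infix/prefix/postfix builders below):
#
#     >>> n = Node("+", Node("*", 3, 4, style='infix'), 5, style='infix')
#     >>> n.repr_p()
#     '((3 * 4) + 5)'
#     >>> n.repr_f()
#     '+(*(3, 4), 5)'
#     >>> n.repr_s()
#     '(+ (* 3 4) 5)'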
class Builder:
    """Function for processing grammar rule to build parse tree.

    Also has a `build` attribute, to do the reverse operation. Which is kind of
    an experiment...
    """
    def __init__(self, kind):
        if callable(kind):
            self.kind = None
            self.cls = kind
        else:
            self.kind = kind
            self.cls = partial(Node, kind, style=self.style)
@pretty_wrapper
class prefix(Builder):
    """Create a function f(_, right) which wraps `cls`."""
    style = 'prefix'

    def __call__(self, _, right):
        return self.cls(right)

    def build(self, node):
        assert self.kind
        return [self.kind, node.args[0]]


@pretty_wrapper
class infix(Builder):
    """Create a function f(left, _, right) which wraps `cls`."""
    style = 'infix'

    def __call__(self, left, _, right):
        return self.cls(left, right)

    def build(self, node):
        assert self.kind
        return [node.args[0], self.kind, node.args[1]]


@pretty_wrapper
class postfix(Builder):
    """Create a function f(left, _) which wraps `cls`."""
    style = 'postfix'

    def __call__(self, left, _):
        return self.cls(left)

    def build(self, node):
        assert self.kind
        return [node.args[0], self.kind]
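# Example: infix("+") is a rule processor that drops the operator token and
# wraps both operands in a Node:
#
#     >>> Plus = infix("+")
#     >>> Plus(1, "+", 2)
#     Node('+', 1, 2)
#     >>> Plus(1, "+", 2).repr_p()
#     '(1 + 2)'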
@pretty_wrapper
def fcall(cls):
    """Create a function f(name, _, args, _) which wraps `cls(name, *args)`."""
    assert callable(cls)
    def f(name, _, args, _2):
        return cls(name, *args)
    return f
# Token-specific

@pretty
def value(token):
    return token.value


@pretty_wrapper
def with_value(func):
    def f(token):
        return func(token.value)
    return f