Skip to content

Instantly share code, notes, and snippets.

@mrdrozdov
Last active October 26, 2018 16:16
Show Gist options
  • Save mrdrozdov/f7d2c0182a253f98fb5da1b8b06f16c1 to your computer and use it in GitHub Desktop.
Save mrdrozdov/f7d2c0182a253f98fb5da1b8b06f16c1 to your computer and use it in GitHub Desktop.
useful tree methods
"""
A tree is represented as a nested tuple.
tree = ('A', ('B', 'C'))
It can also be represented as a collection of (start, end) spans.
tree_spans = get_spans(tree) # [(0, 3), (1, 3)]
Sometimes you need to convert a string into a tree.
tree_str = '( A ( B C ) )'
tokens, transitions = convert_binary_bracketing(tree_str)
tree = build_tree(tokens, transitions)
Or a tree into a string.
tree_str = tree_to_string(tree)
You can get just the tokens from a tree string.
tokens = to_tokens(tree_str)
Or you can pretty-print a tree.
print(print_tree(tree))
"""
from nltk.tree import Tree
from nltk.treeprettyprinter import TreePrettyPrinter
def tree_to_string(parse):
    """Render a nested list/tuple tree as a space-separated bracketed string.

    Anything that is not a list or tuple is treated as a leaf and returned
    unchanged. A one-element sequence is a transparent wrapper around its
    only child. Otherwise the first two elements are rendered as
    ``'( left right )'``; elements beyond the second are not read.

    Args:
        parse: A leaf, or a (possibly nested) list/tuple of subtrees.

    Returns:
        The bracketed string for an internal node, e.g. ``'( A ( B C ) )'``.
        A bare leaf is handed back as-is.
    """
    if not isinstance(parse, (list, tuple)):
        return parse
    if len(parse) == 1:
        # Recurse so a wrapped subtree such as (('A', 'B'),) is rendered
        # instead of being handed back as a raw tuple.
        return tree_to_string(parse[0])
    # format() rather than '+' so non-string leaves (e.g. integer token ids)
    # are stringified instead of raising TypeError.
    return '( {} {} )'.format(
        tree_to_string(parse[0]),
        tree_to_string(parse[1]),
    )
def tree_to_nltk_string(s, symbol='|'):
    """Convert a nested list/tuple tree into a labelled bracket string.

    Every node, leaf or internal, is labelled with ``symbol``. A leaf
    becomes ``'(symbol token)'`` and an internal node becomes
    ``'(symbol left right)'`` built from its first two children.

    Args:
        s: A leaf string, or a nested list/tuple of subtrees.
        symbol: Label written on every node.

    Returns:
        The labelled bracket string.
    """
    is_internal = isinstance(s, (list, tuple))
    if is_internal:
        left_part = tree_to_nltk_string(s[0], symbol)
        right_part = tree_to_nltk_string(s[1], symbol)
        return '({} {} {})'.format(symbol, left_part, right_part)

    # Literal parentheses inside a token are spelled out, presumably so they
    # cannot be mistaken for tree brackets when the string is parsed.
    escaped = s.replace(')', 'RP')
    escaped = escaped.replace('(', 'LP')
    return '({} {})'.format(symbol, escaped)
def print_tree(s):
    """Return a multi-line text drawing of a tree.

    The tree is converted with ``tree_to_nltk_string``, parsed with
    ``Tree.fromstring`` and rendered with ``TreePrettyPrinter(...).text()``.

    Args:
        s: A nested list/tuple tree, or a bare leaf token.

    Returns:
        The text produced by ``TreePrettyPrinter(...).text()``.
    """
    # tree_to_nltk_string already handles bare leaves, so no type guard is
    # needed here; guarding on list/tuple left a single-token tree without
    # a bracket string to parse.
    tree_string = tree_to_nltk_string(s)
    tree = Tree.fromstring(tree_string)
    return TreePrettyPrinter(tree).text()
def to_tokens(parse):
    """Return the words of a bracketed tree string, in order.

    The string is split on whitespace and every piece that is exactly
    ``'('`` or ``')'`` is dropped.
    """
    brackets = ('(', ')')
    words = []
    for piece in parse.split():
        if piece in brackets:
            continue
        words.append(piece)
    return words
def to_indexed_contituents(parse):
    """Collect the (start, end) word spans of a bracketed tree string.

    The string is split on whitespace. Each ``'('`` records the current word
    count, each ``')'`` closes the most recently opened bracket and emits a
    span ending at the current word count, and every other piece counts as
    one word. A string with exactly one piece yields ``{(0, 1)}``.

    Args:
        parse: Bracketed tree string such as ``'( A ( B C ) )'``.

    Returns:
        A set of ``(start, end)`` tuples, end exclusive.
    """
    pieces = parse.split()
    if len(pieces) == 1:
        return {(0, 1)}

    open_starts = []  # word index at which each still-open bracket began
    spans = set()
    words_seen = 0
    for piece in pieces:
        if piece == ')':
            spans.add((open_starts.pop(), words_seen))
        elif piece == '(':
            open_starts.append(words_seen)
        else:
            words_seen += 1
    return spans
def get_spans(tree):
    """Convert a binary nested tree into a list of (start, end) spans.

    Leaves are ``str`` or ``int`` values and each occupies one position.
    Every internal node contributes one span covering its leaves, end
    exclusive. Spans are listed parent first, then the left subtree's spans,
    then the right subtree's.

    NOTE: this module defines ``get_spans`` a second time further down; the
    later definition is the one that stays bound to the name.
    """
    collected = []

    def width_of(node, start):
        if isinstance(node, (str, int)):
            return 1
        # Reserve the parent's slot first so it precedes its descendants.
        slot = len(collected)
        collected.append(None)
        left_width = width_of(node[0], start)
        right_width = width_of(node[1], start + left_width)
        total = left_width + right_width
        collected[slot] = (start, start + total)
        return total

    width_of(tree, 0)
    return collected
def build_tree(tokens, transitions):
    """Rebuild a nested-tuple tree from tokens and shift/reduce transitions.

    Args:
        tokens: The leaves in left-to-right order. Any iterable is accepted.
        transitions: Iterable of actions. ``0`` shifts the next token onto
            the stack; ``1`` pops the top two stack items and pushes them
            back as a ``(left, right)`` tuple. Other values are ignored.

    Returns:
        The single item left on the stack: a leaf, or a nested tuple.

    Raises:
        AssertionError: If the stack does not hold exactly one item after
            all transitions have been applied.
        IndexError: If a shift finds no token left, or a reduce finds fewer
            than two items on the stack.
    """
    stack = []
    # Reversed copy so pop() yields tokens left to right. list() makes this
    # work for tuples and other iterables, not only lists.
    buf = list(tokens)[::-1]
    for t in transitions:
        if t == 0:
            stack.append(buf.pop())
        elif t == 1:
            right = stack.pop()
            left = stack.pop()
            stack.append((left, right))
    # Raised explicitly (same exception type as the former assert) so the
    # check is not stripped when Python runs with -O.
    if len(stack) != 1:
        raise AssertionError(
            'transitions must reduce to exactly one tree, got {} stack items'
            .format(len(stack))
        )
    return stack[0]
def convert_binary_bracketing(parse):
    """Split a bracketed tree string into tokens and shift/reduce transitions.

    The string is split on single spaces. Any piece starting with ``'('`` is
    skipped, ``')'`` emits a reduce (``1``), and every other piece is
    recorded as a token together with a shift (``0``).

    Args:
        parse: Bracketed tree string such as ``'( A ( B C ) )'``.

    Returns:
        A ``(tokens, transitions)`` pair of lists.
    """
    transitions = []
    tokens = []
    for word in parse.split(' '):
        # split(' ') produces '' for empty input and for repeated, leading
        # or trailing spaces; indexing word[0] on those would raise.
        if not word or word[0] == "(":
            continue
        if word == ")":
            transitions.append(1)
        else:
            tokens.append(word)
            transitions.append(0)
    return tokens, transitions
def example_f1(c1, c2):
    """Compute unlabeled F1 between two collections of spans.

    The score is the number of spans shared by ``c1`` and ``c2`` divided by
    the number of distinct spans in ``c2``. For strictly binary trees over
    the same tokens, precision, recall and F1 coincide, so this single
    ratio is returned.

    Args:
        c1: Iterable of hashable spans, e.g. the list from ``get_spans``.
        c2: Iterable of hashable spans; its size is the denominator.

    Returns:
        A float in ``[0.0, 1.0]``. When ``c2`` is empty, the result is
        ``1.0`` if ``c1`` is empty too and ``0.0`` otherwise.
    """
    # Accept lists as well as sets: get_spans returns a list, which has no
    # .intersection method.
    first = set(c1)
    second = set(c2)
    if not second:
        # No spans to match (e.g. a single-token tree): avoid dividing by
        # zero and count two empty span sets as a perfect match.
        return 1.0 if not first else 0.0
    return float(len(first & second)) / len(second)
def get_spans(tree):
    """Convert the tree representation to the span representation.

    Anything that is not a list or tuple is a leaf occupying one position,
    matching how ``tree_to_string`` tells leaves from nodes. A one-element
    sequence is a transparent wrapper around its only child and adds no
    span of its own. Every other node contributes one ``(start, end)`` span
    (end exclusive) covering the leaves of its first two children.

    Args:
        tree: A leaf, or a nested list/tuple of subtrees.

    Returns:
        A list of ``(start, end)`` tuples ordered parent first, then the
        left subtree's spans, then the right subtree's. A bare leaf yields
        an empty list.
    """
    def helper(tr, idx=0):
        # Returns (number of leaves under tr, spans under tr), with idx the
        # position of tr's first leaf.
        if not isinstance(tr, (list, tuple)):
            return 1, []
        if len(tr) == 1:
            return helper(tr[0], idx=idx)
        left, left_spans = helper(tr[0], idx=idx)
        right, right_spans = helper(tr[1], idx=idx + left)
        span = [(idx, idx + left + right)]
        return left + right, span + left_spans + right_spans

    _, spans = helper(tree)
    return spans
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment