Last active
October 26, 2018 16:16
-
-
Save mrdrozdov/f7d2c0182a253f98fb5da1b8b06f16c1 to your computer and use it in GitHub Desktop.
useful tree methods
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A tree is a nested tuple representation.
tree = ('A', ('B', 'C'))
It can also be represented as a set of spans.
tree_spans = get_spans(tree) # [(0, 3), (1, 3)]
Sometimes you need to convert a string into a tree.
tree_str = '( A ( B C ) )'
tokens, transitions = convert_binary_bracketing(tree_str)
tree = build_tree(tokens, transitions)
Or a tree into a string.
tree_str = tree_to_string(tree)
You can get just the tokens from a tree string.
tokens = to_tokens(tree_str)
Or you can pretty-print a tree.
print(print_tree(tree))
""" | |
from nltk.tree import Tree | |
from nltk.treeprettyprinter import TreePrettyPrinter | |
def tree_to_string(parse):
    """Render a nested-tuple tree as a space-separated bracketed string.

    Args:
        parse: Either a leaf (any value that is not a list/tuple, normally a
            token string) or a list/tuple node. A one-element node is
            unwrapped; for longer nodes only the first two children are
            rendered.

    Returns:
        The leaf itself when ``parse`` is a leaf, otherwise a string such as
        ``'( A ( B C ) )'``.
    """
    if not isinstance(parse, (list, tuple)):
        return parse
    if len(parse) == 1:
        # Recurse so that a one-element wrapper around a subtree is rendered
        # as well, instead of handing back the raw nested tuple.
        return tree_to_string(parse[0])
    return '( ' + tree_to_string(parse[0]) + ' ' + tree_to_string(parse[1]) + ' )'
def tree_to_nltk_string(s, symbol='|'):
    """Convert a nested-tuple tree into a labelled bracketed string.

    Every node, leaves included, is emitted as ``(<symbol> ...)``.

    Args:
        s: A leaf string, or a list/tuple whose first two items are subtrees.
        symbol: Label written at the start of every bracket.

    Returns:
        The bracketed string, e.g. ``'(| (| A) (| B))'`` for ``('A', 'B')``.
    """
    if isinstance(s, (list, tuple)):
        left = tree_to_nltk_string(s[0], symbol)
        right = tree_to_nltk_string(s[1], symbol)
        body = left + ' ' + right
    else:
        # Literal parentheses in a token are spelled out as RP / LP so the
        # only brackets in the result are the structural ones.
        body = s.replace(')', 'RP').replace('(', 'LP')
    return '({} {})'.format(symbol, body)
def print_tree(s):
    """Return a text drawing of a nested-tuple tree.

    The tree is converted with ``tree_to_nltk_string``, parsed with
    ``Tree.fromstring`` and rendered with ``TreePrettyPrinter(...).text()``.

    Args:
        s: A leaf string or a list/tuple tree accepted by
            ``tree_to_nltk_string``.

    Returns:
        The rendered tree as a string.
    """
    # tree_to_nltk_string accepts a bare leaf as well as a list/tuple node, so
    # convert unconditionally; guarding on the type left single-leaf input
    # without a usable result.
    tree_string = tree_to_nltk_string(s)
    tree = Tree.fromstring(tree_string)
    return TreePrettyPrinter(tree).text()
def to_tokens(parse):
    """Return the whitespace-separated items of ``parse`` that are not brackets.

    Args:
        parse: A bracketed tree string such as ``'( A ( B C ) )'``.

    Returns:
        A list of every item other than ``'('`` and ``')'``, in order.
    """
    brackets = ('(', ')')
    return [piece for piece in parse.split() if piece not in brackets]
def to_indexed_contituents(parse):
    """Collect the word-index span of every bracketed constituent.

    Args:
        parse: A bracketed tree string whose brackets are separate,
            whitespace-delimited ``'('`` and ``')'`` items.

    Returns:
        A set of ``(start, end)`` tuples, where ``start`` is the number of
        words seen before the opening bracket and ``end`` the number seen
        before the closing one. A string made of a single item yields
        ``{(0, 1)}``.
    """
    pieces = parse.split()
    if len(pieces) == 1:
        return {(0, 1)}

    spans = set()
    open_starts = []  # word counts recorded at each still-open '('
    words_seen = 0
    for piece in pieces:
        if piece == ')':
            spans.add((open_starts.pop(), words_seen))
        elif piece == '(':
            open_starts.append(words_seen)
        else:
            words_seen += 1
    return spans
def get_spans(tree):
    """List the ``(start, end)`` leaf span of every internal node of ``tree``.

    NOTE: an identical ``get_spans`` is defined again later in this file, and
    that later definition is the one bound at import time.

    Args:
        tree: A leaf (``str`` or ``int``) or a nested sequence whose first two
            items are subtrees.

    Returns:
        A list of spans in pre-order: a node's own span, then the spans of
        its left subtree, then those of its right subtree. A bare leaf gives
        an empty list.
    """
    collected = []

    def walk(node, start):
        # Returns the number of leaves under ``node``.
        if isinstance(node, (str, int)):
            return 1
        # Reserve this node's slot first so the result stays in pre-order
        # even though its width is only known after visiting the children.
        slot = len(collected)
        collected.append(None)
        left_width = walk(node[0], start)
        right_width = walk(node[1], start + left_width)
        width = left_width + right_width
        collected[slot] = (start, start + width)
        return width

    walk(tree, 0)
    return collected
def build_tree(tokens, transitions):
    """Build a nested-tuple binary tree from tokens and shift/reduce actions.

    Args:
        tokens: Iterable of leaf tokens in left-to-right order.
        transitions: Iterable of actions. ``0`` moves the next token onto the
            stack; ``1`` pops the top two stack entries and pushes them back
            as a ``(left, right)`` tuple. Any other value is ignored.

    Returns:
        The single entry left on the stack once all transitions are applied.

    Raises:
        IndexError: If a transition needs a token or stack entry that is not
            available.
        AssertionError: If the stack does not end with exactly one entry
            (only while assertions are enabled).
    """
    stack = []
    # Copy into a list before reversing: slicing a tuple gives back a tuple,
    # which has no pop(), so tuple (or other iterable) input used to fail.
    buf = list(tokens)[::-1]
    for t in transitions:
        if t == 0:
            stack.append(buf.pop())
        elif t == 1:
            right = stack.pop()
            left = stack.pop()
            stack.append((left, right))
    assert len(stack) == 1
    return stack[0]
def convert_binary_bracketing(parse):
    """Split a bracketed tree string into tokens and shift/reduce transitions.

    The string is split on single spaces. Pieces that start with ``'('`` are
    skipped, a piece equal to ``')'`` emits transition ``1``, and every other
    piece is recorded as a token together with transition ``0``.

    Args:
        parse: A bracketed tree string such as ``'( A ( B C ) )'``.

    Returns:
        A ``(tokens, transitions)`` pair of lists.
    """
    transitions = []
    tokens = []
    for word in parse.split(' '):
        if not word:
            # Empty pieces come from an empty string or from leading,
            # trailing or doubled spaces; indexing them would raise.
            continue
        if word[0] == "(":
            continue
        if word == ")":
            transitions.append(1)
        else:
            tokens.append(word)
            transitions.append(0)
    return tokens, transitions
def example_f1(c1, c2):
    """
    Compute unlabeled f1.

    The value returned is the number of items of ``c1`` that also occur in
    ``c2``, divided by the size of ``c2``.

    Args:
        c1: A set of spans (anything providing ``intersection``).
        c2: A sized collection of spans; must not be empty.

    Returns:
        The overlap ratio as a float.
    """
    shared = c1.intersection(c2)
    # For strictly binary trees, P = R = F1
    return float(len(shared)) / len(c2)
def get_spans(tree):
    """
    Convert the tree representation to the span representation.

    Args:
        tree: A leaf (``str`` or ``int``) or a nested sequence whose first two
            items are subtrees.

    Returns:
        A list of ``(start, end)`` leaf-index spans, one per internal node,
        ordered node first, then its left subtree, then its right subtree.
        A bare leaf gives an empty list.
    """
    def measure(node, offset=0):
        # Returns (number of leaves under node, spans found under node).
        if isinstance(node, (str, int)):
            return 1, []
        n_left, spans_left = measure(node[0], offset)
        # The right child starts where the left child's leaves end.
        n_right, spans_right = measure(node[1], offset + n_left)
        width = n_left + n_right
        own_span = (offset, offset + width)
        return width, [own_span] + spans_left + spans_right

    return measure(tree)[1]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment