Last active
August 29, 2015 14:01
-
-
Save arne-cl/bc13bdfe7d460d4a228e to your computer and use it in GitHub Desktop.
Prints a rhetorical structure tree (RS3) as a table.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# #!/usr/bin/python | |
# Title:       rst.py
# Description: prints an RST tree as a table
# License:     GPLv3
# Author:      Andre Herzog
# Version:     0.1c
# Date:        26.03.2014
import sys | |
import os | |
import csv | |
from xml.etree import cElementTree as et | |
from collections import deque | |
from collections import OrderedDict | |
from prettytable import PrettyTable | |
class Node(object):
    '''A single node of an RST tree: either a segment (leaf) or a group.'''

    def __init__(self, parent=u"", rel=u"span", child=None,
                 status=u"nuc", start=-1, end=-1, node=True):
        '''Create a node. With no arguments this is a nucleus group node.

        :param parent: id of the parent node ('' for a node without parent)
        :param rel: relation name between this node and its parent
        :param child: list of child node ids; when omitted a fresh empty list
            is created, so that two nodes never share the same child list
        :param status: "nuc" (nucleus) or "sat" (satellite)
        :param start: position before the first token of the node; it is
            stored shifted by one (``start + 1``)
        :param end: position of the last token of the node
        :param node: True for an inner (group) node, False for a leaf
        '''
        if child is None:
            child = []
        self.start = start + 1  # first token of the segment (position)
        self.end = end  # last token of the segment (position)
        self.status = status  # nucleus or satellite
        self.node = node  # node or leaf (True/False)
        self.parent = parent  # id of the parent node
        self.child = child  # id list of the children
        self.relation = rel  # relation name (between node and parent)

    def pprint(self):
        '''Pretty-print the node as a one-row table.'''
        x = PrettyTable(["start", "end", "status", "node", "parent", "child",
                         "relation"])
        # The key must match the column name exactly, otherwise the
        # alignment setting is silently ignored.
        x.align["relation"] = "l"  # left-align the relation column
        x.add_row([self.start, self.end, self.status, self.node, self.parent,
                   self.child, self.relation])
        print(x.get_string(sortby="start"))
class RstTree(OrderedDict):
    '''An RST tree: an ordered mapping from node id to Node.

    Besides the nodes, the tree keeps the tokens of all segments in document
    order (``token``) and the ids of all nodes without a parent (``root``).'''

    # Column header shared by printSatNucTable and writeToCsv.
    _TABLE_HEADER = ["S-ID", "S-START", "S-END", "N-ID", "N-START",
                     "N-END", "RELATION"]

    def __init__(self, path="", nuc=None):
        '''Create a tree, optionally loading it from an rs3 file.

        :param path: rs3 XML file to load; '' loads nothing
        :param nuc: relation names treated as nucleus relations (span and
            multinuc relation names); a built-in list is used when empty
        '''
        super(RstTree, self).__init__()
        self.token = []  # list of tokens
        self.root = []  # list of root node ids
        if nuc:
            # Copy, so later changes to the caller's list do not leak in.
            self._nuc = list(nuc)
        else:
            self._nuc = [u'span', u'conjunction', u'contrast', u'disjunction',
                         u'joint', u'list', u'restatement-rm', u'sequence']
        if path:
            self.loadTree(path)

    def __getAttrib(self, e, name):
        '''Return attribute *name* of XML element *e*, or "" if missing.'''
        return e.attrib.get(name, "")

    def __addAdditionalInfos(self):
        '''Make the information that is implicit in rs3 explicit.

        * fills the ``child`` list of every node by walking from each
          segment up to its root
        * widens the start/end span of every ancestor; only nodes attached
          by a nucleus relation (``self._nuc``) widen their parent's span,
          together with their own children
        '''
        # Add children to the nodes.
        for ID in self:
            if not self[ID].node:
                child = ID
                for i in self.traverseToRoot(self[child].parent):
                    if child not in self[i].child:
                        self[i].child.append(child)
                    child = i
        # Calculate the spans of the nodes.
        for ID in self:
            child = ID
            for i in self.traverseToRoot(self[ID].parent):
                if self[child].relation in self._nuc:
                    starts = [self[child].start, self[i].start]
                    ends = [self[child].end, self[i].end]
                    for c in self[child].child:
                        starts.append(self[c].start)
                        ends.append(self[c].end)
                    self[i].start = min(starts)
                    self[i].end = max(ends)
                child = i

    def loadTree(self, path):
        '''Load all <segment> and <group> elements of an rs3 file.

        :param path: file name (or file object) passed to ElementTree
        '''
        xml = et.ElementTree(file=path)
        count = 0  # number of tokens read so far
        for e in xml.iter():
            if e.tag not in ("segment", "group"):
                continue
            ID = e.attrib['id']
            parent = self.__getAttrib(e, "parent")
            relname = self.__getAttrib(e, "relname")
            status = "nuc" if relname in self._nuc else "sat"
            if not parent:
                self.root.append(ID)
            if e.tag == "segment":
                # Tokens are the whitespace-separated words of the segment;
                # an empty <segment/> has no text at all (e.text is None).
                txt = (e.text or "").split()
                self.token.extend(txt)
                start = count
                count += len(txt)
                self[ID] = Node(parent, relname, [], status, start, count,
                                False)
            else:
                # Placeholder span, narrowed later by min/max in
                # __addAdditionalInfos.
                self[ID] = Node(parent, relname, [], status, count + 1, -1,
                                True)
        self.__addAdditionalInfos()

    def getRelation(self, node):
        '''Return the relation(s) of *node*.

        A node not attached by "span" returns its own relation; otherwise
        the non-"span" relations of its children are returned.
        '''
        if self[node].relation != "span":
            return [self[node].relation]
        return [self[c].relation for c in self[node].child
                if self[c].relation != "span"]

    def traverseToRoot(self, ID):
        '''Yield ID, then each ancestor up to and including the root.

        Follows the ``parent`` links bottom-up and yields nothing for an
        empty ID. A node that has already been visited stops the walk, so
        cyclic parent links in a malformed file cannot loop forever.
        '''
        seen = set()
        while ID and ID not in seen:
            seen.add(ID)
            yield ID
            ID = self[ID].parent

    def traverse(self, ID):
        '''Depth-first, top-down walk from ID along the ``child`` links.

        Yields ``(depth, path)`` for every visited node, where *path* is the
        list of ids from ID down to that node. The same list object is
        modified between yields, so copy it if you need to keep it.
        Because of deque.extendleft, siblings are visited in reverse order.
        '''
        children = deque()
        path = list()
        last = current = 0
        children.append((current, ID))
        while children:
            current, ID = children.popleft()
            # Drop path entries at the same or a deeper level than this node.
            while last and last >= current:
                path.pop()
                last -= 1
            path.append(ID)
            yield current, path
            children.extendleft([(current + 1, node)
                                 for node in self[ID].child])
            last = current

    def getAllDescendents(self, ID):
        '''Yield the ids of all descendants of node ID (ID excluded).'''
        for depth, path in self.traverse(ID):
            if depth:
                yield path[-1]

    def _satNucRows(self):
        '''Yield the rows of the satellite-nucleus-relation table.

        Rows follow ``_TABLE_HEADER``. Multinuclear entries have no
        satellite: they use "-" placeholders and "relation (parent-id)" as
        the relation.
        '''
        for r in self:
            nuc = self[r]
            if nuc.status != "nuc":
                continue
            if not nuc.child:
                # multi nucs without children
                yield ["-", "-", "-", r, nuc.start, nuc.end,
                       nuc.relation + " (" + nuc.parent + ")"]
                continue
            for s in nuc.child:
                sat = self[s]
                if sat.relation == "span":
                    continue
                if sat.relation not in self._nuc:
                    # normal nucs: satellite s attached to nucleus n
                    n = sat.parent
                    yield [s, sat.start, sat.end,
                           n, self[n].start, self[n].end, sat.relation]
                elif nuc.relation != "span":
                    # multi nucs with children
                    yield ["-", "-", "-", r, nuc.start, nuc.end,
                           nuc.relation + " (" + nuc.parent + ")"]

    def printSatNucTable(self):
        '''Print the satellite-nucleus-relation table, sorted by N-END.'''
        x = PrettyTable(list(self._TABLE_HEADER))
        x.align["RELATION"] = "l"  # Left align
        for row in self._satNucRows():
            x.add_row(row)
        print(x.get_string(sortby="N-END"))

    def writeToCsv(self, path):
        '''Write the satellite-nucleus-relation table as tab-separated text.'''
        # The csv module needs a binary file on Python 2 and a text file
        # opened with newline='' on Python 3.
        if sys.version_info[0] < 3:
            csvfile = open(path, 'wb')
        else:
            csvfile = open(path, 'w', newline='', encoding='utf-8')
        with csvfile:
            w = csv.writer(csvfile, delimiter='\t',
                           quotechar='|', quoting=csv.QUOTE_MINIMAL)
            w.writerow(self._TABLE_HEADER)
            w.writerows(self._satNucRows())
# Main #######################################################################
if __name__ == "__main__":
    # Usage: rst.py <path-to-rst-file> [<path-to-csv-output-file>]
    # Without a CSV path the table is printed to stdout instead.
    if len(sys.argv) <= 1:
        # sys.exit(str) prints to stderr and exits with status 1, so
        # calling scripts can detect the error.
        sys.exit("Usage: rst.py <path-to-rst-file> [<path-to-csv-output-file>]")
    rst_path = sys.argv[1]
    csv_path = sys.argv[2] if len(sys.argv) > 2 else ""
    if not os.path.isfile(rst_path):
        sys.exit("Rst file does not exist: " + rst_path)
    # Bound to the module-level name ``rst`` on purpose: some versions of
    # the RstTree table methods look up this global instead of ``self``.
    rst = RstTree(rst_path)
    if csv_path:
        rst.writeToCsv(csv_path)
    else:
        rst.printSatNucTable()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment