@minhlab · Created July 15, 2017
Use this code to count events in OntoNotes 5.0. Here an "event" means a single-token coreference mention whose part-of-speech tag is verbal (the test applied in the main loop below).
import os
import re
from collections import defaultdict
import sys
import ptb  # Penn Treebank loading helpers; defined as the second file in this gist

ontonotes_root = '/Users/cumeo/corpora/ontonotes-release-5.0'
ontonotes_en = os.path.join(ontonotes_root, 'data/files/data/english/annotations')
assert os.path.exists(ontonotes_root), 'please link/put ontonotes into data directory'
def load_coref(content):
    """Parse a .coref file into {chain_id: [(sent, start, end, tokens), ...]}; end is exclusive."""
    chains = defaultdict(list)
    sent = 0
    for line in content.split('\n'):
        if line.strip() and not re.match(r'</?DOC|</?TEXT', line):
            last_tokens = []
            brackets = []
            curr_index = 0
            parts = re.findall(r'<COREF\s[^>]+>|</COREF\s*>|[^\s<>]+', line)
            for s in parts:
                m = re.match(r'<COREF\s+ID="([^"]+)"', s)
                if m:  # opening bracket of a mention
                    brackets.append((m.group(1), curr_index))
                    last_tokens = []
                elif re.match(r'</COREF', s):  # closing bracket: record the mention
                    refid, start_index = brackets.pop()
                    chains[refid].append((sent, start_index, curr_index, tuple(last_tokens)))
                else:  # an ordinary token; undo the escaping used in .coref files
                    s = (s.replace('-AMP-', '&').replace('-LAB-', '<')
                         .replace('-RAB-', '>').replace(r'\*', '*'))
                    last_tokens.append(s)
                    curr_index += 1
            sent += 1
    return chains
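# Worked example (a made-up line in the SGML-like .coref format this parser
# expects, not taken from the corpus):
#   <COREF ID="1">John</COREF> said <COREF ID="1">he</COREF> left .
# Given the logic above, this line (sentence 0) should produce
#   chains['1'] == [(0, 0, 1, ('John',)), (0, 2, 3, ('he',))]
# i.e. (sentence index, start token, end token exclusive, mention tokens).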
class OntonotesDocument:

    def __init__(self, base_path):
        self.base_path = base_path
        self._chains = {}
        coref_path = self.base_path + '.coref'
        if os.path.exists(coref_path):
            self._trees = ptb.load_trees(self.base_path + '.parse')
            with open(coref_path) as f:
                try:
                    self._chains = load_coref(f.read())
                except Exception:
                    sys.stderr.write('Error at document: %s. Ignored.\n' % coref_path)
        else:
            self._chains, self._coref = {}, {}
def iter_docs(dir_path=ontonotes_en):
    for root, _, fnames in os.walk(dir_path):
        for fname in fnames:
            if re.search(r'\.parse$', fname):
                # strip the '.parse' extension to get the document base path
                yield OntonotesDocument(os.path.join(root, fname[:-6]))
if __name__ == '__main__':
    doc_count = 0
    event_doc_count = 0
    event_count = 0
    all_chain_count = 0
    event_chain_count = 0
    for doc in iter_docs():
        if os.path.exists(doc.base_path + '.coref'):
            event_found_in_doc = False
            for chain in doc._chains.values():
                event_found_in_chain = False
                for sent, start_index, end_index, tokens in chain:
                    pos = doc._trees[sent].terminals[start_index].label()
                    # an "event" is a single-token mention whose POS tag is verbal
                    if start_index + 1 == end_index and pos[0] == 'V':
                        assert (doc._trees[sent].terminals[start_index][0],) == tokens
                        event_found_in_chain = True
                        event_found_in_doc = True
                        event_count += 1
                if event_found_in_chain:
                    event_chain_count += 1
                all_chain_count += 1
            if event_found_in_doc:
                event_doc_count += 1
            doc_count += 1
    print('#documents with coreference annotation: %d' % doc_count)
    print('#documents with events: %d' % event_doc_count)
    print('#events: %d' % event_count)
    print('#all chains: %d' % all_chain_count)
    print('#chains with events: %d' % event_chain_count)
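For a quick look at the mentions being counted (rather than only the totals), here is a minimal usage sketch. It assumes the script above is saved as count_events.py (a hypothetical name) and that the OntoNotes paths configured at the top exist locally; it prints the verb-headed single-token mentions of the first annotated document.

# Hypothetical usage sketch; 'count_events' is an assumed module name for the script above.
import os
from count_events import iter_docs

for doc in iter_docs():
    if not os.path.exists(doc.base_path + '.coref'):
        continue  # no coreference annotation for this document
    for refid, chain in doc._chains.items():
        for sent, start, end, tokens in chain:
            pos = doc._trees[sent].terminals[start].label()
            if start + 1 == end and pos.startswith('V'):  # same "event" test as above
                print(doc.base_path, refid, sent, ' '.join(tokens))
    break  # stop after the first document that has a .coref file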
# ---------------------------------------------------------------------------
# ptb module (imported as `ptb` by the script above): helpers for reading
# Penn Treebank-style .parse files into NLTK trees.
# ---------------------------------------------------------------------------
from nltk.tree import ParentedTree
import os

root_dir = 'data/ptb'
# assert os.path.exists(root_dir), 'please link/put PENN TreeBank into data directory'

def _assign_token_ids(root, n):
    """Number surface tokens left to right, starting at n; -NONE- traces get None."""
    assert not isinstance(root, str)
    if len(root) == 1 and isinstance(root[0], str):  # preterminal (POS tag + word)
        if root.label() == '-NONE-':
            root.token_id = None
        else:
            root.token_id = n
            n += 1
    else:
        for child in root:
            n = _assign_token_ids(child, n)
    return n

def _find_terminals(root):
    """Collect preterminal nodes in surface order and attach them to each subtree."""
    assert not isinstance(root, str)
    target_list = []
    if len(root) == 1 and isinstance(root[0], str):
        target_list.append(root)
    else:
        for child in root:
            target_list.extend(_find_terminals(child))
    root.terminals = target_list
    return target_list

def _tree_from_string(s):
    root = ParentedTree.fromstring(s)
    _find_terminals(root)
    _assign_token_ids(root, 0)
    return root

def _iter_syntax_trees(path):
    """Yield one balanced bracketed tree string at a time from a .parse file."""
    with open(path, 'rt') as f:
        parenthesis_count = 0
        buf = []
        line = f.readline()
        while line != '':
            if parenthesis_count == 0:
                # the buffered lines form a complete bracketing; emit it
                s = ' '.join(buf).strip()
                if s:
                    yield s
                del buf[:]
            buf.append(line)
            parenthesis_count += sum(1 for c in line if c == '(')
            parenthesis_count -= sum(1 for c in line if c == ')')
            line = f.readline()
        # emit the final buffered tree
        s = ' '.join(buf).strip()
        if s:
            yield s

def load_trees(path):
    """Load every tree in a .parse file as a ParentedTree with the extra annotations."""
    trees = []
    for s in _iter_syntax_trees(path):
        trees.append(_tree_from_string(s))
    return trees
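As a quick, corpus-free check of these helpers, the sketch below parses a made-up bracketed string (not a real treebank sentence) and inspects the attributes that _find_terminals and _assign_token_ids attach; run it with the functions above in scope.

# Hypothetical quick check of the tree helpers above.
root = _tree_from_string('(S (NP (DT The) (NN cat)) (VP (VBZ sleeps)) (. .))')
print([t.label() for t in root.terminals])    # expected: ['DT', 'NN', 'VBZ', '.']
print([t[0] for t in root.terminals])         # expected: ['The', 'cat', 'sleeps', '.']
print([t.token_id for t in root.terminals])   # expected: [0, 1, 2, 3]; -NONE- nodes would get None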