Last active
February 24, 2020 19:32
-
-
Save ivyleavedtoadflax/cc5b324c44a68c1ccc5c16d4bd960b7c to your computer and use it in GitHub Desktop.
Find overlapping spans in Prodigy documents
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
import itertools | |
from itertools import groupby | |
from operator import itemgetter | |
from pprint import PrettyPrinter | |
import plac | |
from deep_reference_parser.io import read_jsonl, write_jsonl | |
def find_overlaps(docs_path):
    """Print the overlapping annotation spans found in a JSONL file.

    Every document read from ``docs_path`` is accessed through the keys
    ``text``, ``spans`` (a list of dicts with ``start``, ``end`` and
    ``label``) and ``_input_hash``. Characters covered by spans of two
    different labels are collected, grouped into consecutive runs, and
    pretty-printed as a dict keyed by ``_input_hash``.

    Args:
        docs_path: Path handed to ``read_jsonl``.
    """
    manually_checked = read_jsonl(docs_path)

    class spans:
        """Locate characters that are annotated with more than one label."""

        def __init__(self, docs, labels=None):
            # ``None`` sentinel instead of a list literal: a mutable default
            # would be shared between every instance of the class.
            if labels is None:
                labels = ("author", "title", "year")
            self.labels = tuple(labels)
            self.docs = docs

        def get_ranges(self, doc, label):
            """Return the set of character offsets covered by ``label``.

            ``start`` is treated as inclusive and ``end`` as exclusive,
            as in ``range(start, end)``.
            """
            out = list(
                itertools.chain.from_iterable(
                    range(span["start"], span["end"])
                    # A document without any annotation has nothing to scan.
                    for span in doc.get("spans") or ()
                    if span["label"] == label
                )
            )
            # Two spans of the *same* label sharing a character is a
            # different kind of annotation error; stop rather than hide it.
            assert len(out) == len(set(out)), (
                "spans labelled %r overlap each other" % label
            )
            return set(out)

        def split_nonconsecutive(self, data):
            """Split integer offsets into lists of consecutive runs."""
            # groupby only merges *adjacent* items, and the index-minus-value
            # key is constant only along an ascending run, so the offsets
            # (which arrive as an unordered set) must be sorted first.
            ordered = sorted(data)
            return [
                list(map(itemgetter(1), group))
                for _, group in groupby(
                    enumerate(ordered), lambda pair: pair[0] - pair[1]
                )
            ]

        def get_text(self, doc, span, context=0):
            """Return the text covered by ``span`` plus ``context`` chars.

            ``span`` is a non-empty collection of character offsets; the
            text runs from its smallest to its largest offset, inclusive.
            """
            # Clamp at zero: a negative slice start would count from the end
            # of the string and return the wrong (usually empty) text.
            start = max(min(span) - context, 0)
            # max(span) is the last overlapping character itself, so the
            # exclusive slice end has to be one past it.
            end = max(span) + 1 + context
            return doc["text"][start:end]

        def run(self):
            """Return ``{_input_hash: [overlap, ...]}`` for affected docs."""
            out = {}
            for doc in self.docs:
                ranges = {
                    label: self.get_ranges(doc, label)
                    for label in self.labels
                }

                # Union of the pairwise intersections of all label ranges.
                overlaps = set()
                for first, second in itertools.combinations(self.labels, 2):
                    overlaps |= ranges[first] & ranges[second]

                if not overlaps:
                    continue

                out[doc["_input_hash"]] = [
                    {
                        "overlaps": self.get_text(doc, overlap),
                        "overlaps_with_context": self.get_text(
                            doc, overlap, 10
                        ),
                    }
                    for overlap in self.split_nonconsecutive(overlaps)
                ]
            return out

    foo = spans(manually_checked)
    out = foo.run()

    pp = PrettyPrinter(indent=4)
    pp.pprint(out)
if __name__ == "__main__":
    # Run as a script: plac derives the command line from find_overlaps.
    plac.call(find_overlaps)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment