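"""Convert UIMA XMI exports (DKPro Core type system, e.g. as produced by
annotation tools such as INCEpTION) into spaCy DocBin training files.

Reads every *.xmi file in ./corpus/xmi/, extracts the document text
(cas:Sofa), sentences and named entities, and writes one
./corpus/converted/<name>.spacy file per input.
"""
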
from functools import partial
from pathlib import Path

import typer
from lxml import etree
from spacy.tokens import DocBin
from spacy.util import get_lang_class
from tqdm import tqdm


def get_namespaces():
    """Namespace prefixes used in DKPro Core UIMA XMI files."""
    return {
        "pos": "http:///de/tudarmstadt/ukp/dkpro/core/api/lexmorph/type/pos.ecore",
        "tcas": "http:///uima/tcas.ecore",
        "xmi": "http://www.omg.org/XMI",
        "cas": "http:///uima/cas.ecore",
        "tweet": "http:///de/tudarmstadt/ukp/dkpro/core/api/lexmorph/type/pos/tweet.ecore",
        "morph": "http:///de/tudarmstadt/ukp/dkpro/core/api/lexmorph/type/morph.ecore",
        "dependency": "http:///de/tudarmstadt/ukp/dkpro/core/api/syntax/type/dependency.ecore",
        "type5": "http:///de/tudarmstadt/ukp/dkpro/core/api/semantics/type.ecore",
        "type8": "http:///de/tudarmstadt/ukp/dkpro/core/api/transform/type.ecore",
        "type7": "http:///de/tudarmstadt/ukp/dkpro/core/api/syntax/type.ecore",
        "type2": "http:///de/tudarmstadt/ukp/dkpro/core/api/metadata/type.ecore",
        "type9": "http:///org/dkpro/core/api/xml/type.ecore",
        "type3": "http:///de/tudarmstadt/ukp/dkpro/core/api/ner/type.ecore",
        "type4": "http:///de/tudarmstadt/ukp/dkpro/core/api/segmentation/type.ecore",
        "type": "http:///de/tudarmstadt/ukp/dkpro/core/api/coref/type.ecore",
        "type6": "http:///de/tudarmstadt/ukp/dkpro/core/api/structure/type.ecore",
        "constituent": "http:///de/tudarmstadt/ukp/dkpro/core/api/syntax/type/constituent.ecore",
        "chunk": "http:///de/tudarmstadt/ukp/dkpro/core/api/syntax/type/chunk.ecore",
    }


def get_tag_mapping():
    return {
        "LOC": "LOC",
        "PER": "PER",
        "ORG": "ORG",
        "LOCderiv": "LOC",
        "PERderiv": "PER",
        "ORGderiv": "ORG",
    }
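
# The *deriv values appear to be GermEval-style derived-entity labels; the
# mapping above collapses them onto their base classes LOC/PER/ORG.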


def load_xmi(xmi_input):
    """
    Load an XMI file and return the root element of its XML tree.
    """
    with open(xmi_input, "rb") as fin:
        xml_str = fin.read()
    tree = etree.fromstring(xml_str)
    return tree


def get_lemma(token_el, tree, ns):
    """Resolve a token's lemma reference and return its "value" attribute."""
    if (lemma_id := token_el.get("lemma")) is None:
        return None
    lemma = tree.xpath(f"//type4:Lemma[@xmi:id='{lemma_id}']", namespaces=ns)
    if len(lemma) == 0:
        return None
    return lemma[0].attrib.get("value")


def get_pos(token_el, tree, ns):
    """Resolve a token's POS reference and return its coarse-grained tag."""
    if (pos_id := token_el.get("pos")) is None:
        return None
    pos = tree.xpath(f"//pos:POS[@xmi:id='{pos_id}']", namespaces=ns)
    if len(pos) == 0:
        return None
    return pos[0].attrib.get("coarseValue")


def get_morph(token_el, tree, ns):
    """Resolve a token's morphology reference and return its "value" attribute."""
    if (morph_id := token_el.get("morph")) is None:
        return None
    morph = tree.xpath(f"//morph:MorphologicalFeatures[@xmi:id='{morph_id}']", namespaces=ns)
    if len(morph) == 0:
        return None
    return morph[0].attrib.get("value")


def parse_sentences(tree, ns):
    sentence_elements = tree.xpath(
        'type4:Sentence',
        namespaces=ns
    )
    sentences = []
    for sentence_el in tqdm(sentence_elements, desc="parse sentences"):
        begin_ind = int(sentence_el.attrib['begin'])
        end_ind = int(sentence_el.attrib['end'])
        sentences.append({
            "start_char": begin_ind,
            "end_char": end_ind,
            "tag": "sentence",
            "attrib": {}
        })
    return sentences


def parse_tokens(tree, ns):
    token_elements = tree.xpath(
        'type4:Token',
        namespaces=ns
    )
    tokens = []
    for token_el in tqdm(token_elements, desc="parse tokens"):
        begin_ind = int(token_el.attrib['begin'])
        end_ind = int(token_el.attrib['end'])
        tokens.append({
            "start_char": begin_ind,
            "end_char": end_ind,
            "tag": "token",
            "attrib": {}
        })
        lemma = get_lemma(token_el, tree, ns)
        pos = get_pos(token_el, tree, ns)
        morph = get_morph(token_el, tree, ns)
        if lemma is not None:
            tokens[-1]["attrib"]["lemma"] = lemma
        if pos is not None:
            tokens[-1]["attrib"]["pos"] = pos
        if morph is not None:
            tokens[-1]["attrib"]["morph"] = morph
    return tokens


def parse_entities(tree, ns):
    entity_elements = tree.xpath(
        "type3:NamedEntity",
        namespaces=ns
    )
    tag_map = get_tag_mapping()
    entities = []
    for entity in tqdm(entity_elements, desc="parse entities"):
        if 'begin' in entity.attrib and 'end' in entity.attrib and 'value' in entity.attrib:
            entities.append({
                "start_char": int(entity.attrib['begin']),
                "end_char": int(entity.attrib['end']),
                "tag": tag_map[entity.attrib['value']],
                "attrib": {}
            })
            # keep the identifier of the entity just appended, if present, so
            # it can later be attached to the span as a knowledge-base reference
            if 'identifier' in entity.attrib:
                entities[-1]["attrib"] = {
                    "ref": entity.attrib['identifier']
                }
    return entities


def add_annotations(so, view, new_elems):
    for elem in tqdm(new_elems, desc="add annotations"):
        start_ind = view.get_table_pos(elem["start_char"])
        end_ind = view.get_table_pos(elem["end_char"] - 1) + 1
        so.add_inline(
            begin=start_ind,
            end=end_ind,
            tag=elem["tag"],
            depth=None,
            attrib=elem["attrib"],
        )
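
# Note: add_annotations is not called by convert(); its so/view arguments look
# like the Standoff and View objects of the standoffconverter package
# (get_table_pos, add_inline), so it is presumably left over from a
# standoff-based export path.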


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]
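
# Example: list(chunks([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]; the
# last chunk may be shorter than n.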


def convert(export_path: str, n_sents: int, lang: str):
    lang_cls = get_lang_class(lang)
    nlp = lang_cls()
    export_path = Path(export_path)
    assert export_path.exists()
    xmi_path = Path.cwd() / 'corpus' / 'xmi'
    xmi_files = [f for f in xmi_path.iterdir() if f.suffix == ".xmi"]
    for xmi_file in xmi_files:
        db = DocBin()  # create a DocBin object
        xmi_tree = load_xmi(xmi_file)
        xmi_namespaces = get_namespaces()
        raw_string = xmi_tree.xpath(
            "cas:Sofa",
            namespaces=xmi_namespaces
        )[0].attrib["sofaString"]
        entities = parse_entities(xmi_tree, xmi_namespaces)
        sentences = parse_sentences(xmi_tree, xmi_namespaces)
        for chunk in chunks(sentences, n_sents):
            c_begin = chunk[0]["start_char"]
            c_end = chunk[-1]["end_char"]
            c_str = raw_string[c_begin:c_end]
            doc = nlp.make_doc(c_str)  # create doc object from text
            ents = []
            for entity in entities:
                if entity["start_char"] >= c_begin and entity["end_char"] <= c_end:
                    if "ref" in entity["attrib"]:
                        add_fun = partial(doc.char_span, kb_id=entity["attrib"]["ref"])
                    else:
                        add_fun = doc.char_span
                    span = add_fun(
                        entity["start_char"] - c_begin,
                        entity["end_char"] - c_begin,
                        label=entity["tag"],
                        alignment_mode="expand"
                    )
                    # char_span returns None if the offsets cannot be mapped
                    # to tokens of the doc
                    if span is not None:
                        ents.append(span)
            # ents = [ent for ent in ents if ent.kb_id != 0]
            ents = get_non_overlapping_intervals(ents)
            doc.ents = ents
            db.add(doc)
        db.to_disk(f'./corpus/converted/{xmi_file.stem}.spacy')


def get_non_overlapping_intervals(list_of_intervals):
    """
    Given a list of intervals, returns a list of non-overlapping intervals.
    """
    def do_overlap(interval_a, interval_b):
        interval_a = interval_a.start_char, interval_a.end_char
        interval_b = interval_b.start_char, interval_b.end_char
        # make sure that they are sorted
        interval_a, interval_b = sorted([interval_a, interval_b], key=lambda x: x[0])
        return (
            interval_a[0] < interval_b[1]
            and interval_b[0] < interval_a[1]
        )

    chosen = set()
    for interval_a in list_of_intervals:
        if not any(do_overlap(interval_a, interval_b) for interval_b in chosen):
            chosen.add(interval_a)
    return list(chosen)
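
# The selection is greedy and order-dependent: of two overlapping spans, the
# one that appears first in list_of_intervals is kept. For example, given
# spans covering characters (0, 5) and (3, 8), only (0, 5) survives.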


if __name__ == "__main__":
    typer.run(convert)
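
# Example invocation (script name hypothetical; typer maps the positional
# arguments to export_path, n_sents and lang):
#   python convert_xmi_to_spacy.py ./export 10 de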