Extract subject-verb-object triples with spaCy, using coreferee (https://spacy.io/universe/project/coreferee) for coreference resolution.
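
A minimal usage sketch, assuming the file below is saved as triples_extractor.py (the module name, the lighter en_core_web_sm model and the expected output are illustrative, not part of the gist):

    from triples_extractor import TriplesExtractor

    # small model, no coreference resolution: a quick smoke test that avoids
    # the large en_core_web_trf / en_core_web_lg / coreferee downloads
    extractor = TriplesExtractor(model="en_core_web_sm", solve_coref=False)
    print(extractor.semantic_triples(["Miro has a dog."]))
    # expected to print something like: {0: [('miro', 'have', 'dog')]}
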
from typing import Tuple, Dict, List

import spacy
from spacy.cli import download
from spacy.tokens import Token


class DependencyParser:
    def __init__(self):
        self.NEGATION = {"no", "not", "n't", "never", "none"}
        self.SUBJECTS = {"nsubj", "nsubjpass", "csubj", "csubjpass", "agent", "expl"}
        self.OBJECTS = {"dobj", "dative", "attr", "oprd", "pobj"}
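
    # Given the subjects already attached to a verb, also collect nouns joined
    # to them by "and", e.g. "Peter and his wife decided ..." yields both subjects.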
    def get_subs_from_conjunctions(self, subs: List[Token]) -> List[Token]:
        more_subs = []
        for sub in subs:
            if "and" in {tok.lower_ for tok in sub.rights}:
                more_subs.extend([tok for tok in sub.rights if tok.dep_ in self.SUBJECTS or tok.pos_ == "NOUN"])
                more_subs.extend(self.get_subs_from_conjunctions(
                    [tok for tok in sub.rights if tok.dep_ in self.SUBJECTS or tok.pos_ == "NOUN"]))
        return more_subs
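
    # Same idea for objects: follow "and" conjunctions to pick up every
    # coordinated object of the verb.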
    def get_objs_from_conjunctions(self, objs: List[Token]) -> List[Token]:
        more_objs = []
        for obj in objs:
            if "and" in {tok.lower_ for tok in obj.rights}:
                # recurse only on the newly found conjuncts (mirrors the subject
                # version) so already-processed objects are not re-scanned
                conjuncts = [tok for tok in obj.rights if tok.dep_ in self.OBJECTS or tok.pos_ == "NOUN"]
                more_objs.extend(conjuncts)
                more_objs.extend(self.get_objs_from_conjunctions(conjuncts))
        return more_objs
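
    # A token counts as negated if any of its direct children is a negation word.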
    def is_negated(self, tok: Token) -> bool:
        return any(dep.lower_ in self.NEGATION for dep in tok.children)
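
    # Core extraction: for every verb with at least one subject, emit one
    # (subject, verb lemma, object) triple per subject/object pair; a negated
    # verb or object puts a "!" prefix on the lemma. Copular "be" sentences
    # ("beer is nice") are handled via attribute/complement children instead.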
    def find_svos(self, tokens: List[Token]) -> List[Tuple[str, str, str]]:
        svos = []
        verbs = [tok for tok in tokens if tok.pos_ in ["VERB", "AUX"]]
        for v in verbs:
            subs, verb_negated = self.get_all_subs(v)
            if not subs:
                continue
            _, objs = self.get_all_objs(v)
            # Handle copular constructions
            if v.lemma_ == "be":
                cop_svos = self.handle_copular_constructions(v, subs)
                if cop_svos:
                    svos.extend(cop_svos)
                    continue
            # General cases
            for sub in subs:
                for obj in objs:
                    obj_negated = self.is_negated(obj)
                    svos.append((
                        sub.lower_,
                        "!" + v.lemma_ if verb_negated or obj_negated else v.lemma_,
                        obj.lower_
                    ))
        return svos
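
    # Subjects sit to the left of the verb; also report whether the verb itself is negated.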
    def get_all_subs(self, v: Token) -> Tuple[List[Token], bool]:
        verb_negated = self.is_negated(v)
        subs = [tok for tok in v.lefts if tok.dep_ in self.SUBJECTS]
        if subs:
            subs.extend(self.get_subs_from_conjunctions(subs))
        return subs, verb_negated
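
    # Objects sit to the right of the verb, including conjoined ones.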
    def get_all_objs(self, v: Token) -> Tuple[Token, List[Token]]:
        objs = [tok for tok in v.rights if tok.dep_ in self.OBJECTS]
        objs.extend(self.get_objs_from_conjunctions(objs))
        return v, objs
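
    # For copular "be" verbs the "object" is an attribute, adjectival complement
    # or prepositional object, e.g. "Mike is a nice guy" -> ('mike', 'be', 'guy').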
    def handle_copular_constructions(self, v: Token, subs: List[Token]) -> List[Tuple[str, str, str]]:
        svos = []
        for sub in subs:
            objs = [tok for tok in v.rights if tok.dep_ in {"attr", "acomp", "pobj"}]
            for obj in objs:
                svos.append((sub.lower_, v.lemma_, obj.lower_))
        return svos


class TriplesExtractor:
    """Extract semantic triples for knowledge graph construction."""

    def __init__(self, model="en_core_web_trf", solve_coref=True) -> None:
        """Load the spaCy model and, optionally, the coreferee pipeline."""
        if not spacy.util.is_package(model):
            download(model)
        if model == "en_core_web_trf":
            # EXTRA MODEL ALSO NEEDED: coreferee expects en_core_web_lg alongside the transformer model
            if not spacy.util.is_package("en_core_web_lg"):
                download("en_core_web_lg")
        self.nlp = spacy.load(model)
        self.coref = solve_coref
        if solve_coref:
            try:
                self.nlp.add_pipe("coreferee")
            except Exception:
                print("WARNING: coreference resolution not available")
                self.coref = False
    def extract_preps(self, doc):
        triples = []
        for ent in doc.ents:
            if len(ent) > 1:  # multi-word entity: capture it as a single unit
                ent_text = " ".join([t.text for t in ent])
            else:
                ent_text = ent.text
            preps = [prep for prep in ent.root.head.children if prep.dep_ == "prep"]
            for prep in preps:
                for child in prep.children:
                    triples.append((ent_text, "{} {}".format(ent.root.head, prep), child.text))
        return triples
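
    # Rewrite the text with pronouns replaced by their antecedents, using
    # coreferee chains: mentions in singular chains are replaced by the longest
    # proper noun (or noun) in the chain; single-token mentions of plural chains
    # ("they") are replaced by an "and"-joined string of the chain's members.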
    def solve_corefs(self, text: str) -> str:
        doc = self.nlp(text)
        mapping: Dict[int, str] = {}
        for chain in doc._.coref_chains:
            plural = any(len(mention) > 1 for mention in chain.mentions)
            if plural:
                continue
            ctoks = []
            for m in chain.mentions:
                # filter pronouns from candidate replacements
                ctoks += [doc[i] for i in m.token_indexes if doc[i].pos_ in ['NOUN', 'PROPN']]
            if not ctoks:
                continue
            # pick the longest PROPER NOUN token
            propers = [tok for tok in ctoks if tok.pos_ == 'PROPN']
            if propers:
                resolve_tok = max(propers, key=lambda k: len(k.text))
            else:
                # let's just pick the longest NOUN token
                resolve_tok = max(ctoks, key=lambda k: len(k.text))
            for mention in chain.mentions:
                idx = mention[0]
                if resolve_tok.text == doc[idx].text:
                    continue
                print(idx, doc[idx], "->", resolve_tok)
                mapping[idx] = resolve_tok.text
        for chain in doc._.coref_chains:
            plural = any(len(mention) > 1 for mention in chain.mentions)
            if plural:
                m = max(chain.mentions, key=len)
                joint_str = " and ".join([mapping.get(i) or doc[i].text for i in m])
                for mention in chain:
                    if len(mention) == 1:
                        idx = mention[0]
                        mapping[idx] = joint_str
                        print(idx, doc[idx], "->", joint_str)
        tokens = [mapping.get(idx, t.text) for idx, t in enumerate(doc)]
        return " ".join(tokens)

    def semantic_triples(self, documents: List[str]) -> Dict[int, List[Tuple[str, str, str]]]:
        """Extract semantic triples from a list of documents."""
        parser = DependencyParser()
        output_dict = {}
        for idx, text in enumerate(documents):
            if text:
                if self.coref:
                    text = self.solve_corefs(text)
                doc = self.nlp(text)
                svo_lst = parser.find_svos(doc) + self.extract_preps(doc)
                output_dict[idx] = svo_lst
        return output_dict


if __name__ == "__main__":
    test = [
        "Miro has a dog.",
        "Miro is a software developer.",
        "beer is nice",
        "Mike is a nice guy",
        "Chris was an asshole",
        "Apple was founded in Cupertino in the year 1981.",
        "Barrack Obama was born in Hawaii. He was president of the United States and lived in the White House.",
        "He was very busy with his work, Peter had had enough of it. He and his wife decided they needed a holiday. They travelled to Spain because they loved the country very much."
    ]
    # coref resolution
    # 7 He -> Obama
    #
    # 13 it -> work
    # 15 He -> Peter
    # 17 his -> Peter
    # 33 country -> Spain
    # 20 they -> Peter and wife
    # 25 They -> Peter and wife
    # 30 they -> Peter and wife
    extractor = TriplesExtractor()
    triples = extractor.semantic_triples(test)

    from pprint import pprint
    # TODO - obama is missing "lived in the White House"
    pprint(triples)
    # {0: [('miro', 'have', 'dog')],
    #  1: [('miro', 'be', 'developer')],
    #  2: [('beer', 'be', 'nice')],
    #  3: [('mike', 'be', 'guy')],
    #  4: [('chris', 'be', 'asshole')],
    #  5: [('Apple', 'founded in', 'Cupertino'), ('Apple', 'founded in', 'year')],
    #  6: [('obama', 'be', 'president'), ('Barrack Obama', 'born in', 'Hawaii')],
    #  7: [('he', 'be', 'busy'),
    #      ('peter', 'have', 'enough'),
    #      ('peter', 'need', 'holiday'),
    #      ('wife', 'need', 'holiday'),
    #      ('peter', 'love', 'spain'),
    #      ('wife', 'love', 'spain'),
    #      ('Peter', 'travelled to', 'Spain')]}