Skip to content

Instantly share code, notes, and snippets.

@wesslen
Last active June 5, 2023 17:21
Show Gist options
  • Select an option

  • Save wesslen/1271257a6954fd5c9a6e68434b4d921d to your computer and use it in GitHub Desktop.

Select an option

Save wesslen/1271257a6954fd5c9a6e68434b4d921d to your computer and use it in GitHub Desktop.
Prodigy relations recipe with a validate_answer callback that checks that both ends of every relation (head and child) are labeled entities
# Prodigy v1.11.x; some imports will change for v1.12+
import copy
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import srsly
from spacy.language import Language
from spacy.tokens import Doc, Span, Token
from spacy.util import filter_spans
from prodigy.components.loaders import get_stream
from prodigy.core import recipe
from prodigy.models.matcher import create_matchers, parse_pattern_name
from prodigy.types import RecipeSettingsType, StreamType
from prodigy.util import get_labels, load_model, log, msg, split_string
# Label assigned to noun-chunk spans when add_nps is set; it is also always
# included in the list of labels accepted by the span-label filter.
NP_LABEL = "NP"
# Fallback label for patterns (and pattern matches) that don't define a label.
PATTERN_LABEL = "PATTERN"
# Token count at or above which a one-time "long example" warning is printed.
TOKEN_LIMIT = 300
def setup_matchers(
    nlp: Language,
    patterns: Optional[Union[str, Path, List[Dict[str, Any]]]],
    default_label: str = PATTERN_LABEL,
) -> Tuple[Callable[[Doc], List[Tuple[int, int, int]]], List[Dict[str, Any]]]:
    """Build a combined token/phrase matcher from a set of patterns.

    nlp: The pipeline passed on to create_matchers.
    patterns: Pattern dicts, a path to a JSONL file of pattern dicts, or None
        (treated as "no patterns").
    default_label: Label given to every pattern that has no "label" key.
    RETURNS: A tuple of (callable, patterns). The callable takes a Doc and
        returns the token matcher's matches followed by the phrase matcher's
        matches; the list holds the validated patterns with labels filled in.

    Exits via msg.fail(..., exits=1) if an entry is not a dict or has no
    "pattern" key.
    """
    if patterns is None:
        patterns = []
    if patterns and isinstance(patterns, (str, Path)):
        patterns = srsly.read_jsonl(patterns)
    final_patterns = []
    for pattern in patterns:
        if not isinstance(pattern, dict) or "pattern" not in pattern:
            msg.fail(
                "Invalid pattern found. Patterns should be dicts with a key "
                '"pattern" and an optional "label".',
                pattern,
                exits=1,
            )
        # TODO: should we assume that the user always want to use all patterns
        # and pattern matches? It's different from how the other recipes behave
        # but it could lead to confusion if users forget to include pattern
        # labels in --span-label
        if "label" not in pattern:
            # Work on a copy so a list of dicts passed in by the caller is not
            # modified as a side effect.
            pattern = {**pattern, "label": default_label}
        final_patterns.append(pattern)
    matcher, phrase_matcher, _ = create_matchers(nlp, final_patterns)

    def combined_matcher(doc):
        # Token-based matches first, then phrase matches.
        return matcher(doc) + phrase_matcher(doc)

    return combined_matcher, final_patterns
@recipe(
    "rel.manual",
    # fmt: off
    dataset=("Dataset to save annotations to", "positional", None, str),
    spacy_model=("Loadable spaCy pipeline or blank:lang (e.g. blank:en)", "positional", None, str),
    source=("Data to annotate (file path or '-' to read from standard input)", "positional", None, str),
    loader=("Loader (guessed from file extension if not set)", "option", "lo", str),
    label=("Comma-separated relation label(s) to annotate or text file with one label per line", "option", "l", get_labels),
    span_label=("Comma-separated span label(s) to annotate or text file with one label per line", "option", "sl", get_labels),
    patterns=("Patterns file for defining custom spans to be added", "option", "pt", str),
    disable_patterns=("Patterns file for defining tokens to disable (make unselectable)", "option", "dpt", str),
    add_ents=("Add entities predicted by the model", "flag", "AE", bool),
    add_nps=("Add noun phrases (if noun chunks rules are available), based on tagger and parser", "flag", "AN"),
    wrap=("Wrap lines in the UI by default (instead of showing tokens in one row)", "flag", "W", bool),
    exclude=("Comma-separated list of dataset IDs whose annotations to exclude", "option", "e", split_string),
    # NOTE(review): declared as "option" (takes a value) although the other
    # booleans are "flag"s -- confirm this is intended before changing the CLI.
    hide_arrow_heads=("Hide the arrow heads visually", "option", "HA", bool),
    # fmt: on
)
def manual(
    dataset: str,
    spacy_model: Union[str, Language],
    source: Union[str, Iterable[dict]] = "-",
    loader: Optional[str] = None,
    label: Optional[List[str]] = None,
    span_label: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    patterns: Optional[Union[str, List]] = None,
    disable_patterns: Optional[Union[str, List]] = None,
    add_ents: bool = False,
    add_nps: bool = False,
    wrap: bool = False,
    hide_arrow_heads: bool = False,
) -> RecipeSettingsType:
    """Annotate relations between tokens and spans, with answer validation.

    The incoming stream is tokenized with the loaded pipeline and each example
    gets "tokens" and "spans" added. Spans come from the input data, and
    optionally from noun chunks, predicted entities and pattern matches.
    Tokens matched by the disable patterns are marked as disabled.

    The returned components include a validate_answer callback that rejects
    any answer in which the head or child of a relation has no span label.

    RETURNS: The recipe components: view ID, dataset, stream, exclude list,
        validation callback and UI config.
    """
    # Must stay the first statement so locals() only holds the arguments.
    log("RECIPE: Starting recipe rel.manual", locals())
    stream = get_stream(
        source, None, loader, rehash=True, dedup=True, input_key="text", is_binary=False
    )
    nlp = load_model(spacy_model)
    # fmt: off
    if add_ents and "ner" not in nlp.pipe_names:
        msg.warn(f"Adding entities typically requires a named entity recognizer, but no 'ner' component was found in the pipeline of model '{spacy_model}'.")
    if add_nps and ("parser" not in nlp.pipe_names or "tagger" not in nlp.pipe_names):
        msg.warn(f"Adding noun phrases typically requires a tagger and parser. One or both weren't found in the pipeline of model '{spacy_model}'.")
    if add_nps and "noun_chunks" not in nlp.Defaults.syntax_iterators:
        msg.warn(f"Adding noun phrases requires noun chunk rules in the language data, but none were found for language '{nlp.lang}'. To specify your own merge rules, you can use the --patterns argument.")
    # fmt: on
    if add_nps and span_label and NP_LABEL not in span_label:
        # Add NP label if we know we need it and user hasn't set it. A new
        # list is built so the list passed in by the caller is left untouched.
        span_label = [*span_label, NP_LABEL]
    # Set up combined token/phrase matchers with additional merge and disable patterns
    matcher, patterns = setup_matchers(nlp, patterns)
    disable_matcher, disable_patterns = setup_matchers(nlp, disable_patterns)
    # Register token extensions to store info about the spans/tokens
    Span.set_extension("type", default=None, force=True)
    Span.set_extension("label", default=None, force=True)
    Token.set_extension("type", default=None, force=True)
    Token.set_extension("label", default=None, force=True)
    Token.set_extension("disabled", default=False, force=True)

    def setup_spans(
        spans: Iterable[Span], span_type: str, default_label: Optional[str] = None
    ) -> List[Span]:
        """Tag spans (and their tokens) with a type and label, dropping spans
        whose label is not allowed by span_label."""
        # Loop-invariant, so computed once per call. NP_LABEL is always allowed.
        allowed_labels = [*span_label, NP_LABEL] if span_label else [NP_LABEL]
        result = []
        for span in spans:
            if span_type == "pattern":
                new_label, _ = parse_pattern_name(span.label_)
            else:
                new_label = span.label_ or default_label or span_type.upper()
            if span_label is not None and new_label not in allowed_labels:
                continue
            span._.type = span_type
            span._.label = new_label
            # Set token attributes so they can be targeted by disable patterns
            for token in span:
                token._.label = new_label
                token._.type = span_type
            result.append(span)
        return result

    def preprocess_stream(stream: StreamType) -> StreamType:
        """Yield copies of the examples with "tokens" and "spans" filled in."""
        # Adding tokens manually instead of using the full add_tokens wrapper
        # so we have more control over the Doc preprocessing
        data_tuples = ((eg["text"], copy.deepcopy(eg)) for eg in stream)
        warned = False
        for doc, eg in nlp.pipe(data_tuples, as_tuples=True, batch_size=10):
            # If we have pre-defined tokens in the examples, use those to
            # construct a Doc manually
            if "tokens" in eg:
                words = [token["text"] for token in eg["tokens"]]
                spaces = [token.get("ws", True) for token in eg["tokens"]]
                doc = Doc(nlp.vocab, words=words, spaces=spaces)
                for i, token in enumerate(eg["tokens"]):
                    if token.get("disabled"):
                        doc[i]._.disabled = True
            if len(doc) >= TOKEN_LIMIT and not warned:
                msg.warn(
                    f"Long example with {len(doc)} tokens detected. This can "
                    f"potentially lead to slower rendering and annotation in "
                    f"the web app. Consider splitting your texts into smaller "
                    f"chunks or sentences."
                )
                warned = True  # only warn once per stream
            matches = matcher(doc) if patterns else []
            spans = []
            n_skipped = 0
            curr_spans = []  # "spans" already present in input data
            for s in eg.get("spans", []):
                c_span = doc.char_span(s["start"], s["end"], s["label"])
                if c_span:
                    curr_spans.append(c_span)
                else:
                    # Character offsets don't line up with token boundaries.
                    n_skipped += 1
            # The order of the extend() calls is kept as-is because the
            # collected list is handed to filter_spans below.
            spans.extend(setup_spans(curr_spans, "span"))
            if add_nps:
                spans.extend(setup_spans(doc.noun_chunks, "np", NP_LABEL))
            if add_ents:
                spans.extend(setup_spans(doc.ents, "ent"))
            match_spans = [Span(doc, s, e, match_id) for match_id, s, e in matches]
            spans.extend(setup_spans(match_spans, "pattern", PATTERN_LABEL))
            eg["spans"] = [
                {
                    "text": span.text,
                    "start": span.start_char,
                    "token_start": span.start,
                    # Inclusive index of the span's last token.
                    "token_end": span.end - 1,
                    "end": span.end_char,
                    "type": span._.type,
                    "label": span._.label,
                }
                for span in filter_spans(spans)
            ]
            if n_skipped > 0:
                msg.warn(
                    f"Skipped {n_skipped} span(s) that were already present "
                    f"in the input data because the tokenization didn't match."
                )
            # Apply the disable patterns last, so they can use the information of the merge patterns
            disable_matches = disable_matcher(doc) if disable_patterns else []
            for _, start, end in disable_matches:
                for token in doc[start:end]:
                    token._.disabled = True
            eg["tokens"] = [
                {
                    "text": token.text,
                    "start": token.idx,
                    "end": token.idx + len(token.text),
                    "id": i,
                    "ws": bool(token.whitespace_),
                    "disabled": token._.disabled,
                }
                for i, token in enumerate(doc)
            ]
            yield eg

    def validate_answer(eg):
        """Raise ValueError if any relation's head or child has no span label.

        A missing "head_span"/"child_span" entry is treated the same as a span
        without a label. Each message is reported at most once.
        """
        checks = (
            ("head_span", "Head relation is not an entity"),
            ("child_span", "Child relation is not an entity"),
        )
        errors = []
        for rel in eg.get("relations") or []:
            for key, message in checks:
                span = rel.get(key) or {}
                if span.get("label") is None and message not in errors:
                    errors.append(message)
        if errors:
            raise ValueError(" ".join(errors))

    stream = preprocess_stream(stream)
    return {
        "view_id": "relations",
        "dataset": dataset,
        "stream": stream,
        "exclude": exclude,
        "validate_answer": validate_answer,
        "config": {
            "lang": nlp.lang,
            "labels": label,
            "relations_span_labels": span_label,
            "exclude_by": "input",
            "wrap_relations": wrap,
            "custom_theme": {"cardMaxWidth": "90%"},
            "hide_relation_arrow": hide_arrow_heads,
            "auto_count_stream": True,
        },
    }
@wesslen
Copy link
Copy Markdown
Author

wesslen commented Jun 5, 2023

To run:

python -m prodigy rel.manual ner_rels blank:en ./data.jsonl --label relation1,relation2 --span-label entity1,entity2 --wrap -F relations_validation.py

This video shows:

  1. A valid annotation: both the head and the child of the relation are entities
  2. An invalid annotation: the relation's child is not an entity
  3. An invalid annotation: the relation's head is not an entity
New.Recording.Jun.05.2023.0119.PM.mp4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment