Last active
June 5, 2023 17:21
-
-
Save wesslen/1271257a6954fd5c9a6e68434b4d921d to your computer and use it in GitHub Desktop.
Prodigy relations validation with validate_answer callback that checks that both relations are labeled entities
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Prodigy v1.11.x; some imports will change for v1.12+ | |
| import copy | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union | |
| import srsly | |
| from spacy.language import Language | |
| from spacy.tokens import Doc, Span, Token | |
| from spacy.util import filter_spans | |
| from prodigy.components.loaders import get_stream | |
| from prodigy.core import recipe | |
| from prodigy.models.matcher import create_matchers, parse_pattern_name | |
| from prodigy.types import RecipeSettingsType, StreamType | |
| from prodigy.util import get_labels, load_model, log, msg, split_string | |
# Label given to noun-chunk spans; also added to the user's span labels when
# noun phrases are requested.
NP_LABEL = "NP"
# Fallback label for match patterns that do not define their own "label".
PATTERN_LABEL = "PATTERN"
# Documents with at least this many tokens trigger a one-time "long example"
# warning during preprocessing.
TOKEN_LIMIT = 300
def setup_matchers(
    nlp: Language,
    patterns: Optional[Union[str, Path, List[Dict[str, Any]]]],
    default_label: str = PATTERN_LABEL,
) -> Tuple[Callable[[Doc], List[Tuple[int, int, int]]], List[Dict[str, Any]]]:
    """Build a combined token/phrase matcher from a set of match patterns.

    nlp: The pipeline handed to `create_matchers`.
    patterns: Either a path to a JSONL file of patterns, a list of pattern
        dicts, or None (treated as "no patterns").
    default_label: Label used for patterns that have no "label" key.
    RETURNS: A `(matcher, patterns)` tuple. `matcher` is a callable that takes
        a Doc and returns the token-matcher results followed by the
        phrase-matcher results; `patterns` is the validated pattern list with
        a label filled in for every entry.

    Exits via `msg.fail(..., exits=1)` if an entry is not a dict or lacks a
    "pattern" key.
    """
    if patterns is None:
        patterns = []
    if patterns and isinstance(patterns, (str, Path)):
        patterns = srsly.read_jsonl(patterns)
    final_patterns = []
    for pattern in patterns:
        if not isinstance(pattern, dict) or "pattern" not in pattern:
            msg.fail(
                "Invalid pattern found. Patterns should be dicts with a key "
                '"pattern" and an optional "label".',
                pattern,
                exits=1,
            )
        # TODO: should we assume that the user always want to use all patterns
        # and pattern matches? It's different from how the other recipes behave
        # but it could lead to confusion if users forget to include pattern
        # labels in --span-label
        # Build a new dict so a caller-supplied pattern list is never mutated;
        # an existing "label" entry overrides the default.
        final_patterns.append({"label": default_label, **pattern})
    matcher, phrase_matcher, _ = create_matchers(nlp, final_patterns)

    def combined_matcher(doc: Doc) -> List[Tuple[int, int, int]]:
        """Run both matchers on the doc and concatenate their matches."""
        return matcher(doc) + phrase_matcher(doc)

    return combined_matcher, final_patterns
@recipe(
    "rel.manual",
    # fmt: off
    dataset=("Dataset to save annotations to", "positional", None, str),
    spacy_model=("Loadable spaCy pipeline or blank:lang (e.g. blank:en)", "positional", None, str),
    source=("Data to annotate (file path or '-' to read from standard input)", "positional", None, str),
    loader=("Loader (guessed from file extension if not set)", "option", "lo", str),
    label=("Comma-separated relation label(s) to annotate or text file with one label per line", "option", "l", get_labels),
    span_label=("Comma-separated span label(s) to annotate or text file with one label per line", "option", "sl", get_labels),
    patterns=("Patterns file for defining custom spans to be added", "option", "pt", str),
    disable_patterns=("Patterns file for defining tokens to disable (make unselectable)", "option", "dpt", str),
    add_ents=("Add entities predicted by the model", "flag", "AE", bool),
    add_nps=("Add noun phrases (if noun chunks rules are available), based on tagger and parser", "flag", "AN"),
    wrap=("Wrap lines in the UI by default (instead of showing tokens in one row)", "flag", "W", bool),
    exclude=("Comma-separated list of dataset IDs whose annotations to exclude", "option", "e", split_string),
    # NOTE(review): registered as an "option" with a bool converter, unlike the
    # other boolean settings which are "flag"s. Left unchanged to keep the CLI
    # stable -- confirm whether this was meant to be a flag.
    hide_arrow_heads=("Hide the arrow heads visually", "option", "HA", bool),
    # fmt: on
)
def manual(
    dataset: str,
    spacy_model: Union[str, Language],
    source: Union[str, Iterable[dict]] = "-",
    loader: Optional[str] = None,
    label: Optional[List[str]] = None,
    span_label: Optional[List[str]] = None,
    exclude: Optional[List[str]] = None,
    patterns: Optional[Union[str, List]] = None,
    disable_patterns: Optional[Union[str, List]] = None,
    add_ents: bool = False,
    add_nps: bool = False,
    wrap: bool = False,
    hide_arrow_heads: bool = False,
) -> RecipeSettingsType:
    """Manually annotate relations between tokens and pre-highlighted spans.

    The stream is tokenized with the given pipeline and each example gets
    "tokens" and "spans" entries. Spans can come from the input data, from
    model entities, from noun chunks and from match patterns. Submitted
    answers are checked by a `validate_answer` callback that rejects relations
    whose head or child is not a labelled span.

    dataset: Dataset to save annotations to.
    spacy_model: Loadable spaCy pipeline or blank:lang (e.g. blank:en).
    source: Data to annotate (file path or '-' to read from standard input).
    loader: Loader (guessed from file extension if not set).
    label: Relation label(s) to annotate.
    span_label: Span label(s) to annotate; spans with other labels are dropped.
    exclude: Dataset IDs whose annotations to exclude.
    patterns: Patterns for defining custom spans to be added.
    disable_patterns: Patterns for defining tokens to disable.
    add_ents: Add entities predicted by the model.
    add_nps: Add noun phrases.
    wrap: Wrap lines in the UI by default.
    hide_arrow_heads: Hide the arrow heads visually.
    RETURNS: The recipe components (view ID, dataset, stream, validation
        callback and UI config).
    """
    log("RECIPE: Starting recipe rel.manual", locals())
    stream = get_stream(
        source, None, loader, rehash=True, dedup=True, input_key="text", is_binary=False
    )
    nlp = load_model(spacy_model)
    # fmt: off
    if add_ents and "ner" not in nlp.pipe_names:
        msg.warn(f"Adding entities typically requires a named entity recognizer, but no 'ner' component was found in the pipeline of model '{spacy_model}'.")
    if add_nps and ("parser" not in nlp.pipe_names or "tagger" not in nlp.pipe_names):
        msg.warn(f"Adding noun phrases typically requires a tagger and parser. One or both weren't found in the pipeline of model '{spacy_model}'.")
    if add_nps and "noun_chunks" not in nlp.Defaults.syntax_iterators:
        msg.warn(f"Adding noun phrases requires noun chunk rules in the language data, but none were found for language '{nlp.lang}'. To specify your own merge rules, you can use the --patterns argument.")
    # fmt: on
    if add_nps and span_label and NP_LABEL not in span_label:
        # Add NP label if we know we need it and user hasn't set it. A new
        # list is built so the caller's list is left untouched.
        span_label = [*span_label, NP_LABEL]
    # Labels a span may carry when filtering by --span-label. Computed once
    # here because it does not change while the stream is processed.
    allowed_labels = {*span_label, NP_LABEL} if span_label else {NP_LABEL}
    # Set up combined token/phrase matchers with additional merge and disable patterns
    matcher, patterns = setup_matchers(nlp, patterns)
    disable_matcher, disable_patterns = setup_matchers(nlp, disable_patterns)
    # Register token extensions to store info about the spans/tokens
    Span.set_extension("type", default=None, force=True)
    Span.set_extension("label", default=None, force=True)
    Token.set_extension("type", default=None, force=True)
    Token.set_extension("label", default=None, force=True)
    Token.set_extension("disabled", default=False, force=True)

    def setup_spans(
        spans: Iterable[Span], span_type: str, default_label: Optional[str] = None
    ) -> List[Span]:
        """Tag spans (and their tokens) with a type and label.

        Spans whose resolved label is not allowed by `span_label` are dropped
        when `span_label` was provided. RETURNS the spans that were kept.
        """
        result = []
        for span in spans:
            if span_type == "pattern":
                new_label, _ = parse_pattern_name(span.label_)
            else:
                new_label = span.label_ or default_label or span_type.upper()
            if span_label is not None and new_label not in allowed_labels:
                continue
            span._.type = span_type
            span._.label = new_label
            # Set token attributes so they can be targeted by disable patterns
            for token in span:
                token._.label = new_label
                token._.type = span_type
            result.append(span)
        return result

    def preprocess_stream(stream: StreamType) -> StreamType:
        """Yield copies of the examples with "tokens" and "spans" filled in."""
        # Adding tokens manually instead of using the full add_tokens wrapper
        # so we have more control over the Doc preprocessing
        data_tuples = ((eg["text"], copy.deepcopy(eg)) for eg in stream)
        warned = False
        for doc, eg in nlp.pipe(data_tuples, as_tuples=True, batch_size=10):
            # If we have pre-defined tokens in the examples, use those to
            # construct a Doc manually
            if "tokens" in eg:
                words = [token["text"] for token in eg["tokens"]]
                spaces = [token.get("ws", True) for token in eg["tokens"]]
                doc = Doc(nlp.vocab, words=words, spaces=spaces)
                for i, token in enumerate(eg["tokens"]):
                    if token.get("disabled"):
                        doc[i]._.disabled = True
            # Only warn once per stream, on the first long example
            if len(doc) >= TOKEN_LIMIT and not warned:
                msg.warn(
                    f"Long example with {len(doc)} tokens detected. This can "
                    f"potentially lead to slower rendering and annotation in "
                    f"the web app. Consider splitting your texts into smaller "
                    f"chunks or sentences."
                )
                warned = True
            matches = matcher(doc) if patterns else []
            spans = []
            n_skipped = 0
            curr_spans = []  # "spans" already present in input data
            for s in eg.get("spans", []):
                # char_span gives a falsy result if the offsets don't line up
                # with token boundaries; such spans are counted and skipped
                c_span = doc.char_span(s["start"], s["end"], s["label"])
                if c_span:
                    curr_spans.append(c_span)
                else:
                    n_skipped += 1
            spans.extend(setup_spans(curr_spans, "span"))
            if add_nps:
                spans.extend(setup_spans(doc.noun_chunks, "np", NP_LABEL))
            if add_ents:
                spans.extend(setup_spans(doc.ents, "ent"))
            match_spans = [Span(doc, s, e, match_id) for match_id, s, e in matches]
            spans.extend(setup_spans(match_spans, "pattern", PATTERN_LABEL))
            # filter_spans is applied to the merged candidates from all
            # sources before they are serialized
            eg["spans"] = [
                {
                    "text": span.text,
                    "start": span.start_char,
                    "token_start": span.start,
                    # span.end is exclusive, the JSON format uses an
                    # inclusive last-token index
                    "token_end": span.end - 1,
                    "end": span.end_char,
                    "type": span._.type,
                    "label": span._.label,
                }
                for span in filter_spans(spans)
            ]
            if n_skipped > 0:
                msg.warn(
                    f"Skipped {n_skipped} span(s) that were already present "
                    f"in the input data because the tokenization didn't match."
                )
            # Apply the disable patterns last, so they can use the information of the merge patterns
            disable_matches = disable_matcher(doc) if disable_patterns else []
            for _, start, end in disable_matches:
                for token in doc[start:end]:
                    token._.disabled = True
            eg["tokens"] = [
                {
                    "text": token.text,
                    "start": token.idx,
                    "end": token.idx + len(token.text),
                    "id": i,
                    "ws": bool(token.whitespace_),
                    "disabled": token._.disabled,
                }
                for i, token in enumerate(doc)
            ]
            yield eg

    def validate_answer(eg: Dict[str, Any]) -> None:
        """Reject answers with relations that don't connect two labelled spans.

        Raises a ValueError describing the problem if the head or child of any
        relation has no span label. A missing "head_span"/"child_span" entry
        or a missing "label" key is treated the same as a label of None.
        """
        errors = []
        for rel in eg.get("relations") or []:
            for key, role in (("head_span", "Head"), ("child_span", "Child")):
                span = rel.get(key) or {}
                if span.get("label") is None:
                    message = f"{role} relation is not an entity"
                    # Report each problem once, however many relations have it
                    if message not in errors:
                        errors.append(message)
        if errors:
            raise ValueError(" ".join(errors))

    stream = preprocess_stream(stream)
    return {
        "view_id": "relations",
        "dataset": dataset,
        "stream": stream,
        "exclude": exclude,
        "validate_answer": validate_answer,
        "config": {
            "lang": nlp.lang,
            "labels": label,
            "relations_span_labels": span_label,
            "exclude_by": "input",
            "wrap_relations": wrap,
            "custom_theme": {"cardMaxWidth": "90%"},
            "hide_relation_arrow": hide_arrow_heads,
            "auto_count_stream": True,
        },
    }
Author
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
To run:
This video shows:
New.Recording.Jun.05.2023.0119.PM.mp4