Last active
July 1, 2024 16:08
-
-
Save pmbaumgartner/9d5154f3f40125a010f84ef3199cb000 to your computer and use it in GitHub Desktop.
A span candidate suggester function for spaCy that suggests spans containing a digit.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Iterable, cast, List | |
from thinc.api import get_current_ops, Ops | |
from thinc.types import Ragged, Ints1d | |
from spacy.pipeline.spancat import Suggester | |
from spacy.tokens import Doc | |
from spacy.util import registry | |
@registry.misc("ngram_digits_suggester.v1")
def build_ngram_digits_suggester(sizes: List[int]) -> Suggester:
    """Build a span suggester that proposes n-grams containing a digit.

    For every requested n-gram size, each contiguous window of that many
    tokens is considered, and it is kept only if at least one token in the
    window has a "d" in its ``shape_`` (spaCy's shape feature writes digits
    as "d").

    Args:
        sizes: The n-gram lengths (in tokens) to consider. Sizes smaller than
            1 or larger than a given doc yield no candidates for that doc.

    Returns:
        A suggester callable ``(docs, *, ops=None) -> Ragged``.
    """

    def ngram_digits_suggester(
        docs: Iterable[Doc], *, ops: Optional[Ops] = None
    ) -> Ragged:
        """Suggest digit-bearing n-gram spans for each doc.

        Args:
            docs: The documents to generate candidates for.
            ops: Backend used to allocate the output arrays; defaults to the
                current thinc ops when omitted.

        Returns:
            A ``Ragged`` whose data rows are ``[start, end)`` token offsets
            (``end`` exclusive) and whose lengths hold the number of
            candidate spans per doc, in the order the docs were given.
        """
        if ops is None:
            ops = get_current_ops()
        spans = []
        lengths = []
        for doc in docs:
            n_tokens = len(doc)
            # Evaluate the digit test once per token rather than once per
            # (window, token) pair.
            has_digit = ["d" in token.shape_ for token in doc]
            length = 0
            for size in sizes:
                if size < 1 or size > n_tokens:
                    continue
                # `end` is exclusive, so the last valid window starts at
                # n_tokens - size and ends exactly at n_tokens. Windows that
                # reach the final token must be included.
                for start in range(n_tokens - size + 1):
                    end = start + size
                    if any(has_digit[start:end]):
                        spans.append([start, end])
                        length += 1
            lengths.append(length)
        lengths_array = cast(Ints1d, ops.asarray(lengths, dtype="i"))
        if spans:
            output = Ragged(ops.asarray(spans, dtype="i"), lengths_array)
        else:
            # No candidates at all: still return a 2-dimensional (empty)
            # array so the dimensionality check below holds.
            output = Ragged(ops.xp.zeros((0, 0), dtype="i"), lengths_array)
        assert output.dataXd.ndim == 2
        return output

    return ngram_digits_suggester
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment