Created September 20, 2023 08:56
-
-
Save tomaarsen/bfe7515aee0c71cd1f2f874f68c36570 to your computer and use it in GitHub Desktop.
SpanMarker handler.py for Inference Endpoints
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Dict, List | |
from span_marker import SpanMarkerModel | |
class EndpointHandler:
    """Serve a SpanMarker named-entity-recognition model as an Inference Endpoints handler.

    Construct it with a model id/path, then call the instance with a request
    payload to get the detected entities back as a list of dicts.
    """

    def __init__(self, model_id: str) -> None:
        """Load the SpanMarker model and try to move it to the GPU.

        Args:
            model_id: Forwarded unchanged to ``SpanMarkerModel.from_pretrained``
                (presumably a Hub repository id or a local model directory).
        """
        self.model = SpanMarkerModel.from_pretrained(model_id)
        # Best effort: try to place the model on CUDA; this does nothing if it fails.
        self.model.try_cuda()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run entity recognition on the text in ``data["inputs"]``.

        Args:
            data: Request payload whose ``"inputs"`` key holds the text to tag.

        Returns:
            One dict per detected entity, for example
            ``[{"entity_group": "PER", "word": "Tom", "start": 3, "end": 6, "score": 0.82}]``,
            with these keys:

            - ``"entity_group"``: the predicted entity label.
            - ``"word"``: the substring of the input text detected as an entity.
            - ``"start"``: character offset in the input where the entity begins.
            - ``"end"``: character offset where the entity ends, so that
              ``text[start:end] == word``.
            - ``"score"``: the model's confidence for this entity, between 0 and 1.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` key.
        """
        if "inputs" not in data:
            # Fail with an actionable message instead of a bare KeyError('inputs').
            raise KeyError('Request payload must contain an "inputs" key with the text to tag.')
        predictions = self.model.predict(data["inputs"])
        # Rename SpanMarker's prediction keys to the token-classification
        # pipeline-style keys documented above.
        return [
            {
                "entity_group": entity["label"],
                "word": entity["span"],
                "start": entity["char_start_index"],
                "end": entity["char_end_index"],
                "score": entity["score"],
            }
            for entity in predictions
        ]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment