NER utilities
from IPython.display import display, HTML


class DisplayEntities:
    @classmethod
    def display(cls, texts, grouped_entities):
        html = []
        html.append(cls.get_style())
        for text, entities in zip(texts, grouped_entities):
            html.append(cls.show_entities(text, entities))
        display(HTML("".join(html)))

    @classmethod
    def get_style(cls):
        return """<style>
        span[data-tag] {
            padding: 0.15em 0.25em;
            margin: 0px 0.25em;
            line-height: 1;
            display: inline-block;
            border-radius: 0.25em;
            background: rgba(166, 226, 45, 0.2);
        }
        span[data-tag]::after {
            box-sizing: border-box;
            content: attr(data-tag) ':' attr(data-score);
            font-size: 0.6em;
            line-height: 1;
            padding: 0.35em;
            border-radius: 0.35em;
            text-transform: uppercase;
            display: inline-block;
            vertical-align: middle;
            margin: 0px 0px 0.1rem 0.5rem;
            background: rgb(166, 226, 45);
        }
        span[data-tag="QTY"] {
            background: rgba(230, 126, 193, 0.2);
        }
        span[data-tag="QTY"]::after {
            background: rgb(230, 126, 193);
        }
        span[data-tag="NAME"] {
            background: rgba(166, 226, 45, 0.2);
        }
        span[data-tag="NAME"]::after {
            background: rgb(166, 226, 45);
        }
        span[data-tag="COMMENT"] {
            background: rgba(94, 211, 229, 0.2);
        }
        span[data-tag="COMMENT"]::after {
            background: rgb(94, 211, 229);
        }
        span[data-tag="UNIT"] {
            background: rgba(229, 164, 94, 0.2);
        }
        span[data-tag="UNIT"]::after {
            background: rgb(229, 164, 94);
        }
        </style>
        """

    @classmethod
    def show_entities(cls, text, entities):
        html = []
        html.append("<div class='annotations'>")
        prev_end = 0
        for e in entities:
            # plain text between the previous entity and this one
            html.append(text[prev_end : e["start"]])
            if e["entity_group"] != "O":
                html.append(
                    f"<span data-tag='{e['entity_group']}' data-score='{e['score']:.2f}'>"
                )
            html.append(text[e["start"] : e["end"]])
            if e["entity_group"] != "O":
                html.append("</span>")
            prev_end = e["end"]
        # remaining text after the last entity
        html.append(text[prev_end:])
        html.append("</div>")
        return "".join(html)

def show_predictions(texts):
    # predict, pred_trainer, and token_classifier are expected to be defined at module
    # level (see the prediction utilities below)
    outputs, grouped_outputs = predict(texts, pred_trainer, token_classifier)
    DisplayEntities.display(texts, grouped_outputs)
    return outputs, grouped_outputs


if __name__ == "__main__":
    texts = ["2 tablespoons unsalted butter, softened", "2 tablespoons sugar"]
    outputs, grouped_outputs = show_predictions(texts)
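
    # A model-free sketch, for illustration only, of the input format DisplayEntities
    # expects: one list of entities per text, each entity a dict with "entity_group",
    # character-offset "start"/"end", and "score" keys (the shape produced by a
    # transformers token-classification pipeline with grouped entities). These example
    # entities are hand-written, not model output.
    example_text = "2 tablespoons unsalted butter, softened"
    example_entities = [
        {"entity_group": "QTY", "start": 0, "end": 1, "score": 0.99},
        {"entity_group": "UNIT", "start": 2, "end": 13, "score": 0.98},
        {"entity_group": "NAME", "start": 14, "end": 29, "score": 0.95},
        {"entity_group": "COMMENT", "start": 31, "end": 39, "score": 0.90},
    ]
    DisplayEntities.display([example_text], [example_entities])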
import numpy as np
import torch
from collections import defaultdict
from scipy.special import softmax
from tqdm.auto import tqdm
from transformers import DataCollatorForTokenClassification, Trainer, TrainingArguments

# tokenizer, label_list, label_col, and token_classifier are assumed to be defined in the
# surrounding notebook/module before these helpers are called.

def post_process(preds):
    out_dict = defaultdict(list)
    for p in preds:
        g = p["entity_group"]
        if out_dict[g] and out_dict[g][-1]["end"] == p["start"]:
            p_new = out_dict[g][-1]
            p_new["word"] = f"{p_new['word']}{p['word']}"
            p_new["end"] = p["end"]
            p_new["score"] = max(p_new["score"], p["score"])
            out_dict[g][-1] = p_new
        else:
            out_dict[g].append(p.copy())
    return out_dict
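
# For illustration (a hand-made example, not real model output): post_process groups
# predictions by entity_group and merges adjacent spans whose offsets touch, e.g.
#   post_process([
#       {"entity_group": "NAME", "word": "but", "start": 23, "end": 26, "score": 0.95},
#       {"entity_group": "NAME", "word": "ter", "start": 26, "end": 29, "score": 0.90},
#   ])
# returns {"NAME": [{"word": "butter", "start": 23, "end": 29, "score": 0.95}]}.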

def get_predictions(predictions, labels, tokens):
    predictions = softmax(predictions, axis=-1)
    scores = np.max(predictions, axis=-1)
    predictions = np.argmax(predictions, axis=-1)
    outputs = []
    for prediction, label, score, sent_tokens in tqdm(zip(predictions, labels, scores, tokens)):
        outputs.append([])
        tid = 0
        for p, l, s in zip(prediction, label, score):
            if l == -100:
                # masked positions (special tokens and non-initial subwords) are skipped
                continue
            t = sent_tokens[tid]
            o = dict(t, pred=label_list[p], label=label_list[l], score=s)
            outputs[-1].append(o)
            tid += 1
    return outputs

def predict(texts, pred_trainer, token_classifier):
    # normalize and pre-tokenize with the same backend tokenizer used by the model
    texts = [tokenizer.backend_tokenizer.normalizer.normalize_str(t) for t in texts]
    tokens = [
        [
            {"word": x[0], "start": x[1][0], "end": x[1][1]}
            for x in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(t)
        ]
        for t in texts
    ]
    # dummy labels so the collator can build a batch; only the predictions are used
    batch = [{"tokens": [w["word"] for w in t], "label": [0] * len(t)} for t in tokens]
    prediction_output = pred_trainer.predict(batch)
    outputs = get_predictions(
        prediction_output.predictions, prediction_output.label_ids, tokens
    )
    grouped_outputs = [
        token_classifier.group_entities([dict(t, entity=t["pred"]) for t in pred])
        for pred in tqdm(outputs)
    ]
    return outputs, grouped_outputs

def get_aligned_tags(ner_tags, label_mask, ner_tag_id, mask_label_id=-100):
    aligned_tags = [
        ner_tags[i - 1] if mask else mask_label_id
        for i, mask in zip(ner_tag_id, label_mask)
    ]
    return aligned_tags
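
# Worked example (assumed values, for illustration): with word-level ner_tags = [3, 1, 4],
# a subword label_mask of [0, 1, 1, 0, 1, 0] (1 marks the first subword of each word, 0
# marks special tokens and continuation subwords), and ner_tag_id = label_mask.cumsum(-1)
# = [0, 1, 2, 2, 3, 3], this returns [-100, 3, 1, -100, 4, -100]: each word's tag lands on
# its first subword and every other position is masked so the loss ignores it.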

class NERDataCollator(DataCollatorForTokenClassification):
    def __init__(self, label_col, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_col = label_col

    def create_batch(self, batch):
        inputs = self.tokenizer.batch_encode_plus(
            batch["tokens"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            is_split_into_words=True,
            return_special_tokens_mask=True,
            return_length=True,
            return_offsets_mapping=True,
        )
        # The start offset of the first subword of each word is 0, so this marks exactly
        # one subword per word and excludes special tokens.
        label_mask = (
            (inputs["special_tokens_mask"] == 0)
            & (inputs["offset_mapping"][:, :, 0] == 0)
        ).int()
        ner_tag_id = label_mask.cumsum(axis=-1)
        labels = None
        if self.label_col in batch:
            labels = torch.LongTensor(
                [
                    get_aligned_tags(ner_tags, mask, tag_ids, mask_label_id=-100)
                    for ner_tags, mask, tag_ids in zip(
                        batch[self.label_col], label_mask, ner_tag_id
                    )
                ]
            )
        return inputs, labels

    def torch_call(self, features):
        batch = {k: [] for k in features[0].keys()}
        for b in features:
            for k in batch:
                batch[k].append(b[k])
        inputs, labels = self.create_batch(batch)
        # keep only the tensors the model consumes; token_type_ids is absent for some
        # tokenizers (e.g. RoBERTa), hence the membership check
        inputs = {
            k: inputs[k]
            for k in ["input_ids", "token_type_ids", "attention_mask"]
            if k in inputs
        }
        if labels is not None:
            inputs["labels"] = labels
        return inputs

def get_pred_trainer(model):
    training_args = TrainingArguments(
        output_dir="./ner_model",
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        do_train=False,
        do_eval=True,
        dataloader_num_workers=10,
        remove_unused_columns=False,
    )
    pred_trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=NERDataCollator(label_col=label_col, tokenizer=tokenizer),
        tokenizer=tokenizer,
    )
    return pred_trainer
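
# A rough sketch of how these utilities are assumed to fit together end to end. The
# checkpoint name, label_col value, and the way label_list is derived are placeholders;
# the helpers above expect tokenizer, label_list, label_col, and token_classifier to
# exist as module-level globals.
if __name__ == "__main__":
    from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline

    model_name = "path/to/finetuned-ner-model"  # placeholder checkpoint
    label_col = "label"  # key the collator reads word-level tag ids from
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    # assumes config.id2label maps integer ids to tag names
    label_list = [model.config.id2label[i] for i in range(model.config.num_labels)]
    # the pipeline is only used for its group_entities() post-processing
    token_classifier = pipeline("token-classification", model=model, tokenizer=tokenizer)

    pred_trainer = get_pred_trainer(model)
    texts = ["2 tablespoons unsalted butter, softened", "2 tablespoons sugar"]
    outputs, grouped_outputs = predict(texts, pred_trainer, token_classifier)
    by_group = post_process(grouped_outputs[0])  # entities of the first text, keyed by tag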