# NER utilities
from IPython.display import display, HTML


class DisplayEntities:
    """Render NER entity spans as highlighted HTML inside a notebook."""

    @classmethod
    def display(cls, texts, grouped_entities):
        html = []
        html.append(cls.get_style())
        for text, entities in zip(texts, grouped_entities):
            html.append(cls.show_entities(text, entities))
        display(HTML("".join(html)))
    @classmethod
    def get_style(cls):
        return """<style>
        span[data-tag] {
            padding: 0.15em 0.25em;
            margin: 0px 0.25em;
            line-height: 1;
            display: inline-block;
            border-radius: 0.25em;
            background: rgba(166, 226, 45, 0.2);
        }
        span[data-tag]::after {
            box-sizing: border-box;
            content: attr(data-tag) ':' attr(data-score);
            font-size: 0.6em;
            line-height: 1;
            padding: 0.35em;
            border-radius: 0.35em;
            text-transform: uppercase;
            display: inline-block;
            vertical-align: middle;
            margin: 0px 0px 0.1rem 0.5rem;
            background: rgb(166, 226, 45);
        }
        span[data-tag="QTY"] {
            background: rgba(230, 126, 193, 0.2);
        }
        span[data-tag="QTY"]::after {
            background: rgb(230, 126, 193);
        }
        span[data-tag="NAME"] {
            background: rgba(166, 226, 45, 0.2);
        }
        span[data-tag="NAME"]::after {
            background: rgb(166, 226, 45);
        }
        span[data-tag="COMMENT"] {
            background: rgba(94, 211, 229, 0.2);
        }
        span[data-tag="COMMENT"]::after {
            background: rgb(94, 211, 229);
        }
        span[data-tag="UNIT"] {
            background: rgba(229, 164, 94, 0.2);
        }
        span[data-tag="UNIT"]::after {
            background: rgb(229, 164, 94);
        }
        </style>
        """
    @classmethod
    def show_entities(cls, text, entities):
        html = []
        html.append("<div class='annotations'>")
        prev_end = 0
        for e in entities:
            # Emit the untagged text between the previous entity and this one.
            html.append(text[prev_end : e["start"]])
            if e["entity_group"] != "O":
                html.append(
                    f"<span data-tag='{e['entity_group']}' data-score={e['score']:.2f}>"
                )
            html.append(text[e["start"] : e["end"]])
            if e["entity_group"] != "O":
                html.append("</span>")
            prev_end = e["end"]
        # Emit any trailing text after the last entity.
        html.append(text[prev_end:])
        html.append("</div>")
        return "".join(html)

def show_predictions(texts):
    # `pred_trainer` and `token_classifier` are expected to exist at module
    # level; see get_pred_trainer below and the wiring sketch at the end.
    outputs, grouped_outputs = predict(texts, pred_trainer, token_classifier)
    DisplayEntities.display(texts, grouped_outputs)
    return outputs, grouped_outputs


if __name__ == "__main__":
    texts = ["2 tablespoons unsalted butter, softened", "2 tablespoons sugar"]
    outputs, grouped_outputs = show_predictions(texts)

import numpy as np
import torch
from collections import defaultdict
from scipy.special import softmax
from tqdm.auto import tqdm  # used below but missing from the original imports
from transformers import (
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
)

def post_process(preds):
    # Merge adjacent spans of the same entity group (previous span ends where
    # the next one starts) into a single span, keeping the maximum score.
    out_dict = defaultdict(list)
    for p in preds:
        g = p["entity_group"]
        if out_dict[g] and out_dict[g][-1]["end"] == p["start"]:
            p_new = out_dict[g][-1]
            p_new["word"] = f"{p_new['word']}{p['word']}"
            p_new["end"] = p["end"]
            p_new["score"] = max(p_new["score"], p["score"])
            out_dict[g][-1] = p_new
        else:
            out_dict[g].append(p.copy())
    return out_dict
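
# Sketch of what post_process does, with made-up spans: two touching NAME
# spans merge into one entry, e.g.
#   post_process([
#       {"entity_group": "NAME", "word": "unsalted", "start": 14, "end": 22, "score": 0.95},
#       {"entity_group": "NAME", "word": " butter", "start": 22, "end": 29, "score": 0.90},
#   ])
#   -> {"NAME": [{"entity_group": "NAME", "word": "unsalted butter",
#                 "start": 14, "end": 29, "score": 0.95}]}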

def get_predictions(predictions, labels, tokens):
    # Convert raw logits into per-word dicts with the predicted tag, the gold
    # tag, and the softmax confidence. Positions labelled -100 (subword
    # continuations and special tokens) are skipped; `label_list` is the
    # module-level id -> tag mapping.
    predictions = softmax(predictions, axis=-1)
    scores = np.max(predictions, axis=-1)
    predictions = np.argmax(predictions, axis=-1)
    outputs = []
    for prediction, label, score, sent_tokens in tqdm(
        zip(predictions, labels, scores, tokens)
    ):
        outputs.append([])
        tid = 0
        for p, l, s in zip(prediction, label, score):
            if l == -100:
                continue
            t = sent_tokens[tid]
            o = dict(t, pred=label_list[p], label=label_list[l], score=s)
            outputs[-1].append(o)
            tid += 1
    return outputs
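
# Each element of the returned `outputs` is one sentence: a list of word-level
# dicts such as (illustrative values only)
#   {"word": "sugar", "start": 14, "end": 19,
#    "pred": "B-NAME", "label": "O", "score": 0.97}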

def predict(texts, pred_trainer, token_classifier):
    # Normalize and pre-tokenize with the model's own (fast) tokenizer so the
    # character offsets line up with the normalized text; `tokenizer` is the
    # module-level fast tokenizer.
    texts = [tokenizer.backend_tokenizer.normalizer.normalize_str(t) for t in texts]
    tokens = [
        [
            {"word": x[0], "start": x[1][0], "end": x[1][1]}
            for x in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(t)
        ]
        for t in texts
    ]
    # Dummy labels let the collator build a batch for prediction.
    batch = [{"tokens": [w["word"] for w in t], "label": [0] * len(t)} for t in tokens]
    prediction_output = pred_trainer.predict(batch)
    outputs = get_predictions(
        prediction_output.predictions, prediction_output.label_ids, tokens
    )
    # Reuse the pipeline's grouping logic to merge B-/I- tags into entity spans.
    grouped_outputs = [
        token_classifier.group_entities([dict(t, entity=t["pred"]) for t in pred])
        for pred in tqdm(outputs)
    ]
    return outputs, grouped_outputs

def get_aligned_tags(ner_tags, label_mask, ner_tag_id, mask_label_id=-100):
    # `ner_tag_id` is the running count of first-subword positions (a cumsum
    # of `label_mask`), so `i - 1` is the zero-based index of the word-level
    # tag for position i; masked positions get `mask_label_id`.
    aligned_tags = [
        ner_tags[i - 1] if mask else mask_label_id
        for i, mask in zip(ner_tag_id, label_mask)
    ]
    return aligned_tags
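
# Worked sketch of the alignment with made-up subwords: suppose the words
# ["unsalted", "butter"] tokenize as [CLS] un ##salted butter [SEP]. Only
# first subwords are unmasked:
#   label_mask = [0, 1, 0, 1, 0]
#   ner_tag_id = [0, 1, 1, 2, 2]   # cumsum of label_mask
# With word-level tags [3, 4] (say B-NAME, I-NAME):
#   get_aligned_tags([3, 4], [0, 1, 0, 1, 0], [0, 1, 1, 2, 2])
#   -> [-100, 3, -100, 4, -100]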

class NERDataCollator(DataCollatorForTokenClassification):
    def __init__(self, label_col, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.label_col = label_col

    def create_batch(self, batch):
        inputs = self.tokenizer.batch_encode_plus(
            batch["tokens"],
            return_tensors="pt",
            padding=True,
            truncation=True,
            is_split_into_words=True,
            return_special_tokens_mask=True,
            return_length=True,
            return_offsets_mapping=True,
        )
        # A position is the first subword of a word iff it is not a special
        # token and its start offset within the word is 0.
        label_mask = (
            (inputs["special_tokens_mask"] == 0)
            & (inputs["offset_mapping"][:, :, 0] == 0)
        ).int()
        ner_tag_id = label_mask.cumsum(axis=-1)
        labels = None
        if self.label_col in batch:
            labels = torch.LongTensor(
                [
                    get_aligned_tags(ner_tags, mask, tag_ids, mask_label_id=-100)
                    for ner_tags, mask, tag_ids in zip(
                        batch[self.label_col], label_mask, ner_tag_id
                    )
                ]
            )
        return inputs, labels
    def torch_call(self, features):
        # Collate the list of feature dicts into a dict of lists, then
        # tokenize and align labels in one go.
        batch = {k: [] for k in features[0].keys()}
        for b in features:
            for k in batch:
                batch[k].append(b[k])
        inputs, labels = self.create_batch(batch)
        # Keep only the tensors the model consumes (assumes a BERT-style
        # tokenizer that returns token_type_ids).
        inputs = {
            k: inputs[k] for k in ["input_ids", "token_type_ids", "attention_mask"]
        }
        if labels is not None:
            inputs["labels"] = labels
        return inputs

def get_pred_trainer(model):
    # Prediction-only Trainer; `label_col` and `tokenizer` are expected at
    # module level.
    training_args = TrainingArguments(
        output_dir="./ner_model",
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        do_train=False,
        do_eval=True,
        dataloader_num_workers=10,
        remove_unused_columns=False,
    )
    pred_trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=NERDataCollator(label_col=label_col, tokenizer=tokenizer),
        tokenizer=tokenizer,
    )
    return pred_trainer
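
# End-to-end wiring sketch. Everything below is an assumption about how the
# pieces fit together (the checkpoint path, label_col, and aggregation
# strategy are hypothetical), kept commented out so the file stays importable:
#
# from transformers import (
#     AutoModelForTokenClassification,
#     AutoTokenizer,
#     pipeline,
# )
#
# model_name = "path/to/your-ner-checkpoint"  # hypothetical
# tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# model = AutoModelForTokenClassification.from_pretrained(model_name)
# label_list = list(model.config.id2label.values())
# label_col = "label"
# token_classifier = pipeline(
#     "token-classification",
#     model=model,
#     tokenizer=tokenizer,
#     aggregation_strategy="first",
# )
# pred_trainer = get_pred_trainer(model)
# outputs, grouped_outputs = show_predictions(["2 tablespoons sugar"])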