Created September 19, 2019 10:19
-
-
Save stefan-it/c39b63eb0043182010f2f61138751e0f to your computer and use it in GitHub Desktop.
Prediction script for PyTorch-Transformers NER
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
import numpy as np | |
import torch | |
from collections import namedtuple | |
from pytorch_transformers import BertForTokenClassification, BertTokenizer | |
from torch.nn import CrossEntropyLoss | |
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset | |
from typing import List | |
from utils_ner import ( | |
convert_examples_to_features, | |
InputExample, | |
get_labels, | |
read_examples_from_file, | |
) | |
from run_ner import load_and_cache_examples | |
def build_dataset(
    model_name_or_path: str,
    data_dir: str,
    max_seq_length: int,
    tokenizer: BertTokenizer,
    pad_token_label_id: int,
) -> TensorDataset:
    """Build the evaluation dataset by delegating to ``load_and_cache_examples``.

    The caller-supplied settings are combined with two fixed values
    (``local_rank=-1`` and ``model_type="bert"``) into a single ``Config``
    namedtuple, which is handed to ``load_and_cache_examples`` as ``args``.

    Args:
        model_name_or_path: Forwarded unchanged inside the config object.
        data_dir: Forwarded unchanged inside the config object.
        max_seq_length: Forwarded unchanged inside the config object.
        tokenizer: Tokenizer passed straight through to the loader.
        pad_token_label_id: Label id passed straight through to the loader.

    Returns:
        Whatever ``load_and_cache_examples`` returns when called with
        ``evaluate=True`` (annotated here as a ``TensorDataset``).
    """
    # Field order mirrors the order in which the settings were originally
    # declared, so the resulting namedtuple is laid out identically.
    config_cls = namedtuple(
        "Config",
        [
            "local_rank",
            "model_name_or_path",
            "max_seq_length",
            "data_dir",
            "model_type",
        ],
    )
    config = config_cls(
        local_rank=-1,
        model_name_or_path=model_name_or_path,
        max_seq_length=max_seq_length,
        data_dir=data_dir,
        model_type="bert",
    )

    return load_and_cache_examples(
        args=config,
        tokenizer=tokenizer,
        pad_token_label_id=pad_token_label_id,
        evaluate=True,
    )
def prediction(
    dataset: TensorDataset,
    examples: List[InputExample],
    max_seq_length: int,
    pad_token_label_id: int,
    batch_size: int,
    device: str,
    model: BertForTokenClassification,
):
    """Predict a label for every word and print it next to the gold label.

    The dataset is processed sequentially in batches. For each position whose
    gold label id differs from ``pad_token_label_id`` the arg-max label of the
    model's logits is kept. Afterwards one ``word gold_label predicted_label``
    line is printed per word, with an empty line after every example.

    Args:
        dataset: Batches are indexed as ``(input_ids, attention_mask,
            token_type_ids, label_ids)``.
        examples: Examples providing ``words`` and ``labels``; they must be in
            the same order as the rows of ``dataset``.
        max_seq_length: Unused here; kept so existing callers keep working.
        pad_token_label_id: Label id marking positions that must be skipped.
        batch_size: Number of dataset rows per forward pass.
        device: Device the batch tensors are moved to before the forward pass.
        model: Token classification model. It is switched to evaluation mode
            by this function.

    Raises:
        ValueError: If the number of predicted sequences differs from the
            number of examples, or if an example's words, gold labels and
            predicted labels do not all have the same length.
    """
    # The id -> label mapping is the same for every batch, so build it once
    # instead of calling get_labels() on each iteration.
    label_map = dict(enumerate(get_labels()))

    eval_dataloader = DataLoader(
        dataset, sampler=SequentialSampler(dataset), batch_size=batch_size
    )

    # Do not rely on the caller for this: train-mode layers such as dropout
    # would make the predictions non-deterministic.
    model.eval()

    predictions = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)

        with torch.no_grad():
            outputs = model(
                input_ids=batch[0],
                attention_mask=batch[1],
                token_type_ids=batch[2],
                labels=batch[3],
            )
        # The labels are passed in, and the first output is discarded so that
        # the second one can be used as the logits.
        _, logits = outputs[:2]

        pred_ids = np.argmax(logits.detach().cpu().numpy(), axis=2)
        gold_ids = batch[3].detach().cpu().numpy()

        # Keep only the positions that carry a real (non-pad) gold label.
        keep = gold_ids != pad_token_label_id
        predictions.extend(
            [label_map[label_id] for label_id in row[row_mask].tolist()]
            for row, row_mask in zip(pred_ids, keep)
        )

    # Explicit exceptions instead of ``assert``: asserts are stripped when
    # Python runs with -O and would let misaligned output through silently.
    if len(predictions) != len(examples):
        raise ValueError(
            f"Got {len(predictions)} predicted sequences "
            f"for {len(examples)} examples"
        )

    for example, predicted_labels in zip(examples, predictions):
        if not len(example.words) == len(example.labels) == len(predicted_labels):
            raise ValueError(
                f"Length mismatch: {len(example.words)} words, "
                f"{len(example.labels)} gold labels and "
                f"{len(predicted_labels)} predicted labels "
                "(possibly the example was truncated to the max. sequence length)"
            )

        for word, gold_label, predicted_label in zip(
            example.words, example.labels, predicted_labels
        ):
            print(f"{word} {gold_label} {predicted_label}")
        print("")
# Both paths are mandatory: without ``required=True`` click would pass
# ``None`` through and the failure would only surface deep inside the
# model / data loading code.
@click.command()
@click.option(
    "--model_name_or_path",
    type=str,
    required=True,
    help="Define path to fine-tuned model",
)
@click.option("--data_dir", type=str, required=True, help="Define path to data dir")
@click.option(
    "--device", type=str, default="cuda", help="Defines device e.g. cpu or cuda"
)
@click.option(
    "--batch_size", type=int, default=16, help="Defines batch size for evaluation"
)
@click.option(
    "--max_seq_length", type=int, default=128, help="Defines max. sequence length"
)
def run_prediction(model_name_or_path, data_dir, device, batch_size, max_seq_length):
    """Print word, gold label and predicted label for the evaluation data."""
    # NOTE: the docstring above is kept to one line on purpose, because click
    # shows it as the command description in ``--help``.

    # Tokenizer and model are both loaded from the same location.
    tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
    model = BertForTokenClassification.from_pretrained(model_name_or_path)
    model.to(device=device)

    # Reuse the loss function's ignore index as the label id for positions
    # that must be skipped when collecting predictions.
    pad_token_label_id = CrossEntropyLoss().ignore_index

    # The examples are read separately from the dataset because their words
    # and gold labels are needed for printing.
    examples = read_examples_from_file(data_dir, evaluate=True)

    dataset = build_dataset(
        model_name_or_path=model_name_or_path,
        data_dir=data_dir,
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        pad_token_label_id=pad_token_label_id,
    )

    prediction(
        dataset=dataset,
        examples=examples,
        max_seq_length=max_seq_length,
        pad_token_label_id=pad_token_label_id,
        batch_size=batch_size,
        device=device,
        model=model,
    )
# Script entry point: invoke the click command only when this file is run
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_prediction()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.
Is it possible to obtain predictions of individual words for the run_ner task from transformers using this script?