Generate NER predictions with a fine-tuned German BERT token-classification model and write them to a CoNLL-style file (tenor.conll).
import os
import pickle
import warnings
from functools import reduce
from operator import add

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import BertForTokenClassification, logging

# Silence library warnings and run inference on CPU.
warnings.simplefilter(action="ignore", category=Warning)
logging.set_verbosity(logging.ERROR)

device = torch.device("cpu")


def analyze(text):
    """Run the token-classification model on a single string and return
    word-level tokens with their predicted labels."""
    tokenized_sentence = tokenizer.encode(text)
    input_ids = torch.tensor([tokenized_sentence])

    with torch.no_grad():
        output = model(input_ids)
    label_indices = np.argmax(output[0].numpy(), axis=2)

    # Merge WordPiece sub-tokens ("##...") back into whole words; the label
    # of the first sub-token is kept for the merged word.
    tokens = tokenizer.convert_ids_to_tokens(input_ids.numpy()[0])
    new_tokens, new_labels = [], []
    for token, label_idx in zip(tokens, label_indices[0]):
        if token.startswith("##"):
            new_tokens[-1] = new_tokens[-1] + token[2:]
        else:
            new_labels.append(tag_values[label_idx])
            new_tokens.append(token)

    # Re-attach periods that received an entity label to the preceding token
    # and drop them as separate entries.
    to_remove = []
    for idx in range(len(new_tokens)):
        if new_tokens[idx] == "." and new_labels[idx] != "O":
            new_tokens[idx - 1] += "."
            to_remove.append(idx)
    new_tokens = [token for idx, token in enumerate(new_tokens) if idx not in to_remove]
    new_labels = [label for idx, label in enumerate(new_labels) if idx not in to_remove]

    return new_tokens, new_labels


def chunks(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i : i + n]


if __name__ == "__main__":
    print("[+] Reading data")
    data = pd.read_csv("../metadata.csv")
    # The "tenor" column holds "|"-separated entries; flatten them into a
    # set of unique, non-empty strings.
    tenor = data["tenor"].dropna().reset_index(drop=True).str.split("|").tolist()
    tenor = reduce(add, tenor)
    tenor = set(filter(None, tenor))

    print("[+] Downloading model, tag_values, and tokenizer")
    os.system(
        "wget -q https://www.dropbox.com/s/vos8pqwmlbqe0wf/model.pt https://www.dropbox.com/s/u2oojgmmprt0a9d/tag_values.pkl https://www.dropbox.com/s/uj15pab78emefoq/tokenizer.pkl"
    )
    tokenizer = pickle.load(open("tokenizer.pkl", "rb"))
    tag_values = pickle.load(open("tag_values.pkl", "rb"))

    model = BertForTokenClassification.from_pretrained(
        "bert-base-german-cased",
        num_labels=len(tag_values),
        output_attentions=False,
        output_hidden_states=False,
    )
    model.load_state_dict(torch.load("model.pt", map_location=device))
    model.eval()  # inference only

    print("[+] NER annotation")
    conll = open("tenor.conll", "a+")
    for datum in tqdm(tenor):
        # Long texts are processed in 512-character chunks to keep each
        # piece within the model's input limit.
        if len(datum) > 512:
            tokens, labels = [], []
            chunked = list(chunks(datum, 512))
            for c in chunked:
                ts, ls = analyze(c)
                tokens.extend(ts)
                labels.extend(ls)
        else:
            tokens, labels = analyze(datum)

        # Write one "token label" pair per line, skipping the special BERT
        # tokens; a blank line separates documents (CoNLL-style output).
        for token, label in zip(tokens, labels):
            if token == "[CLS]" or token == "[SEP]":
                continue
            line = f"{token} {label}\n"
            conll.write(line)
        conll.write("\n")
    conll.close()
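
A minimal sketch of how analyze() can be called interactively once tokenizer, tag_values, and the fine-tuned model have been loaded as in the __main__ block. The sentence is a made-up example, and the label strings printed depend entirely on the contents of tag_values.pkl:

    # Assumes tokenizer, tag_values, and model are already set up as above.
    tokens, labels = analyze("Der Kläger wohnt in Berlin.")
    for token, label in zip(tokens, labels):
        if token not in ("[CLS]", "[SEP]"):
            print(f"{token}\t{label}")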