NER dataset processing with huggingface tokenizers==0.9.4 / transformers==4.2.1
import logging
import os
import unicodedata
from dataclasses import dataclass
from enum import Enum
from itertools import product, starmap
from pathlib import Path
from typing import Dict, List, Optional, Union

import MeCab
import requests
import textspan
from tokenizers import (
    Encoding,
    NormalizedString,
    PreTokenizedString,
    Tokenizer,
)
from tokenizers.pre_tokenizers import PreTokenizer
from torch.utils.data import Dataset
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

IntList = List[int]
IntListList = List[IntList]
StrList = List[str]

PAD_TOKEN_LABEL_ID = -100
PAD_TOKEN = "[PAD]"
class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"


def download_dataset(data_dir: Union[str, Path]):
    """Download the NE-annotated UD_Japanese-GSD splits into data_dir."""

    def _download_data(url, file_path):
        response = requests.get(url)
        if response.ok:
            with open(file_path, "w") as fp:
                fp.write(response.content.decode("utf8"))
            return file_path

    for mode in Split:
        mode = mode.value
        url = f"https://github.com/megagonlabs/UD_Japanese-GSD/releases/download/v2.6-NE/{mode}.bio"
        file_path = os.path.join(data_dir, f"{mode}.txt")
        if _download_data(url, file_path):
            logger.info(f"{mode} data is successfully downloaded")
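
# The loop above writes {data_dir}/train.txt, {data_dir}/dev.txt and
# {data_dir}/test.txt in the token-per-line BIO format consumed by
# read_examples_from_file() below (tab-separated by default).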
@dataclass
class SpanAnnotation:
    start: int
    end: int
    label: str


@dataclass
class StringSpanExample:
    guid: str
    content: str
    annotations: List[SpanAnnotation]


@dataclass
class TokenClassificationExample:
    guid: str
    words: StrList
    labels: StrList


@dataclass
class InputFeatures:
    input_ids: IntList
    attention_mask: IntList
    label_ids: IntList


def is_boundary_line(line: str) -> bool:
    return line.startswith("-DOCSTART-") or line == "" or line == "\n"
def bio2biolu(lines: StrList, label_idx: int = -1, delimiter: str = "\t") -> StrList:
    """Convert BIO-tagged lines into BIOLU tagging by looking one token ahead."""
    new_lines = []
    n_lines = len(lines)
    for i, line in enumerate(lines):
        if is_boundary_line(line):
            new_lines.append(line)
        else:
            next_iob = None
            if i < n_lines - 1:
                next_line = lines[i + 1].strip()
                if not is_boundary_line(next_line):
                    next_iob = next_line.split(delimiter)[label_idx][0]
            line = line.strip()
            current_line_content = line.split(delimiter)
            current_label = current_line_content[label_idx]
            word = current_line_content[0]
            tag_type = current_label[2:]
            iob = current_label[0]
            iob_transition = (iob, next_iob)
            current_iob = iob
            if iob_transition == ("B", "I"):
                current_iob = "B"
            elif iob_transition == ("I", "I"):
                current_iob = "I"
            elif iob_transition in {("B", "O"), ("B", "B"), ("B", None)}:
                current_iob = "U"
            elif iob_transition in {("I", "B"), ("I", "O"), ("I", None)}:
                current_iob = "L"
            elif iob == "O":
                current_iob = "O"
            else:
                logger.warning(f"Invalid BIO transition: {iob_transition}")
                if iob not in set("BIOLU"):
                    current_iob = "O"
            biolu = f"{current_iob}-{tag_type}" if current_iob != "O" else "O"
            new_line = f"{word}{delimiter}{biolu}"
            new_lines.append(new_line)
    return new_lines
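
# Illustration (not in the original gist) of the conversion above, assuming
# tab-separated word/BIO-tag lines: single-token entities become U-*, and
# entity-final tokens become L-*.
#
#   >>> bio2biolu(["東京\tB-LOC\n", "都\tI-LOC\n", "に\tO\n", "京都\tB-LOC\n", "へ\tO\n"])
#   ['東京\tB-LOC', '都\tL-LOC', 'に\tO', '京都\tU-LOC', 'へ\tO']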
def read_examples_from_file(
    data_dir: str,
    mode: Union[Split, str],
    label_idx: int = -1,
    delimiter: str = "\t",
    is_bio: bool = True,
) -> List[TokenClassificationExample]:
    """
    Read token-wise data like CoNLL2003 from file
    """
    if isinstance(mode, Split):
        mode = mode.value
    file_path = os.path.join(data_dir, f"{mode}.txt")
    guid_index = 1
    examples = []
    with open(file_path, encoding="utf-8") as f:
        lines = [line for line in f]
        if is_bio:
            lines = bio2biolu(lines)
        words = []
        labels = []
        for line in lines:
            if is_boundary_line(line):
                if words:
                    examples.append(
                        TokenClassificationExample(
                            guid=f"{mode}-{guid_index}", words=words, labels=labels
                        )
                    )
                    guid_index += 1
                    words = []
                    labels = []
            else:
                splits = line.strip().split(delimiter)
                words.append(splits[0])
                if len(splits) > 1:
                    labels.append(splits[label_idx])
                else:
                    # for mode = "test"
                    labels.append("O")
        if words:
            examples.append(
                TokenClassificationExample(
                    guid=f"{mode}-{guid_index}", words=words, labels=labels
                )
            )
    return examples
def convert_spandata(
    examples: List[TokenClassificationExample],
) -> List[StringSpanExample]:
    """
    Convert token-wise data like CoNLL2003 into string-wise span data
    """

    def _get_original_spans(words, text):
        word_spans = []
        start = 0
        for w in words:
            word_spans.append((start, start + len(w)))
            start += len(w)
        assert words == [text[s:e] for s, e in word_spans]
        return word_spans

    new_examples: List[StringSpanExample] = []
    for example in examples:
        words = example.words
        text = "".join(words)
        labels = example.labels
        annotations: List[SpanAnnotation] = []
        word_spans = _get_original_spans(words, text)
        label_span = []
        labeltype = ""
        # NOTE: entities that are directly adjacent (no "O" token between them)
        # are merged into a single span carrying the later entity's label.
        for span, label in zip(word_spans, labels):
            if label == "O" and label_span and labeltype:
                start, end = label_span[0][0], label_span[-1][-1]
                annotations.append(
                    SpanAnnotation(start=start, end=end, label=labeltype)
                )
                label_span = []
            elif label != "O":
                labeltype = label[2:]
                label_span.append(span)
        if label_span and labeltype:
            start, end = label_span[0][0], label_span[-1][-1]
            annotations.append(SpanAnnotation(start=start, end=end, label=labeltype))
        new_examples.append(
            StringSpanExample(guid=example.guid, content=text, annotations=annotations)
        )
    return new_examples
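
# Illustration (not in the original gist): a BIOLU-labelled example such as
#   TokenClassificationExample(guid="dev-1",
#                              words=["東京", "都", "に", "住む"],
#                              labels=["B-LOC", "L-LOC", "O", "O"])
# is converted into
#   StringSpanExample(guid="dev-1", content="東京都に住む",
#                     annotations=[SpanAnnotation(start=0, end=3, label="LOC")])
# i.e. word-level labels become character offsets into the concatenated text.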
class LabelTokenAligner:
    """
    Align word-wise BIOLU-labels with subword tokens
    """

    def __init__(self, labels_src: Union[str, StrList]):
        if isinstance(labels_src, str):
            # labels_src is a path to a file listing one label type per line
            with open(labels_src, "r") as f:
                labels = [l for l in f.read().splitlines() if l and l != "O"]
        else:
            labels = [l for l in labels_src if l and l != "O"]
        self.labels_to_id = {"O": 0}
        self.ids_to_label = {0: "O"}
        for i, (label, s) in enumerate(product(labels, "BILU"), 1):
            l = f"{s}-{label}"
            self.labels_to_id[l] = i
            self.ids_to_label[i] = l

    @staticmethod
    def get_ids_to_label(labels_path: str) -> Dict[int, str]:
        with open(labels_path, "r") as f:
            labels = [l for l in f.read().splitlines() if l and l != "O"]
        ids_to_label = {
            i: f"{s}-{label}" for i, (label, s) in enumerate(product(labels, "BILU"), 1)
        }
        ids_to_label[0] = "O"
        return ids_to_label

    @staticmethod
    def align_tokens_and_annotations_bilou(
        tokenized: Encoding, annotations: List[SpanAnnotation]
    ) -> StrList:
        """Make word-wise BIOLU-labels aligned with given subwords
        :param tokenized: output of PreTrainedTokenizerFast
        :param annotations: annotations of string span format
        """
        # Make a list to store our labels the same length as our tokens
        aligned_labels = ["O"] * len(tokenized.tokens)
        for anno in annotations:
            annotation_token_ix_set = set()
            for char_ix in range(anno.start, anno.end):
                token_ix = tokenized.char_to_token(char_ix)
                if token_ix is not None:
                    annotation_token_ix_set.add(token_ix)
            if len(annotation_token_ix_set) == 1:
                # the annotation fits into a single token -> "U"
                token_ix = annotation_token_ix_set.pop()
                prefix = "U"
                aligned_labels[token_ix] = f"{prefix}-{anno.label}"
            else:
                last_token_in_anno_ix = len(annotation_token_ix_set) - 1
                for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
                    if num == 0:
                        prefix = "B"
                    elif num == last_token_in_anno_ix:
                        prefix = "L"
                    else:
                        prefix = "I"
                    aligned_labels[token_ix] = f"{prefix}-{anno.label}"
        return aligned_labels

    def align_labels_with_tokens(
        self, tokenized_text: Encoding, annotations: List[SpanAnnotation]
    ) -> IntList:
        # TODO: switch label encoding scheme, align_tokens_and_annotations_bio
        raw_labels = self.align_tokens_and_annotations_bilou(
            tokenized_text, annotations
        )
        # unknown labels fall back to "O" (id 0)
        return list(map(lambda x: self.labels_to_id.get(x, 0), raw_labels))
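
# For reference (not in the original gist): LabelTokenAligner(["LOC", "PER"])
# yields a 9-class mapping, with "O" fixed at id 0:
#   {"O": 0, "B-LOC": 1, "I-LOC": 2, "L-LOC": 3, "U-LOC": 4,
#    "B-PER": 5, "I-PER": 6, "L-PER": 7, "U-PER": 8}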
def load_pretrained_tokenizer(
    tokenizer_file: str, cache_dir: Optional[str] = None
) -> PreTrainedTokenizerFast:
    """Load a WordPiece tokenizer from tokenizer.json and attach MecabPreTokenizer.
    This indirection is necessary because:
    - BertWordPieceTokenizer cannot be restored from tokenizer.json via .from_file(), and
    - the custom MecabPreTokenizer is not a native PreTokenizer, so it cannot be
      serialized into tokenizer.json and must be re-attached after loading.
    """
    tokenizer = Tokenizer.from_file(tokenizer_file)
    tokenizer.pre_tokenizer = PreTokenizer.custom(MecabPreTokenizer())
    tokenizer_dir = os.path.dirname(tokenizer_file)
    pt_tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        cache_dir=cache_dir,
    )
    # Swap in the backend tokenizer carrying the custom pre-tokenizer; keeping the
    # PreTrainedTokenizerFast wrapper is necessary for pt_tokenizer.save_pretrained(save_path)
    pt_tokenizer._tokenizer = tokenizer
    return pt_tokenizer
class PicklableTagger:
    """MeCab.Tagger wrapper that survives pickling: only the option string is
    serialized, and the Tagger is rebuilt from it when the object is reconstructed."""

    def __init__(self, mecab_option: str):
        self.option = mecab_option
        self.tagger = MeCab.Tagger(mecab_option)

    def __getstate__(self):
        return {"option": self.option}

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getnewargs__(self):
        return (self.option,)

    def __reduce_ex__(self, proto):
        func = PicklableTagger
        args = self.__getnewargs__()
        state = self.__getstate__()
        listitems = None
        dictitems = None
        rv = (func, args, state, listitems, dictitems)
        return rv

    def __call__(self, text):
        return self.parse(text)

    def parse(self, text):
        return self.tagger.parse(text).rstrip()
class MecabPreTokenizer:
    def __init__(
        self,
        mecab_dict_path: Optional[str] = None,
        space_replacement: Optional[str] = None,
    ):
        """Constructs a MecabPreTokenizer for huggingface tokenizers.
        - space_replacement: character that spaces are replaced with before MeCab runs.
          MeCab drops spaces by default, so replacing them with a placeholder
          (a special character such as '_' is often used) and restoring them
          afterwards preserves them in the output tokens.
        """
        self.space_replacement = space_replacement
        mecab_option = (
            f"-Owakati -d {mecab_dict_path}"
            if mecab_dict_path is not None
            else "-Owakati"
        )
        self.mecab = PicklableTagger(mecab_option)

    def tokenize(self, sequence: str) -> List[str]:
        text = unicodedata.normalize("NFKC", sequence)
        if self.space_replacement:
            text = text.replace(" ", self.space_replacement)
            splits = self.mecab.parse(text).strip().split(" ")
            return [x.replace(self.space_replacement, " ") for x in splits]
        else:
            return self.mecab.parse(text).strip().split(" ")

    def custom_split(
        self, i: int, normalized_string: NormalizedString
    ) -> List[NormalizedString]:
        """Map MeCab tokens back onto spans of the original NormalizedString."""
        text = str(normalized_string)
        tokens = self.tokenize(text)
        tokens_spans = textspan.get_original_spans(tokens, text)
        return [
            normalized_string[st:ed]
            for char_spans in tokens_spans
            for st, ed in char_spans
        ]

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.custom_split)
class TokenClassificationDataset(Dataset):
    """
    Build an InputFeatures dataset that a token-classification model can consume
    """

    def __init__(
        self,
        examples: List[StringSpanExample],
        tokenizer: PreTrainedTokenizerFast,
        label_token_aligner: LabelTokenAligner,
        tokens_per_batch: int = 32,
        window_stride: Optional[int] = None,
    ):
        """Tokenize and align labels over long texts: truncation is disabled and
        inputs are split into fixed-size windows instead."""
        self.features: List[InputFeatures] = []
        self.examples: List[TokenClassificationExample] = []
        texts: StrList = [ex.content for ex in examples]
        annotations: List[List[SpanAnnotation]] = [ex.annotations for ex in examples]

        if window_stride is None:
            self.window_stride = tokens_per_batch
        elif window_stride > tokens_per_batch:
            logger.error(
                "window_stride must be smaller than tokens_per_batch(max_seq_length)"
            )
            # fall back to non-overlapping windows
            self.window_stride = tokens_per_batch
        else:
            logger.warning(
                """window_stride != tokens_per_batch:
                The input data windows are overlapping. Merge the overlapping labels after processing InputFeatures.
                """
            )
            self.window_stride = window_stride

        # tokenize text into subwords
        # NOTE: special tokens ([CLS]/[SEP]) are not added here (add_special_tokens=False)
        tokenized_batch: BatchEncoding = tokenizer(texts, add_special_tokens=False)
        encodings: List[Encoding] = tokenized_batch.encodings

        # align word-wise labels with subwords
        aligned_label_ids: IntListList = list(
            starmap(
                label_token_aligner.align_labels_with_tokens,
                zip(encodings, annotations),
            )
        )

        # perform manual padding and register features
        guids: StrList = [ex.guid for ex in examples]
        for guid, encoding, label_ids in zip(guids, encodings, aligned_label_ids):
            seq_length = len(label_ids)
            for start in range(0, seq_length, self.window_stride):
                end = min(start + tokens_per_batch, seq_length)
                n_padding_to_add = max(0, tokens_per_batch - end + start)
                self.features.append(
                    InputFeatures(
                        input_ids=encoding.ids[start:end]
                        + [tokenizer.pad_token_id] * n_padding_to_add,
                        label_ids=(
                            label_ids[start:end]
                            + [PAD_TOKEN_LABEL_ID] * n_padding_to_add
                        ),
                        attention_mask=(
                            encoding.attention_mask[start:end] + [0] * n_padding_to_add
                        ),
                    )
                )
                subwords = encoding.tokens[start:end]
                labels = [
                    label_token_aligner.ids_to_label[i] for i in label_ids[start:end]
                ]
                self.examples.append(
                    TokenClassificationExample(guid=guid, words=subwords, labels=labels)
                )
        self._n_features = len(self.features)

    def __len__(self):
        return self._n_features

    def __getitem__(self, idx) -> InputFeatures:
        return self.features[idx]
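
# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (not part of the original gist). The paths
# "data/" and "tokenizer_dir/tokenizer.json" are placeholders: the tokenizer
# directory is assumed to contain a tokenizer.json plus the companion files
# that AutoTokenizer.from_pretrained() expects. Adjust them to your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    data_dir = "data"  # hypothetical output directory
    os.makedirs(data_dir, exist_ok=True)
    download_dataset(data_dir)

    # token-wise BIO data -> BIOLU -> character-span data
    token_examples = read_examples_from_file(data_dir, Split.train)
    span_examples = convert_spandata(token_examples)

    # WordPiece tokenizer with the custom MeCab pre-tokenizer attached
    tokenizer = load_pretrained_tokenizer("tokenizer_dir/tokenizer.json")

    # build the label map from the entity types observed in the data
    label_types = sorted(
        {anno.label for ex in span_examples for anno in ex.annotations}
    )
    label_token_aligner = LabelTokenAligner(label_types)

    dataset = TokenClassificationDataset(
        span_examples,
        tokenizer,
        label_token_aligner,
        tokens_per_batch=128,
    )
    print(f"{len(dataset)} windows built; first feature: {dataset[0]}")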