@kzinmr
Last active January 15, 2021 12:03
NER dataset processing with huggingface tokenizers==0.9.4 / transformers==4.2.1
import logging
import os
import unicodedata
from dataclasses import dataclass
from enum import Enum
from itertools import product, starmap
from pathlib import Path
from typing import Dict, List, Optional, Union
import MeCab
import requests
import textspan
from tokenizers import (
    Encoding,
    NormalizedString,
    PreTokenizedString,
    Tokenizer,
)
from tokenizers.pre_tokenizers import PreTokenizer
from torch.utils.data import Dataset
from transformers import AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast

logger = logging.getLogger(__name__)

IntList = List[int]
IntListList = List[IntList]
StrList = List[str]

PAD_TOKEN_LABEL_ID = -100
PAD_TOKEN = "[PAD]"
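# PAD_TOKEN_LABEL_ID = -100 matches the default ignore_index of
# torch.nn.CrossEntropyLoss, so padded positions are excluded from the loss.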


class Split(Enum):
    train = "train"
    dev = "dev"
    test = "test"


def download_dataset(data_dir: Union[str, Path]):
    def _download_data(url, file_path):
        response = requests.get(url)
        if response.ok:
            with open(file_path, "w") as fp:
                fp.write(response.content.decode("utf8"))
            return file_path

    for mode in Split:
        mode = mode.value
url = f"https://github.com/megagonlabs/UD_Japanese-GSD/releases/download/v2.6-NE/{mode}.bio"
        file_path = os.path.join(data_dir, f"{mode}.txt")
        if _download_data(url, file_path):
            logger.info(f"{mode} data is successfully downloaded")


@dataclass
class SpanAnnotation:
    start: int
    end: int
    label: str


@dataclass
class StringSpanExample:
    guid: str
    content: str
    annotations: List[SpanAnnotation]


@dataclass
class TokenClassificationExample:
    guid: str
    words: StrList
    labels: StrList


@dataclass
class InputFeatures:
    input_ids: IntList
    attention_mask: IntList
    label_ids: IntList


def is_boundary_line(line: str) -> bool:
    return line.startswith("-DOCSTART-") or line == "" or line == "\n"
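

# A worked example of the BIO -> BIOLU conversion performed by bio2biolu below:
#   B-PER I-PER I-PER O  ->  B-PER I-PER L-PER O   (multi-token entity)
#   B-LOC O              ->  U-LOC O               (single-token entity)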
def bio2biolu(lines: StrList, label_idx: int = -1, delimiter: str = "\t") -> StrList:
    new_lines = []
    n_lines = len(lines)
    for i, line in enumerate(lines):
        if is_boundary_line(line):
            new_lines.append(line)
        else:
            next_iob = None
            if i < n_lines - 1:
                next_line = lines[i + 1].strip()
                if not is_boundary_line(next_line):
                    next_iob = next_line.split(delimiter)[label_idx][0]
            line = line.strip()
            current_line_content = line.split(delimiter)
            current_label = current_line_content[label_idx]
            word = current_line_content[0]
            tag_type = current_label[2:]
            iob = current_label[0]
            iob_transition = (iob, next_iob)
            current_iob = iob
            if iob_transition == ("B", "I"):
                current_iob = "B"
            elif iob_transition == ("I", "I"):
                current_iob = "I"
            elif iob_transition in {("B", "O"), ("B", "B"), ("B", None)}:
                current_iob = "U"
            elif iob_transition in {("I", "B"), ("I", "O"), ("I", None)}:
                current_iob = "L"
            elif iob == "O":
                current_iob = "O"
            else:
                logger.warning(f"Invalid BIO transition: {iob_transition}")
                if iob not in set("BIOLU"):
                    current_iob = "O"
            biolu = f"{current_iob}-{tag_type}" if current_iob != "O" else "O"
            new_line = f"{word}{delimiter}{biolu}"
            new_lines.append(new_line)
    return new_lines


def read_examples_from_file(
    data_dir: str,
    mode: Union[Split, str],
    label_idx: int = -1,
    delimiter: str = "\t",
    is_bio: bool = True,
) -> List[TokenClassificationExample]:
    """
    Read token-wise data like CoNLL2003 from file
    """
    if isinstance(mode, Split):
        mode = mode.value
    file_path = os.path.join(data_dir, f"{mode}.txt")
    guid_index = 1
    examples = []
    with open(file_path, encoding="utf-8") as f:
        lines = [line for line in f]
        if is_bio:
            lines = bio2biolu(lines)
        words = []
        labels = []
        for line in lines:
            if is_boundary_line(line):
                if words:
                    examples.append(
                        TokenClassificationExample(
                            guid=f"{mode}-{guid_index}", words=words, labels=labels
                        )
                    )
                    guid_index += 1
                    words = []
                    labels = []
            else:
                splits = line.strip().split(delimiter)
                words.append(splits[0])
                if len(splits) > 1:
                    labels.append(splits[label_idx])
                else:
                    # for mode = "test"
                    labels.append("O")
        if words:
            examples.append(
                TokenClassificationExample(
                    guid=f"{mode}-{guid_index}", words=words, labels=labels
                )
            )
    return examples


def convert_spandata(
    examples: List[TokenClassificationExample],
) -> List[StringSpanExample]:
    """
    Convert token-wise data like CoNLL2003 into string-wise span data
    """
    def _get_original_spans(words, text):
        word_spans = []
        start = 0
        for w in words:
            word_spans.append((start, start + len(w)))
            start += len(w)
        assert words == [text[s:e] for s, e in word_spans]
        return word_spans

    new_examples: List[StringSpanExample] = []
    for example in examples:
        words = example.words
        text = "".join(words)
        labels = example.labels
        annotations: List[SpanAnnotation] = []
        word_spans = _get_original_spans(words, text)
        label_span = []
        labeltype = ""
        for span, label in zip(word_spans, labels):
            if label == "O" and label_span and labeltype:
                start, end = label_span[0][0], label_span[-1][-1]
                annotations.append(
                    SpanAnnotation(start=start, end=end, label=labeltype)
                )
                label_span = []
            elif label != "O":
                labeltype = label[2:]
                label_span.append(span)
        if label_span and labeltype:
            start, end = label_span[0][0], label_span[-1][-1]
            annotations.append(SpanAnnotation(start=start, end=end, label=labeltype))
        new_examples.append(
            StringSpanExample(guid=example.guid, content=text, annotations=annotations)
        )
    return new_examples


class LabelTokenAligner:
    """
    Align word-wise BIOLU-labels with subword tokens
    """
    def __init__(self, labels_src: Union[str, StrList]):
        if isinstance(labels_src, str):
            # labels_src is a path to a label list file (one label type per line)
            with open(labels_src, "r") as f:
                labels = [l for l in f.read().splitlines() if l and l != "O"]
        else:
            labels = [l for l in labels_src if l and l != "O"]
        self.labels_to_id = {"O": 0}
        self.ids_to_label = {0: "O"}
        for i, (label, s) in enumerate(product(labels, "BILU"), 1):
            l = f"{s}-{label}"
            self.labels_to_id[l] = i
            self.ids_to_label[i] = l

    @staticmethod
    def get_ids_to_label(labels_path: str) -> Dict[int, str]:
        with open(labels_path, "r") as f:
            labels = [l for l in f.read().splitlines() if l and l != "O"]
        ids_to_label = {
            i: f"{s}-{label}" for i, (label, s) in enumerate(product(labels, "BILU"), 1)
        }
        ids_to_label[0] = "O"
        return ids_to_label

    @staticmethod
    def align_tokens_and_annotations_bilou(
        tokenized: Encoding, annotations: List[SpanAnnotation]
    ) -> StrList:
        """Create subword-wise BIOLU labels aligned with the given tokenization
        :param tokenized: an Encoding produced by a PreTrainedTokenizerFast
        :param annotations: annotations in string-span format
        """
        aligned_labels = ["O"] * len(
            tokenized.tokens
        )  # Make a list to store our labels the same length as our tokens
        for anno in annotations:
            annotation_token_ix_set = set()
            for char_ix in range(anno.start, anno.end):
                token_ix = tokenized.char_to_token(char_ix)
                if token_ix is not None:
                    annotation_token_ix_set.add(token_ix)
            if len(annotation_token_ix_set) == 1:
                token_ix = annotation_token_ix_set.pop()
                prefix = "U"
                aligned_labels[token_ix] = f"{prefix}-{anno.label}"
            else:
                last_token_in_anno_ix = len(annotation_token_ix_set) - 1
                for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
                    if num == 0:
                        prefix = "B"
                    elif num == last_token_in_anno_ix:
                        prefix = "L"
                    else:
                        prefix = "I"
                    aligned_labels[token_ix] = f"{prefix}-{anno.label}"
        return aligned_labels

    def align_labels_with_tokens(
        self, tokenized_text: Encoding, annotations: List[SpanAnnotation]
    ) -> IntList:
        # TODO: switch label encoding scheme, align_tokens_and_annotations_bio
        raw_labels = self.align_tokens_and_annotations_bilou(
            tokenized_text, annotations
        )
        return list(map(lambda x: self.labels_to_id.get(x, 0), raw_labels))


def load_pretrained_tokenizer(
    tokenizer_file: str, cache_dir: Optional[str] = None
) -> PreTrainedTokenizerFast:
    """Load BertWordPieceTokenizer from tokenizer.json.
    This workaround is necessary for the following reasons:
    - BertWordPieceTokenizer cannot load tokenizer.json via its .from_file() method
    - Tokenizer.from_file(tokenizer_file) cannot be used as-is, because MecabPreTokenizer
      is a custom pre-tokenizer, not a native PreTokenizer
    """
    tokenizer = Tokenizer.from_file(tokenizer_file)
    tokenizer.pre_tokenizer = PreTokenizer.custom(MecabPreTokenizer())
    tokenizer_dir = os.path.dirname(tokenizer_file)
    pt_tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(
        tokenizer_dir,
        cache_dir=cache_dir,
    )
    # This is necessary for pt_tokenizer.save_pretrained(save_path)
    pt_tokenizer._tokenizer = tokenizer  # ._tokenizer
    return pt_tokenizer
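

# MeCab.Tagger objects are not picklable, so this wrapper serializes only the MeCab
# option string and rebuilds the Tagger when unpickled (via __reduce_ex__), e.g. so
# the custom pre-tokenizer survives pickling in multi-process code.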
class PicklableTagger:
    def __init__(self, mecab_option: str):
        self.option = mecab_option
        self.tagger = MeCab.Tagger(mecab_option)

    def __getstate__(self):
        return {"option": self.option}

    def __setstate__(self, state):
        for k, v in state.items():
            setattr(self, k, v)

    def __getnewargs__(self):
        return (self.option,)

    def __reduce_ex__(self, proto):
        func = PicklableTagger
        args = self.__getnewargs__()
        state = self.__getstate__()
        listitems = None
        dictitems = None
        rv = (func, args, state, listitems, dictitems)
        return rv

    def __call__(self, text):
        return self.parse(text)

    def parse(self, text):
        return self.tagger.parse(text).rstrip()


class MecabPreTokenizer:
    def __init__(
        self,
        mecab_dict_path: Optional[str] = None,
        space_replacement: Optional[str] = None,
    ):
        """Constructs a MecabPreTokenizer for huggingface tokenizers.
        - space_replacement: character used to replace spaces before tokenization.
          Because MeCab drops spaces by default, replacing them first and mapping the
          replacement back to spaces afterwards preserves them; a special character
          such as '_' is sometimes used.
        """
        self.space_replacement = space_replacement
        mecab_option = (
            f"-Owakati -d {mecab_dict_path}"
            if mecab_dict_path is not None
            else "-Owakati"
        )
        self.mecab = PicklableTagger(mecab_option)

    def tokenize(self, sequence: str) -> List[str]:
        text = unicodedata.normalize("NFKC", sequence)
        if self.space_replacement:
            text = text.replace(" ", self.space_replacement)
            splits = self.mecab.parse(text).strip().split(" ")
            return [x.replace(self.space_replacement, " ") for x in splits]
        else:
            return self.mecab.parse(text).strip().split(" ")

    def custom_split(
        self, i: int, normalized_string: NormalizedString
    ) -> List[NormalizedString]:
        text = str(normalized_string)
        tokens = self.tokenize(text)
        tokens_spans = textspan.get_original_spans(tokens, text)
        return [
            normalized_string[st:ed]
            for char_spans in tokens_spans
            for st, ed in char_spans
        ]

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.custom_split)


class TokenClassificationDataset(Dataset):
    """
    Build the feature dataset that is fed to the model
    """

    def __init__(
        self,
        examples: List[StringSpanExample],
        tokenizer: PreTrainedTokenizerFast,
        label_token_aligner: LabelTokenAligner,
        tokens_per_batch: int = 32,
        window_stride: Optional[int] = None,
    ):
        """tokenize_and_align_labels with long text (i.e. truncation is disabled)"""
        self.features: List[InputFeatures] = []
        self.examples: List[TokenClassificationExample] = []
        texts: StrList = [ex.content for ex in examples]
        annotations: List[List[SpanAnnotation]] = [ex.annotations for ex in examples]

        if window_stride is None:
            self.window_stride = tokens_per_batch
        elif window_stride > tokens_per_batch:
            logger.error(
                "window_stride must be smaller than tokens_per_batch(max_seq_length)"
            )
            # fall back to non-overlapping windows
            self.window_stride = tokens_per_batch
        else:
            self.window_stride = window_stride
            logger.warning(
                """window_stride != tokens_per_batch:
                The input data windows are overlapping. Merge the overlapping labels after processing InputFeatures.
                """
            )

        # tokenize text into subwords
        # NOTE: add_special_tokens
        tokenized_batch: BatchEncoding = tokenizer(texts, add_special_tokens=False)
        encodings: List[Encoding] = tokenized_batch.encodings

        # align word-wise labels with subwords
        aligned_label_ids: IntListList = list(
            starmap(
                label_token_aligner.align_labels_with_tokens,
                zip(encodings, annotations),
            )
        )

        # perform manual padding and register features
        guids: StrList = [ex.guid for ex in examples]
        for guid, encoding, label_ids in zip(guids, encodings, aligned_label_ids):
            seq_length = len(label_ids)
            for start in range(0, seq_length, self.window_stride):
                end = min(start + tokens_per_batch, seq_length)
                n_padding_to_add = max(0, tokens_per_batch - end + start)
                self.features.append(
                    InputFeatures(
                        input_ids=encoding.ids[start:end]
                        + [tokenizer.pad_token_id] * n_padding_to_add,
                        label_ids=(
                            label_ids[start:end]
                            + [PAD_TOKEN_LABEL_ID] * n_padding_to_add
                        ),
                        attention_mask=(
                            encoding.attention_mask[start:end] + [0] * n_padding_to_add
                        ),
                    )
                )
                subwords = encoding.tokens[start:end]
                labels = [
                    label_token_aligner.ids_to_label[i] for i in label_ids[start:end]
                ]
                self.examples.append(
                    TokenClassificationExample(guid=guid, words=subwords, labels=labels)
                )
        self._n_features = len(self.features)

    def __len__(self):
        return self._n_features

    def __getitem__(self, idx) -> InputFeatures:
        return self.features[idx]
kzinmr commented Jan 15, 2021

def parse_texts(tokenizer, label_token_aligner, texts, annotations, tokens_per_batch=128, window_stride=None):

    if window_stride is None:
        window_stride = tokens_per_batch
    assert window_stride <= tokens_per_batch

    tokenized_batch: BatchEncoding = tokenizer(texts, add_special_tokens=False)
    encodings: List[Encoding] = tokenized_batch.encodings

    # align word-wise labels with subwords
    aligned_label_ids: IntListList = list(
        starmap(
            label_token_aligner.align_labels_with_tokens,
            zip(encodings, annotations),
        )
    )

    # perform manual padding and register features
    for text, encoding, label_ids in zip(texts, encodings, aligned_label_ids):
        seq_length = len(label_ids)
        for start in range(0, seq_length, window_stride):
            end = min(start + tokens_per_batch, seq_length)
            n_padding_to_add = max(0, tokens_per_batch - end + start)

            offsets = encoding.offsets[start:end]
            subwords = encoding.tokens[start:end]
            labels = [
                label_token_aligner.ids_to_label[i] for i in label_ids[start:end]
            ]
            
            print([f'{text[s:e]}\t{w}\t{l}' for (s,e), w, l in zip(offsets, subwords, labels)])
# dataset: List[StringSpanExample]
# tokenizer = load_pretrained_tokenizer('src/to/electra_small_wiki40b_ja_mecab_ipadic/tokenizer.json')
# label_token_aligner = LabelTokenAligner(['SOMETYPE'])
texts = list(map(lambda x: x.content, dataset))
annotations = list(map(lambda x: sorted(x.annotations, key=lambda x:x.start), dataset))
parse_texts(tokenizer, label_token_aligner, texts, annotations)
