import os
from dataclasses import dataclass
from typing import Optional

import fugashi
import unidic_lite
@dataclass
class Token:
    """A surface token together with its character offsets in the original text."""

    text: str
    start: int
    end: int


@dataclass
class ChunkSpan:
    """A labeled character span (e.g. a named-entity chunk) over the original text."""

    start: int
    end: int
    label: str


@dataclass
class TokenLabelPair:
    """A token surface paired with its BIO label."""

    token: str
    label: str
class MeCabTokenizer:
    def __init__(self):
        # Use the bundled unidic-lite dictionary with fugashi's generic MeCab tagger.
        # mecab_option = "-Owakati"
        # self.wakati = MeCab.Tagger(mecab_option)
        dic_dir = unidic_lite.DICDIR
        mecabrc = os.path.join(dic_dir, "mecabrc")
        mecab_option = "-d {} -r {} ".format(dic_dir, mecabrc)
        self.mecab = fugashi.GenericTagger(mecab_option)

    def tokenize(self, text: str) -> list[str]:
        # return self.mecab.parse(text).strip().split(" ")
        return [word.surface for word in self.mecab(text)]

    def tokenize_with_alignment(self, text: str) -> list[Token]:
        """Tokenize `text` and align each surface token to its character span."""
        token_surfaces = [word.surface for word in self.mecab(text)]
        tokens = []
        _cursor = 0
        for token in token_surfaces:
            # Assumes every surface form occurs verbatim in the original text.
            start = text.index(token, _cursor)
            end = start + len(token)
            tokens.append(Token(token, start, end))
            _cursor = end
        return tokens
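# Illustrative sketch of the expected alignment (assumes UniDic segments
# "太郎の家" as 太郎 / の / 家; actual segmentation may differ):
#
#     tokenizer = MeCabTokenizer()
#     tokenizer.tokenize_with_alignment("太郎の家")
#     # -> [Token("太郎", 0, 2), Token("の", 2, 3), Token("家", 3, 4)]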
class Span2TokenConverter:
    def __init__(self):
        # tokenizer
        self.tokenizer = MeCabTokenizer()

    @staticmethod
    def _get_chunk_span(
        query_span: tuple[int, int], superspans: list[tuple[int, int]]
    ) -> Optional[tuple[int, int]]:
        """Return the span of the chunk that contains the queried token span, if any.

        NOTE: assumes no chunk boundary falls inside a token, i.e. every token is
        either fully inside or fully outside a chunk.
        """
        for superspan in superspans:
            if query_span[0] >= superspan[0] and query_span[1] <= superspan[1]:
                return superspan
        return None
    @classmethod
    def _get_token2label_map(
        cls, spans_of_tokens: list[tuple[int, int]], spans_of_chunks: list[ChunkSpan]
    ) -> dict[tuple[int, int], str]:
        """Map each token's character span to the label of the chunk that contains it."""
        span_tuples = [(span.start, span.end) for span in spans_of_chunks]
        _span2label = {(span.start, span.end): span.label for span in spans_of_chunks}
        tokenspan2tagtype: dict[tuple[int, int], str] = {}
        for original_token_span in spans_of_tokens:
            chunk_span = cls._get_chunk_span(original_token_span, span_tuples)
            if chunk_span is not None:
                tokenspan2tagtype[original_token_span] = _span2label[chunk_span]
        return tokenspan2tagtype
    @staticmethod
    def _get_labels_per_tokens(
        spans_of_tokens: list[tuple[int, int]],
        tokenspan2tagtype: dict[tuple[int, int], str],
    ) -> list[str]:
        """Assign a BIO label to each token from the token-span-to-label map.

        NOTE: chunk identity is not tracked, so two adjacent chunks of the same
        type are merged into a single B-/I- run.
        """
        prev_tagtype: Optional[str] = None
        token_labels: list[str] = []
        for token_span in spans_of_tokens:
            if token_span in tokenspan2tagtype:
                tagtype = tokenspan2tagtype[token_span]
                # Open a new chunk ("B-") when coming from "O" or from a chunk of a
                # different type; otherwise continue the current chunk ("I-").
                if prev_tagtype != tagtype:
                    label = f"B-{tagtype}"
                else:
                    label = f"I-{tagtype}"
                prev_tagtype = tagtype
            else:
                label = "O"
                prev_tagtype = None
            token_labels.append(label)
        return token_labels
    @classmethod
    def get_token_labels(
        cls, tokens: list[Token], spans_of_chunks: list[ChunkSpan]
    ) -> list[TokenLabelPair]:
        """Pair each token with its BIO label, given the labeled character spans."""
        spans_of_tokens = [(token.start, token.end) for token in tokens]
        tokenspan2label = cls._get_token2label_map(spans_of_tokens, spans_of_chunks)
        labels_per_tokens = cls._get_labels_per_tokens(spans_of_tokens, tokenspan2label)
        token_labels = [
            TokenLabelPair(token.text, label)
            for token, label in zip(tokens, labels_per_tokens)
        ]
        return token_labels
    def get_token_label_pairs(
        self, text: str, spans_of_chunks: list[ChunkSpan]
    ) -> list[TokenLabelPair]:
        """Pipeline: original text + labeled character chunk spans -> token/BIO-label pairs.

        Example:
            character chunk spans: [(0, 2, "PERSON")], original text: "太郎の家"
            token-BIO pairs: [("太郎", "B-PERSON"), ("の", "O"), ("家", "O")]
        """
        # 1. tokenize
        tokens = self.tokenizer.tokenize_with_alignment(text)
        # 2. get token-label pairs from spans of chunks and spans of tokens
        token_labels = self.get_token_labels(tokens, spans_of_chunks)
        return token_labels
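# Minimal usage sketch (assumes the fugashi + unidic-lite setup above and that
# UniDic segments "太郎の家" as 太郎 / の / 家; the printed pairs are therefore an
# expectation, not verified output):
if __name__ == "__main__":
    converter = Span2TokenConverter()
    chunks = [ChunkSpan(0, 2, "PERSON")]  # "太郎" is a PERSON chunk
    for pair in converter.get_token_label_pairs("太郎の家", chunks):
        print(pair.token, pair.label)
    # Expected:
    #   太郎 B-PERSON
    #   の O
    #   家 O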