Last active
March 13, 2022 01:52
-
-
Save kzinmr/5ec980dc436f6fac6f526dc81eb722f1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
from typing import Dict, Iterable, List, Tuple | |
@dataclass
class ChunkSpan:
    """Location and type of one labeled chunk inside a sentence's text."""

    # Offset of the chunk's first character.
    start: int
    # Offset one past the chunk's last character.
    end: int
    # Chunk type, e.g. "PERSON".
    label: str
@dataclass
class TokenLabelPair:
    """A single token together with the tag label assigned to it."""

    # Surface string of the token.
    token: str
    # Tag label attached to the token (e.g. "B-PER", "I-PER", "O").
    label: str

    @property
    def text(self) -> str:
        """Return the token's surface string (same value as ``token``)."""
        return self.token
@dataclass
class Sentence(Iterable[TokenLabelPair]):
    """A sentence as IOB-tagged tokens plus the character spans of its chunks.

    ``text`` is the plain concatenation of the token strings (no separator),
    and every ``ChunkSpan`` in ``chunks`` holds offsets into that string.
    """

    # Declared as a dataclass so that ``Sentence(tokens, chunks)`` -- the call
    # made by ``from_conll`` -- has a constructor accepting the two fields.
    token_labels: List[TokenLabelPair]
    chunks: List[ChunkSpan]

    @property
    def text(self) -> str:
        """Return all token strings joined without a separator."""
        return "".join(token.text for token in self.token_labels)

    def __iter__(self):
        """Iterate over the sentence's tokens in order."""
        return iter(self.token_labels)

    @staticmethod
    def __chunk_token_labels(
        tokens: List[TokenLabelPair],
    ) -> List[List[TokenLabelPair]]:
        """Group consecutive B/I-tagged tokens into chunks.

        A label starting with ``B`` always opens a new chunk, a label starting
        with ``I`` extends the open chunk (or opens one if none is open), and
        any other label closes the open chunk.
        """
        chunks: List[List[TokenLabelPair]] = []
        chunk: List[TokenLabelPair] = []
        for token in tokens:
            if token.label.startswith("B"):
                # I -> B: the previous chunk ends where the new one begins.
                if chunk:
                    chunks.append(chunk)
                chunk = [token]
            elif token.label.startswith("I"):
                chunk.append(token)
            elif chunk:
                # B|I -> O
                chunks.append(chunk)
                chunk = []
        # A chunk that reaches the last token has no following tag to close
        # it, so it must be flushed here or it would be lost.
        if chunk:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def __chunk_span(tokens: List[TokenLabelPair]) -> List[Tuple[int, int]]:
        """Return the ``(start, end)`` character span of every chunk.

        Offsets index into the concatenated token text; ``end`` is exclusive.
        The grouping rules are the same as in ``__chunk_token_labels`` so the
        two results line up one-to-one.
        """
        pos = 0
        spans: List[Tuple[int, int]] = []
        chunk_spans: List[Tuple[int, int]] = []
        for token in tokens:
            token_len = len(token.text)
            span = (pos, pos + token_len)
            pos += token_len
            if token.label.startswith("B"):
                # I -> B
                if spans:
                    chunk_spans.append((spans[0][0], spans[-1][1]))
                spans = [span]
            elif token.label.startswith("I"):
                spans.append(span)
            elif spans:
                # B|I -> O
                chunk_spans.append((spans[0][0], spans[-1][1]))
                spans = []
        # Flush a chunk that runs up to the end of the sentence.
        if spans:
            chunk_spans.append((spans[0][0], spans[-1][1]))
        return chunk_spans

    @classmethod
    def __build_chunks(cls, tokens: List[TokenLabelPair]) -> List[ChunkSpan]:
        """Build one ``ChunkSpan`` per chunk found in ``tokens``.

        The chunk type is taken from the chunk's first token: everything after
        the first ``-`` of its label (``B-PER`` -> ``PER``). A label without a
        ``-`` is used unchanged.
        """
        grouped = cls.__chunk_token_labels(tokens)
        spans = cls.__chunk_span(tokens)
        built: List[ChunkSpan] = []
        for members, (start, end) in zip(grouped, spans):
            tag = members[0].label
            # partition keeps hyphenated types intact and cannot raise on a
            # bare "B"/"I" tag.
            _, sep, kind = tag.partition("-")
            built.append(
                ChunkSpan(start=start, end=end, label=kind if sep else tag)
            )
        return built

    @classmethod
    def from_conll(cls, path: str, delimiter: str = "\t") -> List["Sentence"]:
        """Read a CoNLL-style file and return its sentences.

        Sentences are separated by a blank line. Within a sentence, each line
        is split on ``delimiter``; the first column is the token and the last
        column is its label. Lines with fewer than two columns are ignored,
        and blocks that yield no tokens are skipped.
        """
        # CoNLL2003 -> List[Sentence]
        sentences: List[Sentence] = []
        with open(path) as fp:
            blocks = fp.read().split("\n\n")
        for block in blocks:
            tokens: List[TokenLabelPair] = []
            for row in block.split("\n"):
                columns = row.split(delimiter)
                if len(columns) >= 2:
                    tokens.append(TokenLabelPair(columns[0], columns[-1]))
            if not tokens:
                # e.g. the empty piece left by a trailing blank line
                continue
            sentences.append(cls(tokens, cls.__build_chunks(tokens)))
        return sentences

    def export_span_format(self) -> Dict:
        """Return the sentence as a ``{"text": ..., "spans": [...]}`` dict."""
        # {"text": "", "spans": [{"start":0, "end":1, "label": "PERSON"}]}
        return {
            "text": self.text,
            "spans": [
                {"start": c.start, "end": c.end, "label": c.label}
                for c in self.chunks
            ],
        }
if __name__ == "__main__":

    def download_conll_data(filepath):
        """Download the CoNLL-format sample data and save it to ``filepath``.

        Returns ``filepath`` when the download succeeded, or ``None`` when the
        server did not answer with a success status, so the caller can skip
        the conversion step.
        """
        import requests

        url = "https://raw.githubusercontent.com/Hironsan/IOB2Corpus/master/ja.wikipedia.conll"
        response = requests.get(url)
        if not response.ok:
            return None
        with open(filepath, "w") as fp:
            fp.write(response.content.decode("utf8"))
        return filepath

    import json

    # Download to the same file that is converted below; passing a directory
    # such as "." makes open() fail and leaves nothing to read.
    path = download_conll_data("./ja.wikipedia.conll")
    if path:
        sentences = Sentence.from_conll(path)
        # ensure_ascii=False emits raw non-ASCII characters, so the output
        # encoding is pinned to UTF-8 instead of the platform default.
        with open("./ja.wikipedia.jsonl", "wt", encoding="utf-8") as fp:
            for s in sentences:
                jd = s.export_span_format()
                js = json.dumps(jd, ensure_ascii=False)
                fp.write(js)
                fp.write("\n")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment