Skip to content

Instantly share code, notes, and snippets.

@kzinmr
Last active March 13, 2022 01:52
Show Gist options
  • Save kzinmr/5ec980dc436f6fac6f526dc81eb722f1 to your computer and use it in GitHub Desktop.
Save kzinmr/5ec980dc436f6fac6f526dc81eb722f1 to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple
@dataclass()
class ChunkSpan:
    """A labeled character span inside a sentence's text; ``end`` is exclusive."""

    # Character offset of the span's first character.
    start: int
    # Character offset one past the span's last character.
    end: int
    # Entity type taken from the chunk's first IOB tag (e.g. "LOC" for "B-LOC").
    label: str
@dataclass()
class TokenLabelPair:
    """One token read from a CoNLL line, paired with its tag."""

    # Surface string of the token (first column of the line).
    token: str
    # Tag string (last column of the line), e.g. "B-LOC", "I-LOC" or "O".
    label: str

    @property
    def text(self) -> str:
        """Return the token's surface string (alias for ``token``)."""
        surface = self.token
        return surface
@dataclass
class Sentence(Iterable[TokenLabelPair]):
    """A sentence of tagged tokens plus the labeled chunks derived from the tags.

    ``chunks`` hold character offsets into ``text`` (the token surfaces joined
    with no separator); each span's ``end`` is exclusive.
    """

    # Tokens in sentence order, each carrying its raw tag.
    token_labels: List[TokenLabelPair]
    # Labeled spans built from the B*/I* tags of ``token_labels``.
    chunks: List[ChunkSpan]

    @property
    def text(self) -> str:
        """Return the token surfaces concatenated with no separator."""
        return "".join(token.text for token in self.token_labels)

    def __iter__(self):
        """Iterate over the sentence's tokens in order."""
        yield from self.token_labels

    @staticmethod
    def __iter_chunks(
        tokens: List[TokenLabelPair],
    ) -> Iterable[Tuple[int, int, str]]:
        """Yield ``(start, end, first_tag)`` for every chunk in ``tokens``.

        A tag starting with ``B`` always opens a new chunk; a tag starting with
        ``I`` extends the open chunk (or opens one, so a stray ``I`` is kept);
        any other tag such as ``O`` closes it. ``start``/``end`` are character
        offsets into the concatenated token text (``end`` exclusive) and
        ``first_tag`` is the tag of the chunk's first token.

        Offsets and tags are produced in a single pass so they cannot drift
        apart, and a chunk that runs to the end of the sentence is emitted.
        """
        pos = 0
        start = 0
        first_tag = None  # tag of the open chunk's first token; None = no open chunk
        for token in tokens:
            tag = token.label
            if first_tag is not None and not tag.startswith("I"):
                # A new B-tag or an O-tag closes the open chunk, which ends
                # where the current token starts.
                yield start, pos, first_tag
                first_tag = None
            if first_tag is None and tag.startswith(("B", "I")):
                start = pos
                first_tag = tag
            pos += len(token.text)
        if first_tag is not None:
            # Flush a chunk that reaches the end of the sentence.
            yield start, pos, first_tag

    @classmethod
    def __build_chunks(cls, tokens: List[TokenLabelPair]) -> List[ChunkSpan]:
        """Convert the tags of ``tokens`` into ``ChunkSpan`` objects.

        The chunk label is everything after the first ``-`` of the chunk's
        first tag (``"B-ORG"`` -> ``"ORG"``). A tag without ``-`` yields an
        empty label instead of raising ``IndexError``.
        """
        return [
            ChunkSpan(start=start, end=end, label=first_tag.partition("-")[2])
            for start, end, first_tag in cls.__iter_chunks(tokens)
        ]

    @classmethod
    def from_conll(cls, path: str, delimiter: str = "\t") -> List["Sentence"]:
        """Read a CoNLL-style file (e.g. CoNLL-2003 layout) into sentences.

        Sentences are separated by blank (or whitespace-only) lines. Each other
        line is split on ``delimiter``; the first field is the token and the
        last field is its tag. Lines with fewer than two fields are ignored,
        and groups that contain no tokens do not produce a sentence.

        Args:
            path: Path to a UTF-8 encoded CoNLL file.
            delimiter: Column separator; tab by default.

        Returns:
            One ``Sentence`` per non-empty, blank-line-separated group.
        """
        groups: List[List[TokenLabelPair]] = [[]]
        # Explicit UTF-8 so non-ASCII corpora load regardless of the locale.
        with open(path, encoding="utf-8") as fp:
            for raw in fp:
                line = raw.rstrip("\n")
                if not line.strip():
                    # Sentence boundary; collapse runs of blank lines.
                    if groups[-1]:
                        groups.append([])
                    continue
                fields = line.split(delimiter)
                if len(fields) >= 2:
                    groups[-1].append(
                        TokenLabelPair(fields[0], fields[-1].strip())
                    )
        return [
            cls(tokens, cls.__build_chunks(tokens)) for tokens in groups if tokens
        ]

    def export_span_format(self) -> Dict:
        """Return the sentence as a span-annotation dict.

        Shape: ``{"text": "...", "spans": [{"start": 0, "end": 1, "label": "PERSON"}]}``
        where offsets index into ``text``.
        """
        spans = [
            {"start": c.start, "end": c.end, "label": c.label} for c in self.chunks
        ]
        return {"text": self.text, "spans": spans}
if __name__ == "__main__":
    import json
    import sys

    def download_conll_data(filepath):
        """Download the CoNLL-format corpus and save it to ``filepath``.

        Returns ``filepath`` on success, or ``None`` when the HTTP response is
        not OK (a message is printed to stderr in that case).
        """
        import requests

        url = "https://raw.githubusercontent.com/Hironsan/IOB2Corpus/master/ja.wikipedia.conll"
        # Without a timeout a stalled connection would hang the script forever.
        response = requests.get(url, timeout=60)
        if not response.ok:
            print(
                f"download failed: HTTP {response.status_code} for {url}",
                file=sys.stderr,
            )
            return None
        # Write the raw bytes so the platform's default text encoding cannot
        # mangle the non-ASCII corpus.
        with open(filepath, "wb") as fp:
            fp.write(response.content)
        return filepath

    path = "./ja.wikipedia.conll"
    # Pass the file path (not the directory) so the download lands where it
    # is read back below.
    if download_conll_data(path):
        sentences = Sentence.from_conll(path)
        # ensure_ascii=False emits raw non-ASCII text, so the output file must
        # be UTF-8 explicitly.
        with open("./ja.wikipedia.jsonl", "wt", encoding="utf-8") as fp:
            for s in sentences:
                jd = s.export_span_format()
                fp.write(json.dumps(jd, ensure_ascii=False))
                fp.write("\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment