Last active
February 5, 2021 17:15
-
-
Save kzinmr/57ccb70169f1b8939f93faf2bff8ee68 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# 会社名サンプラー | |
import os | |
import pickle | |
import random | |
from dataclasses import dataclass | |
from typing import List | |
import numpy as np | |
def seed_everything(seed: int = 1234):
    """Fix the random seeds used by this module (stdlib ``random`` and NumPy).

    Also exports ``PYTHONHASHSEED``. Note that setting this variable inside an
    already-running interpreter does not change that interpreter's own string
    hash randomization; it only affects Python subprocesses started later.

    Args:
        seed: value passed to every seeding call.
    """
    random.seed(seed)
    os.environ.update(PYTHONHASHSEED=str(seed))
    np.random.seed(seed)
@dataclass()
class SpanAnnotation:
    """A labelled character span inside a text.

    The sampler code in this file treats ``start``/``end`` as a half-open
    slice, i.e. the covered text is ``content[start:end]``.
    """

    start: int  # offset of the first character in the span
    end: int  # offset one past the last character (exclusive)
    label: str  # span label; carried over unchanged when the span is replaced
@dataclass()
class StringSpanExample:
    """One text together with the span annotations laid over it.

    Offsets stored in ``annotations`` index into ``content``.
    """

    guid: str  # example identifier; the sampler derives child guids from it
    content: str  # raw text that the span offsets point into
    annotations: List[SpanAnnotation]  # labelled spans over ``content``
class CompanySampler:
    """Sample company names and splice them into span-annotated examples.

    The pickle at ``sampler_data_path`` must contain a dict with the keys
    ``"prefix2companies"`` and ``"suffix2companies"``. Each maps a key to an
    indexable sequence of company-name strings (presumably grouped by a shared
    name prefix / suffix pattern -- TODO confirm against the data builder).
    """

    def __init__(self, sampler_data_path):
        # SECURITY: pickle.load can execute arbitrary code while loading;
        # only point this at files produced by a trusted pipeline.
        with open(sampler_data_path, "rb") as fp:
            data = pickle.load(fp)
        prefix2companies = data["prefix2companies"]
        suffix2companies = data["suffix2companies"]

        self.prefix2companies = prefix2companies
        self.prefix_keys = list(prefix2companies.keys())
        self.n_prefix_keys = len(self.prefix_keys)
        self.prefix2n_companies = {k: len(v) for k, v in prefix2companies.items()}

        self.suffix2companies = suffix2companies
        self.suffix_keys = list(suffix2companies.keys())
        self.n_suffix_keys = len(self.suffix_keys)
        self.suffix2n_companies = {k: len(v) for k, v in suffix2companies.items()}

    def sample(self, n_sample):
        """Yield ``n_sample`` randomly drawn company names (with replacement).

        Each draw picks the prefix table or the suffix table with equal
        probability, then a uniformly random key of that table, then a
        uniformly random company under that key. Companies under keys with
        few entries are therefore drawn more often than under a flat draw.
        """
        for _ in range(n_sample):
            # NumPy's RNG picks the table; the stdlib RNG picks key and name
            # (seed_everything seeds both).
            if np.random.uniform(low=0.0, high=1.0) > 0.5:
                table, keys = self.prefix2companies, self.prefix_keys
            else:
                table, keys = self.suffix2companies, self.suffix_keys
            # random.choice(seq) consumes the same random state as
            # seq[random.randint(0, len(seq) - 1)]; key is drawn before name.
            yield random.choice(table[random.choice(keys)])

    def sample_with_pattern_coverage(self, min_n_sample):
        """Yield exactly ``min_n_sample`` random companies for every key.

        All prefix keys are visited first, then all suffix keys, so every
        pattern is represented. Draws within a key are with replacement.
        """
        for table, keys in (
            (self.prefix2companies, self.prefix_keys),
            (self.suffix2companies, self.suffix_keys),
        ):
            for key in keys:
                companies = table[key]
                for _ in range(min_n_sample):
                    yield random.choice(companies)

    @staticmethod
    def _replace_example(
        example: StringSpanExample,
        new_target: str,
        anno: SpanAnnotation,
        anno_idx: int,
        new_idx: int = 0,
    ) -> StringSpanExample:
        """Return a copy of ``example`` with the text under ``anno`` replaced.

        ``content[anno.start:anno.end]`` is replaced by ``new_target``. The
        result's annotations change as follows:

        * ``anno`` becomes a span over ``new_target`` with the same label.
        * Annotations starting after ``anno.start`` are shifted by the change
          in length.
        * Earlier annotations are copied unchanged.

        ``example`` itself is not modified.

        Typical use::

            anno_idx, anno = random.choice(list(enumerate(example.annotations)))
            CompanySampler._replace_example(example, new_target, anno, anno_idx)

        Args:
            example: source example (left untouched).
            new_target: replacement text, e.g. a sampled company name.
            anno: the annotation whose span is replaced.
            anno_idx: index of ``anno``; only used to build the new guid.
            new_idx: sample counter; only used to build the new guid.

        Returns:
            A new ``StringSpanExample`` whose guid is
            ``f"{example.guid}-ex{anno_idx}-sample{new_idx}"``.

        Note:
            Assumes annotations do not overlap. Every annotation that shares
            ``anno.start`` is replaced by the new span. Spans that overlap
            ``anno`` without sharing its start are not adjusted correctly.
        """
        text = example.content
        new_len = len(new_target)
        new_text = f"{text[:anno.start]}{new_target}{text[anno.end:]}"
        # Sanity check that new_target landed at anno.start; this can only
        # fail when anno.start lies beyond the end of the text.
        assert new_text[anno.start : anno.start + new_len] == new_target
        new_anno = SpanAnnotation(
            start=anno.start, end=anno.start + new_len, label=anno.label
        )
        # Length change applied to every annotation after the replaced span.
        offset = new_len - (anno.end - anno.start)
        new_annotations = []
        for org_anno in example.annotations:
            if org_anno.start < anno.start:
                # Before the replaced span: same offsets, fresh object.
                updated = SpanAnnotation(
                    start=org_anno.start, end=org_anno.end, label=org_anno.label
                )
            elif org_anno.start > anno.start:
                updated = SpanAnnotation(
                    start=org_anno.start + offset,
                    end=org_anno.end + offset,
                    label=org_anno.label,
                )
            else:
                updated = new_anno
            new_annotations.append(updated)
        return StringSpanExample(
            guid=f"{example.guid}-ex{anno_idx}-sample{new_idx}",
            content=new_text,
            annotations=new_annotations,
        )

    def extend_example(
        self, example: StringSpanExample, n_samples: int = 1
    ) -> List[StringSpanExample]:
        """Create ``n_samples`` variants per annotation of ``example``.

        In each variant, the annotation's text is replaced by a company name
        drawn with ``sample``. Returns
        ``len(example.annotations) * n_samples`` examples, ordered by
        annotation first and then by sample.
        """
        return [
            self._replace_example(example, new_target, anno, i, j)
            for i, anno in enumerate(example.annotations)
            for j, new_target in enumerate(self.sample(n_samples))
        ]

    def cover_target_examples(
        self,
        contexts: List[StringSpanExample],
        min_n_sample: int = 10,
        n_samples: int = 1,
    ) -> List[StringSpanExample]:
        """Place pattern-covering company names into random contexts.

        For every key, ``min_n_sample * n_samples`` companies are drawn via
        ``sample_with_pattern_coverage`` and cut into batches of
        ``n_samples``. Up to ``n_samples`` contexts are chosen once, at
        random. Within each batch, the i-th company replaces a random
        annotation of the i-th chosen context.

        Contexts without annotations are ignored because they have no span to
        replace. When fewer than ``n_samples`` usable contexts exist, only that
        many are used, and the surplus companies of each batch are dropped.
        Returns ``[]`` when ``n_samples`` is less than 1.
        """
        if n_samples < 1:
            # range() below would reject a zero step; nothing to place anyway.
            return []
        companies = list(self.sample_with_pattern_coverage(min_n_sample * n_samples))
        companies_batch = [
            companies[i : i + n_samples] for i in range(0, len(companies), n_samples)
        ]
        # Only contexts with at least one annotation can receive a replacement.
        candidates = [ex for ex in contexts if ex.annotations]
        # Clamp BEFORE sampling: random.sample raises ValueError when asked
        # for more items than the population holds.
        n_contexts = min(n_samples, len(candidates))
        target_examples = random.sample(candidates, n_contexts)
        augmented_examples = []
        for batch_idx, batch in enumerate(companies_batch):
            for ex, company in zip(target_examples, batch):
                anno_idx, anno = random.choice(list(enumerate(ex.annotations)))
                # batch_idx gives each generated example a distinct guid even
                # though the same context is reused in every batch.
                augmented_examples.append(
                    self._replace_example(ex, company, anno, anno_idx, batch_idx)
                )
        return augmented_examples
Author
kzinmr
commented
Feb 5, 2021
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment