Skip to content

Instantly share code, notes, and snippets.

@kzinmr
Last active February 5, 2021 17:15
Show Gist options
  • Save kzinmr/57ccb70169f1b8939f93faf2bff8ee68 to your computer and use it in GitHub Desktop.
Save kzinmr/57ccb70169f1b8939f93faf2bff8ee68 to your computer and use it in GitHub Desktop.
# 会社名サンプラー
import os
import pickle
import random
from dataclasses import dataclass
from typing import List
import numpy as np
def seed_everything(seed: int = 1234):
    """Seed the random generators used by this module.

    Seeds Python's ``random`` module and NumPy's legacy global RNG
    (``np.random``), and exports ``PYTHONHASHSEED``.

    Note: setting ``PYTHONHASHSEED`` at runtime does not change hash
    randomization of the *current* interpreter (it is read at startup); it
    only affects child processes launched afterwards.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
@dataclass
class SpanAnnotation:
    """A labelled character span inside an example's ``content`` string."""

    # Character offset where the span begins.
    start: int
    # Character offset where the span ends; used as an exclusive bound when
    # slicing (``text[:start] + ... + text[end:]``).
    end: int
    # Entity label carried by the span.
    label: str
@dataclass
class StringSpanExample:
    """A text example together with its span annotations."""

    # Identifier of the example; derived examples extend it with a suffix.
    guid: str
    # Raw text the annotation offsets index into.
    content: str
    # Spans over ``content``.
    annotations: List[SpanAnnotation]
class CompanySampler:
    """Sample company names by pattern and splice them into annotated examples.

    Company names are grouped under pattern keys in two tables loaded from a
    pickle: ``prefix2companies`` and ``suffix2companies``. Each table maps a
    key to a sequence of names (presumably strings, since sampled values are
    used as ``new_target: str``). Sampled names replace an existing annotation
    in a ``StringSpanExample`` to produce augmented examples.
    """

    def __init__(self, sampler_data_path):
        """Load the pattern tables from the pickle at ``sampler_data_path``.

        The unpickled object must be a mapping with the keys
        ``"prefix2companies"`` and ``"suffix2companies"``.
        """
        # NOTE(review): pickle.load can execute arbitrary code while loading;
        # only use sampler files from a trusted source.
        with open(sampler_data_path, "rb") as fp:
            data = pickle.load(fp)
        prefix2companies = data["prefix2companies"]
        suffix2companies = data["suffix2companies"]

        self.prefix2companies = prefix2companies
        self.prefix_keys = list(prefix2companies.keys())
        self.n_prefix_keys = len(self.prefix_keys)
        # Group sizes are cached so each draw does not recompute len().
        self.prefix2n_companies = {k: len(v) for k, v in prefix2companies.items()}

        self.suffix2companies = suffix2companies
        self.suffix_keys = list(suffix2companies.keys())
        self.n_suffix_keys = len(self.suffix_keys)
        self.suffix2n_companies = {k: len(v) for k, v in suffix2companies.items()}

    def sample(self, n_sample):
        """Yield ``n_sample`` company names drawn independently at random.

        Each draw picks the prefix table (when a NumPy uniform draw exceeds
        0.5) or otherwise the suffix table, then a uniformly random key in that
        table, then a uniformly random company under that key. An empty table
        or empty group makes ``random.randint`` raise ``ValueError``.
        """
        for _ in range(n_sample):
            if np.random.uniform(low=0.0, high=1.0) > 0.5:
                # prefix
                key = self.prefix_keys[random.randint(0, self.n_prefix_keys - 1)]
                yield self.prefix2companies[key][
                    random.randint(0, self.prefix2n_companies[key] - 1)
                ]
            else:
                # suffix
                key = self.suffix_keys[random.randint(0, self.n_suffix_keys - 1)]
                yield self.suffix2companies[key][
                    random.randint(0, self.suffix2n_companies[key] - 1)
                ]

    def sample_with_pattern_coverage(self, min_n_sample):
        """Yield ``min_n_sample`` names (with replacement) for every pattern key.

        All prefix keys are covered first, then all suffix keys, so every
        pattern is represented exactly ``min_n_sample`` times.
        """
        for key in self.prefix_keys:
            companies = self.prefix2companies[key]
            n_companies = self.prefix2n_companies[key]
            for _ in range(min_n_sample):
                yield companies[random.randint(0, n_companies - 1)]
        for key in self.suffix_keys:
            companies = self.suffix2companies[key]
            n_companies = self.suffix2n_companies[key]
            for _ in range(min_n_sample):
                yield companies[random.randint(0, n_companies - 1)]

    @staticmethod
    def _replace_example(
        example: StringSpanExample,
        new_target: str,
        anno: SpanAnnotation,
        anno_idx: int,
        new_idx: int = 0,
    ) -> StringSpanExample:
        """Return a copy of ``example`` with ``anno``'s span replaced by ``new_target``.

        The replaced annotation is resized to ``new_target``. Annotations that
        start after it are shifted by the length difference, and annotations
        that start before it are copied unchanged.

        Typical use::

            anno_idx, anno = random.choice(list(enumerate(example.annotations)))
            _replace_example(example, new_target, anno, anno_idx)

        Args:
            example: Source example (not mutated).
            new_target: Text inserted in place of ``anno``'s span.
            anno: Annotation of ``example`` to replace.
            anno_idx: Index of ``anno``; only used to build the new guid.
            new_idx: Extra discriminator for the new guid.

        Returns:
            A new example with guid ``"{guid}-ex{anno_idx}-sample{new_idx}"``.
        """
        new_len = len(new_target)
        text = example.content
        new_text = f"{text[:anno.start]}{new_target}{text[anno.end:]}"
        assert new_text[anno.start : anno.start + new_len] == new_target
        new_anno = SpanAnnotation(
            start=anno.start, end=anno.start + new_len, label=anno.label
        )

        # Shift the annotations that follow the replaced span.
        # NOTE(review): assumes annotations do not overlap; one starting inside
        # the replaced span would be shifted too, and any annotation sharing
        # anno.start is replaced by new_anno.
        org_len = anno.end - anno.start
        offset = new_len - org_len
        new_annotations = []
        for org_anno in example.annotations:
            if org_anno.start < anno.start:
                updated = SpanAnnotation(
                    start=org_anno.start, end=org_anno.end, label=org_anno.label
                )
            elif org_anno.start > anno.start:
                updated = SpanAnnotation(
                    start=org_anno.start + offset,
                    end=org_anno.end + offset,
                    label=org_anno.label,
                )
            else:
                updated = new_anno
            new_annotations.append(updated)

        return StringSpanExample(
            guid=f"{example.guid}-ex{anno_idx}-sample{new_idx}",
            content=new_text,
            annotations=new_annotations,
        )

    def extend_example(
        self, example: StringSpanExample, n_samples: int = 1
    ) -> List[StringSpanExample]:
        """For each annotation, create ``n_samples`` copies with a sampled company name.

        Returns ``len(example.annotations) * n_samples`` new examples.
        """
        new_examples = []
        for i, anno in enumerate(example.annotations):
            for j, new_target in enumerate(self.sample(n_samples)):
                new_examples.append(
                    self._replace_example(example, new_target, anno, i, j)
                )
        return new_examples

    def cover_target_examples(
        self,
        contexts: List[StringSpanExample],
        min_n_sample: int = 10,
        n_samples: int = 1,
    ) -> List[StringSpanExample]:
        """Place a fixed number of names per pattern into random contexts.

        Draws ``min_n_sample * n_samples`` names for every pattern key, picks
        up to ``n_samples`` annotated contexts once, and places each name into
        one of those contexts by replacing a random annotation.

        Args:
            contexts: Candidate host examples. Only those with at least one
                annotation are used.
            min_n_sample: Per-pattern multiplier for the number of names drawn.
            n_samples: Number of contexts to draw, clamped to the number of
                annotated contexts available.

        Returns:
            The augmented examples; empty if no usable context exists or
            ``n_samples`` is not positive.
        """
        # Only contexts carrying an annotation can host a replacement.
        candidates = [ex for ex in contexts if ex.annotations]
        # Clamp BEFORE sampling: random.sample raises ValueError when asked for
        # more items than the population holds.
        n_contexts = min(n_samples, len(candidates))
        if n_contexts <= 0:
            return []

        companies = list(self.sample_with_pattern_coverage(min_n_sample * n_samples))
        target_examples = random.sample(candidates, n_contexts)

        augmented_examples = []
        for batch_idx, start in enumerate(range(0, len(companies), n_contexts)):
            batch = companies[start : start + n_contexts]
            for ex, company in zip(target_examples, batch):
                anno_idx, anno = random.choice(list(enumerate(ex.annotations)))
                # batch_idx keeps guids unique when the same context and
                # annotation are reused across batches.
                augmented_examples.append(
                    self._replace_example(ex, company, anno, anno_idx, batch_idx)
                )
        return augmented_examples
@kzinmr
Copy link
Author

kzinmr commented Feb 5, 2021

class CompanySamplerV2:
    """Company-name sampler with prefix, suffix and affix pattern tables.

    Each table maps a pattern key to a list of company names. The names are
    taken from the ``"replaced"`` field of the records stored in the pickle.
    Sampled names replace an existing annotation in a ``StringSpanExample`` to
    produce augmented examples.
    """

    def __init__(
        self,
        sampler_data_path
    ):
        """Load the three pattern tables from the pickle at ``sampler_data_path``.

        The unpickled object must be a mapping with the keys
        ``"prefix2companies"``, ``"suffix2companies"`` and ``"affix2companies"``.
        Each maps a key to records indexable by ``"replaced"``.
        """
        # NOTE(review): pickle.load can execute arbitrary code while loading;
        # only use sampler files from a trusted source.
        with open(sampler_data_path, "rb") as fp:
            data = pickle.load(fp)

        self.prefix2companies = {k: [v["replaced"] for v in vs] for k, vs in data["prefix2companies"].items()}
        self.suffix2companies = {k: [v["replaced"] for v in vs] for k, vs in data["suffix2companies"].items()}
        self.affix2companies = {k: [v["replaced"] for v in vs] for k, vs in data["affix2companies"].items()}

        self.prefix_keys = list(self.prefix2companies.keys())
        self.suffix_keys = list(self.suffix2companies.keys())
        self.affix_keys = list(self.affix2companies.keys())
        self.n_prefix_keys = len(self.prefix_keys)
        self.n_suffix_keys = len(self.suffix_keys)
        self.n_affix_keys = len(self.affix_keys)
        # Group sizes are cached so each draw does not recompute len().
        self.prefix2n_companies = {k: len(v) for k, v in self.prefix2companies.items()}
        self.suffix2n_companies = {k: len(v) for k, v in self.suffix2companies.items()}
        self.affix2n_companies = {k: len(v) for k, v in self.affix2companies.items()}

    @staticmethod
    def _pick(keys, key2companies, key2n_companies):
        """Pick a uniformly random key, then a uniformly random company under it."""
        key = keys[random.randint(0, len(keys) - 1)]
        return key2companies[key][random.randint(0, key2n_companies[key] - 1)]

    def sample(self, n_sample):
        """Yield ``n_sample`` company names drawn independently at random.

        Each draw chooses the prefix, suffix or affix table with roughly equal
        probability (0.33 / 0.33 / 0.34), then a uniformly random key and
        company within it. An empty table or empty group makes
        ``random.randint`` raise ``ValueError``.
        """
        for _ in range(n_sample):
            # Use ONE draw for all three branches. Re-drawing for the suffix
            # test would skew the mix to ~33% prefix / 22% suffix / 45% affix.
            r = np.random.uniform(low=0.0, high=1.0)
            if r < 0.33:
                yield self._pick(self.prefix_keys, self.prefix2companies, self.prefix2n_companies)
            elif r > 0.67:
                yield self._pick(self.suffix_keys, self.suffix2companies, self.suffix2n_companies)
            else:
                yield self._pick(self.affix_keys, self.affix2companies, self.affix2n_companies)

    def sample_with_pattern_coverage(self, min_n_sample):
        """Yield ``min_n_sample`` names (with replacement) for every pattern key.

        Keys are covered in order: all prefix keys, then all suffix keys, then
        all affix keys.
        """
        groups = (
            (self.prefix_keys, self.prefix2companies, self.prefix2n_companies),
            (self.suffix_keys, self.suffix2companies, self.suffix2n_companies),
            (self.affix_keys, self.affix2companies, self.affix2n_companies),
        )
        for keys, key2companies, key2n_companies in groups:
            for key in keys:
                companies = key2companies[key]
                n_companies = key2n_companies[key]
                for _ in range(min_n_sample):
                    yield companies[random.randint(0, n_companies - 1)]

    @staticmethod
    def _replace_example(
        example: StringSpanExample,
        new_target: str,
        anno: SpanAnnotation,
        anno_idx: int,
        new_idx: int = 0,
    ) -> StringSpanExample:
        """Return a copy of ``example`` with ``anno``'s span replaced by ``new_target``.

        The replaced annotation is resized to ``new_target``. Annotations that
        start after it are shifted by the length difference, and annotations
        that start before it are copied unchanged.

        Typical use::

            anno_idx, anno = random.choice(list(enumerate(example.annotations)))
            _replace_example(example, new_target, anno, anno_idx)

        Args:
            example: Source example (not mutated).
            new_target: Text inserted in place of ``anno``'s span.
            anno: Annotation of ``example`` to replace.
            anno_idx: Index of ``anno``; only used to build the new guid.
            new_idx: Extra discriminator for the new guid.

        Returns:
            A new example with guid ``"{guid}-ex{anno_idx}-sample{new_idx}"``.
        """
        new_len = len(new_target)
        text = example.content
        new_text = f"{text[:anno.start]}{new_target}{text[anno.end:]}"
        assert new_text[anno.start : anno.start + new_len] == new_target
        new_anno = SpanAnnotation(
            start=anno.start, end=anno.start + new_len, label=anno.label
        )

        # Shift the annotations that follow the replaced span.
        # NOTE(review): assumes annotations do not overlap; one starting inside
        # the replaced span would be shifted too, and any annotation sharing
        # anno.start is replaced by new_anno.
        org_len = anno.end - anno.start
        offset = new_len - org_len
        new_annotations = []
        for org_anno in example.annotations:
            if org_anno.start < anno.start:
                updated = SpanAnnotation(
                    start=org_anno.start, end=org_anno.end, label=org_anno.label
                )
            elif org_anno.start > anno.start:
                updated = SpanAnnotation(
                    start=org_anno.start + offset,
                    end=org_anno.end + offset,
                    label=org_anno.label,
                )
            else:
                updated = new_anno
            new_annotations.append(updated)

        return StringSpanExample(
            guid=f"{example.guid}-ex{anno_idx}-sample{new_idx}",
            content=new_text,
            annotations=new_annotations,
        )

    def extend_example(
        self, example: StringSpanExample, n_samples: int = 1
    ) -> List[StringSpanExample]:
        """For each annotation, create ``n_samples`` copies with a sampled company name.

        Returns ``len(example.annotations) * n_samples`` new examples.
        """
        new_examples = []
        for i, anno in enumerate(example.annotations):
            for j, new_target in enumerate(self.sample(n_samples)):
                new_examples.append(
                    self._replace_example(example, new_target, anno, i, j)
                )
        return new_examples

    def cover_target_examples(
        self,
        contexts: List[StringSpanExample],
        min_n_sample: int = 10,
        n_samples: int = 1,
    ) -> List[StringSpanExample]:
        """Place a fixed number of names per pattern into random contexts.

        Draws ``min_n_sample * n_samples`` names for every pattern key, picks
        up to ``n_samples`` annotated contexts once, and places each name into
        one of those contexts by replacing a random annotation.

        Args:
            contexts: Candidate host examples. Only those with at least one
                annotation are used.
            min_n_sample: Per-pattern multiplier for the number of names drawn.
            n_samples: Number of contexts to draw, clamped to the number of
                annotated contexts available.

        Returns:
            The augmented examples; empty if no usable context exists or
            ``n_samples`` is not positive.
        """
        # Only contexts carrying an annotation can host a replacement.
        candidates = [ex for ex in contexts if ex.annotations]
        # Clamp BEFORE sampling: random.sample raises ValueError when asked for
        # more items than the population holds.
        n_contexts = min(n_samples, len(candidates))
        if n_contexts <= 0:
            return []

        companies = list(self.sample_with_pattern_coverage(min_n_sample * n_samples))
        target_examples = random.sample(candidates, n_contexts)

        augmented_examples = []
        for batch_idx, start in enumerate(range(0, len(companies), n_contexts)):
            batch = companies[start : start + n_contexts]
            for ex, company in zip(target_examples, batch):
                anno_idx, anno = random.choice(list(enumerate(ex.annotations)))
                # batch_idx keeps guids unique when the same context and
                # annotation are reused across batches.
                augmented_examples.append(
                    self._replace_example(ex, company, anno, anno_idx, batch_idx)
                )
        return augmented_examples

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment