""" | |
Collection of methods built to assist in data augmentation for extraction datasets | |
""" | |
from ast import literal_eval | |
import json | |
import random | |
from collections import defaultdict | |
from functools import partial | |
from typing import Iterable, Dict, Callable | |
import pandas as pd | |
from BlueJet.sources.local import TrainingSet | |


def random_swap(text, options):
    # Can't swap if you have nothing to swap with,
    # so fall back to our only valid option
    options = tuple(options - {text}) or (text,)
    return random.choice(options)
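
# Illustrative sketch (not part of the original gist): random_swap draws a
# replacement from the other observed surface forms, and falls back to the
# original text when it is the only candidate, e.g.
#
#     random_swap("Paris", {"Paris", "Berlin", "Tokyo"})  # -> "Berlin" or "Tokyo"
#     random_swap("Paris", {"Paris"})                      # -> "Paris"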


def _get_swap_lambdas(df, source_col, target_col, **kwargs):
    """Build one swap function per label class from the surface forms seen in the dataset."""
    options = defaultdict(set)
    for _, row in df.iterrows():
        for label in row[target_col]:
            if not label.get("text"):
                label["text"] = row[source_col][label["start"]:label["end"]]
            options[label["label"]] |= {label["text"]}
    # Yes, we are ignoring x on purpose
    return {key: partial(random_swap, options=options[key]) for key in options}


def _get_noop(df, source_col, target_col, **kwargs):
    """Return an identity swapper per label class, preserving the original text."""
    options = set(label["label"] for labels in df[target_col] for label in labels)
    return {key: lambda x: x for key in options}
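
# Illustrative sketch (not part of the original gist), assuming a hypothetical
# dataset whose only label class is "city": _get_swap_lambdas would return
# roughly {"city": partial(random_swap, options={"Paris", "Berlin"})}, while
# _get_noop returns an identity swapper per class, e.g. {"city": lambda x: x},
# so the "original" strategy re-emits each document unchanged.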


class TokenAugmentor:
    """
    Only applies to extraction models.

    TokenAugmentor only replaces the tagged tokens and updates labels. Context augmentation
    will be handled in a separate Augmentor.

    Example:
        df = pd.read_csv("<source>.csv")
        test = TokenAugmentor({"swap": 3, "original": 1})
        test.augment(df, "text", "question_#", "<destination>.csv")
    """

    strategies = {
        "swap": _get_swap_lambdas,
        "original": _get_noop
    }

    def __init__(self, strategy: dict):
        """strategy maps each strategy name ("swap" / "original") to the number of augmented copies per document."""
        for key in strategy:
            if key.lower() not in self.strategies:
                raise ValueError("%s is not an available strategy" % key)
        self.strategy = strategy

    def augment(self, df, source_col: str, target_col: str, results_file: str, **kwargs):
        """Apply every configured strategy to each row and write the augmented rows to results_file."""
        df[target_col] = df[target_col].apply(literal_eval)
        swappers = {}
        for key in self.strategy:
            swappers[key] = self.strategies[key](df, source_col, target_col, **kwargs)
        augmented_data = {source_col: [], target_col: []}
        for _, row in df.iterrows():
            new_sources, new_targets = self._augment_docs(row[source_col], row[target_col], swappers)
            augmented_data[source_col].extend(new_sources)
            augmented_data[target_col].extend(new_targets)
        df = pd.DataFrame(augmented_data)
        df.to_csv(results_file)

    def _augment_docs(self, source: str, target: Iterable[dict], swappers: Dict[str, Callable]):
        new_sources = []
        new_targets = []
        for key, value in self.strategy.items():
            for _ in range(value):
                try:
                    new_source, new_target = self._augment_doc(source, target, swappers[key])
                    new_sources.append(new_source)
                    new_targets.append(json.dumps(new_target))
                except NotImplementedError:
                    continue
        return new_sources, new_targets

    def _augment_doc(self, source: str, target: Iterable[dict], swapper: Dict[str, Callable]):
        offset = 0
        new_source = source
        new_target = []
        last_end = 0
        for entry in sorted(target, key=lambda x: x["start"]):
            original_value = entry["text"]
            new_value = swapper[entry["label"]](entry["text"])
            # Checking for overlap
            if new_target and (last_end > entry["start"]):
                raise NotImplementedError(
                    "Overlapping labels are not yet supported"
                )
            last_end = entry["end"]
            new_start = entry["start"] + offset
            new_end = entry["end"] + offset
            new_source = new_source[:new_start] + new_value + new_source[new_end:]
            offset += (len(new_value) - len(original_value))
            final_end = entry["end"] + offset
            new_target.append({
                "start": new_start,
                "end": final_end,
                "label": entry["label"],
                "text": new_value
            })
        # Ensure that mapping was accomplished successfully. Else error.
        for new_entry in new_target:
            assert new_source[new_entry["start"]:new_entry["end"]] == new_entry["text"]
        return new_source, new_target
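
# Worked example (illustrative, not from the original gist) of the offset
# bookkeeping in _augment_doc, assuming a hypothetical "city" label:
#
#     source = "She flew to Rio last week"
#     target = [{"start": 12, "end": 15, "label": "city", "text": "Rio"}]
#
# If the swapper returns "Berlin", the span [12:15] is replaced, offset becomes
# len("Berlin") - len("Rio") = 3, and the emitted label is
# {"start": 12, "end": 18, "label": "city", "text": "Berlin"}, which the final
# assert checks against the rewritten source.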


def convert_format(source_file: str, source_col: str, target_col: str):
    """
    Convert files from standard Teach Export format to standard BlueJet format
    """
    def rewrite_labels(row):
        text = row[source_col]
        old_labels = json.loads(row[target_col])
        new_labels = []
        for old_label in old_labels:
            extracted_text = text[old_label["startOffset"]:old_label["endOffset"]]
            new_label = {
                "label": old_label["label"],
                "start": old_label["startOffset"],
                "end": old_label["endOffset"],
                "text": extracted_text
            }
            new_labels.append(new_label)
        return new_labels

    df = pd.read_csv(source_file)
    new_label_col = []
    for _, row in df.iterrows():
        new_label_col.append(rewrite_labels(row))
    df[target_col] = new_label_col
    # Overwrite the source file in place with the converted labels
    df.to_csv(source_file)
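

# Minimal usage sketch (not part of the original gist). The file and column
# names are hypothetical placeholders, mirroring the TokenAugmentor docstring.
if __name__ == "__main__":
    convert_format("<source>.csv", "text", "question_#")
    df = pd.read_csv("<source>.csv")
    augmentor = TokenAugmentor({"swap": 3, "original": 1})
    augmentor.augment(df, "text", "question_#", "<destination>.csv")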