preprocess VOC data
""" | |
Script that executes the first stage of the pipeline for VOC data (filtering) | |
""" | |
import argparse | |
import xml.etree.ElementTree as ET | |
import os | |
import random | |
DEFAULT_CONFIG = { | |
'validation': [], | |
'training': [], | |
'training_split': 0.8 | |
} | |
def get_attributes(obj: ET.Element): | |
d = {} | |
for a in obj.find("attributes").findall("attribute"): | |
a_key = a.find("name").text | |
a_val = a.find("value").text | |
d[a_key] = a_val | |
return d | |
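# For illustration, get_attributes expects each <object> to carry an <attributes>
# block shaped like the following (an assumed CVAT-style export; the "Type" name
# and "truck" value are made-up placeholders):
#
#   <object>
#     <name>vehicle</name>
#     <attributes>
#       <attribute>
#         <name>Type</name>
#         <value>truck</value>
#       </attribute>
#     </attributes>
#     ...
#   </object>
#
# which get_attributes turns into {"Type": "truck"}.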
def process_annotation(ann_file, labels, img_dir, out_dir, dumponly):
    """ Process an annotation, filtering out unwanted labels, modifying/adding values where
    needed, and storing it in the suitable output directory

    Args:
        ann_file: path to the input annotation file
        labels: set of labels you want to keep in the annotation
        img_dir: path to the directory where the corresponding image is stored (the image
            filename itself is read from the annotation)
        out_dir: path to the directory where the processed annotation needs to be saved
        dumponly: if True, only print the (possibly Type-suffixed) label names and return
            without writing anything
    Returns: None
    """
    try:
        tree = ET.parse(ann_file)
    except ET.ParseError:
        print(f"0 failed parsing {ann_file}")
        return
    root = tree.getroot()
    # append the 'Type' attribute (if any) to each object's label name
    objects = root.findall("object")
    for o in objects:
        nm = o.find("name")
        attrs = get_attributes(o)
        if 'Type' in attrs:
            nm.text = nm.text + attrs['Type']
        if dumponly:
            print(nm.text)
    if dumponly:
        return
    # find all objects that need to be filtered out -- if every object would be filtered,
    # skip this annotation entirely
    if len(labels) > 0:
        filtered = list(
            filter(lambda x: x.find("name").text not in labels, objects)
        )  # objects that need to be removed from the tree
        if len(objects) == len(filtered):
            return
        for ann_object in filtered:
            root.remove(ann_object)
    # full path to the corresponding image file
    img_file = os.path.join(img_dir, root.find("filename").text)
    # if the image doesn't exist, return (stop processing)
    if not os.path.isfile(img_file):
        print(f"0 image does not exist {ann_file} --> try to find {img_file}")
        return
    # the image does exist -- rewrite the annotation so it points directly at this image
    ann_id = os.path.basename(ann_file)
    # clear the annotation's image directory and store the resolved image path
    folder = root.find("folder")
    folder.text = ""
    root.find("filename").text = img_file
    print(f"1 writing file {ann_file}")
    # save annotation to output directory
    tree.write(os.path.join(out_dir, ann_id))
    return
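# For reference, a single annotation could be processed on its own roughly like
# this (illustrative only -- the paths and label names below are placeholders):
#
#   process_annotation(
#       ann_file="/data/raw/dataset/Annotations/0001.xml",
#       labels={"car", "person"},
#       img_dir="/data/raw/dataset/JPEGImages",
#       out_dir="/data/processed/dataset/train_annotations",
#       dumponly=False,
#   )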
def preprocess_voc_data(root_data_dir, labels, out_dir, dumponly):
    """ Preprocess a VOC data set

    Args:
        root_data_dir: path to the base data directory
            (e.g. raw/dataset)
        labels: labels to keep (or no labels if no filtering is necessary)
        out_dir: path to the output directory; the command-line wrapper resolves
            --output_dir relative to the root data directory
            (e.g. processed/dataset)
        dumponly: if True, only print the labels that were found and write nothing
    """
    voc_images = os.path.join(root_data_dir, "JPEGImages")
    annotation_dir = os.path.join(root_data_dir, "Annotations")
    # create output directories (the test directory is created but not filled by this script)
    train_anns_dir = os.path.join(out_dir, "train_annotations")
    val_anns_dir = os.path.join(out_dir, "val_annotations")
    test_anns_dir = os.path.join(out_dir, "test_annotations")
    if not dumponly:
        os.makedirs(train_anns_dir, exist_ok=True)
        os.makedirs(val_anns_dir, exist_ok=True)
        os.makedirs(test_anns_dir, exist_ok=True)
    # iterate over all annotation files and process them;
    # the annotation dir can contain both subdirectories and annotation files
    for file in os.listdir(annotation_dir):
        file = os.path.join(annotation_dir, file)
        if os.path.isdir(file):
            for ann_file in os.listdir(file):
                ann_file = os.path.join(file, ann_file)
                # random train/validation split according to DEFAULT_CONFIG['training_split']
                rand = random.random()
                ann_dst = train_anns_dir if rand < DEFAULT_CONFIG["training_split"] else val_anns_dir
                process_annotation(ann_file, labels, voc_images, ann_dst, dumponly)
        else:
            ann_file = file
            rand = random.random()
            ann_dst = train_anns_dir if rand < DEFAULT_CONFIG["training_split"] else val_anns_dir
            process_annotation(ann_file, labels, voc_images, ann_dst, dumponly)
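# The directory layout this function expects and produces, derived from the paths
# built above (directory names of the data set itself are placeholders):
#
#   <root_data_dir>/
#       JPEGImages/          input images
#       Annotations/         input VOC/CVAT annotation XML files (flat or in subdirectories)
#   <out_dir>/
#       train_annotations/   filtered annotations for training
#       val_annotations/     filtered annotations for validation
#       test_annotations/    created, but left empty by this script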
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--root_data_dir", | |
help="root data directy -- any specified file/directory path is relative to this root " | |
"data directory", | |
required=True, | |
) | |
parser.add_argument("--output_dir", required=True) | |
parser.add_argument("--target_labels", required=True, nargs="*") | |
parser.add_argument("--dumponly", help="just dump all labels you found") | |
args = parser.parse_args() | |
data_root = os.path.abspath(args.root_data_dir) | |
target_labels = args.target_labels | |
preprocess_voc_data( | |
root_data_dir=data_root, | |
labels=target_labels, | |
out_dir=os.path.join(data_root, args.output_dir), | |
dumponly = args.dumponly | |
) |
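For reference, here is a minimal sketch of how the script could be used; the script filename, the dataset paths, and the label names are placeholders, not part of the gist itself. The same preprocessing can also be driven programmatically by calling preprocess_voc_data directly.

# Hypothetical command line (script name and paths are placeholders):
#   python preprocess_voc.py \
#       --root_data_dir raw/dataset \
#       --output_dir processed/dataset \
#       --target_labels car person
#
# Add --dumponly to only print the labels found in the annotations without writing anything.

# Equivalent programmatic call, assuming the script is importable as "preprocess_voc":
from preprocess_voc import preprocess_voc_data

preprocess_voc_data(
    root_data_dir="/data/raw/dataset",
    labels=["car", "person"],
    out_dir="/data/processed/dataset",
    dumponly=False,
)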