# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import contextlib
import datetime
import io
import itertools
import json
import logging
import os
import pickle
import random

import imageio
import numpy as np
import PIL
from PIL import Image
from matplotlib import pyplot as plt
import pycocotools.mask as mask_util
from fvcore.common.file_io import PathManager, file_lock
from fvcore.common.timer import Timer

from detectron2.structures import Boxes, BoxMode, PolygonMasks
from .. import DatasetCatalog, MetadataCatalog
""" | |
This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format". | |
""" | |
logger = logging.getLogger(__name__) | |
__all__ = ["load_artf", "convert_to_coco_json"] | |
def load_artf(save_pickle, image_root, filter, split, name):
    timer = Timer()
    json_file = PathManager.get_local_path(save_pickle)
    if timer.seconds() > 1:
        logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds()))
    if name is not None:
        meta = MetadataCatalog.get(name)
    cat_ids = [1]
    paths = []
    for dirName, subdirList, fileList in os.walk(image_root):
        paths += [os.path.relpath(os.path.join(dirName, x), image_root) for x in fileList if
                  x.endswith('.jpg') or x.endswith('.jpeg') or x.endswith('.png')]
    paths.sort()
    annotation = {}
    with open(save_pickle, 'rb') as f:
        annotation = pickle.load(f)
    if '__banned' not in annotation.keys():
        annotation['__banned'] = set()

    dirs = [x for x in os.listdir(image_root) if os.path.isdir(os.path.join(image_root, x))]
    counts = {x: 0 for x in dirs}
    counts_real = {x: len(os.listdir(os.path.join(image_root, x))) for x in dirs}
    for x in dirs:
        for p in annotation.keys():
            if x in p:
                counts[x] += 1
    for i, c in counts.items():
        print("{}:{}:{}".format(i, c, counts_real[i]))

    to_delete = []
    for p, x in annotation.items():
        if len(x) == 0 and p != '__banned':
            to_delete.append(p)
    for p in to_delete:
        del annotation[p]
    del annotation['__banned']

    data = []
    img_sizes = {}
    img_scales = {}
    for id, (filename, x) in enumerate(list(annotation.items())):
        img = Image.open(os.path.join(image_root, '../library', filename))
        w, h = img.size
        scale = min(1600 / w, 1600 / h)
        scale = max(max(800 / w, 800 / h), scale)
        if scale > 1.0:
            scale = 1.0
        img_sizes[filename] = (int(w * scale), int(h * scale))
        img_scales[filename] = scale

        rw = -1
        rh = -1
        if os.path.exists(os.path.join(image_root, filename)):
            rimg = Image.open(os.path.join(image_root, filename))
            rw, rh = rimg.size
        if not os.path.exists(os.path.join(image_root, filename)) or rw != int(w * scale) or rh != int(h * scale):
            img = imageio.imread(os.path.join(image_root, '../library', filename))
            img = Image.fromarray(img)
            if scale < 1.0:
                print(img.size, "Resizing to: ", (int(w * scale), int(h * scale)))
                img = img.resize((int(w * scale), int(h * scale)), PIL.Image.LANCZOS)
            os.makedirs(os.path.dirname(os.path.join(image_root, filename)), exist_ok=True)
            imageio.imwrite(os.path.join(image_root, filename), img)
        data.append((id, filename, x))
    print("Data length: %d" % len(data))
    random.seed(0)
    random.shuffle(data)
    data_train = data[:int(len(data) * 0.8)]
    data_test = data[int(len(data) * 0.8):]

    if split == 'train':
        dataset = data_train
    elif split == 'test':
        dataset = data_test
    else:
        assert False, split
    print("Data size: %d" % len(annotation.items()))

    dataset_dicts = []
    ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"]
    num_instances_without_valid_segmentation = 0
    def make_keypoints(filename, points):
        assert len(points) == 4
        scale = img_scales[filename]
        return list(itertools.chain.from_iterable([[int(x * scale), int(y * scale), 2] for x, y in points]))
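    # Example (hypothetical values): with scale == 1.0 and
    # points == [(10, 20), (30, 20), (12, 40), (28, 40)], make_keypoints returns
    # [10, 20, 2, 30, 20, 2, 12, 40, 2, 28, 40, 2], i.e. flat COCO-style
    # (x, y, visibility) triplets with visibility hard-coded to 2 ("labeled and visible").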
    def get_box(filename, points):
        lm = np.array(points)
        # Calculate auxiliary vectors.
        eye_left = lm[0]
        eye_right = lm[1]
        eye_avg = (eye_left + eye_right) * 0.5
        eye_to_eye = eye_right - eye_left
        mouth_left = lm[2]
        mouth_right = lm[3]
        mouth_avg = (mouth_left + mouth_right) * 0.5
        eye_to_mouth = mouth_avg - eye_avg

        # Choose oriented crop rectangle.
        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
        x /= np.hypot(*x)
        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) * 0.8
        y = np.flipud(x) * [-1, 1]
        c = eye_avg + eye_to_mouth * 0.1
        quad = np.stack([c - x - y * 0.8, c - x + y * 1.2, c + x + y * 1.2, c + x - y * 0.8])
        qsize = np.hypot(*x) * 2
        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                int(np.ceil(max(quad[:, 1]))))
        scale = img_scales[filename]
        return [int(x * scale) for x in [crop[0], crop[1], crop[2] - crop[0], crop[3] - crop[1]]]
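    # Note: get_box() appears to mirror the FFHQ-style face-alignment recipe: it builds
    # an oriented crop rectangle from the eye line and the eye-to-mouth vector, then
    # returns its axis-aligned XYWH bounding box scaled to the resized image.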
    for id, filename, points in dataset:
        record = {}
        record["file_name"] = os.path.join(image_root, filename)
        record["height"] = img_sizes[filename][1]
        record["width"] = img_sizes[filename][0]
        image_id = record["image_id"] = id

        n = 4
        anno_list = [points[i:i + n] for i in range(0, len(points), n)]
        objs = []
        for anno in anno_list:
            if len(anno) != 4:
                print('Error with %s' % filename)
                continue
            obj = {
                "bbox": get_box(filename, anno),
                "keypoints": make_keypoints(filename, anno),
                "category_id": 1,
                "bbox_mode": BoxMode.XYWH_ABS
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts
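
# A minimal usage sketch for load_artf (hypothetical paths; the pickle is assumed to map
# relative image paths to flat lists of keypoints, four (x, y) pairs per face):
#
#   dicts = load_artf("data/art_face_save.pth", "data/resized_lib",
#                     filter=None, split="train", name=None)
#   d = dicts[0]
#   print(d["file_name"], d["width"], d["height"])
#   print(d["annotations"][0]["bbox"])       # XYWH_ABS box from get_box()
#   print(d["annotations"][0]["keypoints"])  # flat (x, y, v) triplets from make_keypoints()
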
def convert_to_coco_dict(dataset_name):
    """
    Convert an instance detection/segmentation or keypoint detection dataset
    in detectron2's standard format into COCO json format.

    Generic dataset description can be found here:
    https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset

    COCO data format description can be found here:
    http://cocodataset.org/#format-data

    Args:
        dataset_name (str):
            name of the source dataset
            Must be registered in DatasetCatalog and in detectron2's standard format.
            Must have corresponding metadata "thing_classes"
    Returns:
        coco_dict: serializable dict in COCO json format
    """
    dataset_dicts = DatasetCatalog.get(dataset_name)
    metadata = MetadataCatalog.get(dataset_name)

    # unmap the category mapping ids for COCO
    if hasattr(metadata, "thing_dataset_id_to_contiguous_id"):
        reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()}
        reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id]  # noqa
    else:
        reverse_id_mapper = lambda contiguous_id: contiguous_id  # noqa

    categories = [
        {"id": reverse_id_mapper(id), "name": name}
        for id, name in enumerate(metadata.thing_classes)
    ]
logger.info("Converting dataset dicts into COCO format") | |
coco_images = [] | |
coco_annotations = [] | |
for image_id, image_dict in enumerate(dataset_dicts): | |
coco_image = { | |
"id": image_dict.get("image_id", image_id), | |
"width": image_dict["width"], | |
"height": image_dict["height"], | |
"file_name": image_dict["file_name"], | |
} | |
coco_images.append(coco_image) | |
anns_per_image = image_dict.get("annotations", []) | |
for annotation in anns_per_image: | |
# create a new dict with only COCO fields | |
coco_annotation = {} | |
# COCO requirement: XYWH box format | |
bbox = annotation["bbox"] | |
bbox_mode = annotation["bbox_mode"] | |
bbox = BoxMode.convert(bbox, bbox_mode, BoxMode.XYWH_ABS) | |
# COCO requirement: instance area | |
if "segmentation" in annotation: | |
# Computing areas for instances by counting the pixels | |
segmentation = annotation["segmentation"] | |
# TODO: check segmentation type: RLE, BinaryMask or Polygon | |
if isinstance(segmentation, list): | |
polygons = PolygonMasks([segmentation]) | |
area = polygons.area()[0].item() | |
elif isinstance(segmentation, dict): # RLE | |
area = mask_util.area(segmentation).item() | |
else: | |
raise TypeError(f"Unknown segmentation type {type(segmentation)}!") | |
else: | |
# Computing areas using bounding boxes | |
bbox_xy = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) | |
area = Boxes([bbox_xy]).area()[0].item() | |
if "keypoints" in annotation: | |
keypoints = annotation["keypoints"] # list[int] | |
for idx, v in enumerate(keypoints): | |
if idx % 3 != 2: | |
# COCO's segmentation coordinates are floating points in [0, H or W], | |
# but keypoint coordinates are integers in [0, H-1 or W-1] | |
# For COCO format consistency we substract 0.5 | |
# https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163 | |
keypoints[idx] = v - 0.5 | |
if "num_keypoints" in annotation: | |
num_keypoints = annotation["num_keypoints"] | |
else: | |
num_keypoints = sum(kp > 0 for kp in keypoints[2::3]) | |
# COCO requirement: | |
# linking annotations to images | |
# "id" field must start with 1 | |
coco_annotation["id"] = len(coco_annotations) + 1 | |
coco_annotation["image_id"] = coco_image["id"] | |
coco_annotation["bbox"] = [round(float(x), 3) for x in bbox] | |
coco_annotation["area"] = float(area) | |
coco_annotation["iscrowd"] = annotation.get("iscrowd", 0) | |
coco_annotation["category_id"] = reverse_id_mapper(annotation["category_id"]) | |
# Add optional fields | |
if "keypoints" in annotation: | |
coco_annotation["keypoints"] = keypoints | |
coco_annotation["num_keypoints"] = num_keypoints | |
if "segmentation" in annotation: | |
seg = coco_annotation["segmentation"] = annotation["segmentation"] | |
if isinstance(seg, dict): # RLE | |
counts = seg["counts"] | |
if not isinstance(counts, str): | |
# make it json-serializable | |
seg["counts"] = counts.decode("ascii") | |
coco_annotations.append(coco_annotation) | |
    logger.info(
        "Conversion finished, "
        f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}"
    )

    info = {
        "date_created": str(datetime.datetime.now()),
        "description": "Automatically generated COCO json file for Detectron2.",
    }
    coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None}
    if len(coco_annotations) > 0:
        coco_dict["annotations"] = coco_annotations
    return coco_dict
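
# For reference, the dict built above serializes to standard COCO json, roughly:
#
#   {"info": {...},
#    "images": [{"id": ..., "width": ..., "height": ..., "file_name": ...}, ...],
#    "categories": [{"id": ..., "name": ...}, ...],
#    "licenses": None,
#    "annotations": [{"id": 1, "image_id": ..., "bbox": [x, y, w, h], "area": ...,
#                     "iscrowd": 0, "category_id": ..., "keypoints": [...],
#                     "num_keypoints": ...}, ...]}
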
def convert_to_coco_json(dataset_name, output_file, allow_cached=True):
    """
    Converts dataset into COCO format and saves it to a json file.
    dataset_name must be registered in DatasetCatalog and in detectron2's standard format.

    Args:
        dataset_name: reference from the config file to the catalogs
            must be registered in DatasetCatalog and in detectron2's standard format
        output_file: path of json file that will be saved to
        allow_cached: if json file is already present then skip conversion
    """
    # TODO: The dataset or the conversion script *may* change,
    # a checksum would be useful for validating the cached data
    PathManager.mkdirs(os.path.dirname(output_file))
    with file_lock(output_file):
        if PathManager.exists(output_file) and allow_cached:
            logger.warning(
                f"Using previously cached COCO format annotations at '{output_file}'. "
                "You need to clear the cache file if your dataset has been modified."
            )
        else:
            logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...")
            coco_dict = convert_to_coco_dict(dataset_name)
            logger.info(f"Caching COCO format annotations at '{output_file}' ...")
            with PathManager.open(output_file, "w") as f:
                json.dump(coco_dict, f)
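
# Usage sketch (hypothetical names): once a dataset is registered (see
# register_artf_instances below), its COCO json can be cached with:
#
#   convert_to_coco_json("artf_faces_train", "output/artf_faces_train_coco.json")
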
if __name__ == "__main__": | |
""" | |
Test the COCO json dataset loader. | |
Usage: | |
python -m detectron2.data.datasets.coco \ | |
path/to/json path/to/image_root dataset_name | |
"dataset_name" can be "coco_2014_minival_100", or other | |
pre-registered ones | |
""" | |
# from detectron2.utils.logger import setup_logger | |
# from detectron2.utils.visualizer import Visualizer | |
# import detectron2.data.datasets # noqa # add pre-defined metadata | |
# import sys | |
# | |
# logger = setup_logger(name=__name__) | |
# assert sys.argv[3] in DatasetCatalog.list() | |
# meta = MetadataCatalog.get(sys.argv[3]) | |
# | |
# dicts = load_artf(save_pickle, image_root, filter, split, name) | |
# logger.info("Done loading {} samples.".format(len(dicts))) | |
# | |
# dirname = "coco-data-vis" | |
# os.makedirs(dirname, exist_ok=True) | |
# for d in dicts: | |
# img = np.array(Image.open(d["file_name"])) | |
# visualizer = Visualizer(img, metadata=meta) | |
# vis = visualizer.draw_dataset_dict(d) | |
# fpath = os.path.join(dirname, os.path.basename(d["file_name"])) | |
# vis.save(fpath) | |

###################################
# Other file that does registration
###################################
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import copy
import os

from detectron2.data import DatasetCatalog, MetadataCatalog
from .artf import load_artf

__all__ = ["register_artf_instances"]


def register_artf_instances(name, metadata, root, filter, split):
    assert isinstance(name, str), name
    assert isinstance(root, (str, os.PathLike)), root
    save_pickle = os.path.join(root, 'data', 'art_face_save.pth')
    image_root = os.path.join(root, 'data', 'resized_lib')
    assert os.path.exists(save_pickle), save_pickle
    assert os.path.exists(image_root), image_root
    assert os.path.isfile(save_pickle), save_pickle
    assert os.path.isdir(image_root), image_root

    # 1. register a function which returns dicts
    DatasetCatalog.register(name, lambda: load_artf(save_pickle, image_root, filter, split, name))

    # 2. Optionally, add metadata about this dataset,
    # since they might be useful in evaluation, visualization or logging
    MetadataCatalog.get(name).set(
        pth_file=save_pickle, image_root=image_root, evaluator_type="coco", **metadata
    )
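
# A minimal registration sketch (hypothetical dataset name, root and metadata values;
# "thing_classes" and the keypoint names are assumptions, not taken from this gist):
#
#   for split in ("train", "test"):
#       register_artf_instances(
#           "artf_faces_" + split,
#           metadata={
#               "thing_classes": ["face"],
#               "keypoint_names": ["eye_left", "eye_right", "mouth_left", "mouth_right"],
#               "keypoint_flip_map": [("eye_left", "eye_right"), ("mouth_left", "mouth_right")],
#           },
#           root="/path/to/project", filter=None, split=split,
#       )
#   dicts = DatasetCatalog.get("artf_faces_train")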