import os
import pathlib
import json

import numpy as np
import cv2

from detectron2.structures import BoxMode
from detectron2.data import MetadataCatalog, DatasetCatalog


def make_annotation_dict(json_dict, *, filter_category_id=None):
    ann_list_dict = {}
    for ann in json_dict["annotations"]:
        #print("----#[annotation]#----")
        #del ann["events"]
        image_id = ann['image_id']
        category_id = ann['category_id']
        keypoints = ann['keypoints']
        num_keypoints = ann['num_keypoints']
        d = dict(category_id=category_id, keypoints=keypoints, num_keypoints=num_keypoints)
        ann_list_dict.setdefault(image_id, []).append(d)
    if filter_category_id is None:
        return ann_list_dict
    else:
        def _filter_dict_by_category_id():
            for image_id, records in ann_list_dict.items():
                for record in records:
                    if filter_category_id == record['category_id']:
                        yield image_id, record
        return dict(_filter_dict_by_category_id())
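
# Purely illustrative sketch (values are made up, not read from the real
# annotation file) of what make_annotation_dict returns. Without
# filter_category_id each image id maps to a list of per-annotation dicts;
# with filter_category_id the inner generator yields single
# (image_id, record) pairs, so dict() keeps only the last match per image id:
#
#   {1: [{"category_id": 1, "keypoints": [...], "num_keypoints": 2},
#        {"category_id": 2, "keypoints": [...], "num_keypoints": 2}]}
#   # with filter_category_id=1:
#   {1: {"category_id": 1, "keypoints": [...], "num_keypoints": 2}}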


def load_json(*, filename):
    return json.loads(pathlib.Path(filename).read_text())


def get_mathand_dicts(json_dict, *, image_dir):
    image_dir = pathlib.Path(image_dir)

    def _record_dict_iter():
        for v in json_dict["images"]:
            record = {}
            filename = str(image_dir / v["file_name"])
            height, width = cv2.imread(filename).shape[:2]
            record["file_name"] = filename
            record["image_id"] = v["id"]
            record["height"] = height
            record["width"] = width
            print(record)
            yield record

    dataset_dicts = list(_record_dict_iter())
    return dataset_dicts


from pycocotools.coco import COCO
from typing import Any, Dict, Iterable, List, Optional


def _maybe_add_bbox(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    if "bbox" not in ann_dict:
        return
    obj["bbox"] = ann_dict["bbox"]
    obj["bbox_mode"] = BoxMode.XYWH_ABS


def _maybe_add_keypoints(obj: Dict[str, Any], ann_dict: Dict[str, Any]):
    if "keypoints" not in ann_dict:
        return
    keypts = ann_dict["keypoints"]  # list[int]
    for idx, v in enumerate(keypts):
        if idx % 3 != 2:
            # COCO's segmentation coordinates are floating points in [0, H or W],
            # but keypoint coordinates are integers in [0, H-1 or W-1].
            # Therefore we assume the coordinates are "pixel indices" and
            # add 0.5 to convert to floating point coordinates.
            keypts[idx] = v + 0.5
    obj["keypoints"] = keypts


def _combine_images_with_annotations(
    dataset_name: str,
    image_root: str,
    cat_id: int,
    img_datas: Iterable[Dict[str, Any]],
    ann_datas: Iterable[Iterable[Dict[str, Any]]],
):
    ann_keys = ["iscrowd", "category_id"]
    dataset_dicts = []
    contains_video_frame_info = False
    for img_dict, ann_dicts in zip(img_datas, ann_datas):
        record = {}
        record["file_name"] = os.path.join(image_root, img_dict["file_name"])
        record["height"] = img_dict["height"]
        record["width"] = img_dict["width"]
        record["image_id"] = img_dict["id"]
        record["dataset"] = dataset_name
        if "frame_id" in img_dict:
            record["frame_id"] = img_dict["frame_id"]
            record["video_id"] = img_dict.get("vid_id", None)
            contains_video_frame_info = True
        objs = []
        for ann_dict in ann_dicts:
            assert ann_dict["image_id"] == record["image_id"]
            assert ann_dict.get("ignore", 0) == 0
            obj = {key: ann_dict[key] for key in ann_keys if key in ann_dict}
            # Keep only annotations of the requested category.
            if obj["category_id"] == cat_id:
                _maybe_add_bbox(obj, ann_dict)
                #_maybe_add_segm(obj, ann_dict)
                _maybe_add_keypoints(obj, ann_dict)
                #_maybe_add_densepose(obj, ann_dict)
                objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts
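
# Sketch of one resulting Detectron2 record (file name and numbers are
# made up; only the keys mirror what the function above actually builds):
#
#   {"file_name": "data/mathand/images/0001.jpg",
#    "height": 480, "width": 640, "image_id": 1, "dataset": "mathand_train",
#    "annotations": [{"category_id": 1,
#                     "bbox": [10, 20, 100, 50], "bbox_mode": BoxMode.XYWH_ABS,
#                     "keypoints": [15.5, 25.5, 2, 95.5, 60.5, 2]}]}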


def _load_coco_api():
    # _load_coco_annotations
    json_file = "data/mathand/mathand.json"
    coco_api = COCO(json_file)
    return coco_api


def get_cat_info(coco_api, category_name):
    cats = coco_api.loadCats(coco_api.getCatIds())
    for cat_d in cats:
        cat_id = cat_d['id']
        name = cat_d['name']
        keypoints = cat_d["keypoints"]
        if category_name == name:
            return cat_id, keypoints
    raise ValueError(f"Not found: category_name={category_name}")


def coco_load_test(coco_api, cat_id, image_root, dataset_name=None):
    #_add_categories_metadata(dataset_name,
    # sort indices for reproducible results
    img_ids = sorted(coco_api.imgs.keys())
    imgs = coco_api.loadImgs(img_ids)
    anns = [coco_api.imgToAnns[img_id] for img_id in img_ids]
    #_verify_annotations_have_unique_ids(annotations_json_file, anns)
    dataset_records = _combine_images_with_annotations(dataset_name, image_root, cat_id, imgs, anns)
    #import pprint
    #pprint.pprint(dataset_records)
    return dataset_records


def register_coco_custom_dataset(category_name):
    coco_api = _load_coco_api()
    cat_id, keypoints = get_cat_info(coco_api, category_name)
    # 1. Register a function which returns dicts.
    name = "mathand_train"
    image_root = "data/mathand/images"
    coco_load_test(coco_api, cat_id, image_root, name)
    DatasetCatalog.register(name, lambda: coco_load_test(coco_api, cat_id, image_root, name))
    # 2. Optionally, add metadata about this dataset,
    #    since it might be useful in evaluation, visualization or logging.
    if cat_id == 1:
        keypoint_flip_map = [('down-left', 'down-right')]
    elif cat_id == 2:
        keypoint_flip_map = [('index-finger-tip', 'index-finger-mcp')]
    else:
        raise ValueError(f"Illegal cat_id={cat_id}")
    print(f"Cat={cat_id} keypoints={keypoints} keypoint_flip_map={keypoint_flip_map}")
    MetadataCatalog.get(name).set(
        image_root=image_root, evaluator_type="coco", keypoint_names=keypoints, keypoint_flip_map=keypoint_flip_map
    )
    #DatasetCatalog.register("mathand_train", dataset_func)
    #MetadataCatalog.get("mathand_train").set(thing_classes=["mat"], keypoint_names=["a", "b"], keypoint_flip_map=[("a", "b")])