Created
December 29, 2021 08:20
-
-
Save bendangnuksung/af1a8a00cc4ca6bdc4929e0ffd991eed to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from detectron2.utils.logger import setup_logger | |
from glob import glob | |
setup_logger() | |
import copy | |
from detectron2.evaluation import COCOEvaluator, inference_on_dataset | |
from detectron2.data import build_detection_test_loader | |
from detectron2 import model_zoo | |
from detectron2.config import get_cfg | |
from detectron2.config import CfgNode as CN | |
from detectron2.data.datasets import register_coco_instances | |
from detectron2.engine import DefaultTrainer | |
from detectron2.evaluation import COCOEvaluator | |
from detectron2.evaluation.evaluator import DatasetEvaluator | |
from detectron2.data import MetadataCatalog, DatasetCatalog | |
import os | |
import pycocotools.mask as mask_util | |
import numpy as np | |
import cv2 | |
from detectron2.data import detection_utils as utils | |
import torch | |
import detectron2.data.transforms as T | |
from fastcore.all import * | |
from ensemble_boxes import * | |
from detectron2.utils.logger import setup_logger | |
from detectron2.data import build_detection_train_loader | |
# Annotation mask format label. setup_config() rebinds this global to 'rle'
# when the config uses bitmask input; no other visible code reads it.
INPUT_MASK = 'polygon'
# INPUT_MASK = 'bitmask'

# Per-class post-processing settings, indexed by the dominant predicted class
# id of an image (see iou_and_min_pixel_process).
# Minimum number of mask pixels a prediction must have to be kept.
MIN_PIXELS = [175, 75, 75]
# Minimum detection score a prediction must reach to be kept.
CLASS_CONFIDENCE_THRESHOLDS = [0.56, 0.71, 0.27]

# IoU threshold handed to box NMS in iou_and_min_pixel_process.
IOU_TH = 0.4
# import some common detectron2 utilities | |
# from detectron2.engine import DefaultPredictor | |
# from detectron2.utils.visualizer import Visualizer | |
# from detectron2.data import MetadataCatalog, DatasetCatalog | |
# Taken from https://www.kaggle.com/theoviel/competition-metric-map-iou | |
def precision_at(threshold, iou):
    """Count matched and unmatched objects at a single IoU threshold.

    Args:
        threshold: IoU value a pair must strictly exceed to count as a match.
        iou: 2-D array-like of pairwise IoU values. An empty 1-D input
            (an empty array or a plain empty list) is treated as a single
            row with no columns.

    Returns:
        A 3-tuple ``(true_positives, false_positives, false_negatives)``
        where the first value counts rows with exactly one match, the second
        counts columns with no match and the third counts rows with no match.
        Which of rows/columns are predictions depends on how the caller built
        ``iou``; the callers in this file only use the sum of the three.
    """
    # Coerce first: comparing a plain list with a float raises TypeError, and
    # pycocotools is understood to return a bare [] when either side is empty.
    matches = np.asarray(iou) > threshold
    if matches.shape == (0,):
        matches = np.array([[]])

    matches_per_row = np.sum(matches, axis=1)
    matches_per_column = np.sum(matches, axis=0)

    true_positives = matches_per_row == 1  # Correct objects
    false_positives = matches_per_column == 0  # Columns nothing matched
    false_negatives = matches_per_row == 0  # Rows that matched nothing
    return np.sum(true_positives), np.sum(false_positives), np.sum(false_negatives)
def score_old(pred, targ):
    """Average the per-threshold precision of raw predicted masks.

    Args:
        pred: Mapping whose ``'instances'`` entry exposes ``pred_masks``
            (moved to CPU and converted to NumPy here).
        targ: Iterable of annotation dicts; each ``'segmentation'`` value is
            passed to ``mask_util.iou`` unchanged.

    Returns:
        Mean of ``tp / (tp + fp + fn)`` over IoU thresholds
        0.5, 0.55, ..., 0.95.
    """
    masks = pred['instances'].pred_masks.cpu().numpy()
    encoded_predictions = [
        mask_util.encode(np.asarray(mask, order='F')) for mask in masks
    ]
    encoded_targets = [annotation['segmentation'] for annotation in targ]
    crowd_flags = [0] * len(encoded_targets)
    iou_matrix = mask_util.iou(encoded_predictions, encoded_targets, crowd_flags)

    precisions = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(threshold, iou_matrix)
        precisions.append(tp / (tp + fp + fn))
    return np.mean(precisions)
def convert_contours_to_masks(contours, image_size=(520,704)):
    """Rasterise contours into filled boolean masks.

    Args:
        contours: Iterable of contours; each one is converted to an array and
            reshaped to ``(-1, 2)`` point pairs.
        image_size: Shape of every output mask.

    Returns:
        ``np.ndarray`` stacking one boolean mask per contour. An empty input
        yields an empty array of shape ``(0,)``.
    """
    masks = []
    for contour in contours:
        points = np.squeeze(np.asarray(contour)).reshape(-1, 2)
        # OpenCV contour drawing expects 32-bit integer points; this is a
        # no-op for int32 input and truncates fractional coordinates.
        points = points.astype(np.int32)
        canvas = np.zeros(image_size, np.uint8)
        canvas = cv2.drawContours(canvas, [points], -1, 255, -1)
        # Builtin bool: the np.bool alias no longer exists in NumPy >= 1.24.
        masks.append(canvas.astype(bool))
    return np.array(masks)
def nms_predictions(classes, scores, bboxes, masks,
                    iou_th, shape=(520, 704)):
    """Run box NMS and return the surviving classes, scores and masks.

    Args:
        classes: Class label per prediction.
        scores: List of confidence scores, one per prediction.
        bboxes: Boxes as 4-element sequences; elements 0 and 2 are divided by
            the width and elements 1 and 3 by the height before NMS.
        masks: Masks indexable in the same order as ``scores``.
        iou_th: IoU threshold forwarded to ``nms``.
        shape: ``(height, width)`` used to normalise the boxes.

    Returns:
        ``(nms_classes, nms_scores, nms_masks)`` where the first two come
        straight from ``nms`` and ``nms_masks`` holds the masks matched back
        to the kept scores.

    Raises:
        ValueError: If ``nms`` returns a score that is not in ``scores``.
    """
    he, wd = shape[0], shape[1]
    boxes_list = [[box[0] / wd, box[1] / he, box[2] / wd, box[3] / he]
                  for box in bboxes]
    _, nms_scores, nms_classes = nms(
        boxes=[boxes_list],
        scores=[list(scores)],
        labels=[list(classes)],
        weights=None,
        iou_thr=iou_th
    )

    # nms() does not report which inputs survived, so masks are matched back
    # by score value. Keep every position of each score so that duplicated
    # scores are handed distinct masks instead of the first one repeatedly.
    positions_by_score = {}
    for position, value in enumerate(scores):
        positions_by_score.setdefault(value, []).append(position)

    nms_masks = []
    for kept_score in nms_scores:
        candidates = positions_by_score.get(kept_score)
        if not candidates:
            raise ValueError(
                "{!r} is not in the input scores".format(kept_score))
        # Consume duplicates in input order; the last candidate is reused so
        # a score is never left without a mask.
        position = candidates.pop(0) if len(candidates) > 1 else candidates[0]
        nms_masks.append(masks[position])
    return nms_classes, nms_scores, nms_masks
def get_pred_class_mode(pred_classes, wanted_pred_class_list=(0, 1, 2), default_pred_class=0):
    """Return the most frequent predicted class that is in the wanted set.

    Args:
        pred_classes: Tensor of predicted class ids (detached, moved to CPU
            and converted to NumPy here).
        wanted_pred_class_list: Container of acceptable class ids. The default
            is an immutable tuple so it cannot be modified between calls.
        default_pred_class: Value returned when no wanted class was predicted.

    Returns:
        The wanted class id with the highest count; ties go to the smaller
        class id. ``default_pred_class`` if none of the predicted classes is
        wanted (including when ``pred_classes`` is empty).
    """
    labels = pred_classes.detach().cpu().numpy()
    unique, counts = np.unique(labels, return_counts=True)
    # np.unique returns ids in ascending order; a stable sort on the negated
    # counts therefore puts the most frequent first and keeps ascending id
    # order among equal counts.
    for position in np.argsort(-counts, kind='stable'):
        candidate = unique[position]
        if candidate in wanted_pred_class_list:
            return candidate
    return default_pred_class
def iou_and_min_pixel_process(pred):
    """Filter one image's predictions by score, box NMS, overlap and size.

    Steps:
      1. Pick the dominant predicted class of the image.
      2. Keep predictions whose score reaches that class's confidence
         threshold.
      3. Run box NMS on what is left.
      4. Walk the surviving masks in order, strip pixels already claimed by
         an earlier kept mask, and drop masks that end up with fewer pixels
         than the class's minimum.

    Args:
        pred: Mapping whose ``'instances'`` entry exposes ``pred_classes``,
            ``scores``, ``pred_masks`` and ``pred_boxes``.
            Assumes ``pred_masks`` is ``(N, H, W)`` -- TODO confirm.

    Returns:
        Boolean ``np.ndarray`` stacking the kept masks; shape ``(0,)`` when
        nothing survives.
    """
    instances = pred['instances']
    pred_class = get_pred_class_mode(instances.pred_classes)
    take = instances.scores >= CLASS_CONFIDENCE_THRESHOLDS[pred_class]
    pred_masks = instances.pred_masks[take].cpu().numpy()
    classes = instances.pred_classes[take].cpu().numpy().tolist()
    scores = instances.scores[take].cpu().numpy().tolist()
    bboxes = instances.pred_boxes[take].tensor.cpu().numpy().tolist()

    # Take the image size from the masks themselves (read before NMS turns
    # pred_masks into a list) instead of hard-coding 520x704.
    mask_shape = tuple(pred_masks.shape[-2:])

    if len(classes):
        classes, scores, pred_masks = nms_predictions(
            classes, scores, bboxes, pred_masks,
            iou_th=IOU_TH, shape=mask_shape)

    final_masks = []
    used = np.zeros(mask_shape, dtype=int)
    for mask in pred_masks:
        mask = mask * (1 - used)  # remove pixels owned by earlier masks
        if mask.sum() >= MIN_PIXELS[pred_class]:  # skip predictions with small area
            # Record the claimed pixels; without this `used` stays all zero
            # and the overlap removal above never has any effect.
            used += mask
            final_masks.append(mask)
    # Builtin bool: the np.bool alias no longer exists in NumPy >= 1.24.
    return np.array(final_masks).astype(bool)
def score(pred, targ_ann, coco_annotation_segmentation='rle'):
    """Score one image: mean precision over IoU thresholds 0.5 ... 0.95.

    Args:
        pred: Model output for one image; post-processed with
            ``iou_and_min_pixel_process`` before scoring.
        targ_ann: Iterable of annotation dicts with a ``'segmentation'``
            entry.
        coco_annotation_segmentation: ``'rle'`` when the segmentations can be
            passed to ``mask_util.iou`` as they are; any other value means
            they are contours that are rasterised and encoded first.

    Returns:
        Mean of ``tp / (tp + fp + fn)`` over the thresholds.
    """
    pred_masks = iou_and_min_pixel_process(pred)
    enc_preds = [mask_util.encode(np.asarray(mask, order='F')) for mask in pred_masks]

    enc_targs = [annotation['segmentation'] for annotation in targ_ann]
    if coco_annotation_segmentation != 'rle':
        # Contour annotations: draw them as masks, then encode like the
        # predictions.
        target_masks = convert_contours_to_masks(enc_targs)
        enc_targs = [mask_util.encode(np.asarray(mask, order='F')) for mask in target_masks]

    # Wrap in an array: when post-processing removes every prediction the IoU
    # result may be a plain empty list, which cannot be compared against a
    # threshold.
    ious = np.asarray(mask_util.iou(enc_preds, enc_targs, [0] * len(enc_targs)))

    prec = []
    for threshold in np.arange(0.5, 1.0, 0.05):
        tp, fp, fn = precision_at(threshold, ious)
        prec.append(tp / (tp + fp + fn))
    return np.mean(prec)
class MAPIOUEvaluator(DatasetEvaluator):
    """Evaluator that reports the mean per-image ``score`` of a dataset."""

    def __init__(self, dataset_name):
        """Cache the ground-truth annotations of ``dataset_name`` by image id.

        Args:
            dataset_name: Name looked up in ``DatasetCatalog``.
        """
        dataset_dicts = DatasetCatalog.get(dataset_name)
        self.annotations_cache = {
            item['image_id']: item['annotations'] for item in dataset_dicts
        }
        # Start with an empty score list so process() also works when the
        # caller has not called reset() first.
        self.scores = []

    def reset(self):
        """Discard the scores collected so far."""
        self.scores = []

    def process(self, inputs, outputs):
        """Score each (input, output) pair and remember the result.

        An image with no predicted instances scores 0.
        """
        for inp, out in zip(inputs, outputs):
            if len(out['instances']) == 0:
                self.scores.append(0)
            else:
                targ = self.annotations_cache[inp['image_id']]
                self.scores.append(score(out, targ))

    def evaluate(self):
        """Return the mean of the collected scores under the key "MaP IoU"."""
        return {"MaP IoU": np.mean(self.scores)}
# Only referenced by the commented-out INPUT.*_SIZE_* overrides in setup_config.
SIZE = 704
# Height and width used by get_random_transform_list for crop ranges and for
# the final resize.
HEIGHT = 520
WIDTH = 704
import random | |
# def get_random_transform_list(height, width): | |
def get_random_transform_list():
    """Build a randomly composed list of detectron2 augmentations.

    The list always starts with a vertical and a horizontal ``RandomFlip``
    (probability 0.5 each) and always ends with a ``Resize`` to
    ``[HEIGHT, WIDTH]``. With probability 1/2 an optional stage runs in which
    each of rotation, absolute crop, one photometric transform and
    ``ResizeScale`` is added independently with probability 1/4.

    NOTE(review): the source this block was recovered from had lost its
    indentation. The nesting used here (all optional transforms under the 50%
    gate, final ``Resize`` unconditional) is the most plausible reading --
    confirm against the original.

    Returns:
        List of transform objects in application order.
    """
    def one_in_four():
        return random.randint(0, 3) == 0

    augmentations = [
        T.RandomFlip(prob=0.5, horizontal=False, vertical=True),
        T.RandomFlip(prob=0.5, horizontal=True, vertical=False),
    ]

    if random.randint(0, 1):
        if one_in_four():
            augmentations.append(T.RandomRotation(angle=[1.0, 180.0]))

        if one_in_four():
            crop_height = random.randint(HEIGHT - 120, HEIGHT - 60)
            crop_width = random.randint(WIDTH - 140, WIDTH - 80)
            augmentations.append(T.RandomCrop("absolute", (crop_height, crop_width)))

        if one_in_four():
            # Exactly one photometric transform, chosen uniformly; factories
            # keep construction lazy so only the chosen one is built.
            photometric_factories = (
                lambda: T.RandomBrightness(0.8, 1.2),
                lambda: T.RandomContrast(0.9, 1.2),
                lambda: T.RandomSaturation(0.9, 1.1),
                lambda: T.RandomLighting(0.9),
            )
            augmentations.append(photometric_factories[random.randint(0, 3)]())

        if one_in_four():
            augmentations.append(
                T.ResizeScale(min_scale=0.5, max_scale=0.85,
                              target_height=HEIGHT, target_width=WIDTH))

    augmentations.append(T.Resize([HEIGHT, WIDTH]))
    return augmentations
def custom_mapper(dataset_dict):
    """Load one dataset record, augment it and build its training targets.

    Args:
        dataset_dict: Record with at least ``"file_name"`` and
            ``"annotations"``. It is deep-copied, so the caller's dict is not
            modified.

    Returns:
        The copied record with two added entries: ``"image"``, a float32
        tensor in channel-first layout, and ``"instances"``, built from the
        transformed annotations as bitmasks with empty instances filtered out.
    """
    record = copy.deepcopy(dataset_dict)

    image = utils.read_image(record["file_name"], format="BGR")
    image, transforms = T.apply_transform_gens(get_random_transform_list(), image)
    record["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))
    output_size = image.shape[:2]

    kept_annotations = []
    for obj in record["annotations"]:
        transformed = utils.transform_instance_annotations(obj, transforms, output_size)
        if obj.get("iscrowd", 0) != 0:
            continue
        # Drop annotations whose transformed segmentation has no positive
        # value left.
        if not transformed["segmentation"].max() > 0:
            continue
        kept_annotations.append(transformed)

    instances = utils.annotations_to_instances(
        kept_annotations, output_size, mask_format='bitmask')
    record["instances"] = utils.filter_empty_instances(instances)
    return record
class CocoTrainer(DefaultTrainer):
    """DefaultTrainer that evaluates with ``MAPIOUEvaluator``.

    Disabled alternatives kept for reference:
      * evaluating with ``COCOEvaluator(dataset_name, cfg, False, output_folder)``
        after creating a "coco_eval" output folder;
      * overriding ``build_train_loader`` to return
        ``build_detection_train_loader(cfg, mapper=custom_mapper)``.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a ``MAPIOUEvaluator`` for ``dataset_name``.

        ``cfg`` and ``output_folder`` are accepted but not used.
        """
        return MAPIOUEvaluator(dataset_name)
def setup_config(args):
    """Build the detectron2 config for a training run.

    Args:
        args: Mapping of option name to string value. Keys read here:
            ``train_image``, ``batch``, ``config_file``, ``worker``,
            ``learning_rate``, ``epoch``, ``size``, ``classes``, ``output``.

    Returns:
        The populated config node.

    Side effects:
        Calls ``setup_logger`` with the output directory and rebinds the
        module-level ``INPUT_MASK`` to ``'rle'``.
    """
    global INPUT_MASK

    # NOTE(security): option values are strings parsed with eval(), which
    # executes arbitrary expressions. That is tolerable only while they come
    # from a trusted command line; consider ast.literal_eval or argparse
    # ``type=`` converters.
    batch_size = eval(args['batch'])
    n_images = len(glob(os.path.join(args['train_image'], 'train/*.png')))
    steps_per_epoch = n_images // batch_size

    cfg = get_cfg()
    config_file = args['config_file']
    try:
        config_path = model_zoo.get_config_file(config_file)
    except RuntimeError:
        # Not a model-zoo config name: treat it as a path to a local YAML.
        config_path = config_file
    cfg.merge_from_file(config_path)

    PERIOD = 400
    cfg.DATASETS.TRAIN = ("car_parts_train",)
    cfg.DATASETS.TEST = ("car_parts_val",)
    cfg.DATASETS.VALIDATION = ("car_parts_val",)
    cfg.TEST.EVAL_PERIOD = PERIOD
    # cfg.TEST.DETECTIONS_PER_IMAGE = 100
    cfg.DATALOADER.NUM_WORKERS = eval(args['worker'])

    try:
        cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_file)
    except RuntimeError:
        # The zoo has no checkpoint for a local config. Previously this
        # raised and made the local-config fallback above unusable; now the
        # weights named by the merged config are kept.
        print("No model zoo checkpoint for {}; keeping MODEL.WEIGHTS from the config".format(config_file))

    cfg.SOLVER.IMS_PER_BATCH = batch_size
    cfg.SOLVER.CHECKPOINT_PERIOD = PERIOD
    cfg.SOLVER.BASE_LR = eval(args['learning_rate'])
    cfg.SOLVER.MAX_ITER = (steps_per_epoch * eval(args['epoch']))
    cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = eval(args['size'])
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = eval(args['classes'])
    cfg.MODEL.MASK_ON = True
    cfg.OUTPUT_DIR = args['output']

    # Augmentation handled by the default mapper: horizontal flip plus a
    # relative 0.8 x 0.8 crop.
    cfg.INPUT.RANDOM_FLIP = "horizontal"
    cfg.INPUT.CROP = CN({"ENABLED": True})
    cfg.INPUT.CROP.TYPE = "relative"
    cfg.INPUT.CROP.SIZE = [0.8, 0.8]
    # Disabled size overrides:
    # cfg.INPUT.MAX_SIZE_TEST = SIZE
    # cfg.INPUT.MAX_SIZE_TRAIN = SIZE
    # cfg.INPUT.MIN_SIZE_TEST = SIZE
    # cfg.INPUT.MIN_SIZE_TRAIN = (SIZE)

    setup_logger(args['output'])

    cfg.INPUT.MASK_FORMAT = 'bitmask'
    if cfg.INPUT.MASK_FORMAT == 'bitmask':
        INPUT_MASK = 'rle'

    # Hard-coded schedule override. This replaces the epoch-based MAX_ITER
    # computed above, so the ``epoch`` option currently has no effect on the
    # run length.
    cfg.SOLVER.WARMUP_ITERS = 750
    cfg.SOLVER.MAX_ITER = 75000

    # Disabled fine-tuning preset:
    # cfg.SOLVER.BASE_LR = 0.000005
    # cfg.SOLVER.WARMUP_ITERS = 0
    # cfg.SOLVER.WEIGHT_DECAY = 0.0000005
    # cfg.MODEL.WEIGHTS = "/camcom/ben/models/casx152/c10/4/top/model_0036249.pth"
    # cfg.TEST.EVAL_PERIOD = 200
    # cfg.SOLVER.CHECKPOINT_PERIOD = 200
    return cfg
def train(args):
    """Register the datasets, train the model and print a final COCO evaluation.

    Args:
        args: Mapping of option name to string value. Reads ``train_json``,
            ``train_image``, ``val_json``, ``val_image`` and ``resume`` here;
            the whole mapping is also forwarded to ``setup_config``.
    """
    dataset_sources = (
        ("car_parts_train", args['train_json'], args['train_image']),
        ("car_parts_val", args['val_json'], args['val_image']),
    )
    for dataset_name, json_path, image_root in dataset_sources:
        register_coco_instances(dataset_name, {}, json_path, image_root)

    cfg = setup_config(args)
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    trainer = CocoTrainer(cfg)
    # The resume flag arrives as the text 'True' / 'False'.
    trainer.resume_or_load(resume=eval(args['resume']))

    short_rule = "*" * 20
    print(short_rule)
    print(cfg)
    print(short_rule)

    trainer.train()

    long_rule = "*" * 50
    print(long_rule)
    print("Finished Training")
    print(long_rule)

    # Final pass over the validation set with the standard COCO evaluator.
    evaluator = COCOEvaluator("car_parts_val", cfg, False, output_dir="./output/")
    val_loader = build_detection_test_loader(cfg, "car_parts_val")
    print(inference_on_dataset(trainer.model, val_loader, evaluator))
if __name__ == '__main__':
    import argparse

    # Every option is kept as a string; setup_config() and train() convert
    # the numeric and boolean ones themselves.
    parser = argparse.ArgumentParser(description="Detectron2 training parameters")
    # The four dataset paths have no usable default, so they are required:
    # omitting one now gives a usage error instead of a later failure on None.
    parser.add_argument("-tj", "--train_json", help="Path to the training annotations JSON", required=True)
    parser.add_argument("-ti", "--train_image", help="Path to the training images directory", required=True)
    parser.add_argument("-vj", "--val_json", help="Path to the validation annotations JSON", required=True)
    parser.add_argument("-vi", "--val_image", help="Path to the validation images directory", required=True)
    parser.add_argument("-e", "--epoch", help="Number of Epochs", default='100')
    parser.add_argument("-b", "--batch", help="Number of batch images", default='6')
    parser.add_argument("-lr", "--learning_rate", help="Base learning rate", default='0.0005')
    parser.add_argument("-s", "--size", help="ROI head batch size per image", default='512')
    parser.add_argument("-c", "--classes", help="number of class labels", default='3')
    parser.add_argument("-r", "--resume", help="resume from last training", default='False')
    parser.add_argument("-w", "--worker", help="Number of workers", default='6')
    parser.add_argument("-o", "--output", help="Output Directory path", default='./slawek')
    parser.add_argument("-cfg", "--config_file", help="Model config YAML file", default='COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml')
    args = vars(parser.parse_args())
    train(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment