Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save PallawiSinghal/8b33a517b2dcf037680afc9c119e9b9d to your computer and use it in GitHub Desktop.
Save PallawiSinghal/8b33a517b2dcf037680afc9c119e9b9d to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import os
from detectron2.utils.logger import setup_logger
setup_logger()
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer
# Register the custom COCO-format datasets (annotation JSON + image directory)
# under names that the config below can refer to.
register_coco_instances("my_dataset_train", {}, "/code/detectron2/detectron2/instances_train2017.json", "/code/detectron2/detectron2/train2017")
register_coco_instances("my_dataset_val", {}, "/code/detectron2/detectron2/instances_val2017.json", "/code/detectron2/detectron2/val2017")

# Start from the model-zoo Mask R-CNN X-101 32x8d FPN 3x config and its
# matching checkpoint, then override the settings specific to this dataset.
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
cfg.DATASETS.TRAIN = ("my_dataset_train",)
# No test dataset is configured; "my_dataset_val" is registered but unused here.
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 1
cfg.SOLVER.BASE_LR = 0.00025
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.INPUT.MAX_SIZE_TRAIN = 1333
cfg.INPUT.MIN_SIZE_TRAIN = (1280,)
# These were previously written as `cfg.MAX_SIZE_TEST: 1333`, which Python
# parses as a bare annotation (a no-op) and which also missed the INPUT node,
# so the test-time sizes were never actually set.
cfg.INPUT.MAX_SIZE_TEST = 1333
cfg.INPUT.MIN_SIZE_TEST = 1280

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

# Dump the fully-resolved config once; print it and save it for later reuse.
config_yaml = cfg.dump()
print(config_yaml)
config_path = "/code/detectron2/detectron2/output/custom_mask_rcnn_X_101_32x8d_FPN_3x_my_dataset.yaml"
# The saved config lives in a hard-coded directory that may differ from
# cfg.OUTPUT_DIR, so make sure it exists before opening the file.
os.makedirs(os.path.dirname(config_path), exist_ok=True)
with open(config_path, "w") as f:
    f.write(config_yaml)

# Train from the configured weights (resume=False: do not pick up a previous
# checkpoint from the output directory).
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment