Skip to content

Instantly share code, notes, and snippets.

@tkshnkmr
Last active September 28, 2019 14:41
Show Gist options
  • Save tkshnkmr/d1404f0f714c832d4c815a713610e09c to your computer and use it in GitHub Desktop.
Save tkshnkmr/d1404f0f714c832d4c815a713610e09c to your computer and use it in GitHub Desktop.
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def get_model_instance_segmentation(num_classes, *, pretrained=False):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    Despite its name, this function constructs an object-detection model
    (``fasterrcnn_resnet50_fpn``); no mask/segmentation head is created here.

    Args:
        num_classes: Number of output classes given to the replacement
            ``FastRCNNPredictor`` head. The calling script passes 2 for
            "target class or background", so background is presumably
            counted as one of the classes.
        pretrained: Forwarded unchanged to ``fasterrcnn_resnet50_fpn``.
            Defaults to ``False``, which is the value this function
            previously hard-coded. NOTE(review): per the torchvision
            documentation, ``True`` requests COCO pre-trained weights;
            newer torchvision releases prefer a ``weights=`` argument.

    Returns:
        The detection model with ``roi_heads.box_predictor`` replaced by a
        new ``FastRCNNPredictor`` sized for ``num_classes``.
    """
    # The file only imports ``FastRCNNPredictor`` from a torchvision
    # submodule, which does not bind the name ``torchvision`` itself, so
    # bring the package into scope here.
    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=pretrained
    )

    # Number of input features expected by the existing classification layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Swap the stock head for one sized for ``num_classes``.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
# Two output classes: the target class and background.
num_classes = 2
num_epochs = 10

# Build the detector and place it on the selected device.
model = get_model_instance_segmentation(num_classes)
model.to(device)

# Hand the optimizer only the parameters that require gradients.
params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Number of batches per epoch; shown in the per-iteration progress message.
len_dataloader = len(data_loader)
# Training loop: for every epoch, run one optimisation step per batch
# yielded by ``data_loader`` and report the summed loss after each step.
for epoch in range(num_epochs):
    model.train()

    # Count batches from 1 so the progress message reads "1/N" ... "N/N".
    for i, (imgs, annotations) in enumerate(data_loader, start=1):
        # Move every image tensor and every annotation tensor to ``device``.
        imgs = [img.to(device) for img in imgs]
        annotations = [
            {k: v.to(device) for k, v in t.items()} for t in annotations
        ]

        # Called with targets, the model returns a mapping whose values are
        # the individual loss terms; the total loss is their sum.
        loss_dict = model(imgs, annotations)
        losses = sum(loss_dict.values())

        # Standard step: clear stale gradients, backpropagate, update.
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # ``.item()`` logs the plain scalar instead of the tensor repr
        # (which would include grad_fn/device noise).
        print(f'Iteration: {i}/{len_dataloader}, Loss: {losses.item()}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment