Skip to content

Instantly share code, notes, and snippets.

Created January 3, 2022 20:55
Show Gist options
  • Save manzke/0a8bec5e13564e1a3f69fbd78e734493 to your computer and use it in GitHub Desktop.
Save manzke/0a8bec5e13564e1a3f69fbd78e734493 to your computer and use it in GitHub Desktop.
Using MODNet for Background Removal
import os
import sys
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from src.models.modnet import MODNet
def remove_background(image, matte):
    """Composite the matted foreground of ``image`` onto a white background.

    Args:
        image: Anything ``np.asarray`` accepts as an H x W or H x W x C pixel
            array (typically a PIL image). A 2-D or single-channel input is
            repeated to three channels; a four-channel input keeps only its
            first three channels.
        matte: 2-D H x W matte with values on a 0..255 scale. 255 keeps the
            image pixel unchanged, 0 replaces it with white.

    Returns:
        PIL.Image.Image: the composite, built from a ``uint8`` array.
    """
    pixels = np.asarray(image)

    # unify image channels to 3
    if pixels.ndim == 2:
        pixels = pixels[:, :, None]
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)
    elif pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    # Scale the matte to [0, 1]; the trailing length-1 axis broadcasts over
    # the colour channels, so no repeated copy of the matte is needed.
    alpha = np.asarray(matte)[:, :, None] / 255

    # Alpha-blend against white (255) without materialising a white image.
    blended = pixels * alpha + 255 * (1 - alpha)

    # Round to nearest before the uint8 cast: a bare cast truncates toward
    # zero, which darkens blended pixels by up to one level.
    return Image.fromarray(np.rint(blended).astype(np.uint8))
def _inference_size(im_h, im_w, ref_size):
    """Return the ``(height, width)`` an ``im_h`` x ``im_w`` image is resized to.

    When both sides are smaller than ``ref_size`` or both are larger, the
    shorter side is scaled to ``ref_size`` and the other side keeps the aspect
    ratio; otherwise the original size is kept. Each side is then rounded down
    to a multiple of 32 (presumably a stride requirement of the network --
    confirm against MODNet) and never allowed to drop below 32, so thin images
    cannot produce a zero-sized tensor.
    """
    if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
        if im_w >= im_h:
            im_rh = ref_size
            im_rw = int(im_w / im_h * ref_size)
        else:
            im_rw = ref_size
            im_rh = int(im_h / im_w * ref_size)
    else:
        im_rh, im_rw = im_h, im_w

    return max(32, im_rh - im_rh % 32), max(32, im_rw - im_rw % 32)


if __name__ == '__main__':
    # Command-line entry point: for every image in --input-path, predict a
    # matte with MODNet and write `<stem>.png` (the matte) and
    # `<stem>.foreground.png` (foreground on white) into --output-path.

    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, required=True,
                        help='path of input images')
    parser.add_argument('--output-path', type=str, required=True,
                        help='path of output images')
    parser.add_argument('--ckpt-path', type=str, required=True,
                        help='path of pre-trained MODNet')
    args = parser.parse_args()

    # check input arguments: report every missing path, then stop instead of
    # running on with paths that are known not to exist
    checked_paths = (
        ('input', args.input_path),
        ('output', args.output_path),
        ('ckpt', args.ckpt_path),
    )
    missing_path = False
    for label, path in checked_paths:
        if not os.path.exists(path):
            print('Cannot find {0} path: {1}'.format(label, path))
            missing_path = True
    if missing_path:
        sys.exit(1)

    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform (ToTensor must run first: Normalize
    # and the shape handling below operate on tensors, not PIL images)
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # create MODNet and load the pre-trained ckpt
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    modnet = nn.DataParallel(MODNet(backbone_pretrained=False)).to(device)
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    weights = torch.load(args.ckpt_path, map_location=device)
    modnet.load_state_dict(weights)
    modnet.eval()

    # inference images
    for im_name in os.listdir(args.input_path):
        im_path = os.path.join(args.input_path, im_name)
        if not os.path.isfile(im_path):
            # sub-directories and other non-file entries are not images
            continue
        print('Process image: {0}'.format(im_name))

        # read image and unify its channels to 3
        try:
            with Image.open(im_path) as im_file:
                image = im_file.convert('RGB')
        except OSError as err:
            # not an image, or unreadable: skip it and keep the batch running
            print('Cannot read image: {0} ({1})'.format(im_name, err))
            continue

        # convert image to PyTorch tensor and add mini-batch dim
        im = im_transform(image)[None, :, :, :]

        # resize image for input
        _, _, im_h, im_w = im.shape
        im_rh, im_rw = _inference_size(im_h, im_w, ref_size)
        im = F.interpolate(im, size=(im_rh, im_rw), mode='area')

        # inference (no autograd bookkeeping is needed for prediction)
        with torch.no_grad():
            _, _, matte = modnet(im.to(device), True)
            # resize the matte back to the original image size
            matte = F.interpolate(matte, size=(im_h, im_w), mode='area')
        matte = matte[0][0].cpu().numpy()

        # save matte; a 2-D uint8 array becomes a single-channel image
        matte_image = Image.fromarray((matte * 255).astype('uint8'))
        # splitext keeps dots inside the stem, so `a.b.jpg` -> `a.b.png`
        # rather than colliding with every other `a.*` file
        stem = os.path.splitext(im_name)[0]
        matte_image.save(os.path.join(args.output_path, stem + '.png'))

        # save foreground composited on white
        foreground = remove_background(image, matte_image)
        foreground.save(
            os.path.join(args.output_path, stem + '.foreground.png'))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment