Last active
May 29, 2024 07:32
-
-
Save LukeAI/6af4984c79a7534c9c1330958545367c to your computer and use it in GitHub Desktop.
How to process a directory of images with SAM (Segment Anything Model) and save visualisations of their masks
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from __future__ import annotations | |
import os | |
from pathlib import Path | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor | |
import cv2 | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
# ---------------------------------------------------------------------------
# config
# ---------------------------------------------------------------------------
in_dir = 'my_images'
out_dir = 'segmented'

# SAM variant and the checkpoint file that goes with it. Alternatives:
#   "vit_h" with "sam_vit_h_4b8939.pth"
#   "vit_b" with "sam_vit_b_01ec64.pth"
sam_model = "vit_l"
sam_check = "sam_vit_l_0b3195.pth"

device = "cuda"
# factor applied to the segmentation colours before adding them to the image
transparency = 0.3
# at most this many masks are coloured per image
max_masks = 300

# sam generator params
points_per_batch = 64
points_per_side = 64
pred_iou_thresh = 0.86
stability_score_thresh = 0.92
crop_n_layers = 1
crop_n_points_downscale_factor = 2
min_mask_region_area = 100

# list of random colors: one 3-component float array per drawable mask
colors = [np.random.random(3) for _ in range(max_masks)]
def draw_segmentation(anns):
    """Render mask annotations as a flat-coloured float image.

    Args:
        anns: Sequence of annotation dicts. Each one is read for its
            ``'segmentation'`` entry (a 2-D array used to index the canvas;
            a true/false mask) and its ``'area'`` entry (used as sort key).

    Returns:
        A ``(h, w, 3)`` float64 array that is zero wherever no mask was
        painted, or ``None`` when ``anns`` is empty.
    """
    if not anns:
        return None

    height, width = anns[0]['segmentation'].shape
    canvas = np.zeros((height, width, 3), dtype=np.float64)

    # Paint from the largest area down, so that smaller masks are drawn last
    # and stay visible where they overlap a bigger one.
    by_area = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    for idx, ann in enumerate(by_area[:max_masks]):
        # every pixel selected by this mask gets the idx-th colour
        canvas[ann['segmentation']] = colors[idx]
    return canvas
def process_image(img_path, out_path, mask_generator):
    """Segment one image and save it with the mask colours overlaid.

    Args:
        img_path: Path of the image to load (passed to ``cv2.imread``).
        out_path: Path the result is written to (passed to ``cv2.imwrite``).
        mask_generator: Object whose ``generate(image)`` returns the mask
            annotations that ``draw_segmentation`` consumes.

    Raises:
        OSError: If ``cv2.imread`` could not load ``img_path``.
    """
    image = cv2.imread(img_path)
    if image is None:
        # imread reports failure by returning None instead of raising; fail
        # here with a clear message rather than inside cvtColor.
        raise OSError(f"could not read image: {img_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # mask generator wants the default uint8 image
    masks = mask_generator.generate(image)

    # work in float64, scaled to the 0..1 range
    image = image.astype(np.float64) / 255

    seg = draw_segmentation(masks)
    # draw_segmentation returns None when there are no masks; the image is
    # then saved without any overlay instead of crashing on `None * float`.
    if seg is not None:
        # add segmentation image on top of original image
        image += transparency * seg

    # The additive overlay can push values above 1.0; clip before the uint8
    # cast so bright pixels saturate instead of wrapping around.
    np.clip(image, 0.0, 1.0, out=image)

    # convert back to uint8 (BGR channel order) for saving
    image = (255 * image).astype(np.uint8)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    cv2.imwrite(out_path, image)
if __name__ == "__main__":
    # make sure output dir exists; exist_ok makes this a no-op when it does
    # (the original referenced an undefined name `out_dirs` here)
    os.makedirs(out_dir, exist_ok=True)

    # load SAM model + create mask generator
    sam = sam_model_registry[sam_model](checkpoint=sam_check)
    sam.to(device=device)
    sam = torch.compile(sam)
    mask_generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=points_per_side,
        # configured above but previously never handed to the generator
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        crop_n_layers=crop_n_layers,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        min_mask_region_area=min_mask_region_area,
    )

    # process input directory
    for img in tqdm(os.listdir(in_dir)):
        in_img = os.path.join(in_dir, img)
        # skip entries OpenCV cannot read/decode as an image
        if not cv2.haveImageReader(in_img):
            continue
        # change extension of output image to .png
        out_img = os.path.join(out_dir, Path(img).stem + ".png")
        process_image(in_img, out_img, mask_generator)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment