Last active
November 17, 2023 19:51
-
-
Save humpydonkey/905bde3156d4c0119de33c9f3b474be7 to your computer and use it in GitHub Desktop.
Extract segmentation masks via rembg
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import numpy as np | |
import PIL.Image | |
import matplotlib.pyplot as plt | |
from datasets import load_dataset | |
from rembg import remove | |
import cv2 | |
def show_mask(mask, ax, random_color=False):
    """Overlay a semi-transparent colored mask onto a matplotlib axes.

    Args:
        mask: binary array whose last two dims are (height, width);
            nonzero pixels are painted.
        ax: matplotlib axes to draw on.
        random_color: if True, use a random RGB color with 0.6 alpha
            instead of the default dodger-blue.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30/255, 144/255, 1, 0.6])
    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) RGBA color.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
# 9x9 structuring element used for the morphological closing in get_mask.
kernel = np.ones((9,9),np.uint8)
def get_largest_component(image: np.ndarray) -> np.ndarray:
    """Return a boolean mask of the largest connected foreground component.

    Args:
        image: 2-D binary array; nonzero pixels are treated as foreground.

    Returns:
        Boolean array of the same shape, True only on the largest
        4-connected component. All False when there is no foreground.
    """
    image = image.astype('uint8')
    nb_components, labels, stats, _centroids = cv2.connectedComponentsWithStats(
        image, connectivity=4
    )
    # Label 0 is always the background. The original indexed sizes[1]
    # unconditionally and raised IndexError on an all-background image.
    if nb_components <= 1:
        return np.zeros(image.shape, dtype=bool)
    # stats[:, -1] is the component area (cv2.CC_STAT_AREA); skip row 0.
    sizes = stats[1:, -1]
    max_label = 1 + int(np.argmax(sizes))
    return labels == max_label
def get_mask(img: PIL.Image.Image) -> np.ndarray:
    """Extract a boolean foreground mask for *img*.

    Pipeline:
        1. rembg removes the background from the input image.
        2. Convert the result to grayscale ("L") so removed pixels are 0.
        3. Binarize, morphologically close small holes (9x9 kernel), and
           keep only the largest connected component.

    Args:
        img: input RGB(A) image.

    Returns:
        2-D boolean array, True on the foreground object.
    """
    out = remove(img).convert("L")
    mask = np.array(out)
    # Any nonzero grayscale value counts as foreground.
    mask[mask > 0] = 255
    # Closing fills small gaps/holes inside the binary mask.
    processed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    # Keep only the largest connected region to drop stray speckles.
    return get_largest_component(processed > 0)
def evaluate_all(dataset):
    """Visually inspect mask quality for every image in *dataset*.

    Shows each image with its extracted mask overlaid, titled by row index.
    """
    for i, row in enumerate(dataset):
        image = row["image"]
        mask = get_mask(image)
        fig = plt.figure(figsize=(10, 10))
        ax = fig.gca()
        plt.title(f"Row index: {i}", fontsize=18)
        ax.imshow(image)
        show_mask(mask, ax)
        plt.axis('off')
        plt.show()
def evaluate_one(dataset, index: int = 28):
    """Show one dataset image with its extracted mask overlaid.

    Args:
        dataset: indexable dataset whose rows have an "image" field.
        index: row to inspect (defaults to 28, the original hard-coded row).
    """
    img = dataset[index]["image"]
    mask = get_mask(img)
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    show_mask(mask, plt.gca())
# Original read `if __main__:`, which raises NameError at import time;
# the standard entry-point guard is intended here.
if __name__ == "__main__":
    # Spot-check the pipeline on a single image first.
    dataset = load_dataset("Raspberry-ai/monse-v4")["train"]
    evaluate_one(dataset)
    # Extract a mask column for the whole dataset.
    new_dataset = dataset.map(
        lambda row: {
            "image": row["image"],
            "text": row["text"],
            # PIL.Image.fromarray cannot handle a bool array; convert the
            # boolean mask to a 0/255 uint8 image first.
            "mask": PIL.Image.fromarray(
                get_mask(row["image"]).astype(np.uint8) * 255
            ),
        },
        batched=False,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment