Created
December 30, 2023 20:24
-
-
Save ritwikraha/804b6d9b835c6ee16a0ccf4fd5cffe95 to your computer and use it in GitHub Desktop.
Gradio Pixel Selector Utility
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
'''
TODOs:
- Fetch the SAM model
- Fetch the inpainting model
- Initialize the pipeline
- Create the mask_generator from a SAM or other similar model
- Create relevant functions for inpainting
Reference: Abhishek Thakur's YouTube Video: https://www.youtube.com/watch?v=CERvlvUvVEI&t=764s
'''
# Initialize a Gradio demo for pixel selection.
# Components must be created inside the ``gr.Blocks`` context (and inside the
# relevant ``gr.Row`` context) to become part of the layout, so the nesting
# below is significant.
with gr.Blocks() as demo:
    # Display a title using Markdown
    gr.Markdown("# Pixel Selector using Gradio")

    # State holding the list of selected pixels; starts out empty
    selected_pixels = gr.State([])

    # Row of images: the editable input plus three read-only result panes
    with gr.Row():
        input_img = gr.Image(label="Input")
        mask_img = gr.Image(label="Mask", interactive=False)
        seg_img = gr.Image(label="Segmentation", interactive=False)
        output_img = gr.Image(label="Output", interactive=False)

    # Row for the single-line text prompt
    with gr.Row():
        prompt_text = gr.Textbox(lines=1, label="Prompt")

    # Row for the action buttons
    # NOTE(review): no event handlers are attached to these buttons (or to
    # the input image) in this file yet -- confirm where wiring is intended.
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
# Define a function to generate a mask based on selected pixels | |
def generate_mask(image, selected_pixels, event: gr.SelectData):
    """Build a point-prompted mask and a colour-coded segmentation map.

    The clicked coordinate (``event.index``) is appended to
    ``selected_pixels`` in place, so the caller's list accumulates points
    across calls.

    NOTE(review): relies on module-level ``predictor`` and ``mask_generator``
    objects that are not defined in this file yet (see the TODO list at the
    top of the file). Presumably a SAM predictor and automatic mask
    generator -- confirm once they are wired in.

    Args:
        image: Image handed to ``predictor.set_image`` and
            ``mask_generator.generate``. Presumably an H x W x 3 array
            coming from the input image component -- TODO confirm.
        selected_pixels: Mutable list of previously selected points; the new
            point is appended to it in place.
        event: Gradio selection event; ``event.index`` is the selected point.

    Returns:
        A ``(mask, final_segmentation)`` tuple. ``mask`` is a PIL image built
        from the first predicted mask. ``final_segmentation`` is a uint8
        ``(H, W, 3)`` array in which every generated segment is filled with
        its own colour (all zeros when no segments were generated).
    """
    selected_pixels.append(event.index)

    predictor.set_image(image)
    input_point = np.array(selected_pixels)
    # Every accumulated point gets label 1 (presumably "foreground" in the
    # predictor's convention -- confirm against the model being used).
    input_label = np.ones(input_point.shape[0])
    mask, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )
    # Clear torch cache
    torch.cuda.empty_cache()
    mask = Image.fromarray(mask[0, :, :])

    segmentations = mask_generator.generate(image)
    boolean_masks = [s["segmentation"] for s in segmentations]
    torch.cuda.empty_cache()

    # Size the canvas from the first segment when there is one; fall back to
    # the input image so an empty result no longer raises IndexError.
    if boolean_masks:
        height, width = boolean_masks[0].shape[:2]
    else:
        height, width = np.asarray(image).shape[:2]
    final_segmentation = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint each segment with its own colour; previously the canvas was
    # returned untouched (all black). A fixed seed keeps the colours stable
    # between calls. Later segments overwrite earlier ones where they overlap.
    rng = np.random.default_rng(0)
    for segment in boolean_masks:
        final_segmentation[segment] = rng.integers(0, 256, size=3, dtype=np.uint8)

    return mask, final_segmentation
# Define a function to clear all selections and inputs | |
def clear_selection(selected_pixels, input_img, mask_img, seg_img, output_img, prompt_text):
    """Reset the stored pixel selection and return blank values for the UI.

    Args:
        selected_pixels: List of selected points. It is emptied in place so
            the caller's list really is cleared; rebinding the local name (as
            the previous version did) left the caller's data untouched.
        input_img: Current input image value (unused; accepted so the
            function can be called with the component values).
        mask_img: Current mask image value (unused).
        seg_img: Current segmentation image value (unused).
        output_img: Current output image value (unused).
        prompt_text: Current prompt text (unused).

    Returns:
        A 6-tuple ``(img, mask, seg, out, prompt, neg_prompt)``: ``None`` for
        the four images and ``""`` for the two text values.
        NOTE(review): no negative-prompt component exists in the visible
        layout -- confirm that six outputs is what the eventual wiring needs.
    """
    # Only lists are cleared; any other value (e.g. None) is left alone so
    # the function stays as tolerant of its inputs as before.
    if isinstance(selected_pixels, list):
        selected_pixels.clear()
    return None, None, None, None, "", ""
# Launch the Gradio demo | |
# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment