Skip to content

Instantly share code, notes, and snippets.

@AdoHaha
Created July 29, 2025 17:46
Show Gist options
  • Save AdoHaha/77073795c0ebe8d90898919006b610e2 to your computer and use it in GitHub Desktop.
Save AdoHaha/77073795c0ebe8d90898919006b610e2 to your computer and use it in GitHub Desktop.
Simple app to remove an image's background using Segment Anything (SAM).
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
import os
class ImageSegmenter(tk.Tk):
    """Tk window for click-driven background removal with Segment Anything.

    The user loads an image, clicks on it to grow ("Keep") or shrink
    ("Remove") a selection mask predicted by SAM, and saves a PNG whose
    alpha channel is zero everywhere outside the selection.
    """

    # SAM variant and checkpoint file; the path is relative, so it is
    # resolved against the current working directory.
    MODEL_TYPE = "vit_b"
    CHECKPOINT = "sam_vit_b_01ec64.pth"
    # RGBA colour painted over the selected pixels in the preview.
    OVERLAY_COLOR = (255, 0, 0, 128)

    def __init__(self):
        """Create the window, build the widgets and try to load the model."""
        super().__init__()
        self.title("Image Segmenter")
        self.geometry("800x600")
        self.image_path = None      # path of the currently loaded image
        self.original_image = None  # PIL image in RGBA mode
        self.image_tk = None        # PhotoImage currently shown on the canvas
        self.mask = None            # uint8 array (height, width); 1 = selected
        self.predictor = None       # SamPredictor, or None if loading failed
        self.mode = tk.StringVar(value="keep")
        self.create_widgets()
        self.load_model()

    def create_widgets(self):
        """Build the button row, the Keep/Remove selector and the canvas."""
        # Frame for buttons
        button_frame = tk.Frame(self)
        button_frame.pack(pady=10)
        for text, command in (
            ("Load Image", self.load_image),
            ("Save Image", self.save_image),
            ("Reset", self.reset),
        ):
            tk.Button(button_frame, text=text, command=command).pack(
                side=tk.LEFT, padx=5
            )

        # Frame for radio buttons; both share self.mode.
        radio_frame = tk.Frame(self)
        radio_frame.pack(pady=5)
        for text, value in (("Keep", "keep"), ("Remove", "remove")):
            tk.Radiobutton(
                radio_frame, text=text, variable=self.mode, value=value
            ).pack(side=tk.LEFT, padx=5)

        # Canvas for image
        self.canvas = tk.Canvas(self, cursor="cross")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.on_canvas_click)

    def load_model(self):
        """Load the SAM checkpoint and create ``self.predictor``.

        If the checkpoint file is missing, the user is told to download it
        and ``self.predictor`` stays ``None``.
        """
        checkpoint = self.CHECKPOINT
        if not os.path.exists(checkpoint):
            messagebox.showinfo("Model Not Found", f"Please download the model checkpoint '{checkpoint}' and place it in the application directory.")
            return
        device = "cuda" if torch.cuda.is_available() else "cpu"
        sam = sam_model_registry[self.MODEL_TYPE](checkpoint=checkpoint)
        sam.to(device=device)
        self.predictor = SamPredictor(sam)
        messagebox.showinfo("Model Loaded", "SAM model loaded successfully.")

    def load_image(self):
        """Ask for an image file, show it and hand it to the predictor.

        Nothing changes if the dialog is cancelled or the file cannot be
        read. If no model is loaded the image is still shown, but the user
        is warned that clicks will have no effect.
        """
        path = filedialog.askopenfilename()
        if not path:
            # Cancelled: keep whatever was loaded before.
            return
        try:
            # The context manager closes the file handle once the pixel
            # data has been converted.
            with Image.open(path) as source:
                image = source.convert("RGBA")
        except OSError as exc:
            messagebox.showerror("Error", f"Could not open image:\n{exc}")
            return

        self.image_path = path
        self.original_image = image
        self.mask = np.zeros((image.height, image.width), dtype=np.uint8)
        self.display_image(image)

        if self.predictor is None:
            messagebox.showwarning(
                "Model Not Loaded",
                "The SAM model is not loaded, so clicking on the image "
                "will not select anything.",
            )
            return
        # Set the image for the predictor
        self.predictor.set_image(np.array(image.convert("RGB")))

    def display_image(self, image):
        """Replace the canvas contents with ``image`` (a PIL image)."""
        photo = ImageTk.PhotoImage(image)
        # Drop the previous image item so repeated refreshes do not stack
        # canvas items on top of each other.
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=tk.NW, image=photo)
        # Keep a reference on self: the canvas item alone does not keep the
        # PhotoImage alive.
        self.image_tk = photo
        self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))

    def on_canvas_click(self, event):
        """Segment the region under the click and merge it into the mask.

        In "keep" mode the predicted region is added to the selection; in
        "remove" mode it is subtracted. Clicks outside the image, or before
        both a model and an image are loaded, are ignored.
        """
        if self.predictor is None or self.original_image is None:
            return
        # Translate widget coordinates to canvas coordinates; the image is
        # anchored at canvas (0, 0), so these are also image pixel indices.
        x = int(self.canvas.canvasx(event.x))
        y = int(self.canvas.canvasy(event.y))
        inside = (
            0 <= x < self.original_image.width
            and 0 <= y < self.original_image.height
        )
        if not inside:
            return

        masks, _, _ = self.predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=False,
        )
        self.mask = self._merge_mask(self.mask, masks[0], self.mode.get())
        self.update_display()

    @staticmethod
    def _merge_mask(current, selection, mode):
        """Return ``current`` with ``selection`` added or removed, as uint8.

        ``mode == "keep"`` yields the union; any other value yields
        ``current`` minus ``selection``. Both inputs are cast to bool first
        so the negation is a logical NOT even for integer masks.
        """
        current = np.asarray(current, dtype=bool)
        selection = np.asarray(selection, dtype=bool)
        if mode == "keep":
            merged = current | selection
        else:  # remove
            merged = current & ~selection
        return merged.astype(np.uint8)

    def update_display(self):
        """Redraw the image with the selection tinted by OVERLAY_COLOR."""
        if self.original_image is None or self.mask is None:
            return
        # Create a transparent overlay, coloured only where selected.
        height, width = self.original_image.height, self.original_image.width
        overlay = np.zeros((height, width, 4), dtype=np.uint8)
        overlay[self.mask == 1] = self.OVERLAY_COLOR
        # alpha_composite blends the colours but leaves opaque source
        # pixels opaque, so the preview does not become see-through.
        preview = Image.alpha_composite(
            self.original_image, Image.fromarray(overlay)
        )
        self.display_image(preview)

    def save_image(self):
        """Ask for a destination and save the masked image.

        Pixels outside the selection become fully transparent; pixels
        inside keep the alpha value they had in the source image.
        """
        if self.original_image is None or self.mask is None:
            messagebox.showerror("Error", "No image or mask to save.")
            return
        save_path = filedialog.asksaveasfilename(defaultextension=".png", filetypes=[("PNG files", "*.png")])
        if not save_path:
            return
        # Apply the mask to the original image
        result_data = self._apply_alpha_mask(
            np.array(self.original_image), self.mask
        )
        try:
            Image.fromarray(result_data).save(save_path)
        except (OSError, ValueError) as exc:
            # OSError: the file could not be written. ValueError: the output
            # format could not be determined from the file name.
            messagebox.showerror("Error", f"Could not save image:\n{exc}")
            return
        messagebox.showinfo("Success", f"Image saved to {save_path}")

    @staticmethod
    def _apply_alpha_mask(rgba, mask):
        """Return a copy of ``rgba`` with alpha zeroed where ``mask`` is 0.

        ``rgba`` is a (height, width, 4) array and ``mask`` a
        (height, width) array. The input array is not modified.
        """
        result = np.array(rgba, dtype=np.uint8, copy=True)
        selected = np.asarray(mask, dtype=bool)
        result[..., 3] = np.where(selected, result[..., 3], 0)
        return result

    def reset(self):
        """Forget the current image and mask and clear the canvas."""
        self.image_path = None
        self.original_image = None
        self.image_tk = None
        self.mask = None
        self.canvas.delete("all")
        if self.predictor is not None:
            self.predictor.reset_image()
if __name__ == "__main__":
    # Build the main window, then hand control to Tk's event loop until the
    # window is closed.
    ImageSegmenter().mainloop()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment