Created: July 29, 2025 17:46
-
-
Save AdoHaha/77073795c0ebe8d90898919006b610e2 to your computer and use it in GitHub Desktop.
A simple Tkinter app that removes an image's background using Segment Anything (SAM).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tkinter as tk | |
from tkinter import filedialog, messagebox | |
from PIL import Image, ImageTk | |
import numpy as np | |
import torch | |
from segment_anything import sam_model_registry, SamPredictor | |
import os | |
class ImageSegmenter(tk.Tk):
    """Tkinter app that builds a cut-out selection with Segment Anything (SAM).

    Load an image, then left-click on it. In "Keep" mode the SAM mask predicted
    for the clicked point is added to the selection; in "Remove" mode it is
    subtracted. The selection is drawn as a semi-transparent red overlay, and
    "Save Image" writes a PNG in which everything outside the selection is
    transparent.
    """

    def __init__(self, model_type="vit_b", checkpoint="sam_vit_b_01ec64.pth"):
        """Create the window and widgets, then load the SAM model.

        Args:
            model_type: key into ``sam_model_registry`` (default ``"vit_b"``).
            checkpoint: checkpoint file name/path matching ``model_type``.
        """
        super().__init__()
        self.title("Image Segmenter")
        self.geometry("800x600")

        self.image_path = None      # path of the currently loaded image file
        self.original_image = None  # loaded image, converted to RGBA
        self.image_tk = None        # PhotoImage; must stay referenced while displayed
        self.mask = None            # (H, W) uint8 selection mask, 1 = selected
        self.predictor = None       # SamPredictor, or None if the model failed to load
        self.mode = tk.StringVar(value="keep")

        self.create_widgets()
        self.load_model(model_type, checkpoint)

    def create_widgets(self):
        """Build the action buttons, the Keep/Remove selector and the image canvas."""
        # Row of action buttons.
        button_frame = tk.Frame(self)
        button_frame.pack(pady=10)
        for label, command in (
            ("Load Image", self.load_image),
            ("Save Image", self.save_image),
            ("Reset", self.reset),
        ):
            tk.Button(button_frame, text=label, command=command).pack(side=tk.LEFT, padx=5)

        # Radio buttons bound to self.mode: does a click add to or subtract from the selection?
        radio_frame = tk.Frame(self)
        radio_frame.pack(pady=5)
        for label, value in (("Keep", "keep"), ("Remove", "remove")):
            tk.Radiobutton(radio_frame, text=label, variable=self.mode, value=value).pack(
                side=tk.LEFT, padx=5
            )

        # Canvas that shows the image; left clicks drive the segmentation.
        self.canvas = tk.Canvas(self, cursor="cross")
        self.canvas.pack(fill=tk.BOTH, expand=True)
        self.canvas.bind("<Button-1>", self.on_canvas_click)

    @staticmethod
    def _find_checkpoint(checkpoint):
        """Return an existing path for *checkpoint*, or None if it cannot be found.

        The path is tried as given first (relative paths resolve against the
        current working directory). Then the directory containing this script
        is tried, which is the "application directory" the user is told about.
        """
        if os.path.exists(checkpoint):
            return checkpoint
        app_dir = os.path.dirname(os.path.abspath(__file__))
        candidate = os.path.join(app_dir, checkpoint)
        return candidate if os.path.exists(candidate) else None

    def load_model(self, model_type="vit_b", checkpoint="sam_vit_b_01ec64.pth"):
        """Load the SAM checkpoint and create ``self.predictor``.

        On any failure the user is told why and ``self.predictor`` stays None.
        The rest of the UI checks for that case instead of crashing.
        """
        checkpoint_path = self._find_checkpoint(checkpoint)
        if checkpoint_path is None:
            messagebox.showinfo("Model Not Found", f"Please download the model checkpoint '{checkpoint}' and place it in the application directory.")
            return
        device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
            sam.to(device=device)
        except Exception as err:  # startup boundary: report instead of killing the app
            messagebox.showerror("Model Error", f"Could not load SAM model '{model_type}':\n{err}")
            return
        self.predictor = SamPredictor(sam)
        messagebox.showinfo("Model Loaded", "SAM model loaded successfully.")

    def load_image(self):
        """Ask for an image file, display it, and compute its SAM embedding."""
        if self.predictor is None:
            messagebox.showerror("Error", "The SAM model is not loaded, so images cannot be segmented.")
            return
        path = filedialog.askopenfilename()
        if not path:  # dialog cancelled: keep the current image
            return
        try:
            # The context manager closes the file; convert() has already loaded the pixels.
            with Image.open(path) as img:
                image = img.convert("RGBA")
        except OSError as err:  # covers missing files and PIL.UnidentifiedImageError
            messagebox.showerror("Error", f"Could not open image:\n{err}")
            return

        self.image_path = path
        self.original_image = image
        self.mask = np.zeros((image.height, image.width), dtype=np.uint8)
        self.display_image(image)
        # set_image takes an (H, W, 3) uint8 RGB array.
        self.predictor.set_image(np.array(image.convert("RGB")))

    def display_image(self, image):
        """Draw *image* at the canvas origin, replacing whatever was shown before."""
        # Tk does not hold its own reference to the PhotoImage, so keep one on self.
        self.image_tk = ImageTk.PhotoImage(image)
        self.canvas.delete("all")  # otherwise every redraw stacks another canvas item
        self.canvas.create_image(0, 0, anchor=tk.NW, image=self.image_tk)
        self.canvas.config(scrollregion=self.canvas.bbox(tk.ALL))

    @staticmethod
    def _combine_masks(current, new, mode):
        """Return *current* with *new* added ("keep") or subtracted (any other mode).

        Both inputs are treated as boolean masks (non-zero = selected). The
        result is a uint8 array of 0/1.
        """
        new = np.asarray(new, dtype=bool)
        if mode == "keep":
            combined = np.logical_or(current, new)
        else:  # "remove"
            combined = np.logical_and(current, np.logical_not(new))
        return combined.astype(np.uint8)

    def on_canvas_click(self, event):
        """Predict a SAM mask for the clicked point and merge it into the selection."""
        if self.predictor is None or self.original_image is None:
            return
        # Window coordinates -> canvas coordinates (they differ if the canvas is scrolled).
        # The image is drawn at the canvas origin, so these are image pixel coordinates.
        x = int(self.canvas.canvasx(event.x))
        y = int(self.canvas.canvasy(event.y))
        if not (0 <= x < self.original_image.width and 0 <= y < self.original_image.height):
            return  # click landed on empty canvas outside the image

        masks, _, _ = self.predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),  # 1 marks a foreground point
            multimask_output=False,
        )
        self.mask = self._combine_masks(self.mask, masks[0], self.mode.get())
        self.update_display()

    def update_display(self):
        """Redraw the image with the current selection tinted semi-transparent red."""
        if self.original_image is None or self.mask is None:
            return
        overlay = np.zeros((self.original_image.height, self.original_image.width, 4), dtype=np.uint8)
        overlay[self.mask == 1] = (255, 0, 0, 128)
        # alpha_composite blends the overlay "over" the image and keeps it opaque.
        # paste(..., mask=overlay) would also blend the alpha band itself.
        composited = Image.alpha_composite(self.original_image, Image.fromarray(overlay))
        self.display_image(composited)

    @staticmethod
    def _apply_mask_to_alpha(rgba, mask):
        """Return a copy of the (H, W, 4) array *rgba* with unselected pixels transparent.

        Selected pixels keep their original alpha, so transparency already
        present in the source image is preserved.
        """
        out = np.array(rgba, copy=True)
        out[..., 3] = np.where(np.asarray(mask, dtype=bool), out[..., 3], 0)
        return out

    def save_image(self):
        """Save the image as a PNG with everything outside the selection transparent."""
        if self.original_image is None or self.mask is None:
            messagebox.showerror("Error", "No image or mask to save.")
            return
        save_path = filedialog.asksaveasfilename(defaultextension=".png", filetypes=[("PNG files", "*.png")])
        if not save_path:
            return
        result = Image.fromarray(self._apply_mask_to_alpha(np.array(self.original_image), self.mask))
        try:
            result.save(save_path)
        except (OSError, ValueError) as err:  # e.g. unknown extension, RGBA to JPEG
            messagebox.showerror("Error", f"Could not save image:\n{err}")
            return
        messagebox.showinfo("Success", f"Image saved to {save_path}")

    def reset(self):
        """Forget the current image and selection, clear the canvas and the predictor's image."""
        self.image_path = None
        self.original_image = None
        self.image_tk = None
        self.mask = None
        self.canvas.delete("all")
        if self.predictor is not None:
            self.predictor.reset_image()
if __name__ == "__main__":
    # Constructing the window also triggers load_model(); mainloop() then
    # blocks until the window is closed.
    segmenter = ImageSegmenter()
    segmenter.mainloop()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment