Skip to content

Instantly share code, notes, and snippets.

@CodeZombie
Created October 24, 2025 17:08
Show Gist options
  • Select an option

  • Save CodeZombie/b1752072a9bc1fe7457def40ecf259c1 to your computer and use it in GitHub Desktop.

Select an option

Save CodeZombie/b1752072a9bc1fe7457def40ecf259c1 to your computer and use it in GitHub Desktop.
ComfyUI "Cut Out Mask" node
from PIL import Image, ImageChops
import torch
import numpy as np
# DONE: The mask input is declared as a MASK type (it was previously an IMAGE).
# TODO: This doesn't handle transparency nicely. If a mask has grey areas, they add darkness to the output transparent image. Fix that.
class CutMask:
    """ComfyUI node that makes the parts of an image not covered by a mask transparent.

    The mask intensity (0..1) is combined with the image's alpha channel by an
    element-wise minimum, so each pixel ends up no more opaque than either its
    original alpha or the mask allows.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs: an image, a mask and an invert toggle."""
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "invert": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cut_mask"
    CATEGORY = "image/Jeremy Nodes"

    def invert_greyscale_tensor(self, image_tensor):
        """Return the colour-inverted copy of a greyscale tensor.

        Float tensors are treated as [0, 1] (result ``1 - x``); uint8 tensors
        as [0, 255] (result ``255 - x``).

        Raises:
            TypeError: if the tensor is neither floating point nor uint8.
        """
        if image_tensor.is_floating_point():
            inverted_tensor = 1.0 - image_tensor
        elif image_tensor.dtype == torch.uint8:
            inverted_tensor = 255 - image_tensor
        else:
            raise TypeError(f"Unsupported tensor dtype for inversion: {image_tensor.dtype}. "
                            "Expected float (e.g., torch.float32) or torch.uint8.")
        return inverted_tensor

    def cut_mask(self, image, mask, invert=False):
        """Make the pixels of ``image`` that ``mask`` does not cover transparent.

        Args:
            image: image tensor laid out as (B, H, W, C) with C == 3 (RGB) or
                C == 4 (RGBA); float in [0, 1] or uint8 in [0, 255]. RGB input
                is treated as fully opaque and gains an alpha channel.
            mask: mask tensor of shape (B, H, W), (H, W) or (B, H, W, 1); float
                in [0, 1] or uint8 in [0, 255]. It is bilinearly resized when
                its spatial size differs from the image's.
            invert: if True, invert the mask before applying it.

        Returns:
            A 1-tuple holding a new (B, H, W, 4) tensor whose alpha is
            ``min(original_alpha, mask)``. The dtype matches the input image
            (uint8 in, uint8 out). The input tensors are not modified.

        Raises:
            ValueError: if the image has fewer than 3 or exactly... channels
                other than 3 or >= 4.
        """
        # Record this before normalisation rebinds `image` to a float tensor,
        # otherwise the uint8 round-trip at the end can never trigger.
        was_uint8 = image.dtype == torch.uint8

        if invert:
            mask = self.invert_greyscale_tensor(mask)

        # Normalise pixel values to floats in [0, 1].
        if was_uint8:
            image = image.float() / 255.0
        if mask.dtype == torch.uint8:
            mask = mask.float() / 255.0
        # Match the image's device and dtype so interpolate/minimum operate on
        # compatible tensors.
        mask = mask.to(device=image.device, dtype=image.dtype)

        # Bring the mask to (B, H, W, 1).
        if mask.ndim == 2:  # bare (H, W) mask
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:  # (B, H, W)
            mask = mask.unsqueeze(-1)

        # Resize the mask to the image's spatial size when they differ.
        if mask.shape[1:3] != image.shape[1:3]:
            mask = torch.nn.functional.interpolate(
                mask.permute(0, 3, 1, 2),  # interpolate expects (B, C, H, W)
                size=(image.shape[1], image.shape[2]),
                mode="bilinear",
                align_corners=False,
            ).permute(0, 2, 3, 1)

        channels = image.shape[-1]
        if channels == 3:
            # Without an alpha channel, image[..., 3:4] would be empty and the
            # mask would silently have no effect; start fully opaque instead.
            opaque_alpha = torch.ones_like(image[..., :1])
            image = torch.cat((image, opaque_alpha), dim=-1)
        elif channels < 4:
            raise ValueError(
                f"Expected an RGB or RGBA image (3 or 4 channels), got {channels} channel(s)."
            )

        # A pixel stays visible only if it was visible AND the mask covers it.
        new_alpha = torch.minimum(image[..., 3:4], mask)

        # Write into a copy so the caller's tensor is left untouched.
        output = image.clone()
        output[..., 3:4] = new_alpha

        if was_uint8:
            # Round rather than truncate so 1.0 maps back to exactly 255.
            output = (output * 255.0).round().clamp(0, 255).to(torch.uint8)
        return (output,)
# Register the node class with ComfyUI under the key "CutOutMask".
# (Previously referenced the undefined name `Cut_Out_Mask`, which raised
# NameError on import; the class defined in this file is `CutMask`.)
NODE_CLASS_MAPPINGS = {
    "CutOutMask": CutMask,
}
# Display name for the node registered above.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CutOutMask": "Cut Out Mask",
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment