Created
October 24, 2025 17:08
-
-
Save CodeZombie/b1752072a9bc1fe7457def40ecf259c1 to your computer and use it in GitHub Desktop.
ComfyUI "Cut Out Mask" node
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from PIL import Image, ImageChops | |
| import torch | |
| import numpy as np | |
| # TODO: Mask should be a MASK type, not IMAGE. Fix this. | |
| # TODO: This doesn't handle transparency nicely. If a mask has grey areas, it will add darkness to the output transparent image. Fix that. | |
class CutMask:
    """ComfyUI node that cuts an image out along a mask.

    The mask is written into the image's alpha channel: wherever the mask
    is dark (or the image was already transparent) the result is
    transparent. RGB channels are left untouched.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input sockets of the node: an image, a mask and an invert toggle."""
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
                "invert": ("BOOLEAN", {"default": False})
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cut_mask"
    CATEGORY = "image/Jeremy Nodes"

    def invert_greyscale_tensor(self, image_tensor):
        """Return the inverse of a greyscale tensor.

        Floating-point tensors are treated as living in [0, 1] and are
        inverted as ``1.0 - x``; ``torch.uint8`` tensors are inverted as
        ``255 - x``.

        Raises:
            TypeError: If the tensor is neither floating point nor uint8.
        """
        if image_tensor.is_floating_point():
            inverted_tensor = 1.0 - image_tensor
        elif image_tensor.dtype == torch.uint8:
            inverted_tensor = 255 - image_tensor
        else:
            raise TypeError(f"Unsupported tensor dtype for inversion: {image_tensor.dtype}. "
                            "Expected float (e.g., torch.float32) or torch.uint8.")
        return inverted_tensor

    def cut_mask(self, image, mask, invert=False):
        """Make every pixel of ``image`` that the mask does not cover transparent.

        Args:
            image: Tensor laid out as (B, H, W, C) with C >= 3. A 3-channel
                (RGB) image gets a fully opaque alpha channel appended before
                the mask is applied; with 4 or more channels, channel index 3
                is treated as alpha. Float data is assumed to be in [0, 1];
                uint8 data is scaled from [0, 255].
            mask: Greyscale tensor shaped (H, W), (B, H, W) or (B, H, W, 1).
                Its intensity is used directly as the desired opacity. If its
                spatial size differs from the image it is resized bilinearly.
            invert: When true the mask is inverted first, so the covered
                region is the one that gets removed.

        Returns:
            A 1-tuple holding the RGBA result. The input tensor is never
            modified in place. The result is uint8 if ``image`` was uint8,
            otherwise floating point.

        Raises:
            ValueError: If ``image`` is not 4-D or has fewer than 3 channels.
            TypeError: If ``invert`` is set and the mask dtype is unsupported.
        """
        if image.ndim != 4:
            raise ValueError(
                f"Expected image of shape (B, H, W, C), got {tuple(image.shape)}."
            )
        channels = image.shape[-1]
        if channels < 3:
            raise ValueError(
                f"Expected an image with at least 3 channels, got {channels}."
            )

        if invert:
            mask = self.invert_greyscale_tensor(mask)

        # Keep both tensors on the image's device.
        mask = mask.to(image.device)

        # Remember the input dtype *before* normalising, so the result can be
        # converted back afterwards.
        return_uint8 = image.dtype == torch.uint8
        if return_uint8:
            image = image.float() / 255.0
        if mask.dtype == torch.uint8:
            mask = mask.float() / 255.0

        # Bring the mask to (B, H, W, 1) so it lines up with the alpha slice.
        if mask.ndim == 2:  # (H, W)
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:  # (B, H, W)
            mask = mask.unsqueeze(-1)

        # Resize the mask to the image's spatial size when they differ.
        target_hw = image.shape[1:3]
        if mask.shape[1:3] != target_hw:
            # interpolate() works on (B, C, H, W), hence the permutes.
            mask = torch.nn.functional.interpolate(
                mask.permute(0, 3, 1, 2),
                size=tuple(target_hw),
                mode="bilinear",
                align_corners=False,
            ).permute(0, 2, 3, 1)

        mask = mask.to(image.dtype)

        if channels == 3:
            # No alpha channel yet: start from a fully opaque one. Without
            # this the alpha slice below would be empty and the mask would
            # silently have no effect.
            opaque = torch.ones_like(image[..., :1])
            output_image_tensor = torch.cat((image, opaque), dim=-1)
        else:
            # Work on a copy so the caller's tensor is left untouched.
            output_image_tensor = image.clone()

        # New alpha is the minimum of the existing alpha and the mask: a pixel
        # stays visible only if it was visible before AND the mask covers it.
        output_image_tensor[..., 3:4] = torch.min(
            output_image_tensor[..., 3:4], mask
        )

        if return_uint8:
            output_image_tensor = (
                (output_image_tensor * 255.0).round().clamp(0, 255).to(torch.uint8)
            )

        return (output_image_tensor, )
# Node registration tables, keyed by the node's internal identifier.
# The class registered here must be the one defined in this module (CutMask);
# the previous name, Cut_Out_Mask, is not defined and raised NameError on import.
NODE_CLASS_MAPPINGS = {
    "CutOutMask": CutMask,
}

# Human-readable label for the same identifier.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CutOutMask": "Cut Out Mask",
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment