Created
July 27, 2025 21:50
-
-
Save marduk191/3f7a59a05bac0242dabcc9bd57133cab to your computer and use it in GitHub Desktop.
comfyui upscaler
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from PIL import Image, ImageFilter, ImageEnhance | |
| import math | |
| import comfy.model_management as mm | |
| import comfy.utils | |
| import comfy.samplers | |
| import folder_paths | |
class UltraMaxUpscaler:
    """ComfyUI node that upscales an IMAGE and re-diffuses it tile by tile.

    ``upscale`` runs the pipeline in this order:

    1. Interpolate the input to the target size, optionally with frequency
       separation.
    2. Run overlapping tiles through the supplied model, VAE and
       conditioning.
    3. Feather-blend the tiles back together.
    4. Apply edge-aware detail boosting.
    5. Apply a light Gaussian smoothing pass.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs: widget types, defaults and ranges."""
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "vae": ("VAE",),
                "upscale_factor": ("FLOAT", {"default": 2.0, "min": 1.0, "max": 6.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 100}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0, "step": 0.1}),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
                "denoise_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "tile_size": ("INT", {"default": 768, "min": 256, "max": 2048, "step": 64}),
                "tile_overlap": ("INT", {"default": 64, "min": 32, "max": 256, "step": 16}),
                "detail_boost": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.0, "step": 0.1}),
                "edge_enhance": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 3.0, "step": 0.1}),
                "noise_reduction": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05}),
                "speed_mode": (["balanced", "fast", "quality"], {"default": "balanced"}),
                "frequency_separation": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    # Tiles whose width or height is at or below this many pixels are not diffused.
    MIN_TILE_SIZE = 64

    def __init__(self):
        # Device on which intermediates are created and to which inputs are moved.
        self.device = mm.get_torch_device()

    # ------------------------------------------------------------------ layout

    def ensure_bhwc_format(self, tensor):
        """Return ``tensor`` in BHWC layout (ComfyUI IMAGE convention).

        A 3-D tensor gets a leading batch dimension. A 4-D tensor is treated
        as BCHW, and permuted, only when dim 1 is 3 and the last dim is not.
        A BHWC image that happens to be 3 pixels tall is therefore left
        alone. This mirrors the last-dim-first check in
        ``ensure_bchw_format``.
        """
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.dim() == 4 and tensor.shape[1] == 3 and tensor.shape[-1] != 3:
            tensor = tensor.permute(0, 2, 3, 1)  # BCHW -> BHWC
        return tensor

    def ensure_bchw_format(self, tensor):
        """Return ``tensor`` in BCHW layout (PyTorch conv/interpolate convention).

        Raises:
            ValueError: if a 4-D tensor has 3 channels in neither dim 1 nor
                the last dim.
        """
        if tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)
        if tensor.dim() == 4:
            if tensor.shape[-1] == 3:
                tensor = tensor.permute(0, 3, 1, 2)  # BHWC -> BCHW
            elif tensor.shape[1] != 3:
                raise ValueError(f"Cannot convert tensor with shape {tensor.shape} to BCHW format")
        return tensor

    # --------------------------------------------------------------- filtering

    def gaussian_blur_fast(self, tensor, sigma):
        """Separable Gaussian blur. Returns a BCHW tensor with the input's spatial size.

        Borders are padded by replicating edge pixels. Zero padding would
        pull border values toward 0 and darken the image edges. The kernel
        is built on the input's device and dtype. A non-positive ``sigma``
        returns the (BCHW) input unchanged.
        """
        tensor = self.ensure_bchw_format(tensor)
        if sigma <= 0:
            return tensor
        kernel_size = int(2 * math.ceil(2 * sigma) + 1)  # always odd
        radius = kernel_size // 2

        offsets = torch.arange(kernel_size, dtype=torch.float32, device=tensor.device) - radius
        kernel_1d = torch.exp(-0.5 * (offsets / sigma) ** 2)
        kernel_1d = (kernel_1d / kernel_1d.sum()).to(tensor.dtype)

        channels = tensor.shape[1]
        kernel_h = kernel_1d.view(1, 1, 1, -1).repeat(channels, 1, 1, 1)
        kernel_v = kernel_1d.view(1, 1, -1, 1).repeat(channels, 1, 1, 1)

        # Horizontal pass, then vertical pass (depthwise: one kernel per channel).
        blurred = F.conv2d(F.pad(tensor, (radius, radius, 0, 0), mode="replicate"),
                           kernel_h, groups=channels)
        blurred = F.conv2d(F.pad(blurred, (0, 0, radius, radius), mode="replicate"),
                           kernel_v, groups=channels)
        return blurred

    def extract_edges(self, tensor):
        """Sobel gradient magnitude of the channel mean, repeated to every channel (BCHW).

        Replicate padding keeps the image border from registering as a
        strong false edge.
        """
        tensor = self.ensure_bchw_format(tensor)
        gray = torch.mean(tensor, dim=1, keepdim=True)

        sobel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
                               dtype=gray.dtype, device=gray.device)
        sobel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
                               dtype=gray.dtype, device=gray.device)

        padded = F.pad(gray, (1, 1, 1, 1), mode="replicate")
        edge_x = F.conv2d(padded, sobel_x)
        edge_y = F.conv2d(padded, sobel_y)

        edges = torch.sqrt(edge_x ** 2 + edge_y ** 2)
        return edges.repeat(1, tensor.shape[1], 1, 1)

    def frequency_separation_upscale(self, tensor, scale_factor, sigma=1.0, high_freq_gain=1.1):
        """Upscale by resizing low and high frequencies separately.

        Steps:

        - The image is split at a Gaussian blur of ``sigma``.
        - The low band is resized with antialiased bicubic interpolation.
        - The high band is resized with bilinear interpolation and
          multiplied by ``high_freq_gain``.
        - The two bands are summed and clamped to [0, 1].

        The output size is ``int(dim * scale_factor)`` per spatial dim, the
        same size the non-separated path in ``upscale`` produces. The result
        keeps the input layout: BHWC in gives BHWC out.
        """
        original_format_bhwc = tensor.shape[-1] == 3
        tensor = self.ensure_bchw_format(tensor)
        _, _, height, width = tensor.shape
        size = (int(height * scale_factor), int(width * scale_factor))

        low_freq = self.gaussian_blur_fast(tensor, sigma=sigma)
        high_freq = tensor - low_freq

        low_upscaled = F.interpolate(low_freq, size=size, mode='bicubic',
                                     align_corners=False, antialias=True)
        high_upscaled = F.interpolate(high_freq, size=size, mode='bilinear',
                                      align_corners=False)

        result = torch.clamp(low_upscaled + high_upscaled * high_freq_gain, 0, 1)
        if original_format_bhwc:
            result = self.ensure_bhwc_format(result)
        return result

    def adaptive_sharpen(self, tensor, strength, edge_mask=None):
        """Scale the image's fine-detail layer by ``strength``, optionally only on edges.

        The detail layer is ``tensor - blur(tensor, sigma=0.8)``. The output
        is ``blur + detail * gain``, clamped to [0, 1]:

        - ``strength == 1`` is a no-op (``upscale`` skips the call then).
        - ``strength > 1`` sharpens.
        - ``strength < 1`` softens.

        Without ``edge_mask`` the gain is ``strength`` everywhere. With a
        mask, the mask's magnitude is normalised to [0, 1]. The gain then
        interpolates from 1 (flat areas) to ``strength`` (strongest edges).
        The result keeps the input layout.
        """
        original_format_bhwc = tensor.shape[-1] == 3
        tensor = self.ensure_bchw_format(tensor)

        blurred = self.gaussian_blur_fast(tensor, sigma=0.8)
        detail = tensor - blurred

        if edge_mask is not None:
            edge_mask = self.ensure_bchw_format(edge_mask)
            if edge_mask.shape[2:] != tensor.shape[2:]:
                edge_mask = F.interpolate(edge_mask, size=tensor.shape[2:],
                                          mode='bilinear', align_corners=False)
            edge_strength = torch.abs(edge_mask)
            edge_strength = edge_strength / (edge_strength.max() + 1e-8)
            gain = 1.0 + edge_strength * (strength - 1.0)
        else:
            gain = strength

        # The gain multiplies the detail layer, so rebuild from the blurred base.
        enhanced = torch.clamp(blurred + detail * gain, 0, 1)
        if original_format_bhwc:
            enhanced = self.ensure_bhwc_format(enhanced)
        return enhanced

    # ----------------------------------------------------------------- tiling

    @staticmethod
    def _adjust_for_speed_mode(speed_mode, steps, tile_size):
        """Return ``(steps, tile_size)`` adjusted for ``speed_mode``.

        - "fast" halves the steps (floor 10) and caps tiles at 512 px. It
          never adds steps.
        - "quality" adds 10 steps (capped at 50) and uses tiles of at least
          768 px. It never removes steps.
        - Any other mode returns the inputs unchanged.
        """
        if speed_mode == "fast":
            steps = min(steps, max(10, steps // 2))
            tile_size = min(tile_size, 512)
        elif speed_mode == "quality":
            steps = max(steps, min(steps + 10, 50))
            tile_size = max(tile_size, 768)
        return steps, tile_size

    @staticmethod
    def _tile_positions(length, tile, step):
        """Start offsets of ``tile``-sized windows covering ``[0, length)``.

        Windows advance by ``step``. The last window is aligned to the end,
        so every window is full size unless ``length <= tile``.
        """
        if length <= tile:
            return [0]
        positions = list(range(0, length - tile, step))
        positions.append(length - tile)
        return positions

    def calculate_tiles(self, width, height, tile_size, overlap):
        """Return ``(x, y, w, h)`` tiles, row-major, that cover the whole image.

        All tiles share one size: ``min(tile_size, width)`` by
        ``min(tile_size, height)``. Neighbouring tiles overlap by at least
        ``overlap`` px. ``overlap`` is clamped to ``[0, tile_size // 2]`` so
        the step is always positive. Returns ``[]`` when the tile would be
        ``MIN_TILE_SIZE`` px or smaller in either dimension.

        Raises:
            ValueError: if ``tile_size`` is not positive.
        """
        if tile_size <= 0:
            raise ValueError(f"tile_size must be positive, got {tile_size}")
        overlap = max(0, min(overlap, tile_size // 2))
        step_size = tile_size - overlap

        tile_w = min(tile_size, width)
        tile_h = min(tile_size, height)
        if tile_w <= self.MIN_TILE_SIZE or tile_h <= self.MIN_TILE_SIZE:
            return []

        xs = self._tile_positions(width, tile_size, step_size)
        ys = self._tile_positions(height, tile_size, step_size)
        return [(x, y, tile_w, tile_h) for y in ys for x in xs]

    def process_tile(self, model, positive, negative, vae, tile_tensor,
                     seed, steps, cfg, sampler_name, scheduler, denoise_strength):
        """Run one tile through VAE encode -> ``comfy.sample.sample`` -> VAE decode.

        Returns a BHWC tensor on ``self.device``. It always has the tile's
        spatial size, so it can be blended back into the tile's slot.

        Raises:
            ValueError: if the tile is not 4-D with 3 channels after layout
                normalisation.
        """
        import comfy.sample
        import latent_preview

        tile_bhwc = self.ensure_bhwc_format(tile_tensor)
        if tile_bhwc.dim() != 4 or tile_bhwc.shape[-1] != 3:
            raise ValueError(f"Expected BHWC format with 3 channels, got shape: {tile_bhwc.shape}")
        tile_h, tile_w = tile_bhwc.shape[1], tile_bhwc.shape[2]

        # The VAE downsamples by a fixed ratio. A tile whose size is not a
        # multiple of it would decode at a different size than the slot it is
        # blended into, so pad up (replicating edge pixels) and crop after
        # decoding. NOTE(review): ratio falls back to 8 when
        # `vae.downscale_ratio` is missing or not a plain int -- confirm for
        # the VAEs in use.
        ratio = getattr(vae, "downscale_ratio", 8)
        if not isinstance(ratio, int) or ratio <= 0:
            ratio = 8
        pad_h = (-tile_h) % ratio
        pad_w = (-tile_w) % ratio
        vae_input = tile_bhwc[:, :, :, :3]
        if pad_h or pad_w:
            vae_input = F.pad(vae_input.permute(0, 3, 1, 2), (0, pad_w, 0, pad_h),
                              mode="replicate").permute(0, 2, 3, 1)

        with torch.no_grad():
            latent_samples = vae.encode(vae_input)

        if denoise_strength > 0:
            noise = comfy.sample.prepare_noise(latent_samples, seed, None)
        else:
            noise = torch.zeros_like(latent_samples)

        callback = latent_preview.prepare_callback(model, steps)
        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample(
            model, noise, steps, cfg, sampler_name, scheduler,
            positive, negative, latent_samples,
            denoise=denoise_strength,
            disable_noise=False,
            start_step=None,
            last_step=None,
            force_full_denoise=True,
            noise_mask=None,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=seed
        )

        with torch.no_grad():
            decoded = vae.decode(samples)

        decoded = self.ensure_bhwc_format(decoded.to(self.device))
        decoded = decoded[:, :tile_h, :tile_w, :]  # drop the padding
        if tuple(decoded.shape[1:3]) != (tile_h, tile_w):
            # Last resort if the VAE ratio guess above was wrong.
            resized = F.interpolate(self.ensure_bchw_format(decoded), size=(tile_h, tile_w),
                                    mode="bilinear", align_corners=False)
            decoded = self.ensure_bhwc_format(resized)
        return decoded

    @staticmethod
    def _feather_ramp(length, fade, fade_start, fade_end, device):
        """1-D blend weights of ``length`` ones.

        When ``fade > 1``, a linear 0 -> (fade-1)/fade ramp is applied over
        ``fade`` samples at each enabled end.
        """
        weights = torch.ones(length, device=device)
        if fade > 1:
            ramp = torch.arange(fade, dtype=torch.float32, device=device) / fade
            if fade_start:
                weights[:fade] *= ramp
            if fade_end:
                weights[length - fade:] *= ramp.flip(0)
        return weights

    def blend_tiles(self, tiles_data, final_width, final_height, overlap):
        """Feather-blend ``((x, y, w, h), tile)`` pairs into one BHWC image.

        Each tile is faded toward 0 over ``min(overlap // 2, min(w, h) // 4)``
        px on every side that borders another tile (not the image border).
        The weighted sum is then normalised by the total weight. The output
        batch size follows the tiles' batch size. With no tiles, a
        ``(1, H, W, 3)`` zero image is returned.
        """
        result = None
        weight_map = torch.zeros((1, final_height, final_width, 1), device=self.device)

        for (x, y, w, h), tile_data in tiles_data:
            tile_data = self.ensure_bhwc_format(tile_data).to(self.device)
            if result is None:
                result = torch.zeros((tile_data.shape[0], final_height, final_width, 3),
                                     device=self.device)

            fade_size = min(overlap // 2, min(w, h) // 4)
            rows = self._feather_ramp(h, fade_size, y > 0, y + h < final_height, self.device)
            cols = self._feather_ramp(w, fade_size, x > 0, x + w < final_width, self.device)
            mask = (rows[:, None] * cols[None, :]).view(1, h, w, 1)

            result[:, y:y+h, x:x+w, :] += tile_data * mask
            weight_map[:, y:y+h, x:x+w, :] += mask

        if result is None:
            return torch.zeros((1, final_height, final_width, 3), device=self.device)
        return result / torch.clamp(weight_map, min=1e-8)

    # -------------------------------------------------------------- entry point

    def upscale(self, image, model, positive, negative, vae, upscale_factor,
                seed, steps, cfg, sampler_name, scheduler, denoise_strength,
                tile_size, tile_overlap, detail_boost, edge_enhance,
                noise_reduction, speed_mode, frequency_separation):
        """Node entry point. Returns a 1-tuple with the upscaled BHWC image in [0, 1]."""
        image = self.ensure_bhwc_format(image.to(self.device))
        _, height, width, _ = image.shape
        target_width = int(width * upscale_factor)
        target_height = int(height * upscale_factor)
        print(f"Upscaling from {width}x{height} to {target_width}x{target_height}")

        steps, tile_size = self._adjust_for_speed_mode(speed_mode, steps, tile_size)
        # Same clamp calculate_tiles applies, so blend feathering matches the tiling.
        tile_overlap = max(0, min(tile_overlap, tile_size // 2))

        # Initial interpolation to the target size.
        if frequency_separation:
            upscaled = self.frequency_separation_upscale(image, upscale_factor)
        else:
            upscaled_bchw = F.interpolate(
                self.ensure_bchw_format(image),
                size=(target_height, target_width),
                mode='bicubic',
                align_corners=False,
                antialias=True
            )
            upscaled = self.ensure_bhwc_format(upscaled_bchw)
        print(f"Initial upscale complete: {upscaled.shape}")

        # Edge mask comes from the pre-diffusion image and is only consumed
        # by adaptive_sharpen.
        # NOTE(review): only the sign of edge_enhance matters (it toggles the
        # mask); its magnitude is not used anywhere.
        edge_mask = None
        if edge_enhance > 0 and detail_boost != 1.0:
            edge_mask = self.extract_edges(upscaled)

        tiles = self.calculate_tiles(target_width, target_height, tile_size, tile_overlap)
        print(f"Processing {len(tiles)} tiles")

        if denoise_strength > 0 and tiles:
            processed_tiles = []
            for i, (x, y, w, h) in enumerate(tiles):
                print(f"Processing tile {i+1}/{len(tiles)}: {w}x{h} at ({x},{y})")
                tile = upscaled[:, y:y+h, x:x+w, :]
                # Keep per-tile seeds inside the 64-bit range declared for the
                # "seed" input; seed + i can run past it near the maximum.
                tile_seed = (seed + i) & 0xFFFFFFFFFFFFFFFF
                try:
                    processed_tile = self.process_tile(
                        model, positive, negative, vae, tile,
                        tile_seed, steps, cfg, sampler_name, scheduler, denoise_strength
                    )
                    processed_tiles.append(((x, y, w, h), self.ensure_bhwc_format(processed_tile)))
                except mm.InterruptProcessingException:
                    # ComfyUI's interrupt signal: stop the node instead of
                    # treating it as one failed tile.
                    raise
                except Exception as e:
                    # Best effort: keep the interpolated tile if diffusion fails.
                    print(f"Error processing tile {i}: {e}")
                    import traceback
                    traceback.print_exc()
                    processed_tiles.append(((x, y, w, h), tile))

            if processed_tiles:
                upscaled = self.blend_tiles(processed_tiles, target_width, target_height, tile_overlap)

        if detail_boost != 1.0:
            upscaled = self.adaptive_sharpen(upscaled, detail_boost, edge_mask)

        # Noise reduction: linear mix with a light blur.
        if noise_reduction > 0:
            upscaled_bchw = self.ensure_bchw_format(upscaled)
            smoothed = self.gaussian_blur_fast(upscaled_bchw, sigma=0.5)
            upscaled_bchw = (1 - noise_reduction) * upscaled_bchw + noise_reduction * smoothed
            upscaled = self.ensure_bhwc_format(upscaled_bchw)

        upscaled = torch.clamp(upscaled, 0, 1)
        print("Upscaling complete!")
        return (upscaled,)
# ComfyUI node registration: node id -> implementing class, and
# node id -> human-readable label.
NODE_CLASS_MAPPINGS = dict(UltraMaxUpscaler=UltraMaxUpscaler)
NODE_DISPLAY_NAME_MAPPINGS = dict(UltraMaxUpscaler="Ultra Max Upscaler")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment