@marduk191
Created July 27, 2025 21:50
comfyui upscaler
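"""UltraMaxUpscaler: a ComfyUI custom node that upscales an image with a
frequency-separated resample, optionally refines it tile-by-tile through a
diffusion model (img2img), and finishes with edge-aware sharpening and light
noise reduction."""
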
import math

import torch
import torch.nn.functional as F

import comfy.model_management as mm
import comfy.sample
import comfy.samplers
import comfy.utils
import latent_preview
class UltraMaxUpscaler:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "model": ("MODEL",),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "vae": ("VAE",),
                "upscale_factor": ("FLOAT", {"default": 2.0, "min": 1.0, "max": 6.0, "step": 0.1}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "steps": ("INT", {"default": 20, "min": 1, "max": 100}),
                "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 30.0, "step": 0.1}),
                "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
                "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
                "denoise_strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "tile_size": ("INT", {"default": 768, "min": 256, "max": 2048, "step": 64}),
                "tile_overlap": ("INT", {"default": 64, "min": 32, "max": 256, "step": 16}),
                "detail_boost": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.0, "step": 0.1}),
                "edge_enhance": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 3.0, "step": 0.1}),
                "noise_reduction": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.05}),
                "speed_mode": (["balanced", "fast", "quality"], {"default": "balanced"}),
                "frequency_separation": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    def __init__(self):
        self.device = mm.get_torch_device()

    def ensure_bhwc_format(self, tensor):
        """Ensure tensor is in BHWC format (ComfyUI standard)"""
        if len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)  # Add batch dimension
        if len(tensor.shape) == 4:
            if tensor.shape[1] == 3:  # BCHW format
                tensor = tensor.permute(0, 2, 3, 1)  # Convert to BHWC
        return tensor

    def ensure_bchw_format(self, tensor):
        """Ensure tensor is in BCHW format (for PyTorch operations)"""
        if len(tensor.shape) == 3:
            tensor = tensor.unsqueeze(0)  # Add batch dimension
        if len(tensor.shape) == 4:
            if tensor.shape[-1] == 3:  # BHWC format
                tensor = tensor.permute(0, 3, 1, 2)  # Convert to BCHW
            elif tensor.shape[1] == 3:  # Already BCHW
                pass  # No conversion needed
            else:
                raise ValueError(f"Cannot convert tensor with shape {tensor.shape} to BCHW format")
        return tensor

    def gaussian_blur_fast(self, tensor, sigma):
        """Fast Gaussian blur using separable convolution"""
        tensor = self.ensure_bchw_format(tensor)
        kernel_size = int(2 * math.ceil(2 * sigma) + 1)
        kernel_size = kernel_size if kernel_size % 2 == 1 else kernel_size + 1
        # Create 1D Gaussian kernel
        x = torch.arange(kernel_size, dtype=torch.float32, device=self.device)
        x = x - kernel_size // 2
        kernel_1d = torch.exp(-0.5 * (x / sigma) ** 2)
        kernel_1d = kernel_1d / kernel_1d.sum()
        # Apply separable convolution
        padding = kernel_size // 2
        # Horizontal blur
        kernel_h = kernel_1d.view(1, 1, 1, -1).repeat(tensor.shape[1], 1, 1, 1)
        blurred = F.conv2d(tensor, kernel_h, padding=(0, padding), groups=tensor.shape[1])
        # Vertical blur
        kernel_v = kernel_1d.view(1, 1, -1, 1).repeat(tensor.shape[1], 1, 1, 1)
        blurred = F.conv2d(blurred, kernel_v, padding=(padding, 0), groups=tensor.shape[1])
        return blurred

    def extract_edges(self, tensor):
        """Extract edge information using the Sobel operator"""
        tensor = self.ensure_bchw_format(tensor)
        # Convert to grayscale
        gray = torch.mean(tensor, dim=1, keepdim=True)
        # Sobel kernels
        sobel_x = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]],
                               dtype=torch.float32, device=self.device)
        sobel_y = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]],
                               dtype=torch.float32, device=self.device)
        # Apply Sobel filters
        edge_x = F.conv2d(gray, sobel_x, padding=1)
        edge_y = F.conv2d(gray, sobel_y, padding=1)
        # Combine into a gradient magnitude
        edges = torch.sqrt(edge_x ** 2 + edge_y ** 2)
        edges = edges.repeat(1, tensor.shape[1], 1, 1)  # Expand to RGB
        return edges

    def frequency_separation_upscale(self, tensor, scale_factor):
        """Upscale using frequency separation for better detail preservation"""
        original_format_bhwc = tensor.shape[-1] == 3
        tensor = self.ensure_bchw_format(tensor)
        # Separate frequencies
        low_freq = self.gaussian_blur_fast(tensor, sigma=1.0)
        high_freq = tensor - low_freq
        # Upscale the low frequency with bicubic (smooth areas)
        low_upscaled = F.interpolate(
            low_freq,
            scale_factor=scale_factor,
            mode='bicubic',
            align_corners=False,
            antialias=True
        )
        # Upscale the high frequency with bilinear (preserve sharp details)
        high_upscaled = F.interpolate(
            high_freq,
            scale_factor=scale_factor,
            mode='bilinear',
            align_corners=False
        )
        # Enhance high-frequency details slightly
        high_upscaled = high_upscaled * 1.1
        # Combine frequencies
        result = low_upscaled + high_upscaled
        result = torch.clamp(result, 0, 1)
        # Convert back to the original format if needed
        if original_format_bhwc:
            result = self.ensure_bhwc_format(result)
        return result

    def adaptive_sharpen(self, tensor, strength, edge_mask=None):
        """Apply adaptive sharpening based on edge information"""
        original_format_bhwc = tensor.shape[-1] == 3
        tensor = self.ensure_bchw_format(tensor)
        # Create an unsharp mask
        blurred = self.gaussian_blur_fast(tensor, sigma=0.8)
        detail = tensor - blurred
        # Apply edge-aware enhancement
        if edge_mask is not None:
            edge_mask = self.ensure_bchw_format(edge_mask)
            if edge_mask.shape[2:] != tensor.shape[2:]:
                edge_mask = F.interpolate(edge_mask, size=tensor.shape[2:], mode='bilinear')
            # Normalize the edge mask
            edge_strength = torch.abs(edge_mask)
            edge_strength = edge_strength / (edge_strength.max() + 1e-8)
            # Apply stronger enhancement in edge areas
            strength_map = 1.0 + edge_strength * (strength - 1.0)
            detail = detail * strength_map
        else:
            detail = detail * strength
        # Apply the enhancement
        enhanced = tensor + detail
        enhanced = torch.clamp(enhanced, 0, 1)
        # Convert back to the original format if needed
        if original_format_bhwc:
            enhanced = self.ensure_bhwc_format(enhanced)
        return enhanced

    def calculate_tiles(self, width, height, tile_size, overlap):
        """Calculate tile positions, snapping edge tiles back so the whole image is covered"""
        tiles = []
        # Guard against a non-positive step when overlap >= tile_size
        step_size = max(1, tile_size - overlap)
        y = 0
        while y < height:
            # Snap the last row back so it ends at the bottom edge instead of
            # leaving a thin strip that would stay unprocessed
            y = min(y, max(0, height - tile_size))
            tile_h = min(tile_size, height - y)
            x = 0
            while x < width:
                x = min(x, max(0, width - tile_size))
                tile_w = min(tile_size, width - x)
                tiles.append((x, y, tile_w, tile_h))
                if x + tile_w >= width:
                    break
                x += step_size
            if y + tile_h >= height:
                break
            y += step_size
        return tiles
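
    # Example: with the defaults (tile_size=768, tile_overlap=64) a 1500x1500
    # target is covered by a 3x3 grid of 768px tiles; the last row and column
    # are snapped back so they end exactly at the image border.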

    def process_tile(self, model, positive, negative, vae, tile_tensor,
                     seed, steps, cfg, sampler_name, scheduler, denoise_strength):
        """Process a single tile through the diffusion model"""
        print(f"Input tile tensor shape: {tile_tensor.shape}")
        # Ensure the tile is in BHWC format for the VAE (ComfyUI standard)
        tile_bhwc = self.ensure_bhwc_format(tile_tensor)
        print(f"Tile tensor in BHWC format: {tile_bhwc.shape}")
        # Validate dimensions: should be BHWC with 3 channels
        if len(tile_bhwc.shape) != 4 or tile_bhwc.shape[-1] != 3:
            raise ValueError(f"Expected BHWC format with 3 channels, got shape: {tile_bhwc.shape}")
        # Encode to latent space (the VAE expects BHWC pixels)
        with torch.no_grad():
            latent_samples = vae.encode(tile_bhwc[:, :, :, :3])
        # Prepare noise for the img2img pass
        if denoise_strength > 0:
            noise = comfy.sample.prepare_noise(latent_samples, seed)
        else:
            noise = torch.zeros_like(latent_samples)
        # Sample using ComfyUI's sampling system
        callback = latent_preview.prepare_callback(model, steps)
        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample(
            model, noise, steps, cfg, sampler_name, scheduler,
            positive, negative, latent_samples,
            denoise=denoise_strength,
            disable_noise=False,
            start_step=None,
            last_step=None,
            force_full_denoise=True,
            noise_mask=None,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=seed
        )
        # Decode back to image space
        with torch.no_grad():
            decoded = vae.decode(samples)
        # Ensure the decoded tensor is on the correct device
        decoded = decoded.to(self.device)
        return decoded

    def blend_tiles(self, tiles_data, final_width, final_height, overlap):
        """Blend tiles with smooth feathering"""
        result = torch.zeros((1, final_height, final_width, 3), device=self.device)
        weight_map = torch.zeros((1, final_height, final_width, 1), device=self.device)
        for (x, y, w, h), tile_data in tiles_data:
            # Ensure the tile is in BHWC format and on the correct device
            tile_data = self.ensure_bhwc_format(tile_data)
            tile_data = tile_data.to(self.device)
            # VAE round-tripping can shift spatial dims by a few pixels when a
            # tile is not a multiple of the latent stride; resize back to (h, w)
            if tile_data.shape[1] != h or tile_data.shape[2] != w:
                tile_bchw = tile_data.permute(0, 3, 1, 2)
                tile_bchw = F.interpolate(tile_bchw, size=(h, w), mode='bilinear', align_corners=False)
                tile_data = tile_bchw.permute(0, 2, 3, 1)
            # Create a feathering mask with smooth transitions on shared edges
            mask = torch.ones((1, h, w, 1), device=self.device)
            fade_size = min(overlap // 2, min(w, h) // 4)
            if fade_size > 1:
                for i in range(fade_size):
                    alpha = i / fade_size
                    if y > 0:  # Top edge
                        mask[:, i, :, :] *= alpha
                    if y + h < final_height:  # Bottom edge
                        mask[:, h - 1 - i, :, :] *= alpha
                    if x > 0:  # Left edge
                        mask[:, :, i, :] *= alpha
                    if x + w < final_width:  # Right edge
                        mask[:, :, w - 1 - i, :] *= alpha
            # Blend the tile into the result
            result[:, y:y+h, x:x+w, :] += tile_data * mask
            weight_map[:, y:y+h, x:x+w, :] += mask
        # Normalize by the accumulated weights
        weight_map = torch.clamp(weight_map, min=1e-8)
        result = result / weight_map
        return result

    def upscale(self, image, model, positive, negative, vae, upscale_factor,
                seed, steps, cfg, sampler_name, scheduler, denoise_strength,
                tile_size, tile_overlap, detail_boost, edge_enhance,
                noise_reduction, speed_mode, frequency_separation):
        # Ensure the input is on the correct device and in BHWC format
        image = image.to(self.device)
        image = self.ensure_bhwc_format(image)
        batch, height, width, channels = image.shape
        target_width = int(width * upscale_factor)
        target_height = int(height * upscale_factor)
        print(f"Upscaling from {width}x{height} to {target_width}x{target_height}")
        # Adjust parameters based on the speed mode
        if speed_mode == "fast":
            steps = max(10, steps // 2)
            tile_size = min(tile_size, 512)
        elif speed_mode == "quality":
            steps = min(steps + 10, 50)
            tile_size = max(tile_size, 768)
        # Initial upscale
        if frequency_separation:
            upscaled = self.frequency_separation_upscale(image, upscale_factor)
        else:
            image_bchw = self.ensure_bchw_format(image)
            upscaled_bchw = F.interpolate(
                image_bchw,
                size=(target_height, target_width),
                mode='bicubic',
                align_corners=False,
                antialias=True
            )
            upscaled = self.ensure_bhwc_format(upscaled_bchw)
        print(f"Initial upscale complete: {upscaled.shape}")
        # Extract edge information for enhancement
        edge_mask = None
        if edge_enhance > 0:
            edge_mask = self.extract_edges(upscaled)
        # Calculate tiles for processing
        tiles = self.calculate_tiles(target_width, target_height, tile_size, tile_overlap)
        print(f"Processing {len(tiles)} tiles")
        # Process tiles if denoising is enabled
        if denoise_strength > 0 and len(tiles) > 0:
            processed_tiles = []
            print(f"Upscaled tensor shape before tile processing: {upscaled.shape}")
            for i, (x, y, w, h) in enumerate(tiles):
                print(f"Processing tile {i+1}/{len(tiles)}: {w}x{h} at ({x},{y})")
                # Extract the tile
                tile = upscaled[:, y:y+h, x:x+w, :]
                print(f"Extracted tile shape: {tile.shape}")
                try:
                    # Process the tile through the diffusion model
                    processed_tile = self.process_tile(
                        model, positive, negative, vae, tile,
                        seed + i, steps, cfg, sampler_name, scheduler, denoise_strength
                    )
                    processed_tile = self.ensure_bhwc_format(processed_tile)
                    processed_tiles.append(((x, y, w, h), processed_tile))
                except Exception as e:
                    print(f"Error processing tile {i}: {e}")
                    import traceback
                    traceback.print_exc()
                    # Use the original tile if processing fails
                    processed_tiles.append(((x, y, w, h), tile))
            # Blend the processed tiles
            if processed_tiles:
                upscaled = self.blend_tiles(processed_tiles, target_width, target_height, tile_overlap)
        # Apply final enhancements
        if detail_boost != 1.0:
            upscaled = self.adaptive_sharpen(upscaled, detail_boost, edge_mask)
        # Noise reduction: blend a slightly blurred copy back in
        if noise_reduction > 0:
            upscaled_bchw = self.ensure_bchw_format(upscaled)
            smoothed = self.gaussian_blur_fast(upscaled_bchw, sigma=0.5)
            upscaled_bchw = (1 - noise_reduction) * upscaled_bchw + noise_reduction * smoothed
            upscaled = self.ensure_bhwc_format(upscaled_bchw)
        # Final clamp to the valid pixel range
        upscaled = torch.clamp(upscaled, 0, 1)
        print("Upscaling complete!")
        return (upscaled,)

# Node mappings for ComfyUI
NODE_CLASS_MAPPINGS = {
    "UltraMaxUpscaler": UltraMaxUpscaler,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "UltraMaxUpscaler": "Ultra Max Upscaler",
}
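
# Usage notes (a sketch; assumes a standard ComfyUI install, and the file name
# below is arbitrary): save this file as ComfyUI/custom_nodes/ultramax_upscaler.py
# and restart ComfyUI. The node then appears as "Ultra Max Upscaler" under the
# image/upscaling category. Wire it like a KSampler: feed it the IMAGE to
# upscale plus the MODEL, positive/negative CONDITIONING, and VAE from your
# checkpoint and prompt nodes. Moderate denoise_strength values (the 0.5
# default or below) keep the result faithful to the source image; higher values
# let the diffusion pass invent more detail in each tile.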