Last active
February 21, 2025 16:02
-
-
Save catboxanon/69ce64e0389fa803d26dc59bb444af53 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
from scipy.ndimage import gaussian_filter | |
from skimage.transform import resize | |
import modules.scripts as scripts | |
from modules import devices, script_callbacks, shared | |
from modules.processing import StableDiffusionProcessingTxt2Img | |
from modules.script_callbacks import ExtraNoiseParams | |
from modules.scripts import AlwaysVisible | |
# Display name: used as the script title, the accordion label, and the
# prefix of the infotext keys written by Script.process.
NAME = "Extra Noise Mask"

# Shared state: written by Script.process, read by the on_extra_noise callback.
MASK: "np.ndarray | None" = None  # painted mask from the canvas, or None when inactive
BLUR_RADIUS = 0  # Gaussian sigma applied to the mask after it is resized to the noise's H x W
class Script(scripts.Script):
    """Txt2img UI and processing hooks for the "Extra Noise Mask" feature.

    The user paints a mask on an uploaded image. ``process`` publishes that
    mask through the module-level ``MASK``/``BLUR_RADIUS`` globals so the
    module-level ``on_extra_noise`` callback can apply it to the extra noise.
    """

    def title(self):
        return NAME

    def ui(self, _):
        """Build the accordion UI.

        Returns:
            The components, in the order ``process``/``postprocess`` receive them:
            ``[enabled, return_mask, canvas, blur_radius]``.
        """
        with gr.Group():
            with gr.Accordion(NAME, open=False):
                enabled = gr.Checkbox(label="Enabled", value=False)
                return_mask = gr.Checkbox(label="Return mask", value=False)
                blur_radius = gr.Slider(label="Blur radius", value=0, minimum=0, maximum=32, step=1)
                canvas = gr.Image(
                    image_mode="RGB",
                    source='upload',
                    tool='sketch',
                    type='numpy',
                    height=768,
                    show_label=False,
                    show_download_button=False,
                    interactive=True,
                    brush_color=shared.opts.data.get('img2img_inpaint_mask_brush_color', '#FFFFFF'),  # type: ignore
                    brush_radius=128,
                )
        return [
            enabled,
            return_mask,
            canvas,
            blur_radius,
        ]

    def show(self, is_img2img):
        """Show the script (always visible) on the txt2img tab only."""
        return AlwaysVisible if not is_img2img else False

    def process(self,
                p: StableDiffusionProcessingTxt2Img,
                enabled: bool,
                return_mask: bool,
                canvas: dict,
                blur_radius: float,
                ):
        """Publish the painted mask for the extra-noise callback.

        Args:
            p: The processing object; the raw mask is stashed on it as
                ``_extra_noise_mask`` for ``postprocess``.
            enabled: Whether the feature is on.
            return_mask: Consumed by ``postprocess`` (via ``*args``), unused here.
            canvas: The sketch component's value, a dict holding a ``"mask"``
                array. It may be ``None`` when no image was uploaded.
            blur_radius: Gaussian sigma used by the callback.
        """
        global MASK, BLUR_RADIUS
        mask = canvas.get("mask") if enabled and isinstance(canvas, dict) else None
        if mask is None:
            # Disabled, or enabled without an uploaded image: make sure no
            # mask from an earlier run stays active.
            MASK = None
            BLUR_RADIUS = 0
            return

        MASK = mask
        p._extra_noise_mask = mask  # type: ignore
        BLUR_RADIUS = blur_radius
        if not mask.any():
            # Nothing painted: the callback leaves the noise untouched, so do
            # not advertise the feature in the infotext.
            return
        p.extra_generation_params.update({
            f'{NAME} enabled': enabled,
            f'{NAME} blur radius': blur_radius,
        })

    def postprocess(self, p, processed, *args):
        """Optionally append the mask to the outputs, then deactivate it.

        ``args`` mirrors ``ui()``: ``(enabled, return_mask, canvas, blur_radius)``.
        """
        global MASK, BLUR_RADIUS
        mask = getattr(p, '_extra_noise_mask', None)
        if mask is not None and len(args) > 1 and args[1]:
            processed.images.append(Image.fromarray(mask))
        # The extra-noise callback is registered globally, and img2img runs
        # never call this script's process(). Clear the mask so it cannot
        # leak into a later generation.
        MASK = None
        BLUR_RADIUS = 0
def on_extra_noise(params: ExtraNoiseParams):
    """Scale the extra noise by the user-painted mask, if one is active.

    Reads the module-level ``MASK`` and ``BLUR_RADIUS`` globals. When there is
    no mask, or the mask is all zeros, ``params.noise`` is left unchanged.

    The global mask is intentionally *not* overwritten here. The callback can
    fire several times per generation (for example with batch count > 1), and
    each call must start from the original painted array.
    """
    mask = MASK
    if mask is None or not mask.any():
        return

    noise = params.noise
    height, width = noise.shape[-2:]
    # Resize to the noise's spatial size; for integer input, skimage returns
    # floats rescaled to [0, 1].
    mask = resize(mask, (height, width))
    if mask.ndim == 3:
        # Collapse the colour channels so the blur acts only on the spatial axes.
        mask = mask.mean(axis=-1)
    mask = gaussian_filter(mask, sigma=BLUR_RADIUS)

    # Match the noise's device and dtype so the product keeps the noise's dtype
    # (a float64 mask would otherwise promote it). The (1, 1, H, W) shape
    # broadcasts over batch and any number of latent channels.
    mask_t = torch.from_numpy(mask).to(device=noise.device, dtype=noise.dtype)
    params.noise = noise * mask_t[None, None]


script_callbacks.on_extra_noise(on_extra_noise)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment