Created
May 4, 2023 01:24
-
-
Save papuSpartan/f97c4b423352bf26e628910970b9e555 to your computer and use it in GitHub Desktop.
SAG patch for AUTOMATIC1111 stable-diffusion-webui (dev branch) at commit 335428c2c8139dfe07ba096a6defa75036660244
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
From cef07987cc6e1776d3c3d88691ce94c8fc2c3a0c Mon Sep 17 00:00:00 2001 | |
From: papuSpartan <[email protected]> | |
Date: Wed, 3 May 2023 20:06:27 -0500 | |
Subject: [PATCH] patch in callbacks for SAG to dev | |
--- | |
modules/script_callbacks.py | 31 ++++++++++++++++++++++++++++++- | |
modules/sd_samplers_kdiffusion.py | 9 ++++++++- | |
2 files changed, 38 insertions(+), 2 deletions(-) | |
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py | |
index 17109732..e0436a00 100644 | |
--- a/modules/script_callbacks.py | |
+++ b/modules/script_callbacks.py | |
@@ -53,7 +53,7 @@ class CFGDenoiserParams: | |
class CFGDenoisedParams: | |
-    def __init__(self, x, sampling_step, total_sampling_steps): | |
+    def __init__(self, x, sampling_step, total_sampling_steps, inner_model):  # NOTE(review): inner_model is now required; any caller still passing 3 args breaks - consider inner_model=None | |
        self.x = x | |
        """Latent image representation in the process of being denoised""" | |
@@ -63,6 +63,25 @@ class CFGDenoisedParams: | |
        self.total_sampling_steps = total_sampling_steps | |
        """Total number of sampling steps planned""" | |
+        self.inner_model = inner_model | |
+        """Inner model reference that is being used for denoising""" | |
+ | |
+ | |
+class AfterCFGCallbackParams: | |
+    """Parameters passed to callbacks registered with on_cfg_after_cfg().""" | |
+    def __init__(self, x, sampling_step, total_sampling_steps): | |
+        self.x = x | |
+        """Latent image representation in the process of being denoised""" | |
+ | |
+        self.sampling_step = sampling_step | |
+        """Current sampling step number""" | |
+ | |
+        self.total_sampling_steps = total_sampling_steps | |
+        """Total number of sampling steps planned""" | |
+ | |
+        self.output_altered = False | |
+        """A flag for CFGDenoiser: set to True after replacing x so that x is used as the denoiser output""" | |
+ | |
class UiTrainTabParams: | |
    def __init__(self, txt2img_preview_params): | |
@@ -87,6 +103,7 @@ callback_map = dict( | |
    callbacks_image_saved=[], | |
    callbacks_cfg_denoiser=[], | |
    callbacks_cfg_denoised=[], | |
+    callbacks_cfg_after_cfg=[],  # filled by on_cfg_after_cfg(), iterated by cfg_after_cfg_callback() | |
    callbacks_before_component=[], | |
    callbacks_after_component=[], | |
    callbacks_image_grid=[], | |
@@ -185,6 +202,12 @@ def cfg_denoised_callback(params: CFGDenoisedParams): | |
        except Exception: | |
            report_exception(c, 'cfg_denoised_callback') | |
+def cfg_after_cfg_callback(params: AfterCFGCallbackParams):  # call every registered after-CFG callback; exceptions are caught and passed to report_exception | |
+    for c in callback_map['callbacks_cfg_after_cfg']: | |
+        try: | |
+            c.callback(params) | |
+        except Exception: | |
+            report_exception(c, 'cfg_after_cfg_callback') | |
def before_component_callback(component, **kwargs): | |
    for c in callback_map['callbacks_before_component']: | |
@@ -331,6 +354,12 @@ def on_cfg_denoised(callback): | |
    """ | |
    add_callback(callback_map['callbacks_cfg_denoised'], callback) | |
+def on_cfg_after_cfg(callback): | |
+    """register a function to be called in the k-diffusion CFGDenoiser after CFG calculations (and mask blending) have completed. | |
+    The callback is called with one argument: | |
+       - params: AfterCFGCallbackParams - the denoised latent and sampling state; to replace the output, assign params.x and set params.output_altered = True. | |
+    """ | |
+    add_callback(callback_map['callbacks_cfg_after_cfg'], callback) | |
def on_before_component(callback): | |
    """register a function to be called before a component is created. | |
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py | |
index eb98e599..8e63859b 100644 | |
--- a/modules/sd_samplers_kdiffusion.py | |
+++ b/modules/sd_samplers_kdiffusion.py | |
@@ -9,6 +9,7 @@ from modules.shared import opts, state | |
import modules.shared as shared | |
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback | |
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback | |
+from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback  # post-CFG hook added for SAG | |
samplers_k_diffusion = [ | |
    ('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}), | |
@@ -161,7 +162,7 @@ class CFGDenoiser(torch.nn.Module): | |
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) | |
            x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be | |
-        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps) | |
+        denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)  # expose the inner model to cfg_denoised callbacks | |
        cfg_denoised_callback(denoised_params) | |
        devices.test_for_nans(x_out, "unet") | |
@@ -181,7 +182,13 @@ class CFGDenoiser(torch.nn.Module): | |
        if self.mask is not None: | |
            denoised = self.init_latent * self.mask + self.nmask * denoised | |
+        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)  # NOTE(review): params.x is the same object as denoised, so in-place edits apply even without output_altered | |
+        cfg_after_cfg_callback(after_cfg_callback_params) | |
+ | |
+        if after_cfg_callback_params.output_altered:  # only adopt params.x when a callback explicitly flagged a replacement | |
+            denoised = after_cfg_callback_params.x | |
        self.step += 1 | |
+ | |
        return denoised | |
-- | |
2.40.0.windows.1 | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment