Created January 14, 2024 18:24
-
-
Save morrisalp/2c8d150b6187a3bf50b4a89695da78e1 to your computer and use it in GitHub Desktop.
Diffusers SDXL pipeline with gradients (overriding `no_grad`), tested with diffusers v0.25.0. Use `output_type="latent"` when calling the pipeline to get latents that carry gradients.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from diffusers import StableDiffusionXLPipeline | |
class CustomPipeline(StableDiffusionXLPipeline):
    """SDXL pipeline whose ``__call__`` does not disable autograd.

    The base ``StableDiffusionXLPipeline.__call__`` is wrapped by a
    ``@torch.no_grad`` decorator. This subclass invokes the undecorated
    function instead, so gradients are tracked whenever the caller's grad mode
    is enabled.

    Call the pipeline with ``output_type="latent"`` to get latents that still
    carry gradients. For other output types, the image is detached before
    post-processing (see ``from_pretrained``).
    """

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the pipeline and patch its image processor to detach inputs.

        All arguments are forwarded unchanged to the base class's
        ``from_pretrained``.

        Returns:
            The loaded pipeline instance, with ``image_processor.postprocess``
            replaced by a version that detaches its ``image`` argument.

        Raises:
            NotImplementedError: if the loaded pipeline has a watermarker
                (a non-None ``watermark`` attribute); watermarking is not
                supported by this subclass.
        """
        pipe = super().from_pretrained(*args, **kwargs)

        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently let a watermarker through.
        if getattr(pipe, "watermark", None) is not None:
            raise NotImplementedError(
                "watermarking is currently not supported "
                "(pipeline.watermark must be None)"
            )

        def postprocess_detached(image, *args, **kwargs):
            # Detach first so post-processing never operates on a tensor that
            # is part of the autograd graph. The class-level method is called
            # explicitly, because the instance attribute is this very function.
            processor = pipe.image_processor
            return type(processor).postprocess(
                processor, image.detach(), *args, **kwargs
            )

        pipe.image_processor.postprocess = postprocess_detached
        return pipe

    def __call__(self, *args, **kwargs):
        """Run the SDXL pipeline without the base class's ``no_grad`` wrapper.

        Accepts exactly the arguments of ``StableDiffusionXLPipeline.__call__``
        and returns whatever it returns. Gradients are recorded only if grad
        mode is enabled in the calling context.

        Raises:
            AttributeError: if the base ``__call__`` is not a decorated
                function exposing ``__wrapped__``.
        """
        base_call = super().__call__
        try:
            # ``__wrapped__`` (set by functools.wraps in the decorator) is the
            # plain, undecorated function, so ``self`` must be passed explicitly.
            undecorated = base_call.__wrapped__
        except AttributeError as err:
            raise AttributeError(
                "StableDiffusionXLPipeline.__call__ has no __wrapped__ "
                "attribute; expected it to be wrapped by @torch.no_grad() "
                "(this subclass was tested with diffusers v0.25.0)"
            ) from err
        return undecorated(self, *args, **kwargs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.