A quick hacked version of the SD3 Gradio UI that has live previews (via TAESD3)
#!/usr/bin/env python3
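# Quick hacked version of the SD3 Gradio demo with live previews via TAESD3:
# after every denoising step the estimated clean latents are decoded with the
# tiny TAESD3 autoencoder and streamed to the UI, and the full SD3 VAE decodes
# the final image. Assumed (unpinned) dependencies: torch, diffusers,
# transformers, gradio, numpy.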
import gradio as gr
import numpy as np
import random
import inspect  # used by retrieve_timesteps below
import torch
from diffusers import (
    StableDiffusion3Pipeline,
    SD3Transformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    AutoencoderTiny,
)
from typing import Any, Callable, Dict, List, Optional, Union

# import spaces
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16
repo = "stabilityai/stable-diffusion-3-medium-diffusers"
pipe = StableDiffusion3Pipeline.from_pretrained(repo, torch_dtype=torch.float16).to(
    device
)

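# Tiny VAE (TAESD3), used only to decode fast intermediate previews.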
taesd3 = (
    AutoencoderTiny.from_pretrained("madebyollin/taesd3", torch_dtype=torch.float16)
    .half()
    .eval()
    .requires_grad_(False)
    .to(device)
)
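# Compile the tiny decoder so per-step preview decoding is faster (requires a
# PyTorch build with torch.compile support).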
taesd3.decoder.layers = torch.compile(
    taesd3.decoder.layers,
    fullgraph=True,
    dynamic=False,
    mode="max-autotune-no-cudagraphs",
)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1344

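# Flow-matching x0 estimate: clean sample ≈ sample - sigma_t * velocity, using
# the sigma that corresponds to the given timestep.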
def get_pred_original_sample(sched, model_output, timestep, sample):
    return (
        sample
        - sched.sigmas[(sched.timesteps == timestep).nonzero().item()] * model_output
    )

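# Helper copied from the diffusers pipeline code; it resolves the timestep
# schedule for the custom pipeline call below.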
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys()
        )
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

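# Modified version of StableDiffusion3Pipeline.__call__ that is a generator:
# it yields a TAESD3-decoded preview image after every denoising step and the
# full-VAE-decoded image at the end.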
@torch.no_grad()
def sd3_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    prompt_3: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 7.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    negative_prompt_3: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        prompt_2,
        prompt_3,
        height,
        width,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._clip_skip = clip_skip
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
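    # 3. Encode prompt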
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_3=prompt_3,
        negative_prompt=negative_prompt,
        negative_prompt_2=negative_prompt_2,
        negative_prompt_3=negative_prompt_3,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
        device=device,
        clip_skip=self.clip_skip,
        num_images_per_prompt=num_images_per_prompt,
    )

    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        pooled_prompt_embeds = torch.cat(
            [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
        )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler, num_inference_steps, device, timesteps
    )
    num_warmup_steps = max(
        len(timesteps) - num_inference_steps * self.scheduler.order, 0
    )
    self._num_timesteps = len(timesteps)

    # 5. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Denoising loop
    # with self.progress_bar(total=num_inference_steps) as progress_bar:
    if True:
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue

            # expand the latents if we are doing classifier free guidance
            latent_model_input = (
                torch.cat([latents] * 2)
                if self.do_classifier_free_guidance
                else latents
            )
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timestep = t.expand(latent_model_input.shape[0])

            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_prompt_embeds,
                joint_attention_kwargs=self.joint_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if self.do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
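            # Live preview: estimate the clean latents from the current
            # prediction and decode them with TAESD3 instead of the full VAE.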
            x0_pred = get_pred_original_sample(self.scheduler, noise_pred, t, latents)
            yield self.image_processor.postprocess(taesd3.decode(x0_pred)[0])[0]

            # if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
            # progress_bar.update()
    #
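    # Final frame: decode the last latents with the full SD3 VAE (undoing the
    # latent scaling and shift first).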
    yield self.image_processor.postprocess(
        self.vae.decode(
            (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
            return_dict=False,
        )[0]
    )[0]

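# Gradio handler: forwards the UI settings to the streaming pipeline call above
# and yields each preview frame, then the finished image.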
# @spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    yield from sd3_pipe_call_that_returns_an_iterable_of_images(
        pipe,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 580px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            f"""
        # Demo [Stable Diffusion 3 Medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
        Learn more about the [Stable Diffusion 3 series](https://stability.ai/news/stable-diffusion-3). Try on [Stability AI API](https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post), [Stable Assistant](https://stability.ai/stable-assistant), or on Discord via [Stable Artisan](https://stability.ai/stable-artisan). Run locally with [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [diffusers](https://github.com/huggingface/diffusers)
        """
        )
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )
            run_button = gr.Button("Run", scale=0)
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1024,
                )
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit, negative_prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=result,
    )

demo.launch(share=True)