Hacks for stable (non-flickery) preview demo of diffusers FLUX.1 model in jupyter notebooks
import base64
import io
import time

import torch
import torchvision.transforms.functional as TF
from IPython.display import HTML, display
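# --- Setup sketch (assumption; not part of the original gist) ---------------
# The code below expects `pipe`, `taesd3`, and a few size/step globals to
# already exist in the notebook. One plausible way to define them, using
# diffusers' FluxPipeline and the TAEF1 tiny autoencoder (keeping the
# original variable name `taesd3`):
from diffusers import AutoencoderTiny, FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
taesd3 = AutoencoderTiny.from_pretrained(
    "madebyollin/taef1", torch_dtype=torch.bfloat16
).to("cuda")

N_STEPS = 4  # FLUX.1 Schnell is designed for very few (1-4) steps
HEIGHT, WIDTH = 1024, 1024  # generation resolution
IM_HEIGHT, IM_WIDTH = HEIGHT, WIDTH  # on-screen preview size, in px
prompt = "a watercolor painting of a lighthouse at dawn"  # example prompt
generator = torch.manual_seed(42)  # example seed
# ----------------------------------------------------------------------------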
def get_pred_original_sample(sched, model_output, timestep, sample):
    # For flow-matching schedulers, the predicted clean sample is
    # x0 = x_t - sigma_t * v_t; look up sigma_t by matching the timestep.
    return sample - sched.sigmas[(sched.timesteps == timestep).nonzero().item()] * model_output
# TODO: fix awful globals
prev_img_str = None
def pil_to_html(pil_img, h=IM_HEIGHT, w=IM_WIDTH):
    global prev_img_str
    # Super complicated workaround for flickering in the JupyterLab image
    # display: manually keep the previous image onscreen (stacked underneath)
    # until the new image is fully loaded, since JupyterLab would otherwise
    # flicker during the loading interval.
    buffered = io.BytesIO()
    pil_img.save(buffered, format="JPEG", quality=90)  # adjust quality as needed for network speed
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    if prev_img_str is None:
        prev_img_str = img_str
    style = "position:absolute;left:0;top:0;display:inline-block;"
    images = f'<img style="{style}z-index:0;" width={w} height={h} src="data:image/jpeg;base64,{prev_img_str}" />'
    images += f'<img style="{style}z-index:10;" width={w} height={h} src="data:image/jpeg;base64,{img_str}" />'
    prev_img_str = img_str
    return HTML(f'<div style="position:relative;background:black;width:{w}px;height:{h}px;">{images}</div>')
class Previewer:
    def __init__(self):
        self.text_handle = None
        self.im_handle = None
        self.reset()

    def reset(self):
        self.last_step_time = None
        self.i = 0

    @torch.no_grad()
    def step(self, x0_est):
        tick = time.time()
        # ugh (the first call after reset() has no previous step to time)
        step_duration = None if self.last_step_time is None else tick - self.last_step_time
        step_ms = "?" if step_duration is None else f"{step_duration * 1000:.1f}"
        # ughhhh (reach into the pipeline's private helper to unpack the latents)
        x0_est = pipe._unpack_latents(x0_est, HEIGHT, WIDTH, pipe.vae_scale_factor)
        dec = taesd3.decoder(x0_est).mul_(0.5).add_(0.5).clamp_(0, 1).mul_(255).round_().byte().cpu()
        tock = time.time()
        pbar = (
            f"<div style='vertical-align:middle;display:inline-block;background:mediumspringgreen;height:1.5em;width:{self.i + 1}em;'></div>"
            f"<div style='vertical-align:middle;display:inline-block;background:#404040;height:1.5em;width:{N_STEPS - self.i - 1}em;margin-right:1em;'></div>"
        )
        disp_text = HTML(
            f"<div style='font-family:monospace;'>{pbar} FLUX.1 Schnell step "
            f"<strong>{self.i + 1: 3d} / {N_STEPS}</strong> took <strong>{step_ms}</strong>ms; "
            f"TAEF1 previewing took <strong>{(tock - tick) * 1000:.1f}</strong>ms</div>"
        )
        im = TF.to_pil_image(dec[0])
        disp_im = pil_to_html(im)
        if self.text_handle is None:
            self.text_handle = display(disp_text, display_id=True)
            self.im_handle = display(disp_im, display_id=True)
        else:
            self.text_handle.update(disp_text)
            self.im_handle.update(disp_im)
        self.last_step_time = time.time()
        self.i += 1

previewer = Previewer()
def add_taesd_previewing(pipe, taesd3, previewer):
    # We have to hook the scheduler's step() to even get at the predicted
    # result; wrap it so each denoising step also renders a preview.
    sched = pipe.scheduler
    if not hasattr(sched, "_step"):
        sched._step = sched.step  # stash the original step() exactly once
    def step_and_preview(*args, **kwargs):
        previewer.step(get_pred_original_sample(sched, *args))
        return sched._step(*args, **kwargs)
    sched.step = step_and_preview

add_taesd_previewing(pipe, taesd3, previewer)
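# (Note, not in the original gist: since the original step() was stashed as
# sched._step, the preview hook can be undone later by restoring it:
#     pipe.scheduler.step = pipe.scheduler._step
# )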
# Code to actually launch the demo animation.
previewer.reset()
with torch.no_grad():
    previewer.last_step_time = time.time()
    images = pipe(
        prompt=prompt,
        width=WIDTH,    # must match the WIDTH used by the previewer's unpacking
        height=HEIGHT,  # likewise for HEIGHT
        output_type="latent",
        num_inference_steps=N_STEPS,
        guidance_scale=0.0,
        max_sequence_length=256,
        generator=generator,
    ).images
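
# (Sketch, not in the original gist.) With output_type="latent", `images`
# holds packed latents; one way to decode the finished result with the
# full-quality VAE, mirroring the pipeline's own decoding path:
with torch.no_grad():
    latents = pipe._unpack_latents(images, HEIGHT, WIDTH, pipe.vae_scale_factor)
    latents = latents / pipe.vae.config.scaling_factor + pipe.vae.config.shift_factor
    final = pipe.vae.decode(latents).sample  # in [-1, 1]
final_pil = TF.to_pil_image(final[0].float().mul(0.5).add(0.5).clamp(0, 1))
final_pil  # show the finished image in the notebook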