Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save madebyollin/f8c09233d3d8a91bd0866941e14690a0 to your computer and use it in GitHub Desktop.
Save madebyollin/f8c09233d3d8a91bd0866941e14690a0 to your computer and use it in GitHub Desktop.
Hacks for a stable (non-flickery) preview demo of the diffusers FLUX.1 model in Jupyter notebooks
from IPython.display import HTML
def get_pred_original_sample(sched, model_output, timestep, sample):
    """Return the denoised-sample estimate for the current scheduler step.

    The result is ``sample - sigma * model_output``, where ``sigma`` is the
    entry of ``sched.sigmas`` at the position where ``sched.timesteps``
    equals ``timestep``.

    Args:
        sched: Object exposing indexable ``sigmas`` and a ``timesteps`` tensor.
        model_output: Model prediction for this step.
        timestep: Timestep value to look up in ``sched.timesteps``.
        sample: Current noisy sample.

    Returns:
        The estimate, computed with ordinary tensor arithmetic.
    """
    # .item() needs exactly one match, so `timestep` must occur once in the schedule
    step_mask = sched.timesteps == timestep
    step_index = step_mask.nonzero().item()
    sigma = sched.sigmas[step_index]
    return sample - sigma * model_output
# TODO: fix awful globals
# Base64-encoded JPEG of the most recently rendered frame (None until the first call).
prev_img_str = None


def pil_to_html(pil_img, h=IM_HEIGHT, w=IM_WIDTH):
    """Wrap ``pil_img`` in an HTML snippet that avoids flicker in JupyterLab.

    JupyterLab flickers while a freshly swapped-in image is still loading, so
    the previous frame is kept onscreen underneath the new one: both are
    emitted as absolutely positioned ``<img>`` tags, the old frame at
    ``z-index:0`` and the new frame at ``z-index:10``.

    Args:
        pil_img: Image object providing ``save(buffer, format=..., quality=...)``.
        h: Display height in pixels.
        w: Display width in pixels.

    Returns:
        An ``HTML`` object holding both layers on a black background.

    Side effects:
        Updates the module-level ``prev_img_str`` to the frame just encoded.
    """
    global prev_img_str

    jpeg_buffer = io.BytesIO()
    pil_img.save(jpeg_buffer, format="JPEG", quality=90)  # adjust as needed for network quality
    img_str = base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")

    # on the very first frame there is nothing to hold onscreen, so the
    # back layer simply shows the new frame as well
    back_str = img_str if prev_img_str is None else prev_img_str

    style = "position:absolute;left:0;top:0;display:inline-block;"
    layers = [
        f'<img style="{style}; z-index:0;" width={w} height={h} src="data:image/jpeg;base64,{back_str}" />',
        f'<img style="{style}; z-index:10;" width={w} height={h} src="data:image/jpeg;base64,{img_str}" />',
    ]
    prev_img_str = img_str

    images = "".join(layers)
    return HTML(f'<div style="background:black;width:{w}px;height:{h}px;">{images}</div>')
class Previewer:
    """Live notebook preview of a running generation.

    Each call to ``step`` decodes the current latent estimate with the
    module-level ``taesd3`` decoder and updates two notebook displays in
    place: a status line with a progress bar, and the preview image.

    ``reset`` clears only the step counter and timer; the display handles
    are kept, so later runs keep updating the same two output areas.
    """

    def __init__(self):
        # display handles, created on the first step and updated afterwards
        self.text_handle = None
        self.im_handle = None
        self.reset()

    def reset(self):
        """Restart step counting and timing for a new run."""
        self.last_step_time = None
        self.i = 0

    @staticmethod
    def _status_html(step_index, n_steps, step_duration, preview_duration):
        """Build the HTML markup for the progress bar and timing line.

        Args:
            step_index: Zero-based index of the step just previewed.
            n_steps: Total number of steps in the run.
            step_duration: Seconds since the previous step finished, or
                ``None`` when no previous timestamp exists.
            preview_duration: Seconds spent producing the preview.

        Returns:
            The markup as a plain string.
        """
        pbar = f"<div style='vertical-align: middle;display:inline-block;background:mediumspringgreen;height:1.5em;width:{step_index+1}em;'></div><div style='vertical-align: middle; display:inline-block;background:#404040;height:1.5em;width:{n_steps-step_index-1}em;margin-right:1em;'></div>"
        # without a previous timestamp there is nothing to measure; show a
        # placeholder instead of failing on `None * 1000`
        if step_duration is None:
            took = "<strong>n/a</strong>"
        else:
            took = f"<strong>{step_duration*1000:.1f}</strong>ms"
        return f"<div style='font-family:monospace;'>{pbar} FLUX.1 Schnell step <strong>{step_index + 1: 3d} / {n_steps}</strong> took {took}; TAEF1 previewing took <strong>{preview_duration*1000:.1f}</strong>ms</div>"

    @th.no_grad()
    def step(self, x0_est):
        """Decode ``x0_est`` and refresh the status line and preview image.

        Reads the module-level ``pipe``, ``taesd3``, ``HEIGHT``, ``WIDTH``
        and ``N_STEPS``.

        Args:
            x0_est: Packed latent estimate, in the layout accepted by
                ``pipe._unpack_latents``.
        """
        tick = time.time()
        # time since the previous step finished (None right after reset())
        step_duration = None if self.last_step_time is None else tick - self.last_step_time
        x0_est = pipe._unpack_latents(x0_est, HEIGHT, WIDTH, pipe.vae_scale_factor)
        # rescale the decoder output by 0.5x + 0.5, clamp to [0, 1], convert to bytes on CPU
        dec = taesd3.decoder(x0_est).mul_(0.5).add_(0.5).clamp_(0, 1).mul_(255).round_().byte().cpu()
        tock = time.time()
        disp_text = HTML(self._status_html(self.i, N_STEPS, step_duration, tock - tick))
        im = TF.to_pil_image(dec[0])
        disp_im = pil_to_html(im)
        if self.text_handle is None:
            # first step: create the two displays and keep their handles
            self.text_handle = display(disp_text, display_id=True)
            self.im_handle = display(disp_im, display_id=True)
        else:
            self.text_handle.update(disp_text)
            self.im_handle.update(disp_im)
        # stamp after rendering so preview time is not counted in the next step's duration
        self.last_step_time = time.time()
        self.i += 1
# Module-level Previewer instance; it is handed to add_taesd_previewing and
# reset before the pipeline is launched further down.
previewer = Previewer()
def add_taesd_previewing(pipe, taesd3, previewer):
    """Hook ``pipe.scheduler.step`` so every step also refreshes the preview.

    The original ``step`` is stashed once as ``sched._step``. Calling this
    function again replaces the hook rather than stacking a second one.

    Args:
        pipe: Pipeline object exposing a ``scheduler`` attribute.
        taesd3: Not used inside this function; kept for interface
            compatibility (``Previewer.step`` reads the module-level
            ``taesd3`` instead).
        previewer: Object whose ``step(x0_est)`` is called before each
            scheduler step.
    """
    import functools  # function-scope import keeps this block self-contained

    sched = pipe.scheduler
    if not hasattr(sched, "_step"):
        sched._step = sched.step

    # Naming the three leading parameters lets the hook work whether the caller
    # passes them positionally or by keyword, and lets extra positional
    # arguments through to the real step. wraps() keeps the original signature
    # visible to callers that introspect `scheduler.step`.
    @functools.wraps(sched._step)
    def step_and_preview(model_output, timestep, sample, *args, **kwargs):
        previewer.step(get_pred_original_sample(sched, model_output, timestep, sample))
        return sched._step(model_output, timestep, sample, *args, **kwargs)

    sched.step = step_and_preview
# have to hook inside the scheduler to even get at the predicted result
add_taesd_previewing(pipe, taesd3, previewer)

# code to actually launch the demo animation
previewer.reset()
pipe_kwargs = dict(
    prompt=prompt,
    width=width,
    height=height,
    output_type="latent",
    num_inference_steps=N_STEPS,
    guidance_scale=0.0,
    max_sequence_length=256,
    generator=generator,
)
with torch.no_grad():
    # seed the timer so the first step's duration is measured from launch
    previewer.last_step_time = time.time()
    images = pipe(**pipe_kwargs).images
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment