Last active
September 8, 2022 00:32
-
-
Save Christopher-Hayes/151d0180fabcc8e1eb02957088869bc0 to your computer and use it in GitHub Desktop.
Output Grid Code Block for Stable Diffusion Colab
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#@title 🌌 Run to start dreaming.{ vertical-output: true, display-mode: "form" } | |
import IPython | |
import base64 | |
from io import BytesIO | |
all_images = [] | |
# Clear sample output | |
!rm /content/stable-diffusion/outputs/txt2img-samples/samples/* | |
!rm /content/stable-diffusion/outputs/txt2img-samples/grid* | |
# Python really needs objects like JS
class Option:
    """Minimal value holder exposing a ``.value`` attribute.

    Used to place plain Python values into the widget dictionary, whose
    consumers read every entry through ``entry.value``.
    """

    def __init__(self, seed):
        # Despite the parameter name, any value may be stored (an int seed,
        # an iteration count, ...).
        self.value = seed

    def __repr__(self):
        return f"Option({self.value!r})"

    def __str__(self):
        # __str__ must return a str; the stored value is frequently an int,
        # and returning it unchanged makes str(option) raise TypeError.
        return str(self.value)
# Clear code output
# (wipes what this cell has displayed so far, e.g. the messages printed by
# the shell commands above, before the new grid is drawn)
clear_output()
### HTML ELEMENTS
def createGrid():
    """Inject the empty image grid and the "copy seeds" bar into the output.

    Displays a short status string, then a JavaScript snippet that:
      * resets the browser-side ``seeds`` array,
      * appends a 6-column CSS grid (``#image-grid``) to ``#output-footer``,
      * appends a clickable bar that copies the selected seeds (joined by
        spaces) to the clipboard,
      * refreshes that bar whenever an ``updateSeeds`` event is dispatched
        on ``window``.
    """
    display('create grid..')
    # NOTE: this script is NOT passed through str.format(), so the doubled
    # braces reach the browser verbatim. JavaScript reads `{{ ... }}` as a
    # block nested inside the function body, so the handlers still work.
    grid_script = """
    var seeds = []
    const outputFooter = document.querySelector("#output-footer")
    const grid = document.createElement('div')
    grid.classList.add('stream')
    grid.id = 'image-grid'
    grid.style = `
    display: grid;
    grid-template-columns: repeat(6, 1fr);
    grid-gap: 10px;
    padding: 3em 2em 5em 2em;
    `
    outputFooter.appendChild(grid)
    const showSeeds = document.createElement('div')
    showSeeds.style = `
    border: 2px solid grey;
    background: #333;
    margin: 0 0 2em 0;
    padding: 2em 1em;
    cursor: pointer;
    width: 100%;
    `
    showSeeds.onmouseover = showSeeds.onfocus = () => {{
    showSeeds.style.background = '#111'
    }}
    showSeeds.onmouseout = showSeeds.onblur = () => {{
    showSeeds.style.background = '#333'
    }}
    showSeeds.innerText = 'Click to copy seeds: ' + seeds.join(' ')
    showSeeds.onclick = () => {{
    navigator.clipboard.writeText(seeds.join(' ')).then(() => {{
    console.log('seeds copied to clipboard!')
    }}, () => {{
    console.log('failed to copy to clipboard')
    }})
    }}
    outputFooter.appendChild(showSeeds)
    // Listen for the "updateSeeds" event
    window.addEventListener('updateSeeds', function (e) {{
    seeds = e.detail
    showSeeds.innerText = 'Click to copy seeds: ' + seeds.join(' ')
    }})
    """
    display(IPython.display.Javascript(grid_script))
# Add image to grid
def addToGrid(img_base64, seedStr):
    """Append one image tile to ``#image-grid`` in the cell output.

    Args:
        img_base64: Data URI used as the ``src`` of the ``<img>`` element.
        seedStr: Seed of the image, as text. Clicking the tile toggles this
            seed in the browser-side ``seeds`` array and dispatches an
            ``updateSeeds`` event carrying that array.
    """
    # Doubled braces are str.format() escapes for literal JS braces; the only
    # placeholders are {img_base64} and {seed}.
    tile_template = """
    const grid = document.querySelector('#image-grid')
    const container = document.createElement('div')
    container.style = `
    display: flex;
    flex-direction: column;
    `
    const img = document.createElement('img')
    img.src = `{img_base64}`
    img.style = `
    width: 100%;
    height: 100%;
    object-fit: cover;
    `
    const button = document.createElement('button')
    let active = false
    button.style = `
    transition: all 80ms ease-out;
    border: 4px solid grey;
    border-radius: 4px;
    filter: drop-shadow(0px 0px 0px black);
    padding: 0;
    cursor: pointer;
    `
    button.onmouseover = button.onfocus = () => {{
    button.style.transform = active ? 'scale(1)' : 'scale(1.2)';
    button.style.filter = 'drop-shadow(0px 2px 12px black)';
    button.style.zIndex = '10';
    }}
    button.onmouseout = button.onblur = () => {{
    button.style.transform = active ? 'scale(0.8)' : 'scale(1)';
    button.style.filter = 'drop-shadow(0px 0px 0px black)';
    button.style.zIndex = 'unset';
    }}
    button.onclick = () => {{
    active = !active
    if (active) {{
    seeds.push('{seed}')
    button.style.opacity = '0.5'
    button.style.border = '4px solid #4747f5';
    button.style.transform = 'scale(0.8)';
    }} else {{
    seeds.splice(seeds.indexOf('{seed}'), 1)
    button.style.border = '4px solid grey';
    button.style.opacity = '1'
    button.style.transform = 'scale(1)';
    }}
    // Create new custom event to share the "seeds" array object
    var event = new CustomEvent('updateSeeds', {{detail: seeds}})
    window.dispatchEvent(event)
    }}
    button.appendChild(img)
    container.appendChild(button)
    grid.appendChild(container)
    """
    tile_script = tile_template.format(img_base64=img_base64, seed=seedStr)
    display(IPython.display.Javascript(tile_script))
# Create image grid
createGrid()
### Prepare config
import re

# Snapshot the widgets into a dict whose entries are read through `.value`
widgetDict = get_widget_extractor(widget_opt)
# Get iterable list of base seeds from the seeds widget string.
# Seeds may be separated by spaces and/or commas; empty pieces (repeated
# separators, a hanging trailing comma) are dropped.
seeds = [s for s in re.split(r'[,\s]+', widgetDict['seeds'].value) if s]
# Remember how many consecutive seeds to render per base seed, then force a
# single sampling iteration per run() call: the loop below calls run() once
# per seed instead.
iterCount = widgetDict['n_iter'].value
widgetDict['n_iter'] = Option(1)
# Reset the browser-side list of selected seeds (declare it on first use)
display(IPython.display.Javascript('if (typeof seeds === "undefined") { var seeds = [] } else { seeds = [] };'))
# Define run
def doRun(seedStr):
    """Run inference once and add every image it produced to the grid.

    Args:
        seedStr: Seed used for this run, as text; attached to each tile so
            it can be selected in the browser.

    The images are taken from the module-level ``all_images`` list, which is
    drained completely: nothing is left behind to pile up across runs, and
    an empty list (no image was collected) is simply a no-op instead of an
    IndexError from ``pop()``.
    """
    # Run inference
    run(widgetDict)
    # Update image grid, oldest image first
    while all_images:
        image = all_images.pop(0)
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        encoded = base64.b64encode(buffered.getvalue()).decode("ascii")
        addToGrid("data:image/jpeg;base64," + encoded, seedStr)
# Runs
for baseSeed in seeds:
    # Render `iterCount` consecutive seeds starting at each base seed.
    # The loop variable is deliberately not called `iter`: at cell level that
    # would rebind the builtin for every cell executed afterwards.
    for offset in range(iterCount):
        # Seed
        seed = int(baseSeed) + offset
        seedStr = str(seed)
        # Add seed to dict manually
        widgetDict['seed'] = Option(seed)
        print('Seed:', seedStr)
        doRun(seedStr)
# Get Seeds button at bottom
# NOTE: the triple-quoted string below is a bare expression statement, so
# nothing inside it is executed. It is disabled legacy code kept for
# reference only (its addToGrid call passes a single argument, while
# addToGrid is defined with two parameters).
"""display(IPython.display.Javascript('''
console.log('create get seeds button')
const button = document.createElement('button')
button.innerText = 'Get seeds'
button.onclick = () => {{
console.log('seeds:', seeds)
const streams = document.querySelectorAll(".stream")
streams[streams.length - 1].appendChild(document.createTextNode(`Seeds: ${seeds.join(' ')}`))
}}
const streams = document.querySelectorAll(".stream")
streams[streams.length - 1].appendChild(button)
'''))
for img in all_images:
print('show image in grid..')
buffered = BytesIO()
img.save(buffered, format="JPEG")
img_str = base64.b64encode(buffered.getvalue())
img_base64 = (bytes("data:image/jpeg;base64,", encoding='utf-8') + img_str).decode("utf-8")
addToGrid(img_base64)
"""
print('Batch complete!')
# Clear VRAM
# (asks PyTorch to release its cached, currently unused GPU memory)
torch.cuda.empty_cache()
#@title Main one-time Setup Code (replaces txt2img.py from SD repo) { display-mode: "form" } | |
# Special installs | |
!pip install diffusers | |
# Slightly modified version of: https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py | |
import argparse, os, sys, glob | |
import torch | |
import numpy as np | |
import datetime | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from PIL.PngImagePlugin import PngInfo | |
#from tqdm.auto import tqdm, trange # NOTE: updated for notebook | |
from tqdm import tqdm, trange # NOTE: updated for notebook | |
from itertools import islice | |
from einops import rearrange | |
from torchvision.utils import make_grid | |
import time | |
import rich | |
from pytorch_lightning import seed_everything | |
from torch import autocast | |
from contextlib import contextmanager, nullcontext | |
from ldm.util import instantiate_from_config | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from scripts.txt2img import chunk, load_model_from_config | |
from IPython.display import clear_output | |
# Code to turn kwargs into Jupyter widgets | |
import ipywidgets as widgets | |
from collections import OrderedDict | |
def load_model(opt):
    """Load the model described by ``opt`` (kept separate from inference).

    When ``opt.laion400m`` is set, ``opt.config``, ``opt.ckpt`` and
    ``opt.outdir`` are overwritten with the LAION 400M defaults first.

    Args:
        opt: Options object providing ``laion400m``, ``config`` and ``ckpt``.

    Returns:
        The model built by ``load_model_from_config``, moved to CUDA when
        available and to the CPU otherwise.
    """
    if opt.laion400m:
        print("Falling back to LAION 400M model...")
        opt.config = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
        opt.ckpt = "models/ldm/text2img-large/model.ckpt"
        opt.outdir = "outputs/txt2img-samples-laion400m"

    model_config = OmegaConf.load(str(opt.config))
    loaded = load_model_from_config(model_config, str(opt.ckpt))

    use_cuda = torch.cuda.is_available()
    if not use_cuda:
        print("Warning - running in CPU mode!")
    target = torch.device("cuda" if use_cuda else "cpu")
    return loaded.to(target)
# Every image produced by run_inference() is appended here so that the
# notebook cell driving the run can pick it up and display it.
all_images = []


def run_inference(opt, model):
    """Sample images for ``opt`` using an already loaded ``model``.

    Inference is kept separate from model loading. Each generated sample is
    appended to the module-level ``all_images`` list and written as a PNG,
    with prompt/seed/steps/scale metadata, to ``<opt.outdir>/samples``
    (unless ``opt.skip_save`` is set).

    Args:
        opt: Options object exposing the attributes read below (``seed``,
            ``plms``, ``outdir``, ``n_samples``, ``n_rows``, ``from_file``,
            ``prompt``, ``fixed_code``, ``C``, ``H``, ``W``, ``f``,
            ``precision``, ``n_iter``, ``scale``, ``ddim_steps``,
            ``ddim_eta``, ``skip_save``, ``skip_grid``).
        model: Model returned by ``load_model``.
    """
    seed_everything(opt.seed)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    batch_size = opt.n_samples
    n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
    if not opt.from_file:
        prompt = opt.prompt
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        with open(opt.from_file, "r") as f:
            data = f.read().splitlines()
            data = list(chunk(data, batch_size))

    # * Variables for saved image filenames
    # date + time
    datetimeStr = datetime.datetime.now().isoformat()
    # Filename-safe prompt string
    slugPrompt = "".join(c if c.isalnum() else "_" for c in opt.prompt)

    sample_path = os.path.join(outpath, "samples")
    os.makedirs(sample_path, exist_ok=True)
    base_count = len(os.listdir(sample_path))
    grid_count = len(os.listdir(outpath)) - 1

    start_code = None
    if opt.fixed_code:
        # Allocate the shared starting noise on the model's own device; no
        # `device` name is defined in this function, so referring to one
        # raised NameError whenever "fixed_code" was ticked.
        start_code = torch.randn(
            [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f],
            device=next(model.parameters()).device,
        )

    # Latent shape of one sample; identical for every batch
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                all_samples = list()
                for _ in range(opt.n_iter):
                    for prompts in data:
                        uc = None
                        if opt.scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                         conditioning=c,
                                                         batch_size=opt.n_samples,
                                                         shape=shape,
                                                         verbose=False,
                                                         unconditional_guidance_scale=opt.scale,
                                                         unconditional_conditioning=uc,
                                                         eta=opt.ddim_eta,
                                                         x_T=start_code)

                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        # Map the decoder output from [-1, 1] to [0, 1]
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        if not opt.skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                imgPath = f'{slugPrompt[:150]}_{opt.seed}_{datetimeStr}_{base_count:05}.png'
                                img = Image.fromarray(x_sample.astype(np.uint8))
                                all_images.append(img)

                                metadata = PngInfo()
                                metadata.add_text("artist", 'Chris Hayes')
                                metadata.add_text("copyright", 'Public Domain')
                                metadata.add_text("software", "Stable Diffusion 1.4")
                                metadata.add_text("title", opt.prompt)
                                # Record the seed this run was actually seeded
                                # with (opt.seed) rather than whatever global
                                # `seed` happens to exist in the notebook.
                                config_text = f"prompt: {opt.prompt}, seed: {opt.seed}, steps: {opt.ddim_steps}, CGS: {opt.scale}"
                                metadata.add_text("config", config_text)
                                img.save(os.path.join(sample_path, imgPath), pnginfo=metadata)
                                base_count += 1

                        if not opt.skip_grid:
                            all_samples.append(x_samples_ddim)

                if not opt.skip_grid:
                    # The grid is still assembled, but saving and inline
                    # display are currently disabled (see commented lines).
                    grid = torch.stack(all_samples, 0)
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
                    grid = make_grid(grid, nrow=n_rows)

                    # to image
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
                    #Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{slugPrompt[:150]}_{opt.seed}_{datetimeStr}_{grid_count:04}.png'))
                    grid_count += 1

                    # display
                    #if opt.display_inline:
                    #clear_output()
                    #display(Image.fromarray(grid.astype(np.uint8)))
def run(opt):
    """Run inference for ``opt``, reloading the model only when required.

    The model is (re)loaded when ``opt.config`` or ``opt.ckpt`` differs from
    the values remembered from the previous load; otherwise the cached
    module-level ``model`` is reused.
    """
    # FIXME global hack
    global last_config, last_ckpt, model

    model_is_current = opt.config == last_config and opt.ckpt == last_ckpt
    if not model_is_current:
        model = load_model(opt)
        # FIXME global hack
        last_config, last_ckpt = opt.config, opt.ckpt

    run_inference(opt, model)


# FIXME global hack
# Empty strings never match a real path, so the first run() loads the model.
last_config, last_ckpt = "", ""
####### Widget GUI code #######
def get_widget_extractor(widget_dict):
    """Wrap ``widget_dict`` so entries can be read argparse-style.

    Attribute access (``opts.name``) returns ``widget_dict["name"].value``,
    which keeps the diff against the argparse-based script small. Item
    access (``opts["name"]``) still returns the stored object itself.

    Args:
        widget_dict: Mapping of option name to an object with a ``.value``
            attribute (e.g. an ipywidgets widget).

    Returns:
        An ``OrderedDict`` subclass instance holding a copy of the mapping.
    """
    class WidgetDict(OrderedDict):
        def __getattr__(self, val):
            # Only reached when normal attribute lookup fails. A missing
            # option must surface as AttributeError (not KeyError) so that
            # hasattr() and getattr(obj, name, default) behave correctly.
            try:
                entry = self[val]
            except KeyError:
                raise AttributeError(val) from None
            return entry.value

    return WidgetDict(widget_dict)
# Allows long widget descriptions
style = {'description_width': 'initial'}
# Force widget width to max
layout = widgets.Layout(width='100%')

# Keyword arguments shared by every widget, and by every checkbox, below
_common = dict(layout=layout, style=style, disabled=False)
_checkbox = dict(_common, indent=False)

# args from argparse converted to widgets:
# https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py#L48-L177
# Insertion order matters: it is the order in which the options are rendered.
widget_opt = OrderedDict()
widget_opt['outdir'] = widgets.Text(
    value="outputs/txt2img-samples",
    description='dir to write results to',
    **_common,
)
widget_opt['skip_grid'] = widgets.Checkbox(
    value=False,
    description='do not save a grid, only individual samples. Helpful when evaluating lots of samples',
    **_checkbox,
)
widget_opt['skip_save'] = widgets.Checkbox(
    value=False,
    description='do not save individual samples. For speed measurements.',
    **_checkbox,
)
widget_opt['plms'] = widgets.Checkbox(
    value=False,
    description='use plms sampling (not checked = ddim)',
    **_checkbox,
)
widget_opt['laion400m'] = widgets.Checkbox(
    value=False,
    description='uses the LAION400M model',
    **_checkbox,
)
widget_opt['fixed_code'] = widgets.Checkbox(
    value=False,
    description='if enabled, uses the same starting code across samples',
    **_checkbox,
)
widget_opt['ddim_eta'] = widgets.FloatText(
    value=0.0,
    description='ddim eta (eta=0.0 corresponds to deterministic sampling',
    **_common,
)
widget_opt['C'] = widgets.IntText(
    value=4,
    description='latent channels',
    **_common,
)
widget_opt['f'] = widgets.IntText(
    value=8,
    description='downsampling factor',
    **_common,
)
widget_opt['n_rows'] = widgets.IntText(
    value=0,
    description='rows in the grid (default: n_samples)',
    **_common,
)
widget_opt['scale'] = widgets.FloatText(
    value=7.5,
    description='unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))',
    **_common,
)
widget_opt['from_file'] = widgets.Text(
    value=None,
    description='if specified, load prompts from this file',
    **_common,
)
widget_opt['config'] = widgets.Text(
    value="configs/stable-diffusion/v1-inference.yaml",
    description='path to config which constructs model',
    **_common,
)
widget_opt['ckpt'] = widgets.Text(
    value="models/ldm/stable-diffusion-v1/model.ckpt",
    description='path to checkpoint of model',
    **_common,
)
widget_opt['precision'] = widgets.Combobox(
    value="autocast",
    options=["full", "autocast"],
    description='evaluate at this precision',
    **_common,
)
# Extra option for the notebook
widget_opt['display_inline'] = widgets.Checkbox(
    value=True,
    description='display output images inline (in addition to saving them)',
    **_checkbox,
)
# Common
widget_opt['n_iter'] = widgets.IntText(
    value=1,
    description='sample this often',
    **_common,
)
widget_opt['n_samples'] = widgets.IntText(
    value=1,
    description='how many samples to produce for each given prompt. A.k.a. batch size',
    **_common,
)
widget_opt['H'] = widgets.IntText(
    value=512,
    description='image height, in pixel space',
    **_common,
)
widget_opt['W'] = widgets.IntText(
    value=512,
    description='image width, in pixel space',
    **_common,
)
widget_opt['prompt'] = widgets.Text(
    #value="a painting of a virus monster playing guitar", # script default
    value="a photograph of an astronaut riding a horse", # README default
    description='the prompt to render',
    **_common,
)
widget_opt['ddim_steps'] = widgets.IntText(
    value=50,
    description='number of ddim sampling steps',
    **_common,
)
widget_opt['seeds'] = widgets.Text(
    value='42',
    description='multiple seeds for batch runs (separate with a space)',
    **_common,
)
# Button that runs inference once for every seed in the `seeds` widget.
# Alternatively, you can just run the following in a new cell:
# run(get_widget_extractor(widget_opt))
run_button = widgets.Button(
    description='CLICK TO DREAM',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Click to run (settings will update automatically)',
    icon='check'
)
run_button_out = widgets.Output()


# this doesn't get used (the button is not part of the rendered GUI)
def on_run_button_click(b):
    """Click handler: run inference once per seed listed in the seeds widget.

    Args:
        b: The button instance passed in by ipywidgets (unused).
    """
    from types import SimpleNamespace

    with run_button_out:
        widgetDict = get_widget_extractor(widget_opt)
        # Attribute access yields the widget's *value* (item access would
        # return the widget object, which has no .split). Accept both spaces
        # and commas between seeds.
        for seed in widgetDict.seeds.replace(',', ' ').split():
            #clear_output()
            # Entries are read back through `.value`, so wrap the integer
            # seed in a small holder instead of storing it bare.
            widgetDict['seed'] = SimpleNamespace(value=int(seed))
            run(widgetDict)


run_button.on_click(on_run_button_click)
# Package into box and render
#primary_options = ['prompt', 'outdir'] # options to put up top
#secondary_options = [k for k in widget_opt.keys() if k not in primary_options] # rest, ordered by insertion
# Model-loading options are rendered last; everything else keeps its
# insertion order.
load_options = ['config', 'ckpt']
inference_options = [name for name in widget_opt if name not in load_options]
# make sure we didn't miss any options
assert all(name in inference_options + load_options for name in widget_opt)
# Package into box for rendering
gui = widgets.VBox(
    [widget_opt[name] for name in (*inference_options, *load_options)]  # + [run_button, run_button_out]
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment