"""make variations of input image""" | |
import argparse | |
import glob | |
import os | |
import sys | |
import time | |
from contextlib import nullcontext | |
from itertools import islice | |
import numpy as np | |
import PIL | |
import torch | |
from einops import rearrange, repeat | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from ldm.models.diffusion.plms import PLMSSampler | |
from ldm.util import instantiate_from_config | |
from omegaconf import OmegaConf | |
from PIL import Image | |
from pytorch_lightning import seed_everything | |
from torch import autocast | |
from torchvision.utils import make_grid | |
from tqdm import tqdm, trange | |
def chunk(it, size):
    # Split an iterable into fixed-size tuples, e.g. chunk("abcde", 2) ->
    # ("a", "b"), ("c", "d"), ("e",); used to batch prompts read from a file.
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates missing/unexpected keys between checkpoint and config
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    model.cuda()
    model.eval()
    return model
def load_img(path):
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    # resize to integer multiple of 32
    w, h = map(lambda x: x - x % 32, (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with batch dim
    image = torch.from_numpy(image)
    return 2. * image - 1.  # scale from [0, 1] to [-1, 1]
def load_model(config="configs/stable-diffusion/v1-inference.yaml", ckpt="models/ldm/stable-diffusion-v1/model.ckpt"):
    config = OmegaConf.load(config)
    model = load_model_from_config(config, ckpt)
    model = model.to(torch.float16)  # half precision to reduce VRAM usage
    return model
def main(model, prompt: str, init_img: str, skip_grid=True, ddim_steps=50, plms=False, fixed_code=False, ddim_eta=0.0, n_iter=1, C=4, f=8, n_samples=1, n_rows=0, scale=5.0, strength=0.75, from_file=None, seed=42, precision="autocast", prompt_correction=[]):
    seed_everything(seed)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    if plms:
        raise NotImplementedError("PLMS sampler not (yet) supported")
    sampler = DDIMSampler(model)
    batch_size = n_samples
    n_rows = n_rows if n_rows > 0 else batch_size
    if not from_file:
        assert prompt is not None
        data = [batch_size * [prompt]]
    else:
        print(f"reading prompts from {from_file}")
        with open(from_file, "r") as fp:
            data = fp.read().splitlines()
            data = list(chunk(data, batch_size))
    assert os.path.isfile(init_img)
    init_image = load_img(init_img).to(torch.float16).to(device)
    init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
    init_latent = model.get_first_stage_encoding(
        model.encode_first_stage(init_image))  # move to latent space
    sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta, verbose=False)
    assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
    # strength controls how much of the DDIM schedule is re-run on the encoded image
    t_enc = int(strength * ddim_steps)
    print(f"target t_enc is {t_enc} steps")
    images = []
    precision_scope = autocast if precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                tic = time.time()
                for n in trange(n_iter, desc="Sampling"):
                    for prompts in tqdm(data, desc="data"):
                        uc = None
                        if scale != 1.0:
                            # empty-prompt conditioning for classifier-free guidance
                            uc = model.get_learned_conditioning(batch_size * [""])
                        if isinstance(prompts, tuple):
                            prompts = list(prompts)
                        c = model.get_learned_conditioning(prompts)
                        # apply weighted prompt corrections of the form "text::weight"
                        for pw in prompt_correction:
                            p, weight = pw.rsplit('::', 1)
                            c += float(weight) * model.get_learned_conditioning(batch_size * [p])
                        # encode (scaled latent)
                        z_enc = sampler.stochastic_encode(
                            init_latent, torch.tensor([t_enc] * batch_size).to(device))
                        # decode it
                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=scale,
                                                 unconditional_conditioning=uc)
                        x_samples = model.decode_first_stage(samples)
                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                        for x_sample in x_samples:
                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                            images.append(Image.fromarray(x_sample.astype(np.uint8)))
                toc = time.time()
                print(f"Sampling took {toc - tic:.2f}s")
    return images
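
# The helper below is not part of the original gist: it is a minimal sketch of
# how load_model() and main() fit together. The prompt, image path, and
# correction weight are illustrative placeholders, not values from the gist.
def example_usage():
    model = load_model()  # defaults to the v1 config/checkpoint paths above
    images = main(
        model,
        prompt="a watercolor landscape, soft light",
        init_img="sketch.png",  # hypothetical input image
        n_samples=3,            # three variations per prompt
        strength=0.6,           # lower strength stays closer to the input image
        scale=7.5,              # classifier-free guidance scale
        prompt_correction=["photorealistic::-0.3"],  # "text::weight" pairs
    )
    for i, img in enumerate(images):
        img.save(f"out_{i:02d}.png")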
if __name__ == "__main__":
    # Minimal CLI entry point: expects a prompt and an input image path.
    for i, img in enumerate(main(load_model(), prompt=sys.argv[1], init_img=sys.argv[2])):
        img.save(f"variation_{i:04d}.png")