Created October 11, 2022 11:12
Save arenasys/e7df1162fee7183ac26d059b85462ef7 to your computer and use it in GitHub Desktop.
Squarize images (pad them to a square) via latent-diffusion (LD) infilling.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse, os, sys, glob | |
import urllib.request | |
import statistics | |
# Usage: copy this script into a stable-diffusion-webui checkout and run
# `python squarize.py`. On the first run the inpainting model is downloaded
# and the squarize/ folder tree is created. Every image placed in
# squarize/input is padded to a square, the padding is inpainted, and the
# result is written to squarize/output.
# Requires a working stable-diffusion-webui installation.
STEPS = 50  # DDIM sampling steps per image -- edit this to trade speed for quality
def include(path):
    """Make modules under ``path`` importable by appending its absolute form to ``sys.path``."""
    sys.path.append(os.path.abspath(path))
def download(url, location):
    """Download ``url`` to ``location`` without leaving a truncated file behind.

    The data is first written to ``<location>.part`` and only renamed to
    ``location`` once the transfer finished. If the transfer fails (including
    Ctrl-C), the partial file is removed and the error is re-raised, so a
    later ``os.path.exists(location)`` check never mistakes an interrupted
    download for a complete one.
    """
    partial = os.fspath(location) + ".part"
    try:
        urllib.request.urlretrieve(url, partial)
    except BaseException:
        # Clean up, then propagate the original error unchanged.
        if os.path.exists(partial):
            os.remove(partial)
        raise
    # Atomic rename on the same filesystem.
    os.replace(partial, location)
# Working folders: user inputs, final results, and intermediate
# letterboxed images + masks.
for _folder in ("squarize/input", "squarize/output", "squarize/tmp"):
    os.makedirs(_folder, exist_ok=True)

# Put the webui's bundled repositories on sys.path so the `ldm` imports
# further down can be resolved.
for _repo in ("repositories/stable-diffusion", "repositories/taming-transformers"):
    include(_repo)

model_path = "repositories/stable-diffusion/models/ldm/inpainting_big/"
_ckpt_file = model_path + "inpainting_big.ckpt"
if not os.path.exists(_ckpt_file):
    print("Downloading inpainting model (3.3GB)...")
    download("https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1", _ckpt_file)
from omegaconf import OmegaConf | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
from ldm.util import instantiate_from_config | |
from ldm.models.diffusion.ddim import DDIMSampler | |
from PIL import Image, ImageDraw, ImageOps, ImageFilter | |
import operator | |
# File extensions (without the dot) that are globbed from the input and
# tmp folders.
ext = "png jpg jpeg webp".split()
# When True, inputs whose letterboxed image and mask already exist in the
# tmp folder are not processed again.
skip_existing = False
def get_edge_colors(img, vertical, inset=0.01, samples=16):
    """Summarise the colour of two opposite edges of ``img``.

    Pixels are sampled along the left/right edges when ``vertical`` is true,
    otherwise along the top/bottom edges, stepping about ``samples`` times
    along the edge. ``inset`` is a fraction of the image's longer side by
    which the sample line is moved in from the border.

    Assumes an RGB image (``getpixel`` returns 3-tuples).

    Returns:
        ``(a_colour, b_colour, a_variance, b_variance)``: the mean RGB colour
        of the first (left/top) and second (right/bottom) edge, and the
        variance of per-pixel brightness (channel mean) along each edge.
    """
    width, height = img.size
    inset_px = int(max(width, height) * inset)
    edge_a, edge_b = [], []
    if vertical:
        # max(1, ...) keeps the stride valid for edges shorter than `samples`.
        step = max(1, height // samples)
        for y in range(0, height, step):
            edge_a.append(img.getpixel((inset_px, y)))
            edge_b.append(img.getpixel((width - 1 - inset_px, y)))
    else:
        step = max(1, width // samples)
        for x in range(0, width, step):
            edge_a.append(img.getpixel((x, inset_px)))
            edge_b.append(img.getpixel((x, height - 1 - inset_px)))

    def mean_colour(pixels):
        # Divide by the number of points actually sampled: the stride can
        # produce more than `samples` points (e.g. 17 on a 100 px edge), or
        # fewer when the edge is shorter than `samples`.
        count = len(pixels)
        return tuple(sum(p[c] for p in pixels) // count for c in range(3))

    def brightness_variance(pixels):
        levels = [sum(p) // 3 for p in pixels]
        # statistics.variance needs at least two data points.
        return statistics.variance(levels) if len(levels) > 1 else 0

    return (mean_colour(edge_a), mean_colour(edge_b),
            brightness_variance(edge_a), brightness_variance(edge_b))
def contain(w, h, mw, mh):
    """Return the size of a ``w`` x ``h`` image scaled to fit an ``mw`` x ``mh`` box.

    The aspect ratio is preserved: the result fills the box along one axis
    and the other side is scaled proportionally and truncated to an int.
    If the aspect ratios already match, the box size itself is returned.
    """
    ratio = w / h
    box_ratio = mw / mh
    if ratio > box_ratio:
        # Relatively wider than the box: full box width, height scaled to
        # match. Always apply it -- skipping the update when the new height
        # happened to equal the source height (e.g. 512x256 into 512x512)
        # stretched the image to the full box.
        return mw, int(h / w * mw)
    if ratio < box_ratio:
        # Relatively taller than the box: full box height, width scaled.
        return int(w / h * mh), mh
    return mw, mh
def position(w, h, margin=0.1, dim=512):
    """Place a ``w`` x ``h`` image on a ``dim`` x ``dim`` canvas.

    The image is scaled with ``contain`` (aspect preserved) to fit a box
    whose long side is ``int(dim * (1 + margin))`` and whose short side is
    ``dim``, then centred on the canvas. A positive ``margin`` lets the long
    side overflow the canvas, which yields negative offsets.

    Returns:
        ``(x, y, width, height)``: the paste offset and the scaled size.
    """
    long_side = int(dim * (1 + margin))
    if w < h:
        # Portrait: width bounded by the canvas, height by the enlarged box.
        fit_w, fit_h = contain(w, h, dim, long_side)
    else:
        # Landscape/square: height bounded by the canvas.
        fit_w, fit_h = contain(w, h, long_side, dim)
    # Centre on both axes from the size `contain` actually produced. The
    # long side does not always fill the enlarged box (near-square images
    # with margin > 0), so a fixed offset for that axis would mis-centre.
    x = int((dim - fit_w) / 2)
    y = int((dim - fit_h) / 2)
    return x, y, fit_w, fit_h
def letterbox(w, h, margin=0.1, dim=512):
    """Return the two bar rectangles left uncovered by a positioned image.

    Each rectangle is ``(x, y, width, height)``, as unpacked by ``process``.
    Portrait images (``w < h``) get a left and a right bar; the others get a
    top and a bottom bar. Every bar extends 3 px into the image so the
    inpainting mask overlaps the seam.
    """
    overlap = 3
    left, top, width, height = position(w, h, margin, dim)
    if w >= h:
        # Bars above and below the image.
        first_bar = (0, 0, dim, top + overlap)
        second_bar = (0, top + height - overlap, dim, top + overlap)
        return (first_bar, second_bar)
    # Bars left and right of the image.
    first_bar = (0, 0, left + overlap, dim)
    second_bar = (left + width - overlap, 0, left + overlap, dim)
    return (first_bar, second_bar)
def process(in_img, out_img, out_mask, margin=0.0, dim=512):
    """Letterbox one image onto a square canvas and write its inpainting mask.

    The image at ``in_img`` is converted to RGB, scaled and placed with
    ``position``, and pasted onto a black ``dim`` x ``dim`` canvas. The two
    empty bars are flood-filled with the mean colour of the adjacent image
    edge (from ``get_edge_colors``), and the result is saved to ``out_img``.

    A mask of the same size is saved to ``out_mask``. It is black everywhere
    except white rectangles (from ``letterbox``) over each bar whose edge
    brightness variance exceeds 100, and is blurred with a radius-1 Gaussian.
    """
    # Load and convert fully, then release the file handle right away.
    with Image.open(in_img) as src:
        img = src.convert('RGB')
    iw, ih = img.size
    x, y, w, h = position(iw, ih, margin, dim)
    # Portrait images get bars left/right; the others get bars top/bottom.
    vertical = w < h
    a, b, a_v, b_v = get_edge_colors(img, vertical)

    crop = Image.new(mode='RGB', size=(dim, dim))
    crop.paste(img.resize((w, h)), (x, y))
    ImageDraw.floodfill(crop, (0, 0), a)
    ImageDraw.floodfill(crop, (dim - 1, 0) if vertical else (0, dim - 1), b)
    crop.save(out_img)

    # New RGB images are black, i.e. "keep" everywhere by default.
    mask = Image.new(mode='RGB', size=(dim, dim))
    mask_d = ImageDraw.Draw(mask)
    a_l, b_l = letterbox(iw, ih, margin, dim)
    # Only mark a bar for inpainting when its edge is busy; for a flat edge
    # the flood-fill colour is kept as is.
    for variance, (bx, by, bw, bh) in ((a_v, a_l), (b_v, b_l)):
        if variance > 100:
            mask_d.rectangle(((bx, by), (bx + bw, by + bh)), fill="white")
    mask = mask.filter(ImageFilter.GaussianBlur(radius=1))
    mask.save(out_mask)
def do_letterbox(indir, tmpdir):
    """Letterbox every image in ``indir`` into ``tmpdir`` for inpainting.

    For each file matching one of the ``ext`` extensions, ``process`` writes
    ``tmpdir/<name>`` (the squared image) and ``tmpdir/mask_<name>`` (its
    mask). When ``skip_existing`` is true, inputs whose two outputs already
    exist are skipped. Prints a message and returns if no inputs are found.
    """
    files = []
    for e in ext:
        files += glob.glob(os.path.join(indir, f"*.{e}"))
    if not files:
        print("NO DATA FOUND! PUT IMAGES IN squarize/input")
        return
    for in_f in files:
        name = os.path.basename(in_f)
        # Build output paths from the basename. String-replacing "indir/"
        # fails when glob returns backslash separators (Windows) or indir
        # ends with a slash, which made the mask overwrite the input image.
        out_f = os.path.join(tmpdir, name)
        mask_f = os.path.join(tmpdir, "mask_" + name)
        if skip_existing and os.path.exists(out_f) and os.path.exists(mask_f):
            print("SKIP", in_f)
            continue
        print("PROCESS", in_f)
        process(in_f, out_f, mask_f)
def make_batch(image, mask, device):
    """Load an image/mask file pair as inpainting tensors on ``device``.

    Returns a dict with keys "image", "mask" and "masked_image". Each value
    is a float32 tensor with a leading batch axis of 1 and channels first,
    moved to ``device`` and mapped from [0, 1] to [-1, 1]. The mask is
    thresholded at 0.5 beforehand, so it ends up as exactly -1 (keep) or
    1 (inpaint). "masked_image" is the image with the masked pixels zeroed
    before that mapping.
    """
    pixels = np.array(Image.open(image).convert("RGB")).astype(np.float32) / 255.0
    # HWC -> NCHW with a batch of one.
    image_t = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))

    grey = np.array(Image.open(mask).convert("L")).astype(np.float32) / 255.0
    # Hard binarisation: below 0.5 -> 0, otherwise 1.
    binary = np.where(grey < 0.5, 0.0, 1.0).astype(np.float32)[None, None]
    mask_t = torch.from_numpy(binary)

    raw = {
        "image": image_t,
        "mask": mask_t,
        "masked_image": (1 - mask_t) * image_t,
    }
    return {name: tensor.to(device=device) * 2.0 - 1.0 for name, tensor in raw.items()}
def do_inpaint(tmpdir, outdir, steps):
    """Inpaint every letterboxed image in ``tmpdir`` and save it to ``outdir``.

    Masks are found as ``mask_<name>`` in ``tmpdir`` for each extension in
    ``ext``. Each mask is paired with ``<name>`` in the same folder. The
    inpainting model is built from ``model_path`` and sampled with DDIM for
    ``steps`` steps. Each result is saved under the image's file name.
    """
    prefix = "mask_"
    masks = []
    for e in ext:
        masks += sorted(glob.glob(os.path.join(tmpdir, f"{prefix}*.{e}")))
    # Derive each image path from its own mask so both lists stay aligned
    # across extensions. Strip only the basename prefix, so a "mask_"
    # elsewhere in the path is left alone.
    images = [os.path.join(os.path.dirname(m), os.path.basename(m)[len(prefix):])
              for m in masks]
    print(f"Found {len(masks)} inputs.")
    os.makedirs(outdir, exist_ok=True)
    if not masks:
        # Nothing to do: skip loading the multi-GB checkpoint.
        return

    config = OmegaConf.load(model_path + "config.yaml")
    model = instantiate_from_config(config.model)
    # NOTE: torch.load unpickles the file -- only use a trusted checkpoint.
    # Loading onto the CPU first lets a GPU-saved checkpoint load on
    # CPU-only machines; the model is moved to `device` just below.
    checkpoint = torch.load(model_path + "inpainting_big.ckpt", map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"], strict=False)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)

    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks), total=len(masks)):
                outpath = os.path.join(outdir, os.path.basename(image))
                _inpaint_one(model, sampler, image, mask, outpath, steps, device)


def _inpaint_one(model, sampler, image, mask, outpath, steps, device):
    """Inpaint one image/mask file pair and save the composite to ``outpath``."""
    batch = make_batch(image, mask, device=device)
    # Condition on the encoded masked image plus the mask resized to the
    # encoding's spatial size.
    c = model.cond_stage_model.encode(batch["masked_image"])
    cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
    c = torch.cat((c, cc), dim=1)
    # Sample shape excludes the single mask channel appended above.
    shape = (c.shape[1] - 1,) + c.shape[2:]
    samples_ddim, _ = sampler.sample(S=steps,
                                     conditioning=c,
                                     batch_size=c.shape[0],
                                     shape=shape,
                                     verbose=False)
    x_samples_ddim = model.decode_first_stage(samples_ddim)

    # Map everything from [-1, 1] back to [0, 1].
    image_t = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
    mask_t = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
    predicted = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    # Keep original pixels outside the mask; use the model output inside it.
    inpainted = (1 - mask_t) * image_t + mask_t * predicted
    inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
    Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
# Stage 1: pad every input to a square; images and masks go to squarize/tmp.
do_letterbox(indir="squarize/input", tmpdir="squarize/tmp")
# Stage 2: inpaint the padding and write the results to squarize/output.
do_inpaint(tmpdir="squarize/tmp", outdir="squarize/output", steps=STEPS)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.