# squarize.py: squarize images via latent diffusion (LD) infilling
import argparse, os, sys, glob
import urllib.request
import statistics
# Put this script in the stable-diffusion-webui folder and run it: python squarize.py
# The inpainting model will be downloaded and the squarize folder will be created.
# All images in squarize/input get processed and inpainted, ending up in squarize/output.
# Note: stable-diffusion-webui itself needs to be in working order.
STEPS = 50  # number of DDIM sampling steps; edit this
def include(path):
    path = os.path.abspath(path)
    sys.path.append(path)
def download(url, location):
    urllib.request.urlretrieve(url, location)
os.makedirs("squarize/input", exist_ok=True)
os.makedirs("squarize/output", exist_ok=True)
os.makedirs("squarize/tmp", exist_ok=True)
include("repositories/stable-diffusion")
include("repositories/taming-transformers")
model_path = "repositories/stable-diffusion/models/ldm/inpainting_big/"
os.makedirs(model_path, exist_ok=True)  # make sure the target folder exists before downloading
if not os.path.exists(model_path+"inpainting_big.ckpt"):
    print("Downloading inpainting model (3.3GB)...")
    download("https://heibox.uni-heidelberg.de/f/4d9ac7ea40c64582b7c9/?dl=1", model_path+"inpainting_big.ckpt")
from omegaconf import OmegaConf
from tqdm import tqdm
import numpy as np
import torch
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from PIL import Image, ImageDraw, ImageOps, ImageFilter
import operator
ext = ["png", "jpg", "jpeg", "webp"]
skip_existing = False
def get_edge_colors(img, vertical, inset=0.01, samples=16):
    # sample colors along the two opposing edges, slightly inset from the border
    inset = int(max(img.size[0], img.size[1])*inset)
    a_s = []
    b_s = []
    if vertical:
        for y in range(0, img.size[1], img.size[1]//samples):
            a_s += [img.getpixel((inset,y))]
            b_s += [img.getpixel((img.size[0]-1-inset,y))]
    else:
        for x in range(0, img.size[0], img.size[0]//samples):
            a_s += [img.getpixel((x,inset))]
            b_s += [img.getpixel((x,img.size[1]-1-inset))]
    # brightness variance of each edge: near zero means a flat, uniform border
    a_v = statistics.variance([sum(i)//3 for i in a_s])
    b_v = statistics.variance([sum(i)//3 for i in b_s])
    # average color of each edge (divide by the actual sample count, which can
    # exceed `samples` when the step does not divide the dimension evenly)
    a_t = (0,0,0)
    b_t = (0,0,0)
    for i in a_s:
        a_t = tuple(map(operator.add, a_t, i))
    for i in b_s:
        b_t = tuple(map(operator.add, b_t, i))
    a_t = tuple(map(operator.floordiv, a_t, (len(a_s),)*3))
    b_t = tuple(map(operator.floordiv, b_t, (len(b_s),)*3))
    return a_t, b_t, a_v, b_v
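# Rough illustration (values are hypothetical, not from the gist): for a portrait
# image with a solid white left border and a detailed right edge,
# get_edge_colors(img, vertical=True) might return something like
# ((255, 255, 255), (143, 96, 80), 0, 2500): near-zero variance marks the left
# edge as flat, large variance marks the right edge as busy.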
def contain(w, h, mw, mh):
    # fit the w:h aspect ratio inside an mw x mh box, shrinking one dimension
    r = w/h
    dr = mw/mh
    if r != dr:
        if r > dr:
            # wider than the box: fit to width, reduce height
            nh = int(h / w * mw)
            if nh != h:
                mh = nh
        else:
            # taller than the box: fit to height, reduce width
            nw = int(w / h * mh)
            if nw != w:
                mw = nw
    return mw, mh
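# Worked example (hand-computed): contain(1000, 500, 512, 563) returns (512, 256),
# since the 2:1 image is wider than the 512x563 box and gets fitted to its width.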
def position(w, h, margin=0.1, dim=512):
    # scale the long side past the canvas by `margin` so the image bleeds off it
    d = int(dim * (1+margin))
    o = int((dim-d)/2)  # (possibly negative) offset centering the oversized side
    if w < h:
        w, h = contain(w, h, dim, d)
        oo = int((dim-w)/2)
        return oo, o, w, h
    else:
        w, h = contain(w, h, d, dim)
        oo = int((dim-h)/2)
        return o, oo, w, h
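# Worked example (hand-computed, defaults margin=0.1, dim=512): position(500, 1000)
# returns (115, -25, 281, 563): the portrait is scaled to 281x563, bleeding past the
# 512px canvas vertically, and centered at x=115. Note that process() below calls
# this with margin=0.0, in which case d == dim, o == 0 and there is no bleed.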
def letterbox(w, h, margin=0.1, dim=512):
    # rectangles (x, y, w, h) covering the two empty bars, extended i=3 px
    # into the image so the inpainted seam can blend
    ix,iy,iw,ih = position(w,h,margin,dim)
    i = 3
    if w < h:
        return ((0, 0, ix+i, dim), (ix+iw-i, 0, ix+i, dim))
    else:
        return ((0, 0, dim, iy+i), (0, iy+ih-i, dim, iy+i))
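# Illustration (hand-computed, margin=0.0, dim=512): a 500x1000 input is pasted at
# x=128 with width 256, so letterbox(500, 1000, 0.0) returns the two bar rectangles
# (0, 0, 131, 512) and (381, 0, 131, 512) as (x, y, w, h); each extends 3 px into
# the pasted image so the inpainted seam blends across the boundary.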
def process(in_img, out_img, out_mask, margin=0.0, dim=512):
    img = Image.open(in_img).convert('RGB')
    iw, ih = img.size[0], img.size[1]
    x,y,w,h = position(iw, ih, margin, dim)
    crop = Image.new(mode='RGB',size=(dim,dim))
    a_v, b_v = 0, 0
    if w < h:
        # portrait: flood-fill the left/right bars with the sampled edge colors
        a,b,a_v,b_v = get_edge_colors(img, True)
        img = img.resize((w, h))
        crop.paste(img, (x, y))
        ImageDraw.floodfill(crop, (0,0), a)
        ImageDraw.floodfill(crop, (dim-1,0), b)
    else:
        # landscape: flood-fill the top/bottom bars
        a,b,a_v,b_v = get_edge_colors(img, False)
        img = img.resize((w, h))
        crop.paste(img, (x, y))
        ImageDraw.floodfill(crop, (0,0), a)
        ImageDraw.floodfill(crop, (0,dim-1), b)
    crop.save(out_img)
    # build the mask: white marks the bars the model should inpaint
    mask = Image.new(mode='RGB',size=(dim,dim))
    mask_d = ImageDraw.Draw(mask)
    mask_d.rectangle(((0,0), (dim, dim)), fill="black")
    a_l, b_l = letterbox(iw, ih, margin, dim)
    # only mask a bar when its edge was busy (high variance); flat borders are
    # already handled acceptably by the flood fill alone
    if a_v > 100:
        x,y,w,h = a_l
        mask_d.rectangle(((x,y), (x+w,y+h)), fill="white")
    if b_v > 100:
        x,y,w,h = b_l
        mask_d.rectangle(((x,y), (x+w,y+h)), fill="white")
    mask = mask.filter(ImageFilter.GaussianBlur(radius=1))
    mask.save(out_mask)
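# Note: the flood fills are seeded at canvas corners, which land inside the empty
# bars, so each bar is pre-filled with its edge's average color; the mask then
# requests inpainting only for bars whose edges were too busy for a flat fill.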
def do_letterbox(indir, tmpdir):
    files = []
    for e in ext:
        files += glob.glob(os.path.join(indir, f"*.{e}"))
    if len(files) == 0:
        print("NO DATA FOUND! PUT IMAGES IN squarize/input")
        return
    for f in files:
        in_f = f
        out_f = f.replace(indir, tmpdir)
        mask_f = f.replace(indir + "/", tmpdir + "/mask_")
        if not skip_existing or not (os.path.exists(out_f) and os.path.exists(mask_f)):
            print("PROCESS", in_f)
            process(in_f, out_f, mask_f)
        else:
            print("SKIP", in_f)
def make_batch(image, mask, device):
    image = np.array(Image.open(image).convert("RGB"))
    image = image.astype(np.float32)/255.0
    image = image[None].transpose(0,3,1,2)  # HWC -> NCHW
    image = torch.from_numpy(image)
    mask = np.array(Image.open(mask).convert("L"))
    mask = mask.astype(np.float32)/255.0
    mask = mask[None,None]
    mask[mask < 0.5] = 0  # binarize the blurred mask
    mask[mask >= 0.5] = 1
    mask = torch.from_numpy(mask)
    masked_image = (1-mask)*image
    batch = {"image": image, "mask": mask, "masked_image": masked_image}
    for k in batch:
        batch[k] = batch[k].to(device=device)
        batch[k] = batch[k]*2.0-1.0  # rescale [0,1] -> [-1,1]
    return batch
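# For a 512x512 pair this yields "image" and "masked_image" tensors of shape
# (1, 3, 512, 512) and a "mask" of shape (1, 1, 512, 512), all scaled from [0, 1]
# to the [-1, 1] range the LDM inpainting model expects.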
def do_inpaint(tmpdir, outdir, steps):
    images = []
    masks = []
    for e in ext:
        masks += sorted(glob.glob(os.path.join(tmpdir, f"mask_*.{e}")))
    images += [x.replace("mask_", "") for x in masks]
    print(f"Found {len(masks)} inputs.")
    config = OmegaConf.load(model_path + "config.yaml")
    model = instantiate_from_config(config.model)
    model.load_state_dict(torch.load(model_path + "inpainting_big.ckpt")["state_dict"],
                          strict=False)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
    sampler = DDIMSampler(model)
    os.makedirs(outdir, exist_ok=True)
    with torch.no_grad():
        with model.ema_scope():
            for image, mask in tqdm(zip(images, masks)):
                outpath = os.path.join(outdir, os.path.split(image)[1])
                batch = make_batch(image, mask, device=device)
                # encode masked image and concat downsampled mask
                c = model.cond_stage_model.encode(batch["masked_image"])
                cc = torch.nn.functional.interpolate(batch["mask"],
                                                     size=c.shape[-2:])
                c = torch.cat((c, cc), dim=1)
                # latent shape: drop the mask channel from the conditioning
                shape = (c.shape[1]-1,)+c.shape[2:]
                samples_ddim, _ = sampler.sample(S=steps,
                                                 conditioning=c,
                                                 batch_size=c.shape[0],
                                                 shape=shape,
                                                 verbose=False)
                x_samples_ddim = model.decode_first_stage(samples_ddim)
                # map everything back from [-1,1] to [0,1]
                image = torch.clamp((batch["image"]+1.0)/2.0,
                                    min=0.0, max=1.0)
                mask = torch.clamp((batch["mask"]+1.0)/2.0,
                                   min=0.0, max=1.0)
                predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
                                              min=0.0, max=1.0)
                inpainted = (1-mask)*image+mask*predicted_image
                inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
                Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
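# Only the masked (white) regions take the model's prediction; everywhere else the
# original pixels are composited back, so existing image content is left untouched.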
do_letterbox("squarize/input", "squarize/tmp")
do_inpaint("squarize/tmp", "squarize/output", STEPS)