|
# Find images with content hidden by alpha channel |
|
# Ref https://github.com/kohya-ss/sd-scripts/issues/1269 |
|
# City96 | 2024 |
|
import csv
import os

import torch
import torchvision
from torch.multiprocessing import Pool
from tqdm import tqdm
|
|
|
# Root directory of the source image dataset (searched recursively).
SRC_PATH = r"Z:\booru\Test\images"

# Directory where flagged images are written. Only their RGB channels are
# saved (the alpha channel is dropped), so any hidden content is visible.
# 'None' to disable.
DST_PATH = r"E:\booru\Test\hidden"

# CSV file listing each flagged image path and its score. 'None' to disable.
LOG_PATH = r"E:\booru\Test\hidden.csv"

# Minimum background standard deviation (pixel values scaled to [0, 1])
# for an image to be flagged.
MIN_STD = 0.1

# Maximum number of images to check. 'None' (or 0) for no limit.
LIMIT = None
|
|
|
def get_target_files(src_path, exts=(".png",), limit=LIMIT):
    """
    Collect paths of files under ``src_path`` whose extension is in ``exts``.

    The directory tree is walked recursively with ``os.walk``. Extension
    matching is case-insensitive on both sides.

    Args:
        src_path: root directory to search.
        exts: iterable of extensions, including the leading dot.
        limit: stop and return as soon as this many paths are collected.
            A falsy value (None or 0) means no limit.

    Returns:
        List of ``os.path.join(root, fname)`` paths, in ``os.walk`` order.
    """
    # Normalise once so callers may pass ".PNG" etc.; set gives O(1) lookups.
    wanted = {ext.lower() for ext in exts}
    paths = []
    for root, _, files in os.walk(src_path):
        for fname in files:
            if os.path.splitext(fname)[1].lower() not in wanted:
                continue
            paths.append(os.path.join(root, fname))
            if limit and len(paths) >= limit:
                return paths
    return paths
|
|
|
def read_image(img_path, dims=(4,)):
    """
    Load an image, keeping it only if it has an accepted channel count.

    Args:
        img_path: path to the image file.
        dims: accepted sizes of the first (channel) dimension of the decoded
            tensor. The default ``(4,)`` keeps only images that decode to
            four channels.

    Returns:
        The tensor returned by ``torchvision.io.read_image``, or ``None`` if
        decoding fails or its channel count is not in ``dims``.
    """
    try:
        img = torchvision.io.read_image(img_path)
    except Exception:
        # Unreadable / unsupported files are skipped rather than fatal.
        # Unlike a bare ``except:``, this still lets KeyboardInterrupt and
        # SystemExit propagate so the run can be aborted.
        return None
    return img if img.shape[0] in dims else None
|
|
|
def check_image(img_path, dst_path=DST_PATH, threshold=MIN_STD):
    """
    Check whether the fully transparent area of an RGBA image hides content.

    The score is the standard deviation of the RGB values of the fully
    transparent (alpha == 0) pixels, computed per channel, taking the largest
    of the three. A flat fill of any colour under the transparent area scores
    0; varied hidden content scores high.

    Args:
        img_path: path to the image file.
        dst_path: if truthy, directory where the RGB channels of a flagged
            image (alpha dropped) are saved under the same basename. An
            existing file at that path is never overwritten.
        threshold: minimum score (pixel values scaled to [0, 1]) to flag.

    Returns:
        ``(img_path, score)`` if the image is flagged, otherwise ``None``.
        ``None`` is also returned for images ``read_image`` rejects and for
        images with fewer than two fully transparent pixels.
    """
    # load image & verify alpha (read_image's default keeps 4-channel images only)
    raw = read_image(img_path)
    if raw is None:
        return None

    # Scale to [0, 1] using the decoded integer dtype's range (255 for uint8).
    if raw.dtype.is_floating_point:
        scale = 1.0
    else:
        scale = float(torch.iinfo(raw.dtype).max)
    image = raw[:3].float() / scale

    # Fully transparent pixels: their RGB is invisible when alpha is honoured.
    hidden = raw[3] == 0
    # A sample standard deviation needs at least two values (else NaN).
    if int(hidden.sum()) < 2:
        return None

    # Measure ONLY the hidden pixels. Taking the std of the whole image with
    # visible pixels zeroed would make a flat, non-black background next to a
    # large opaque region look like high variance. Per-channel std keeps a
    # uniform coloured fill (e.g. pure red) from scoring high as well.
    bgdev = image[:, hidden].std(dim=1).max()

    if bgdev > threshold:
        if dst_path:
            # NOTE: only the basename is kept, so same-named files from
            # different subfolders map to one path; the first one saved wins.
            path = os.path.join(dst_path, os.path.basename(img_path))
            if not os.path.isfile(path):
                torchvision.utils.save_image(image, path)
        return (img_path, float(bgdev))
    return None
|
|
|
if __name__ == "__main__":
    print("Parsing paths")
    paths = get_target_files(SRC_PATH)
    # DST_PATH may be None ("disabled"); os.makedirs(None) would raise.
    if DST_PATH:
        os.makedirs(DST_PATH, exist_ok=True)

    print(f"Checking {len(paths)} images")
    # The context manager shuts the workers down even if one of them raises.
    # All results are consumed inside the block, so terminate-on-exit is safe.
    with Pool(4) as pool:
        results = pool.imap(check_image, paths)
        found = [x for x in tqdm(results, total=len(paths)) if x]

    print(f"Found {len(found)} images.")
    if LOG_PATH:
        # csv.writer quotes paths containing commas or quotes, which would
        # otherwise corrupt the log; "\n" keeps the original line endings.
        with open(LOG_PATH, "w", encoding="utf-8", newline="") as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["name", "conf"])
            for name, conf in found:
                writer.writerow([name, round(conf, 4)])