scrapes boorus into training data for waifu diffusion
#!/usr/bin/python
import os
import re
import json
import random
from PIL import Image, ImageDraw, ImageOps
from multiprocessing import Pool
import subprocess
import sys
import urllib.request
import shutil
import operator
### INSTALL
# get PIL:
# pip install pillow
#
# get gallery-dl:
# available in linux repos (AUR etc)
# on windows this script will download gallery-dl.exe automatically
GALLERY_DL_EXE_URL = "https://github.com/mikf/gallery-dl/releases/download/v1.23.1/gallery-dl.exe"
### USAGE
### DOWNLOAD
# download from URL, most search and post urls will work:
# python arenas_downloader.py "https://gelbooru.com/index.php?page=post&s=list&tags=white_panties+panties_under_pantyhose+-rating%3aexplicit"
#
# download from URL with user/pass (needed for some boorus)
# python arenas_downloader.py USER PASS "https://chan.sankakucomplex.com/?tags=black_hair+dress+polka_dot_panties+-rating%3Aexplicit&commit=Search"
### PROCESS
# process into training data:
# python arenas_downloader.py -p
# produces the img and txt folders, as needed by waifu diffusion training
### GROUPS
# group operations
# there will usually be excessive groups of images (20+) that look almost exactly the same.
# these commands let you find and delete/prune these groups (*only tested/working with sankaku)
# python arenas_downloader.py -ge
# extract groups into the groups folder
# python arenas_downloader.py -gc
# compiles the groups back into the data folder
# so run -ge, delete/prune the groups, run -gc, then run -p
### OTHER
# view help:
# python arenas_downloader.py -h
### CONFIG
#these tags get put at the front of the tag string
IMPORTANT_TAGS = ["female", "male"]
#if you want to map tags to different values, applies to substrings
TAG_MAP = {"1girl": "one girl ", "2girls": "two girls ", "3girls": "three girls ",
"1boy": "one boy ", "2boys": "two boys ", "3boys": "three boys ",
"1futa": "one futa ", "2futa": "two futa "}
#so "1boy2girls" becomes "one boy two girls" etc
#tags that are deleted
DEAD_TAGS = ["character", "rule_63", "tagme"]
#threads to use, set to CPU count
THREADS = 12
#increasing will reduce the amount of letterboxing, but increase
# the chance of important bits being cropped (like heads)
LETTERBOX_MARGIN=0.15 # 15% of the edge can be removed
#jpg or png
SAVE_EXT = "jpg"
### CONSTANTS
CONFIG = '{\n\t"extractor":{\n\t\t"base-directory":"./",\n\t\t"filename": "{id}.{extension}",\n\t\t"directory": ["data"],\n\t\t"pixiv": {"tags": "translated"}\n\t}\n}'
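#for reference, CONFIG expands to this gallery-dl.conf:
#{
#    "extractor":{
#        "base-directory":"./",
#        "filename": "{id}.{extension}",
#        "directory": ["data"],
#        "pixiv": {"tags": "translated"}
#    }
#}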
HELP = """USAGE:
python arenas_downloader.py USERNAME PASSWORD \"URL\"
python arenas_downloader.py \"URL\"
python arenas_downloader.py -p
python arenas_downloader.py -ge
python arenas_downloader.py -gc
SEE TOP OF CODE FOR DETAILS
run `gallery-dl oauth:pixiv` and follow the instructions to log in to pixiv (normal username/password doesn't work)"""
BRACKET_REGEX = re.compile(r"[\s_]?\([^\)]+\)") #matches bracketed qualifiers like "_(cosplay)"
PATH_REGEX = re.compile(r"\.\w+\.json$") #matches the trailing ".EXT.json" of metadata files
SAMPLES = 16
def normalize_tag(tag):
    #strip bracketed qualifiers, punctuation and underscores from a tag
    tag = BRACKET_REGEX.sub("", tag)
    tag = tag.replace(".", "")
    tag = tag.replace("'", "")
    tag = tag.replace("_", " ")
    return tag.strip()
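#for example, on a hypothetical tag:
# normalize_tag("hatsune_miku_(cosplay)") gives "hatsune miku"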
def booru_to_normalized(json):
    booru = json["category"]
    tags = []
    prefix = []
    if booru == "e621":
        for sub in json["tags"]:
            if sub == "meta" or sub == "invalid":
                continue
            if sub == "artist" or sub == "character":
                prefix += json["tags"][sub]
                continue
            tags += json["tags"][sub]
    elif booru == "danbooru":
        tags = json["tag_string"].split()
    elif booru == "pixiv":
        tags = [t for t in json["tags"] if not "bookmarks" in t]
    elif type(json["tags"]) == list:
        tags = json["tags"]
    elif type(json["tags"]) == str:
        tags = json["tags"].split()
    else:
        print("UNSUPPORTED METADATA")
        exit()
    random.shuffle(tags)
    #put prefix first
    tags = prefix + tags
    #map tags
    for i in range(len(tags)):
        for t in TAG_MAP:
            tags[i] = tags[i].replace(t, TAG_MAP[t])
    #put the important tags first
    for i in IMPORTANT_TAGS:
        if i in tags:
            tags = [i] + [t for t in tags if t != i]
    #normalize and delete tags
    tags = [normalize_tag(tag) for tag in tags if not tag in DEAD_TAGS]
    tag_string = ", ".join(tags)
    return tag_string
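#a minimal sketch of the output, on hypothetical danbooru metadata:
# booru_to_normalized({"category": "danbooru", "tag_string": "solo 1girl hatsune_miku tagme"})
#could give "solo, one girl, hatsune miku" ("tagme" is dead, order is random
#apart from any IMPORTANT_TAGS, which always come first)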
def write_tags(tags, path):
    with open(path, 'w') as f:
        f.write(tags)
def extract_tags(json_file, txt_file):
    with open(json_file, 'r') as f:
        metadata = json.load(f)
    write_tags(booru_to_normalized(metadata), txt_file)
def add_to_groups(groups, json, file):
    #posts sharing a parent_id end up in the same group,
    #posts without a parent form a group of their own
    id = json["id"]
    group = id
    if "parent_id" in json and json["parent_id"]:
        group = json["parent_id"]
    if not group in groups:
        groups[group] = []
    groups[group].append(file)
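#e.g. hypothetical posts 101 and 102, both with parent_id 100,
#both land in groups[100]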
def extract_groups():
    groups = {}
    for file in os.listdir("data"):
        if file.endswith(".json"):
            json_file = os.path.join("data", file)
            with open(json_file, 'r') as f:
                metadata = json.load(f)
            add_to_groups(groups, metadata, json_file)
    for g in groups:
        z = len(groups[g])
        if z < 3:
            continue
        folder = f"groups/{z}/{g}"
        os.makedirs(folder, exist_ok=True)
        for f in groups[g]:
            os.rename(f, f.replace("data", folder))
            f = f.replace(".json", "")
            os.rename(f, f.replace("data", folder))
    print("DONE, groups is populated")
def compile_groups():
    group_sizes = next(os.walk("groups"))[1]
    for size in group_sizes:
        size_folder = os.path.join("groups", f"{size}")
        groups = next(os.walk(size_folder))[1]
        for g in groups:
            group_folder = os.path.join(size_folder, f"{g}")
            for file in os.listdir(group_folder):
                group_file = os.path.join(group_folder, file)
                data_file = os.path.join("data", file)
                os.rename(group_file, data_file)
    print("DONE, groups is now empty")
def get_edge_colors(img, vertical, inset=3, samples=16):
    #average the pixel colors along two opposite edges of the image
    a = (0,0,0)
    b = (0,0,0)
    n = 0 #count actual samples, the step can yield more than `samples` of them
    if vertical:
        for y in range(0, img.size[1], max(1, img.size[1]//samples)):
            a = tuple(map(operator.add, a, img.getpixel((inset,y))))
            b = tuple(map(operator.add, b, img.getpixel((img.size[0]-1-inset,y))))
            n += 1
    else:
        for x in range(0, img.size[0], max(1, img.size[0]//samples)):
            a = tuple(map(operator.add, a, img.getpixel((x,inset))))
            b = tuple(map(operator.add, b, img.getpixel((x,img.size[1]-1-inset))))
            n += 1
    a = tuple(map(operator.floordiv, a, (n, n, n)))
    b = tuple(map(operator.floordiv, b, (n, n, n)))
    return a, b
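#e.g. for a tall image, get_edge_colors(img, True) returns the average colors
#of the left and right edges, used below to tint the letterbox bars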
def square(in_img, out_img, margin=0.1, dim=512):
    #fit an image into a dim x dim square, cropping up to margin% off the
    #long side and filling the letterbox bars with the average edge colors
    img = Image.open(in_img).convert('RGB')
    d = int(dim * (1+margin)) #target for the long side, the overshoot gets cropped
    o = int((dim-d)/2) #negative offset, centers the overshoot
    crop = Image.new(mode='RGB',size=(dim,dim))
    if img.size[0] < img.size[1]:
        #portrait: bars on the left/right, crop top/bottom
        a,b = get_edge_colors(img, True)
        img = ImageOps.pad(img, (dim, d))
        oo = int((dim-img.size[0])/2)
        crop.paste(img, (oo, o))
        ImageDraw.floodfill(crop, (0,0), a)
        ImageDraw.floodfill(crop, (dim-1,0), b)
    else:
        #landscape: bars on the top/bottom, crop left/right
        a,b = get_edge_colors(img, False)
        img = ImageOps.pad(img, (d, dim))
        oo = int((dim-img.size[1])/2)
        crop.paste(img, (o, oo))
        ImageDraw.floodfill(crop, (0,0), a)
        ImageDraw.floodfill(crop, (0,dim-1), b)
    crop.save(out_img)
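#a minimal sketch, assuming a portrait image at the hypothetical path data/123.png:
# square("data/123.png", "img/123.jpg", margin=LETTERBOX_MARGIN)
#writes a 512x512 jpg with up to 15% cropped off the top/bottom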
done = 0
total = 0
def process_single(file):
    global done
    json_file = os.path.join("data", file)
    full_file = os.path.join("data", file.replace(".json", ""))
    img_file = os.path.join("img", PATH_REGEX.sub(f".{SAVE_EXT}", file))
    if not os.path.isfile(full_file):
        return
    try:
        square(full_file, img_file, LETTERBOX_MARGIN)
        txt_file = img_file.replace("img", "txt", 1).replace(f".{SAVE_EXT}", ".txt")
        extract_tags(json_file, txt_file)
    except Exception as e:
        print(e)
        print(f"cannot crop {full_file} -> {img_file}, skipping...")
    done += 1
    #done is per-worker, so the printed percentage is only an estimate
    if total:
        print(f"[{done/total*100:.2f}%] {full_file}")
def process_all():
    global total
    os.makedirs("txt", exist_ok=True)
    os.makedirs("img", exist_ok=True)
    files = []
    for file in os.listdir("data"):
        if file.endswith(".json"):
            files += [file]
    #guess that each thread will handle an equal number of files
    total = len(files)//THREADS
    print("PROCESSING")
    with Pool(THREADS) as p:
        p.map(process_single, files)
    print("DONE")
def download(url, user=None, password=None):
    if not os.path.isfile("gallery-dl.conf"):
        with open("gallery-dl.conf", "w") as f:
            f.write(CONFIG)
    if shutil.which("gallery-dl") is None:
        print("gallery-dl not found")
        if os.name == "nt":
            print("downloading gallery-dl...")
            urllib.request.urlretrieve(GALLERY_DL_EXE_URL, "gallery-dl.exe")
        else:
            exit()
    args = ["gallery-dl", "--write-metadata", "-c", "gallery-dl.conf"]
    if user and password:
        args += ["-u", user, "-p", password]
    args += [url]
    print("DOWNLOADING")
    proc = subprocess.Popen(args, shell=False)
    proc.communicate()
    print("DONE")
if __name__ == '__main__':
    if len(sys.argv) == 1 or (len(sys.argv) == 2 and sys.argv[1] == "-h"):
        print(HELP)
    elif len(sys.argv) == 2 and sys.argv[1] == "-p":
        process_all()
    elif len(sys.argv) == 2 and sys.argv[1] == "-ge":
        extract_groups()
    elif len(sys.argv) == 2 and sys.argv[1] == "-gc":
        compile_groups()
    elif len(sys.argv) == 4:
        download(sys.argv[3], sys.argv[1], sys.argv[2])
    elif len(sys.argv) == 2:
        download(sys.argv[1])
    else:
        #unrecognized argument count, show usage instead of exiting silently
        print(HELP)