Scrapes boorus into training data for waifu diffusion
#!/usr/bin/python
import os
import re
import json
import random
from PIL import Image, ImageDraw, ImageOps
from multiprocessing import Pool
import subprocess
import sys
import urllib.request
import shutil
import operator

### INSTALL
# get PIL:
# pip install pillow
#
# get gallery-dl:
# available in linux repos (AUR etc)
# on windows this script will download gallery-dl.exe automatically
GALLERY_DL_EXE_URL = "https://github.com/mikf/gallery-dl/releases/download/v1.23.1/gallery-dl.exe"
### USAGE
### DOWNLOAD
# download from URL, most search and post URLs will work:
# python arenas_downloader.py "https://gelbooru.com/index.php?page=post&s=list&tags=white_panties+panties_under_pantyhose+-rating%3aexplicit"
#
# download from URL with user/pass (needed for some boorus):
# python arenas_downloader.py USER PASS "https://chan.sankakucomplex.com/?tags=black_hair+dress+polka_dot_panties+-rating%3Aexplicit&commit=Search"
### PROCESS
# process into training data:
# python arenas_downloader.py -p
# produces the img and txt folders, as needed by waifu diffusion training
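#
# illustrative layout after a download and a -p run (the id 123456 is made up,
# real names come from the booru post ids):
#   data/123456.png            downloaded image (gallery-dl)
#   data/123456.png.json       metadata written by --write-metadata
#   img/123456.jpg             512x512 square crop
#   txt/123456.txt             comma-separated tag string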
### GROUPS
# group operations
# there will usually be excessive groups of images (20+) that look almost exactly the same.
# these commands let you find and delete/prune these groups (*only tested/working with sankaku)
# python arenas_downloader.py -ge
#   extracts groups into the groups folder
# python arenas_downloader.py -gc
#   compiles the groups back into the data folder
# so run -ge, delete/prune the groups, run -gc, then run -p
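#
# illustrative groups layout after -ge (ids made up): a parent post 987 with 5 near-duplicates
#   groups/5/987/
#     1000001.jpg
#     1000001.jpg.json
#     ... (4 more image/metadata pairs)
# delete or prune whole group folders here, then -gc moves whatever remains back into data/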
### OTHER
# view help:
# python arenas_downloader.py -h

### CONFIG
# these tags get put at the front of the tag string
IMPORTANT_TAGS = ["female", "male"]
# if you want to map tags to different values, this applies to substrings
TAG_MAP = {"1girl": "one girl ", "2girls": "two girls ", "3girls": "three girls ",
           "1boy": "one boy ", "2boys": "two boys ", "3boys": "three boys ",
           "1futa": "one futa ", "2futa": "two futa "}
# so "1boy2girls" becomes "one boy two girls" etc
# tags that are deleted
DEAD_TAGS = ["character", "rule_63", "tagme"]
# threads to use, set to your CPU count
THREADS = 12
# increasing will reduce the amount of letterboxing, but increase
# the chance of important bits being cropped (like heads) -- see the worked example below
LETTERBOX_MARGIN = 0.15  # 15% of the edge can be removed
# jpg or png
SAVE_EXT = "jpg"
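# worked example of the margin, using the square() function below: with dim=512 and
# LETTERBOX_MARGIN=0.15 the long side of the image is fitted against int(512*1.15) = 588 px,
# so up to int((588-512)/2) = 38 px can be cropped from each end of the long side before
# the remaining space is letterboxed.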
### CONSTANTS
CONFIG = '{\n\t"extractor":{\n\t\t"base-directory":"./",\n\t\t"filename": "{id}.{extension}",\n\t\t"directory": ["data"],\n\t\t"pixiv": {"tags": "translated"}\n\t}\n}'
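# CONFIG above is the gallery-dl.conf written on first run; expanded it reads:
# {
#     "extractor": {
#         "base-directory": "./",
#         "filename": "{id}.{extension}",
#         "directory": ["data"],
#         "pixiv": {"tags": "translated"}
#     }
# }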
HELP = """USAGE:
python arenas_downloader.py USERNAME PASSWORD \"URL\"
python arenas_downloader.py \"URL\"
python arenas_downloader.py -p
python arenas_downloader.py -ge
python arenas_downloader.py -gc
SEE TOP OF CODE FOR DETAILS

run `gallery-dl oauth:pixiv` and follow the instructions to log in to pixiv (a normal username/password doesn't work)"""

BRACKET_REGEX = re.compile(r"[\s_]?\([^\)]+\)")
PATH_REGEX = re.compile(r"\.\w+\.json")
SAMPLES = 16
def normalize_tag(tag):
    tag = BRACKET_REGEX.sub("", tag)
    tag = tag.replace(".", "")
    tag = tag.replace("'", "")
    tag = tag.replace("_", " ")
    return tag.strip()
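# e.g. normalize_tag("hatsune_miku_(vocaloid)") -> "hatsune miku"
# (the tag itself is just an example; bracketed qualifiers, dots, quotes and underscores are stripped)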
def booru_to_normalized(json):
    booru = json["category"]
    tags = []
    prefix = []
    if booru == "e621":
        for sub in json["tags"]:
            if sub == "meta" or sub == "invalid":
                continue
            if sub == "artist" or sub == "character":
                prefix += json["tags"][sub]
                continue
            tags += json["tags"][sub]
    elif booru == "danbooru":
        tags = json["tag_string"].split()
    elif booru == "pixiv":
        tags = [t for t in json["tags"] if not "bookmarks" in t]
    elif type(json["tags"]) == list:
        tags = json["tags"]
    elif type(json["tags"]) == str:
        tags = json["tags"].split()
    else:
        print("UNSUPPORTED METADATA")
        exit()
    random.shuffle(tags)
    # put prefix first
    tags = prefix + tags
    # map tags
    for i in range(len(tags)):
        for t in TAG_MAP:
            tags[i] = tags[i].replace(t, TAG_MAP[t])
    # put the important tags first
    for i in IMPORTANT_TAGS:
        if i in tags:
            tags = [i] + [t for t in tags if t != i]
    # normalize and delete tags
    tags = [normalize_tag(tag) for tag in tags if not tag in DEAD_TAGS]
    tag_string = ", ".join(tags)
    return tag_string
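# illustrative example (field values are made up, the final order is shuffled):
#   {"category": "danbooru", "tag_string": "1girl solo long_hair tagme"}
#   -> e.g. "long hair, one girl, solo"   ("tagme" is in DEAD_TAGS so it is dropped)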
def write_tags(tags, path):
    with open(path, 'w') as f:
        f.write(tags)


def extract_tags(json_file, txt_file):
    with open(json_file, 'r') as f:
        metadata = json.load(f)
    write_tags(booru_to_normalized(metadata), txt_file)
def add_to_groups(groups, json, file):
    # posts are grouped by parent_id, falling back to the post's own id
    id = json["id"]
    group = id
    if "parent_id" in json and json["parent_id"]:
        group = json["parent_id"]
    if not group in groups:
        groups[group] = []
    groups[group] = groups[group] + [file]
def extract_groups():
    groups = {}
    for file in os.listdir("data"):
        if file.endswith(".json"):
            json_file = os.path.join("data", file)
            with open(json_file, 'r') as f:
                metadata = json.load(f)
            add_to_groups(groups, metadata, json_file)
    for g in groups:
        z = len(groups[g])
        if z < 3:
            continue
        folder = f"groups/{z}/{g}"
        os.makedirs(folder, exist_ok=True)
        for f in groups[g]:
            # move the metadata file and its image into groups/<size>/<parent id>
            os.rename(f, f.replace("data", folder))
            f = f.replace(".json", "")
            os.rename(f, f.replace("data", folder))
    print("DONE, groups is populated")
def compile_groups():
    group_sizes = next(os.walk("groups"))[1]
    for size in group_sizes:
        size_folder = os.path.join("groups", f"{size}")
        groups = next(os.walk(size_folder))[1]
        for g in groups:
            group_folder = os.path.join(size_folder, f"{g}")
            for file in os.listdir(group_folder):
                group_file = os.path.join(group_folder, file)
                data_file = os.path.join("data", file)
                os.rename(group_file, data_file)
    print("DONE, groups is now empty")
def get_edge_colors(img, vertical, inset=3, samples=16):
    # average the colour along two opposite edges (left/right when vertical,
    # top/bottom otherwise), sampling a few pixels in from the border
    a = (0, 0, 0)
    b = (0, 0, 0)
    if vertical:
        for y in range(0, img.size[1], img.size[1]//samples):
            a = tuple(map(operator.add, a, img.getpixel((inset, y))))
            b = tuple(map(operator.add, b, img.getpixel((img.size[0]-1-inset, y))))
    else:
        for x in range(0, img.size[0], img.size[0]//samples):
            a = tuple(map(operator.add, a, img.getpixel((x, inset))))
            b = tuple(map(operator.add, b, img.getpixel((x, img.size[1]-1-inset))))
    a = tuple(map(operator.floordiv, a, (samples, samples, samples)))
    b = tuple(map(operator.floordiv, b, (samples, samples, samples)))
    return a, b
def square(in_img, out_img, margin=0.1, dim=512):
    img = Image.open(in_img).convert('RGB')
    d = int(dim * (1+margin))
    o = int((dim-d)/2)
    crop = Image.new(mode='RGB', size=(dim, dim))
    if img.size[0] < img.size[1]:
        # portrait: crop up to `margin` off the top/bottom, letterbox the sides
        a, b = get_edge_colors(img, True)
        img = ImageOps.pad(img, (dim, d))
        oo = int((dim-img.size[0])/2)
        crop.paste(img, (oo, o))
        # recolour the letterbox bars with the averaged edge colours
        ImageDraw.floodfill(crop, (0, 0), a)
        ImageDraw.floodfill(crop, (dim-1, 0), b)
    else:
        # landscape or square: crop the sides, letterbox the top/bottom
        a, b = get_edge_colors(img, False)
        img = ImageOps.pad(img, (d, dim))
        oo = int((dim-img.size[1])/2)
        crop.paste(img, (o, oo))
        ImageDraw.floodfill(crop, (0, 0), a)
        ImageDraw.floodfill(crop, (0, dim-1), b)
    crop.save(out_img)
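# e.g. square("data/123456.png", "img/123456.jpg", LETTERBOX_MARGIN) writes a
# 512x512 crop (paths and id are illustrative); process_single below calls this
# for every downloaded image.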
done = 0
total = 0


def process_single(file):
    global done
    json_file = os.path.join("data", file)
    full_file = os.path.join("data", file.replace(".json", ""))
    img_file = os.path.join("img", PATH_REGEX.sub(f".{SAVE_EXT}", file))
    if not os.path.isfile(full_file):
        return
    try:
        square(full_file, img_file, LETTERBOX_MARGIN)
        txt_file = img_file.replace("img", "txt").replace(f"{SAVE_EXT}", "txt")
        extract_tags(json_file, txt_file)
    except Exception as e:
        print(e)
        print(f"cannot crop {full_file} -> {img_file}, skipping...")
    done += 1
    if total:
        print(f"[{done/total*100:.2f}%] {full_file}")
def process_all():
    global total
    os.makedirs("txt", exist_ok=True)
    os.makedirs("img", exist_ok=True)
    files = []
    for file in os.listdir("data"):
        if file.endswith(".json"):
            files += [file]
    # guess that each thread will handle an equal number of files
    total = len(files)//THREADS
    print("PROCESSING")
    with Pool(THREADS) as p:
        p.map(process_single, files)
    print("DONE")
def download(url, user=None, password=None):
    if not os.path.isfile("gallery-dl.conf"):
        config = CONFIG
        with open("gallery-dl.conf", "w") as f:
            f.write(config)
    if shutil.which("gallery-dl") is None:
        print("gallery-dl not found")
        if os.name == "nt":
            print("downloading gallery-dl...")
            urllib.request.urlretrieve(GALLERY_DL_EXE_URL, "gallery-dl.exe")
        else:
            exit()
    args = ["gallery-dl", "--write-metadata", "-c", "gallery-dl.conf"]
    if user and password:
        args += ["-u", user, "-p", password]
    args += [url]
    print("DOWNLOADING")
    proc = subprocess.Popen(args, shell=False)
    proc.communicate()
    print("DONE")
if __name__ == '__main__':
    if len(sys.argv) == 1 or (len(sys.argv) == 2 and sys.argv[1] == "-h"):
        print(HELP)
    elif len(sys.argv) == 2 and sys.argv[1] == "-p":
        process_all()
    elif len(sys.argv) == 2 and sys.argv[1] == "-ge":
        extract_groups()
    elif len(sys.argv) == 2 and sys.argv[1] == "-gc":
        compile_groups()
    elif len(sys.argv) == 4:
        download(sys.argv[3], sys.argv[1], sys.argv[2])
    elif len(sys.argv) == 2:
        download(sys.argv[1])