Created August 8, 2024 08:54
bucket
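"""
Preprocessing pipeline: read WebDataset .tar shards of (image, caption, metadata)
samples, bucket them by resolution, and write MosaicML Streaming (MDS) shards
containing SDXL-VAE latents, Pile-T5-XL text embeddings (quantized to uint8),
SSCD copy-detection embeddings, and per-sample metadata. Samples are filtered by
minimum resolution, repeated captions, and SigLIP text-image similarity.
"""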
import os
import torch
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from diffusers.models import AutoencoderKL
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
import logging
import time
import numpy as np
from typing import Any
from tqdm import tqdm
import webdataset as wds
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, SiglipModel
import argparse
from torchvision import transforms
import glob
import re

# Initialize logging
logging.basicConfig(level=logging.INFO)


def modify_caption(caption: str) -> str:
    """
    Removes common prefix substrings from CogVLM outputs.

    Args:
        caption (str): A string containing a CogVLM caption.

    Returns:
        str: The caption with the prefix substring removed
        or altered if it was present.
    """
    try:
        base_words = ['showcases ', 'portrays ', 'appears to be ', 'is ', 'depicts ', 'features ', 'displays: ']
        prefix_substrings = [("The image " + s, '') for s in base_words] + [("This image " + s, '') for s in base_words]
        prefix_substrings += [("In this " + s, '') for s in ["picture, ", "depiction, ", "piece, ", "image, ", "scene, "]]
        prefix_substrings += [
            ('In this artwork, ', 'Artwork of '),
            ('In this illustration, ', 'Illustration of '),
            ('In this art piece, ', 'Art of ')
        ]
        pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
        replacers = {opening: replacer for opening, replacer in prefix_substrings}

        def replace_fn(match):
            return replacers[match.group(0)]

        return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE).capitalize()
    except Exception:
        return caption
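

# NOTE: the simpler definition below re-defines modify_caption and therefore
# shadows the prefix-stripping version above; only this last definition is used
# at runtime.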
def modify_caption(caption: str) -> str:
    cap = caption.replace("This image displays", '').strip()
    if cap.startswith(':'):
        cap = cap[1:]
    return cap.strip()


class uint8(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.uint8)


class np16(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.float16)


_encodings["np16"] = np16
_encodings["uint8"] = uint8

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
small_288 = transforms.Compose(
    [
        transforms.Resize(288),
        transforms.CenterCrop(288),
        transforms.ToTensor(),
        normalize,
    ]
)
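# small_288 is the 288x288 resize/center-crop + ImageNet normalization applied to
# the inputs of the SSCD copy-detection model loaded below.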


def crop_to_center(image, new_size=768):
    width, height = image.size
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def prepare_image(pil_image):
    # Scale RGB pixels to [-1, 1] and convert HWC -> CHW for the VAE.
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr)
    return image


def center_crop_to_nearest_multiple_of_64(image):
    w, h = image.size
    # If the image is very large, halve it until it is at most ~2 megapixels.
    if w > 1024 or h > 1024:
        while w * h > 1024 * 1024 * 2:
            w = w // 2
            h = h // 2
        image = image.resize((w, h))
    new_w = w - w % 64
    new_h = h - h % 64
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    right = left + new_w
    bottom = top + new_h
    return image.crop((left, top, right, bottom))
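# Cropping to multiples of 64 px keeps the 8x-downsampled SDXL-VAE latent
# dimensions at integer multiples of 8.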


def wds_preprocess(x):
    key, pil_image, _json = x
    pil_image = pil_image.convert("RGB")
    pil_image = center_crop_to_nearest_multiple_of_64(pil_image)
    image_for_vae = prepare_image(pil_image)
    image_for_sscd = small_288(pil_image)
    # print(_json)
    caption = _json["caption"]
    uid = _json.get("uid", None)
    if uid is None:
        uid = key
    watermark_class = _json.get("watermark_class_id", 1)
    if watermark_class is None:
        watermark_class = 0
    est = _json.get("aesthetic_score", None)
    if est is None:
        est = 100
    # Example metadata record:
    # {'uid': '95ff62922b3536189768bcc883598109', 'clip_b32_similarity_score': 0.299560546875, 'clip_l14_similarity_score': 0.3017578125, 'caption': 'Picture of Car seat 0+ cover Little Goose', 'url': 'https://dealers.little-dutch.com/content/images/thumbs/002/0023598_1000.jpeg', 'key': '000010028', 'status': 'success', 'error_message': None, 'width': None, 'height': None, 'original_width': None, 'original_height': None, 'exif': '{}', 'sha512': 'b1e6f78d70b10645f54682c5cb01a8ba9584f6e34b4f292b431350fa93e94060'}
    return ([pil_image], caption, image_for_sscd, uid, watermark_class, est)


COLUMNS = {
    "key": "str",
    "caption": "str",
    "vae_latents": "np16",
    "t5_xl_embeddings": "uint8",
    "sscd_embeddings": "np16",
    "hps_score": "str",
    "width_height": "str",
}

prompt_counter = {}
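# prompt_counter tracks how many times each caption has been seen across shards;
# filter_cond() below drops samples whose caption has already appeared more than
# 5 times (a simple duplicate-caption filter).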


@torch.no_grad()
def convert_to_mds(dataset_paths, out_roots, device, is_test=False):
    logging.info(f"Processing on {device}")

    vae_model = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    vae_model = vae_model.to(device).eval()
    vae_model.to(memory_format=torch.channels_last)

    t5tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pile-t5-xl", use_fast=False)
    t5tokenizer.pad_token = t5tokenizer.bos_token
    t5model = AutoModelForSeq2SeqLM.from_pretrained("EleutherAI/pile-t5-xl", torch_dtype=torch.bfloat16)
    t5model = t5model.to(device).eval()

    siglip_model = SiglipModel.from_pretrained("google/siglip-large-patch16-256").to(device).eval()
    siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
    # t5model.encoder = torch.compile(t5model.encoder, mode='reduce-overhead')

    # Expects the SSCD copy-detection checkpoint to be available locally.
    sscd_model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)

    num_chunk = 1
    dataset_bulks = [
        dataset_paths[i : i + num_chunk] for i in range(0, len(dataset_paths), num_chunk)
    ]
    out_roots_bulks = [out_roots[i : i + num_chunk] for i in range(0, len(out_roots), num_chunk)]

    for dataset_paths, out_roots in zip(dataset_bulks, out_roots_bulks):
        for dataset_path in dataset_paths:
            if not os.path.exists(dataset_path):
                logging.info(f"Dataset not found: {dataset_path}")
                return
        out_root = out_roots[0]

        def filter_cond(sample):
            _, img, info = sample
            prompt = info.get("caption")
            is_cand = True
            if prompt not in prompt_counter:
                prompt_counter[prompt] = 1
            else:
                prompt_counter[prompt] += 1
            if prompt_counter[prompt] > 5:
                is_cand = False
            return (img.size[0] * img.size[1] >= 1024 * 512) and is_cand

        dataset = wds.DataPipeline(
            wds.SimpleShardList(dataset_paths),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.decode("pil", handler=wds.warn_and_continue),
            wds.to_tuple("__key__", "jpg;png", "json", handler=wds.warn_and_continue),
            wds.select(filter_cond),
            wds.map(wds_preprocess),
            wds.batched(64),
        )
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=2,
            shuffle=False,
            drop_last=False,
        )

        t0 = time.time()
        localbatches_datasets = {}
        # Iterate over the dataset and build a bucket of samples for each resolution.
        print("Start Making Bucket Batching")
        for idx, batch in tqdm(enumerate(dataloader)):
            pil_image, captions, image_for_sscd, uids, watermark_class, est = batch
            for i in range(len(captions)):
                w, h = pil_image[i][0].size
                # print(w, h)
                if w * h <= 512 * 512:
                    continue
                if (w, h) not in localbatches_datasets:
                    localbatches_datasets[(w, h)] = []
                localbatches_datasets[(w, h)].append((pil_image[i][0], captions[i], image_for_sscd[i], uids[i], watermark_class[i], est[i]))
        print("End Making Bucket Batching")
        for k in localbatches_datasets.keys():
            print(f"Resolution {k} has {len(localbatches_datasets[k])} samples")

        for (w, h), localbatches in localbatches_datasets.items():
            # Skip buckets that are too small to be worth writing.
            if len(localbatches) < 32:
                continue
            print(f"Processing resolution {w}x{h}")
            out_root_res = os.path.join(out_root, f"{w}_{h}")
            sub_data_root = os.path.join(out_root_res, "data")
            if os.path.exists(sub_data_root):
                for file in os.listdir(sub_data_root):
                    os.remove(os.path.join(sub_data_root, file))
            os.makedirs(sub_data_root, exist_ok=True)

            with MDSWriter(out=sub_data_root, columns=COLUMNS) as out:
                # Scale the batch size so each batch covers a roughly constant pixel
                # budget (64 samples at 512x512).
                BS = int(64 * 512 * 512 / (w * h))

                def collate_fn(batch):
                    pil_image, captions, image_for_sscd, uids, watermark_class, est = zip(*batch)
                    return pil_image, captions, image_for_sscd, uids, watermark_class, est

                local_dl = DataLoader(localbatches, batch_size=BS, num_workers=2, prefetch_factor=4, shuffle=False, drop_last=False, collate_fn=collate_fn)

                for pil_images, captions, image_for_sscd, uids, watermark_class, est in local_dl:
                    image_for_vaes = [prepare_image(pil) for pil in pil_images]
                    image_for_vae = torch.stack(image_for_vaes).to(device).half()

                    ### SSCD
                    image_for_sscd = torch.stack(image_for_sscd).to(
                        device, memory_format=torch.channels_last
                    )
                    sscd_embeddings = sscd_model(image_for_sscd)
                    sscd_embeddings = sscd_embeddings.cpu().numpy().astype(np.float16)

                    ### VAE
                    vae_latents = vae_model.encode(image_for_vae).latent_dist.sample()
                    vae_outputs = vae_latents.cpu().numpy().astype(np.float16)

                    ### T5
                    t5_inputs = t5tokenizer(
                        captions,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=256,
                    )
                    t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()}
                    t5_outputs = t5model.encoder(**t5_inputs)[0]
                    # Zero out embeddings at padding positions.
                    mask = t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
                    t5_outputs = t5_outputs * mask
                    t5_outputs = (
                        ((t5_outputs.clip(-0.25, 0.25) / 0.5 + 0.5) * 255.0)
                        .to(torch.uint8)
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )
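                    # The expression above clips the T5 hidden states to [-0.25, 0.25]
                    # and maps them affinely to [0, 255] so they fit in uint8:
                    #     u = (clip(t, -0.25, 0.25) / 0.5 + 0.5) * 255
                    # A consumer can approximately invert this at load time with
                    #     t ~= (u / 255.0 - 0.5) * 0.5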

                    ### SigLIP
                    siglip_inputs = siglip_processor(
                        text=captions,
                        images=pil_images,
                        return_tensors="pt",
                        padding='max_length',
                        truncation=True,
                    ).to(device)
                    siglip_outputs = siglip_model(**siglip_inputs)
                    siglip_text_embeddings = siglip_outputs.text_embeds.cpu().numpy().astype(np.float16)
                    siglip_image_embeddings = siglip_outputs.image_embeds.cpu().numpy().astype(np.float16)
                    # Row-wise cosine similarity (the returned embeddings are already L2-normalized).
                    siglip_similarities = np.einsum("ij,ij->i", siglip_text_embeddings, siglip_image_embeddings)

                    ### Write
                    for i in range(len(captions)):
                        # Skip samples whose SigLIP text-image similarity is too low.
                        if siglip_similarities[i] < 0.06:
                            continue
                        sample = {
                            "vae_latents": vae_outputs[i],
                            "caption": str(captions[i]),
                            "t5_xl_embeddings": t5_outputs[i],
                            "sscd_embeddings": sscd_embeddings[i],
                            "key": uids[i],
                            "hps_score": str(est[i]),
                            "width_height": f"{w}x{h}",
                        }
                        out.write(sample)


def main(datasetinfos, out_roots, is_test=False, device_name="cuda"):
    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
    print(f"Processing on {device}")
    convert_to_mds(datasetinfos, out_roots, device, is_test=is_test)
    logging.info("Finished processing images.")


def detect_small_or_nonexistent_dirs(current_dir, start=0, end=18503, max_size=512):
    small_or_nonexistent_dirs = []
    for i in range(start, end + 1):
        dir_name = f"{i:05d}"
        dir_path = os.path.join(current_dir, dir_name)
        if not os.path.exists(dir_path):
            if i % 64 < 8:
                small_or_nonexistent_dirs.append(i)
        elif os.path.isdir(dir_path):
            total_size = 0
            for dirpath, dirnames, filenames in os.walk(dir_path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total_size += os.path.getsize(fp)
            if total_size < max_size:
                small_or_nonexistent_dirs.append(i)
    return small_or_nonexistent_dirs


def save_to_json(data, filename):
    with open(filename, "w") as f:
        json.dump(data, f)
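

# detect_small_or_nonexistent_dirs / save_to_json are optional helpers for re-running
# only the shards whose output directory is missing or suspiciously small; they are
# referenced only by the commented-out reqsids lines below.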

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert images to MDS format.")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for processing (cuda or cpu).",
    )
    parser.add_argument(
        "--file_index", type=int, default=0, help="File index to process."
    )
    parser.add_argument(
        "--is_test", action="store_true", help="Run in test mode with reduced dataset."
    )
    parser.add_argument(
        "--outdir_basepath",
        type=str,
        default="/home/ubuntu/bucketmds",
        help="Output directory path.",
    )
    parser.add_argument(
        "--tar_indir_basepath",
        type=str,
        default="/home/ubuntu/wds_deduped_0.3m_highres",
        help="Input directory path.",
    )
    args = parser.parse_args()

    # reqsids = json.load(open("{outdir_basepath}/small_or_nonexistent_dirs.json"))
    # number of tars
    tars = glob.glob(f"{args.tar_indir_basepath}/*.tar")
    reqsids = range(len(tars))
    # reqsids = detect_small_or_nonexistent_dirs(args.outdir_basepath, start=0, end=18503, max_size=512)
    # reqsids = [16516 - 64, 16516, 16580, 16644, 16708, 16772, 16836, 16900, 16964, 17028, 17092, 17156, 17220, 17284, 17348, 17412, 17476, 17540, 17604, 17668, 17732, 17796, 17860, 17924, 17988, 18052, 18116, 18180, 18244, 18308, 18372, 18436, 18500]
    # print(reqsids)

    out_roots, datasetinfos = [], []
    for i, reqid in enumerate(reqsids):
        if i % 8 == args.file_index:
            out_root = f"{args.outdir_basepath}/{str(int(reqid)).zfill(5)}"
            dataset_path = f"{args.tar_indir_basepath}/{str(int(reqid)).zfill(5)}.tar"
            out_roots.append(out_root)
            datasetinfos.append(dataset_path)

    main(datasetinfos, out_roots, is_test=args.is_test, device_name=args.device)
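
Usage note: shards are assigned round-robin over 8 worker indices (i % 8 == args.file_index), so, assuming the script is saved as bucket.py, a full run launches one process per index, e.g. python bucket.py --file_index 0 through python bucket.py --file_index 7 (plus --device, --tar_indir_basepath, and --outdir_basepath as needed).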