
@cloneofsimo
Created August 8, 2024 08:54
bucket
import os
import torch
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from diffusers.models import AutoencoderKL
from streaming import MDSWriter
from streaming.base.format.mds.encodings import Encoding, _encodings
import logging
import time
import numpy as np
from typing import Any
from tqdm import tqdm
import webdataset as wds
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, SiglipModel
import argparse
from torchvision import transforms
import glob
import re

# Initialize logging
logging.basicConfig(level=logging.INFO)
def modify_caption(caption: str) -> str:
    """
    Removes common prefix substrings from CogVLM outputs.

    Args:
        caption (str): A string containing a CogVLM caption.

    Returns:
        str: The caption with the prefix substring removed
            or altered if it was present.
    """
    try:
        base_words = ['showcases ', 'portrays ', 'appears to be ', 'is ', 'depicts ', 'features ', 'displays: ']
        prefix_substrings = [("The image " + s, '') for s in base_words] + [("This image " + s, '') for s in base_words]
        prefix_substrings += [("In this " + s, '') for s in ["picture, ", "depiction, ", "piece, ", "image, ", "scene, "]]
        prefix_substrings += [
            ('In this artwork, ', 'Artwork of '),
            ('In this illustration, ', 'Illustration of '),
            ('In this art piece, ', 'Art of ')
        ]
        pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
        replacers = {opening: replacer for opening, replacer in prefix_substrings}

        def replace_fn(match):
            return replacers[match.group(0)]

        return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE).capitalize()
    except Exception:
        return caption
# NOTE: this redefinition shadows the modify_caption above; only this simpler
# prefix strip is actually used.
def modify_caption(caption: str) -> str:
    cap = caption.replace("This image displays", '').strip()
    if cap.startswith(':'):
        cap = cap[1:]
    return cap.strip()
class uint8(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.uint8)


class np16(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.float16)


_encodings["np16"] = np16
_encodings["uint8"] = uint8
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
small_288 = transforms.Compose(
    [
        transforms.Resize(288),
        transforms.CenterCrop(288),
        transforms.ToTensor(),
        normalize,
    ]
)
def crop_to_center(image, new_size=768):
    width, height = image.size
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def prepare_image(pil_image):
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr)
    return image
def center_crop_to_nearest_multiple_of_64(image):
    w, h = image.size
    # If the image is too large, downscale by repeated halving until the
    # pixel count is at most 1024 * 1024 * 2.
    if w > 1024 or h > 1024:
        while w * h > 1024 * 1024 * 2:
            w = w // 2
            h = h // 2
        image = image.resize((w, h))
    new_w = w - w % 64
    new_h = h - h % 64
    left = (w - new_w) // 2
    top = (h - new_h) // 2
    right = left + new_w
    bottom = top + new_h
    return image.crop((left, top, right, bottom))
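# Worked example (illustrative, not part of the original gist):
#   a 1000x700 image is under the 1024 size gate, so it is only
#   center-cropped to 960x640 (the largest multiples of 64 that fit);
#   a 4000x3000 image is halved twice to 1000x750 (the first size with
#   w * h <= 1024 * 1024 * 2) and then center-cropped to 960x704.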
def wds_preprocess(x):
    key, pil_image, _json = x
    pil_image = pil_image.convert("RGB")
    pil_image = center_crop_to_nearest_multiple_of_64(pil_image)
    image_for_vae = prepare_image(pil_image)  # not returned; the VAE input is rebuilt per bucket below
    image_for_sscd = small_288(pil_image)
    caption = _json["caption"]
    uid = _json.get("uid", None)
    if uid is None:
        uid = key
    watermark_class = _json.get("watermark_class_id", 1)
    if watermark_class is None:
        watermark_class = 0
    est = _json.get("aesthetic_score", None)
    if est is None:
        est = 100
    # Example _json entry:
    # {'uid': '95ff62922b3536189768bcc883598109', 'clip_b32_similarity_score': 0.299560546875, 'clip_l14_similarity_score': 0.3017578125, 'caption': 'Picture of Car seat 0+ cover Little Goose', 'url': 'https://dealers.little-dutch.com/content/images/thumbs/002/0023598_1000.jpeg', 'key': '000010028', 'status': 'success', 'error_message': None, 'width': None, 'height': None, 'original_width': None, 'original_height': None, 'exif': '{}', 'sha512': 'b1e6f78d70b10645f54682c5cb01a8ba9584f6e34b4f292b431350fa93e94060'}
    return ([pil_image], caption, image_for_sscd, uid, watermark_class, est)
COLUMNS = {
    "key": "str",
    "caption": "str",
    "vae_latents": "np16",
    "t5_xl_embeddings": "uint8",
    "sscd_embeddings": "np16",
    "hps_score": "str",
    "width_height": "str",
}
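# Column contents, summarizing the write loop in convert_to_mds below:
#   vae_latents      - fp16 SDXL-VAE latents of the bucketed image
#   t5_xl_embeddings - Pile-T5-XL encoder states, clipped to [-0.25, 0.25]
#                      and quantized to uint8 at write time
#   sscd_embeddings  - fp16 SSCD copy-detection descriptors
#   hps_score        - aesthetic score carried over from the source metadata
#   width_height     - bucket resolution as a "WxH" string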
prompt_counter = {}
@torch.no_grad()
def convert_to_mds(dataset_paths, out_roots, device, is_test=False):
    logging.info(f"Processing on {device}")

    vae_model = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    vae_model = vae_model.to(device).eval()
    vae_model.to(memory_format=torch.channels_last)

    t5tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pile-t5-xl", use_fast=False)
    t5tokenizer.pad_token = t5tokenizer.bos_token
    t5model = AutoModelForSeq2SeqLM.from_pretrained("EleutherAI/pile-t5-xl", torch_dtype=torch.bfloat16)
    t5model = t5model.to(device).eval()

    siglip_model = SiglipModel.from_pretrained("google/siglip-large-patch16-256").to(device).eval()
    siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
    # t5model.encoder = torch.compile(t5model.encoder, mode='reduce-overhead')

    sscd_model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)
    num_chunk = 1
    dataset_bulks = [
        dataset_paths[i : i + num_chunk] for i in range(0, len(dataset_paths), num_chunk)
    ]
    out_roots_bulks = [out_roots[i : i + num_chunk] for i in range(0, len(out_roots), num_chunk)]

    for dataset_paths, out_roots in zip(dataset_bulks, out_roots_bulks):
        for dataset_path in dataset_paths:
            if not os.path.exists(dataset_path):
                logging.info(f"Dataset not found: {dataset_path}")
                return
        out_root = out_roots[0]

        def filter_cond(sample):
            _, img, info = sample
            prompt = info.get("caption")
            is_cand = True
            if prompt not in prompt_counter:
                prompt_counter[prompt] = 1
            else:
                prompt_counter[prompt] += 1
            if prompt_counter[prompt] > 5:
                is_cand = False
            return (img.size[0] * img.size[1] >= 1024 * 512) and is_cand

        dataset = wds.DataPipeline(
            wds.SimpleShardList(dataset_paths),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.decode("pil", handler=wds.warn_and_continue),
            wds.to_tuple("__key__", "jpg;png", "json", handler=wds.warn_and_continue),
            wds.select(filter_cond),
            wds.map(wds_preprocess),
            wds.batched(64),
        )
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=2,
            shuffle=False,
            drop_last=False,
        )
        t0 = time.time()
        localbatches_datasets = {}

        # Iterate over the dataset and build one bucket of samples per resolution.
        print("Start Making Bucket Batching")
        for idx, batch in tqdm(enumerate(dataloader)):
            pil_image, captions, image_for_sscd, uids, watermark_class, est = batch
            for i in range(len(captions)):
                w, h = pil_image[i][0].size
                # print(w, h)
                if w * h <= 512 * 512:
                    continue
                if (w, h) not in localbatches_datasets:
                    localbatches_datasets[(w, h)] = []
                localbatches_datasets[(w, h)].append(
                    (pil_image[i][0], captions[i], image_for_sscd[i], uids[i], watermark_class[i], est[i])
                )
        print("End Making Bucket Batching")

        for k in localbatches_datasets.keys():
            print(f"Resolution {k} has {len(localbatches_datasets[k])} samples")

        for (w, h), localbatches in localbatches_datasets.items():
            # Skip buckets with too few samples.
            if len(localbatches) < 32:
                continue
            print(f"Processing resolution {w}x{h}")

            out_root_res = os.path.join(out_root, f"{w}_{h}")
            sub_data_root = os.path.join(out_root_res, "data")
            if os.path.exists(sub_data_root):
                for file in os.listdir(sub_data_root):
                    os.remove(os.path.join(sub_data_root, file))
            os.makedirs(sub_data_root, exist_ok=True)
            with MDSWriter(out=sub_data_root, columns=COLUMNS) as out:
                # Scale the batch size so each batch covers roughly 64 * 512 * 512 pixels.
                BS = int(64 * 512 * 512 / (w * h))

                def collate_fn(batch):
                    pil_image, captions, image_for_sscd, uids, watermark_class, est = zip(*batch)
                    return pil_image, captions, image_for_sscd, uids, watermark_class, est

                local_dl = DataLoader(localbatches, batch_size=BS, num_workers=2, prefetch_factor=4, shuffle=False, drop_last=False, collate_fn=collate_fn)

                for pil_images, captions, image_for_sscd, uids, watermark_class, est in local_dl:
                    image_for_vaes = [prepare_image(pil) for pil in pil_images]
                    image_for_vae = torch.stack(image_for_vaes).to(device).half()

                    ### SSCD
                    image_for_sscd = torch.stack(image_for_sscd).to(device)
                    image_for_sscd = image_for_sscd.to(
                        device, memory_format=torch.channels_last
                    )
                    sscd_embeddings = sscd_model(image_for_sscd)
                    sscd_embeddings = sscd_embeddings.cpu().numpy().astype(np.float16)

                    ### VAE
                    image_for_vae = image_for_vae.to(device).half()
                    vae_latents = vae_model.encode(image_for_vae).latent_dist.sample()
                    vae_outputs = vae_latents.cpu().numpy().astype(np.float16)

                    ### T5
                    t5_inputs = t5tokenizer(
                        captions,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=256,
                    )
                    t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()}
                    t5_outputs = t5model.encoder(**t5_inputs)[0]
                    mask = (
                        t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
                    )
                    t5_outputs = t5_outputs * mask
                    # Clip to [-0.25, 0.25] and quantize to uint8.
                    t5_outputs = (
                        ((t5_outputs.clip(-0.25, 0.25) / 0.5 + 0.5) * 255.0)
                        .to(torch.uint8)
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )

                    ### SigLIP
                    siglip_inputs = siglip_processor(
                        text=captions,
                        images=pil_images,
                        return_tensors="pt",
                        padding='max_length',
                        truncation=True,
                    ).to(device)
                    siglip_outputs = siglip_model(**siglip_inputs)
                    siglip_text_embeddings = siglip_outputs.text_embeds.cpu().numpy().astype(np.float16)
                    siglip_image_embeddings = siglip_outputs.image_embeds.cpu().numpy().astype(np.float16)
                    # Element-wise text-image similarity, used as a filter below.
                    siglip_similarities = np.einsum("ij,ij->i", siglip_text_embeddings, siglip_image_embeddings)

                    ### Write
                    for i in range(len(captions)):
                        if siglip_similarities[i] < 0.06:
                            continue
                        sample = {
                            "vae_latents": vae_outputs[i],
                            "caption": str(captions[i]),
                            "t5_xl_embeddings": t5_outputs[i],
                            "sscd_embeddings": sscd_embeddings[i],
                            "key": uids[i],
                            "hps_score": str(est[i]),  # per-sample score
                            "width_height": f"{w}x{h}",
                        }
                        out.write(sample)
def main(datasetinfos, out_roots, is_test=False, device_name="cuda"):
    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
    print(f"Processing on {device}")
    convert_to_mds(datasetinfos, out_roots, device, is_test=is_test)
    logging.info("Finished processing images.")
def detect_small_or_nonexistent_dirs(current_dir, start=0, end=18503, max_size=512):
    small_or_nonexistent_dirs = []
    for i in range(start, end + 1):
        dir_name = f"{i:05d}"
        dir_path = os.path.join(current_dir, dir_name)
        if not os.path.exists(dir_path):
            if i % 64 < 8:
                small_or_nonexistent_dirs.append(i)
        elif os.path.isdir(dir_path):
            total_size = 0
            for dirpath, dirnames, filenames in os.walk(dir_path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total_size += os.path.getsize(fp)
            if total_size < max_size:
                small_or_nonexistent_dirs.append(i)
    return small_or_nonexistent_dirs


def save_to_json(data, filename):
    with open(filename, "w") as f:
        json.dump(data, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert images to MDS format.")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for processing (cuda or cpu).",
    )
    parser.add_argument(
        "--file_index", type=int, default=0, help="File index to process."
    )
    parser.add_argument(
        "--is_test", action="store_true", help="Run in test mode with reduced dataset."
    )
    parser.add_argument(
        "--outdir_basepath",
        type=str,
        default="/home/ubuntu/bucketmds",
        help="Output directory path.",
    )
    parser.add_argument(
        "--tar_indir_basepath",
        type=str,
        default="/home/ubuntu/wds_deduped_0.3m_highres",
        help="Input directory path.",
    )
    args = parser.parse_args()

    # reqsids = json.load(open("{outdir_basepath}/small_or_nonexistent_dirs.json"))
    # One id per tar shard in the input directory.
    tars = glob.glob(f"{args.tar_indir_basepath}/*.tar")
    reqsids = range(len(tars))
    # reqsids = detect_small_or_nonexistent_dirs(args.outdir_basepath, start=0, end=18503, max_size=512)
    # reqsids = [16516 - 64, 16516, 16580, 16644, 16708, 16772, 16836, 16900, 16964, 17028, 17092, 17156, 17220, 17284, 17348, 17412, 17476, 17540, 17604, 17668, 17732, 17796, 17860, 17924, 17988, 18052, 18116, 18180, 18244, 18308, 18372, 18436, 18500]
    # print(reqsids)

    out_roots, datasetinfos = [], []
    # Shard the work across 8 processes; this process handles every 8th tar,
    # selected by --file_index.
    for i, reqid in enumerate(reqsids):
        if i % 8 == args.file_index:
            out_root = f"{args.outdir_basepath}/{str(int(reqid)).zfill(5)}"
            dataset_path = f"{args.tar_indir_basepath}/{str(int(reqid)).zfill(5)}.tar"
            out_roots.append(out_root)
            datasetinfos.append(dataset_path)

    main(datasetinfos, out_roots, is_test=args.is_test, device_name=args.device)
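# ------------------------------------------------------------------
# Read-back sketch (not part of the original gist; the path is an example).
# The custom "np16"/"uint8" encodings above must also be registered in the
# reading process before opening the shards.
#
#   from streaming import StreamingDataset
#
#   ds = StreamingDataset(local="/home/ubuntu/bucketmds/00000/1024_768/data", shuffle=False)
#   sample = ds[0]
#   vae_latents = sample["vae_latents"]    # flat fp16 array; reshape to (4, H // 8, W // 8) for the SDXL VAE
#   t5_uint8 = sample["t5_xl_embeddings"]  # flat uint8 array; reshape to (seq_len, d_model)
#   # Invert the quantization used at write time: x -> (x / 255 - 0.5) * 0.5
#   t5_embeddings = (t5_uint8.astype(np.float32) / 255.0 - 0.5) * 0.5
# ------------------------------------------------------------------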