Created
July 11, 2024 04:43
-
-
Save cloneofsimo/d53e102d1b2d4d9a5a00f507e330147b to your computer and use it in GitHub Desktop.
lol
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import os | |
| import torch | |
| import json | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from diffusers.models import AutoencoderKL | |
| from streaming import MDSWriter | |
| import logging | |
| import time | |
| import numpy as np | |
| from typing import Any | |
| import json | |
| from streaming.base.format.mds.encodings import Encoding, _encodings | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
| import webdataset as wds | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import argparse | |
| from torchvision import transforms | |
| import glob | |
| import re | |
# Configure the root logger so that INFO-level (and higher) records are emitted.
logging.basicConfig(level=logging.INFO)
# Caption openers recognised by modify_caption, each paired with the text that
# replaces it. Order matters: the regex alternation tries them left to right.
_CAPTION_OPENER_VERBS = ['showcases ', 'portrays ', 'appears to be ', 'is ', 'depicts ', 'features ', 'displays: ']
_CAPTION_OPENERS = (
    [("The image " + verb, '') for verb in _CAPTION_OPENER_VERBS]
    + [("This image " + verb, '') for verb in _CAPTION_OPENER_VERBS]
    + [("In this " + noun, '') for noun in ["picture, ", "depiction, ", "piece, ", "image, ", "scene, "]]
    + [
        ('In this artwork, ', 'Artwork of '),
        ('In this illustration, ', 'Illustration of '),
        ('In this art piece, ', 'Art of '),
    ]
)
# One capture group per opener, so the group index (not the matched text)
# identifies which opener fired. This keeps the lookup correct when
# IGNORECASE lets the matched text differ in case from the table entry.
_CAPTION_OPENER_PATTERN = re.compile(
    '|'.join(f"({re.escape(opener)})" for opener, _ in _CAPTION_OPENERS),
    flags=re.IGNORECASE,
)


def modify_caption(caption: str) -> str:
    """
    Removes common opener substrings from CogVLM outputs.

    The first occurrence (case-insensitive) of a known opener such as
    "The image depicts " or "In this artwork, " is removed or replaced, and
    the result is passed through ``str.capitalize()`` (first character
    upper-cased, the rest lower-cased).

    NOTE: this file defines ``modify_caption`` a second time further down;
    the later definition is the one bound to the name after import.

    Args:
        caption (str): A string containing a cogvlm caption.
    Returns:
        str: The caption with the opener removed or altered if it was
            present. A non-string input is returned unchanged.
    """
    if not isinstance(caption, str):
        # Best-effort contract: anything that is not text passes through.
        return caption

    def replace_fn(match):
        # lastindex is the 1-based number of the capture group that matched.
        return _CAPTION_OPENERS[match.lastindex - 1][1]

    return _CAPTION_OPENER_PATTERN.sub(replace_fn, caption, count=1).capitalize()
def modify_caption(caption: str) -> str:
    """Strip the phrase "This image displays" and a leading colon from a caption.

    Every occurrence of the phrase is removed, surrounding whitespace is
    trimmed, and if what remains starts with ':' that single character is
    dropped before trimming again.

    Args:
        caption: The raw caption text.
    Returns:
        The cleaned caption.
    """
    trimmed = caption.replace("This image displays", '').strip()
    if not trimmed.startswith(':'):
        return trimmed
    return trimmed[1:].strip()
class uint8(Encoding):
    """Encoding that stores an array as raw bytes and reads it back as uint8."""

    def encode(self, obj: Any) -> bytes:
        """Return the raw buffer of ``obj`` as produced by its ``tobytes()``."""
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes) -> Any:
        """Reinterpret ``data`` as a flat ``np.uint8`` array.

        The original shape is not part of the payload, so the caller has to
        reshape the result if it needs more than one dimension.
        """
        return np.frombuffer(data, dtype=np.uint8)
class np16(Encoding):
    """Encoding that stores an array as raw bytes and reads it back as float16."""

    def encode(self, obj: Any) -> bytes:
        """Return the raw buffer of ``obj`` as produced by its ``tobytes()``.

        No dtype conversion happens here; presumably callers pass float16
        arrays so that ``decode`` round-trips — TODO confirm.
        """
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes) -> Any:
        """Reinterpret ``data`` as a flat ``np.float16`` array.

        The original shape is not part of the payload, so the caller has to
        reshape the result if it needs more than one dimension.
        """
        return np.frombuffer(data, dtype=np.float16)
# Register the custom encodings under the names used as column types below.
_encodings.update({"np16": np16, "uint8": uint8})
# Per-channel normalisation (the commonly used ImageNet statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Preprocessing for the SSCD model input: resize the short side to 288,
# take the central 288x288 crop, convert to a tensor and normalise.
_small_288_steps = [
    transforms.Resize(288),
    transforms.CenterCrop(288),
    transforms.ToTensor(),
    normalize,
]
small_288 = transforms.Compose(_small_288_steps)
def crop_to_center(image, new_size=768):
    """Crop a ``new_size`` x ``new_size`` window centred on ``image``.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method (e.g. a PIL image).
        new_size: Side length of the square crop window.
    Returns:
        Whatever ``image.crop`` returns for the centred
        ``(left, top, right, bottom)`` box. The box coordinates are computed
        with true division and may therefore be fractional.
    """
    width, height = image.size
    box = (
        (width - new_size) / 2,   # left
        (height - new_size) / 2,  # top
        (width + new_size) / 2,   # right
        (height + new_size) / 2,  # bottom
    )
    return image.crop(box)
def prepare_image(pil_image):
    """Convert an image into a channels-first float tensor scaled to [-1, 1].

    Args:
        pil_image: Object whose ``convert("RGB")`` result can be turned into
            a (height, width, channel) array of 0-255 values.
    Returns:
        torch.Tensor: float32 tensor laid out as (channel, height, width),
        with each value mapped through ``v / 127.5 - 1``.
    """
    pixels = np.array(pil_image.convert("RGB"))
    scaled = pixels.astype(np.float32) / 127.5 - 1
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    channels_first = np.transpose(scaled, [2, 0, 1])
    return torch.from_numpy(channels_first)
def center_crop_to_nearest_multiple_of_64(image):
    """Downscale an oversized image, then centre-crop both sides to multiples of 64.

    While either side exceeds 1024 both sides are halved (integer division);
    if any halving happened the image is resized to the reduced size. The
    result is then cropped around its centre so that width and height are
    each rounded down to a multiple of 64.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` plus
            ``resize(size)`` and ``crop(box)`` (e.g. a PIL image).
    Returns:
        Whatever ``image.crop`` returns for the centred box.
    """
    width, height = image.size

    # Halve repeatedly until both sides fit within 1024 pixels.
    oversized = max(width, height) > 1024
    while max(width, height) > 1024:
        width //= 2
        height //= 2
    if oversized:
        image = image.resize((width, height))

    crop_w = (width // 64) * 64
    crop_h = (height // 64) * 64
    x0 = (width - crop_w) // 2
    y0 = (height - crop_h) // 2
    return image.crop((x0, y0, x0 + crop_w, y0 + crop_h))
def wds_preprocess(x):
    """Turn one decoded WebDataset sample into model inputs plus metadata.

    Args:
        x: Tuple ``(key, pil_image, metadata_dict)``. The metadata must
            contain ``"caption"``; ``"uid"``, ``"watermark_class_id"`` and
            ``"aesthetic_score"`` are optional.
    Returns:
        Tuple ``([image_for_vae], caption, image_for_sscd, uid,
        watermark_class, est)`` where
        * ``image_for_vae`` is the 64-aligned centre crop as produced by
          ``prepare_image``, wrapped in a one-element list (presumably so
          that batching does not try to stack differently sized tensors —
          TODO confirm),
        * ``image_for_sscd`` is the same crop passed through ``small_288``,
        * ``uid`` falls back to the sample key when absent or ``None``,
        * ``watermark_class`` is 1 when the key is absent and 0 when it is
          present but ``None``,
        * ``est`` is the aesthetic score, or 100 when absent or ``None``.
    """
    sample_key, picture, meta = x

    cropped = center_crop_to_nearest_multiple_of_64(picture.convert("RGB"))
    image_for_vae = prepare_image(cropped)
    image_for_sscd = small_288(cropped)

    caption = meta["caption"]

    uid = meta.get("uid", None)
    uid = sample_key if uid is None else uid

    watermark_class = meta.get("watermark_class_id", 1)
    watermark_class = 0 if watermark_class is None else watermark_class

    est = meta.get("aesthetic_score", None)
    est = 100 if est is None else est

    return ([image_for_vae], caption, image_for_sscd, uid, watermark_class, est)
# Column name -> encoding name for the MDS output. "np16" and "uint8" are the
# custom encodings registered earlier in this file; "str" is a plain string.
COLUMNS = dict(
    key="str",
    caption="str",
    vae_latents="np16",
    t5_xl_embeddings="uint8",
    sscd_embeddings="np16",
    hps_score="str",
    width_height="str",
)
@torch.no_grad()
def convert_to_mds(dataset_paths, out_roots, device, is_test=False):
    """Encode WebDataset shards into resolution-bucketed MDS datasets.

    For each shard the samples are first grouped by the spatial shape of
    their VAE input tensor. Every group that holds at least 64 samples is
    then run through the SSCD, VAE and T5 models and written with
    ``MDSWriter`` to ``<out_root>/<w>_<h>/data`` (existing files in that
    directory are removed first).

    Args:
        dataset_paths: Paths of the input ``.tar`` shards.
        out_roots: Output root directories, parallel to ``dataset_paths``.
        device: Torch device hosting all three models and their inputs.
        is_test: Accepted for interface compatibility; not used here.

    A shard whose path does not exist is logged and skipped; the remaining
    shards are still processed.
    """
    logging.info(f"Processing on {device}")

    vae_model = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    vae_model = vae_model.to(device).eval()
    vae_model.to(memory_format=torch.channels_last)

    t5tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pile-t5-xl", use_fast=False)
    t5tokenizer.pad_token = t5tokenizer.bos_token
    t5model = AutoModelForSeq2SeqLM.from_pretrained(
        "EleutherAI/pile-t5-xl", torch_dtype=torch.bfloat16
    )
    t5model = t5model.to(device).eval()

    # Follow `device` instead of a hard-coded "cuda" so the CPU fallback works.
    sscd_model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)

    num_chunk = 1
    dataset_bulks = [
        dataset_paths[i : i + num_chunk] for i in range(0, len(dataset_paths), num_chunk)
    ]
    out_roots_bulks = [out_roots[i : i + num_chunk] for i in range(0, len(out_roots), num_chunk)]

    for bulk_paths, bulk_out_roots in zip(dataset_bulks, out_roots_bulks):
        missing_paths = [path for path in bulk_paths if not os.path.exists(path)]
        if missing_paths:
            for path in missing_paths:
                logging.info(f"Dataset not found: {path}")
            # Skip only this bulk rather than abandoning every later shard.
            continue
        out_root = bulk_out_roots[0]

        dataset = wds.DataPipeline(
            wds.SimpleShardList(bulk_paths),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.decode("pil", handler=wds.warn_and_continue),
            wds.to_tuple("__key__", "jpg;png", "json", handler=wds.warn_and_continue),
            wds.map(wds_preprocess),
            wds.batched(64),
        )
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=2,
            shuffle=False,
            drop_last=False,
        )

        # Pass 1: bucket every sample of the shard by its tensor resolution.
        # NOTE(review): for a (C, H, W) tensor, as built by `prepare_image`,
        # shape[-2:] is (height, width), so `w` actually holds the height.
        # The names are kept because they determine the existing directory
        # names and `width_height` values — confirm before renaming.
        localbatches_datasets = {}
        print("Start Making Bucket Batching")
        for _, batch in tqdm(enumerate(dataloader)):
            image_for_vae, captions, image_for_sscd, uids, watermark_class, est = batch
            for i in range(len(captions)):
                w, h = image_for_vae[i][0].shape[-2:]
                # Drop samples whose area is at most 512x512.
                if w * h <= 512 * 512:
                    continue
                localbatches_datasets.setdefault((w, h), []).append(
                    (
                        image_for_vae[i][0],
                        captions[i],
                        image_for_sscd[i],
                        uids[i],
                        watermark_class[i],
                        est[i],
                    )
                )
        print("End Making Bucket Batching")

        for resolution, bucket in localbatches_datasets.items():
            print(f"Resolution {resolution} has {len(bucket)} samples")

        # Pass 2: encode each sufficiently large bucket and write it out.
        for (w, h), localbatches in localbatches_datasets.items():
            # Buckets with fewer than 64 samples are not written.
            if len(localbatches) < 64:
                continue
            print(f"Processing resolution {w}x{h}")

            out_root_res = os.path.join(out_root, f"{w}_{h}")
            sub_data_root = os.path.join(out_root_res, f"data")
            if os.path.exists(sub_data_root):
                for file in os.listdir(sub_data_root):
                    os.remove(os.path.join(sub_data_root, file))
            os.makedirs(sub_data_root, exist_ok=True)

            with MDSWriter(out=sub_data_root, columns=COLUMNS) as out:
                # Scale the batch size so each batch holds roughly the pixel
                # count of 64 images at 512x512.
                BS = int(64 * 512 * 512 / (w * h))
                local_dl = DataLoader(
                    localbatches,
                    batch_size=BS,
                    num_workers=2,
                    prefetch_factor=4,
                    shuffle=False,
                    drop_last=False,
                )
                for image_for_vae, captions, image_for_sscd, uids, _watermark_class, est in local_dl:
                    ### SSCD
                    image_for_sscd = image_for_sscd.to(
                        device, memory_format=torch.channels_last
                    )
                    sscd_embeddings = sscd_model(image_for_sscd)
                    sscd_embeddings = sscd_embeddings.cpu().numpy().astype(np.float16)

                    ### VAE
                    image_for_vae = image_for_vae.to(device).half()
                    vae_latents = vae_model.encode(image_for_vae).latent_dist.sample()
                    vae_outputs = vae_latents.cpu().numpy().astype(np.float16)

                    ### T5
                    t5_inputs = t5tokenizer(
                        captions,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=256,
                    )
                    t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()}
                    t5_outputs = t5model.encoder(**t5_inputs)[0]
                    # Zero the embeddings at padded positions.
                    mask = (
                        t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
                    )
                    t5_outputs = t5_outputs * mask
                    # Quantise: clip to [-0.25, 0.25], then map linearly to 0..255.
                    t5_outputs = (
                        ((t5_outputs.clip(-0.25, 0.25) / 0.5 + 0.5) * 255.0)
                        .to(torch.uint8)
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )

                    ### Write
                    for i in range(len(captions)):
                        # Store this sample's own score, not the repr of the
                        # whole batch; unwrap tensor/NumPy scalars first.
                        score = est[i]
                        if hasattr(score, "item"):
                            score = score.item()
                        sample = {
                            "vae_latents": vae_outputs[i],
                            "caption": str(captions[i]),
                            "t5_xl_embeddings": t5_outputs[i],
                            "sscd_embeddings": sscd_embeddings[i],
                            "key": uids[i],
                            "hps_score": str(score),
                            "width_height": f"{w}x{h}",
                        }
                        out.write(sample)
def main(datasetinfos, out_roots, is_test=False, device_name="cuda"):
    """Pick the compute device and run the shard-to-MDS conversion.

    Args:
        datasetinfos: Paths of the input shards, passed to ``convert_to_mds``.
        out_roots: Output root directories, parallel to ``datasetinfos``.
        is_test: Forwarded unchanged to ``convert_to_mds``.
        device_name: Device to use when CUDA is available; otherwise "cpu"
            is used regardless of this value.
    """
    chosen_name = device_name if torch.cuda.is_available() else "cpu"
    device = torch.device(chosen_name)
    print(f"Processing on {device}")
    convert_to_mds(datasetinfos, out_roots, device, is_test=is_test)
    logging.info("Finished processing images.")
def detect_small_or_nonexistent_dirs(current_dir, start=0, end=18503, max_size=512):
    """List indices whose output directory is missing or nearly empty.

    Directories are named with the zero-padded five-digit index and looked
    up under ``current_dir`` for every index in ``start..end`` (inclusive).

    Args:
        current_dir: Directory containing the numbered sub-directories.
        start: First index to check.
        end: Last index to check (inclusive).
        max_size: A directory whose files total fewer bytes than this is
            reported.
    Returns:
        list[int]: Indices in ascending order for which either
        * the path does not exist and ``index % 64 < 8``, or
        * the path is a directory whose files (recursively) sum to less
          than ``max_size`` bytes.
        Paths that exist but are not directories are never reported.
    """
    flagged = []
    for index in range(start, end + 1):
        candidate = os.path.join(current_dir, f"{index:05d}")

        if not os.path.exists(candidate):
            # Missing directories only count for the first 8 of every 64 indices.
            if index % 64 < 8:
                flagged.append(index)
            continue

        if not os.path.isdir(candidate):
            continue

        bytes_on_disk = sum(
            os.path.getsize(os.path.join(root, name))
            for root, _subdirs, names in os.walk(candidate)
            for name in names
        )
        if bytes_on_disk < max_size:
            flagged.append(index)
    return flagged
def save_to_json(data, filename):
    """Serialize ``data`` as JSON into ``filename``, overwriting any existing file.

    Args:
        data: Any object accepted by ``json.dump``.
        filename: Destination path, opened in text write mode.
    """
    with open(filename, "w") as handle:
        json.dump(data, handle)
if __name__ == "__main__":
    # Command-line entry point: select this process's share of the tar shards
    # and convert them.
    parser = argparse.ArgumentParser(description="Convert images to MDS format.")
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for processing (cuda or cpu).",
    )
    parser.add_argument(
        "--file_index", type=int, default=0, help="File index to process."
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=8,
        help="Total number of parallel processes the shards are split across.",
    )
    parser.add_argument(
        "--is_test", action="store_true", help="Run in test mode with reduced dataset."
    )
    parser.add_argument(
        "--outdir_basepath",
        type=str,
        default="/home/ubuntu/bucketmds",
        help="Output directory path.",
    )
    parser.add_argument(
        "--tar_indir_basepath",
        type=str,
        default="/home/ubuntu/wds_deduped_0.3m_highres",
        help="Input directory path.",
    )
    args = parser.parse_args()

    # Shard ids are taken to be 0..N-1, where N is the number of tar files
    # found; each id maps to "<id zero-padded to 5 digits>.tar".
    # To re-run only failed shards, the ids can instead come from
    # detect_small_or_nonexistent_dirs(args.outdir_basepath, ...).
    tars = glob.glob(f"{args.tar_indir_basepath}/*.tar")
    reqsids = range(len(tars))

    out_roots, datasetinfos = [], []
    for i, reqid in enumerate(reqsids):
        # Round-robin split: this process handles every num_processes-th shard.
        if i % args.num_processes == args.file_index:
            shard_name = str(int(reqid)).zfill(5)
            out_roots.append(f"{args.outdir_basepath}/{shard_name}")
            datasetinfos.append(f"{args.tar_indir_basepath}/{shard_name}.tar")

    main(datasetinfos, out_roots, is_test=args.is_test, device_name=args.device)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env bash
# Launch one feature-generation process per file index, each bound to one GPU
# and pinned to its own range of CPU cores.
NUM_GPUS=8
START_INDEX=0
END_INDEX=7
NUM_CPUS=$(nproc)  # Get the number of available CPUs
CORES_PER_PROCESS=24

echo "Number of GPUs: $NUM_GPUS | Number of CPUs: $NUM_CPUS | Cores per process: $CORES_PER_PROCESS"

for ((i=START_INDEX; i<=END_INDEX; i++)); do
    GPU_INDEX=$((i % NUM_GPUS))
    CPU_START=$(( (i * CORES_PER_PROCESS) % NUM_CPUS ))
    CPU_END=$(( CPU_START + CORES_PER_PROCESS - 1 ))
    # Keep the range inside the machine's CPUs; taskset rejects a list that
    # names a CPU index past the last one.
    if (( CPU_END >= NUM_CPUS )); then
        CPU_END=$(( NUM_CPUS - 1 ))
    fi

    # CUDA_VISIBLE_DEVICES is set for this command only, so each background
    # process sees exactly one GPU.
    CUDA_VISIBLE_DEVICES="$GPU_INDEX" taskset -c "$CPU_START-$CPU_END" python /home/ubuntu/make_shards/run_featgen_hr_bucket.py --device cuda --file_index "$i" --tar_indir_basepath "/home/ubuntu/mjwds_deduped_0.3m_highres" --outdir_basepath "/home/ubuntu/mj_bucketmds" &

    echo "Started process $i on GPU $GPU_INDEX with CPU $CPU_START-$CPU_END"
done

wait  # Wait for all background processes to finish
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment