Skip to content

Instantly share code, notes, and snippets.

@cloneofsimo
Created July 11, 2024 04:43
Show Gist options
  • Select an option

  • Save cloneofsimo/d53e102d1b2d4d9a5a00f507e330147b to your computer and use it in GitHub Desktop.

Select an option

Save cloneofsimo/d53e102d1b2d4d9a5a00f507e330147b to your computer and use it in GitHub Desktop.
lol
import os
import torch
import json
from PIL import Image
from torch.utils.data import Dataset
from diffusers.models import AutoencoderKL
from streaming import MDSWriter
import logging
import time
import numpy as np
from typing import Any
import json
from streaming.base.format.mds.encodings import Encoding, _encodings
from tqdm import tqdm
from torch.utils.data import DataLoader
import webdataset as wds
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import argparse
from torchvision import transforms
import glob
import re
# Initialize logging: configure the root logger at import time so that the
# INFO-level progress messages emitted via `logging.info` below are shown.
logging.basicConfig(level=logging.INFO)
# Verbs that may follow "The image " / "This image " at the start of a caption.
_CAPTION_VERBS = ['showcases ', 'portrays ', 'appears to be ', 'is ', 'depicts ', 'features ', 'displays: ']

# (opening, replacement) pairs, in matching priority order.
_CAPTION_PREFIX_RULES = (
    [("The image " + verb, '') for verb in _CAPTION_VERBS]
    + [("This image " + verb, '') for verb in _CAPTION_VERBS]
    + [("In this " + noun, '') for noun in ["picture, ", "depiction, ", "piece, ", "image, ", "scene, "]]
    + [
        ('In this artwork, ', 'Artwork of '),
        ('In this illustration, ', 'Illustration of '),
        ('In this art piece, ', 'Art of '),
    ]
)

# One capture group per rule, so the rule that matched can be recovered from
# `match.lastindex` regardless of the letter case of the matched text.
_CAPTION_PREFIX_RE = re.compile(
    '|'.join("(" + re.escape(opening) + ")" for opening, _ in _CAPTION_PREFIX_RULES),
    flags=re.IGNORECASE,
)


def modify_caption(caption: str) -> str:
    """
    Removes common prefix substrings from CogVLM outputs.

    The first occurrence (case-insensitive) of one of the known openings, such
    as "The image depicts " or "In this artwork, ", is removed or replaced, and
    the result is passed through ``str.capitalize`` (which upper-cases the
    first character and lower-cases the rest).

    Previously the replacement was looked up by the exact matched text, so a
    match that differed in letter case (e.g. "the image is ") raised a
    ``KeyError`` that was silently swallowed and the caption was returned
    untouched. The matched rule is now identified by its capture-group index.

    NOTE: a later definition of ``modify_caption`` in this file rebinds the
    name, so this version is only in effect if that one is removed.

    Args:
        caption (str): A string containing a cogvlm caption.
    Returns:
        str: The caption with the prefix substring removed
        or altered if it was present. Non-string input is returned unchanged.
    """
    if not isinstance(caption, str):
        # Preserve the old best-effort contract: never raise on odd input.
        return caption

    def replace_fn(match):
        # Exactly one top-level group participates in any match.
        return _CAPTION_PREFIX_RULES[match.lastindex - 1][1]

    return _CAPTION_PREFIX_RE.sub(replace_fn, caption, count=1).capitalize()
def modify_caption(caption: str) -> str:
    """
    Strip the boilerplate phrase "This image displays" from a caption.

    Every occurrence of the phrase is removed, surrounding whitespace is
    trimmed, and a single leading ':' left over from "displays:" is dropped.

    Args:
        caption (str): The raw caption text.
    Returns:
        str: The cleaned caption with outer whitespace removed.
    """
    cleaned = caption.replace("This image displays", '').strip()
    # Drop at most one leading colon; anything after it is kept verbatim.
    without_colon = cleaned[1:] if cleaned.startswith(':') else cleaned
    return without_colon.strip()
class uint8(Encoding):
    """MDS column codec that stores an array as its raw bytes, read back as uint8.

    Only the byte buffer is stored, so `decode` returns a flat 1-D array; the
    original shape is not recorded.
    """

    def encode(self, obj: Any) -> bytes:
        """Serialize `obj` through its `tobytes()` method."""
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes) -> Any:
        """Interpret `data` as a 1-D NumPy array of uint8 values."""
        return np.frombuffer(data, dtype=np.uint8)
class np16(Encoding):
    """MDS column codec that stores an array as its raw bytes, read back as float16.

    Only the byte buffer is stored, so `decode` returns a flat 1-D array; the
    original shape is not recorded.
    """

    def encode(self, obj: Any) -> bytes:
        """Serialize `obj` through its `tobytes()` method."""
        raw = obj.tobytes()
        return raw

    def decode(self, data: bytes) -> Any:
        """Interpret `data` as a 1-D NumPy array of float16 values."""
        return np.frombuffer(data, dtype=np.float16)
# Register the custom codecs under the names referenced by the MDS column spec.
for _codec_name, _codec_cls in (("np16", np16), ("uint8", uint8)):
    _encodings[_codec_name] = _codec_cls

# Per-channel normalization applied to the SSCD input image.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# SSCD preprocessing: resize to 288, center-crop to 288x288, convert to a
# tensor, then normalize.
small_288 = transforms.Compose(
    [transforms.Resize(288), transforms.CenterCrop(288), transforms.ToTensor(), normalize]
)
def crop_to_center(image, new_size=768):
    """Crop a square of side `new_size` around the center of `image`.

    Args:
        image: Object exposing `.size` as (width, height) and `.crop(box)`.
        new_size: Side length of the square crop.
    Returns:
        Whatever `image.crop` returns for the centered
        (left, top, right, bottom) box. The box coordinates are computed with
        true division and are therefore floats.
    """
    width, height = image.size
    # sign = -1 yields (left, top); sign = +1 yields (right, bottom).
    box = tuple(
        (extent + sign * new_size) / 2
        for sign in (-1, 1)
        for extent in (width, height)
    )
    return image.crop(box)
def prepare_image(pil_image):
    """Turn an image into a channel-first float32 tensor for the VAE.

    The image is converted to RGB, its pixel values are mapped with
    `value / 127.5 - 1` (so 0 -> -1 and 255 -> 1), and the channel axis is
    moved from last to first.

    Args:
        pil_image: Object whose `.convert("RGB")` result can be turned into a
            NumPy array with the channel as the last of three axes.
    Returns:
        torch.Tensor: float32 tensor with the channel axis first.
    """
    pixels = np.array(pil_image.convert("RGB")).astype(np.float32)
    scaled = pixels / 127.5 - 1
    channel_first = scaled.transpose(2, 0, 1)
    return torch.from_numpy(channel_first)
def center_crop_to_nearest_multiple_of_64(image):
    """Downscale an oversized image, then center-crop it to multiples of 64.

    While either side exceeds 1024, both sides are halved (integer division);
    if any halving happened the image is resized once to the final size. The
    result is then cropped around its center so that width and height are the
    largest multiples of 64 that fit.

    Args:
        image: Object exposing `.size` as (width, height), `.resize((w, h))`
            and `.crop((left, top, right, bottom))`.
    Returns:
        The cropped image as returned by `.crop`.
    """
    width, height = image.size

    # Repeated halving keeps the aspect ratio (up to integer rounding).
    needs_resize = False
    while max(width, height) > 1024:
        width //= 2
        height //= 2
        needs_resize = True
    if needs_resize:
        image = image.resize((width, height))

    crop_w = (width // 64) * 64
    crop_h = (height // 64) * 64
    left = (width - crop_w) // 2
    top = (height - crop_h) // 2
    return image.crop((left, top, left + crop_w, top + crop_h))
def wds_preprocess(x):
    """Turn one decoded WebDataset sample into the fields used downstream.

    Args:
        x: Tuple of (key, image, metadata dict). The metadata must contain
            "caption"; "uid", "watermark_class_id" and "aesthetic_score" are
            optional.
    Returns:
        tuple: ([image_for_vae], caption, image_for_sscd, uid,
        watermark_class, est). The VAE tensor is wrapped in a one-element list.
    """
    key, pil_image, meta = x

    # Both model inputs are derived from the same RGB, 64-aligned crop.
    cropped = center_crop_to_nearest_multiple_of_64(pil_image.convert("RGB"))
    image_for_vae = prepare_image(cropped)
    image_for_sscd = small_288(cropped)

    caption = meta["caption"]

    # Fall back to the sample key when no uid is recorded.
    uid = meta.get("uid")
    uid = key if uid is None else uid

    # A missing key defaults to 1, whereas an explicit None becomes 0.
    watermark_class = meta.get("watermark_class_id", 1)
    watermark_class = 0 if watermark_class is None else watermark_class

    # Samples without an aesthetic score get the placeholder value 100.
    est = meta.get("aesthetic_score")
    est = 100 if est is None else est

    return ([image_for_vae], caption, image_for_sscd, uid, watermark_class, est)
# MDS column name -> encoding name. "np16" and "uint8" are the custom codecs
# registered in this file; "str" is referenced by name only.
COLUMNS = dict(
    key="str",
    caption="str",
    vae_latents="np16",
    t5_xl_embeddings="uint8",
    sscd_embeddings="np16",
    hps_score="str",
    width_height="str",
)
def _as_python_scalar(value):
    """Return `value.item()` for tensor/NumPy scalars, otherwise `value` unchanged."""
    return value.item() if hasattr(value, "item") else value


@torch.no_grad()
def convert_to_mds(dataset_paths, out_roots, device, is_test=False):
    """Encode WebDataset tar shards into per-resolution MDS datasets.

    For every shard the samples are grouped into buckets keyed by the last two
    dimensions of the VAE input tensor. Each bucket with at least 64 samples is
    written to `<out_root>/<w>_<h>/data` with VAE latents (float16), T5 encoder
    states quantized to uint8, SSCD embeddings (float16), the caption, the
    sample key, the per-sample score and the bucket resolution.

    Changes relative to the earlier implementation:
      * `hps_score` now stores the score of the individual sample; it used to
        store the string form of the whole batch.
      * The SSCD model and its inputs follow `device` instead of a hard-coded
        "cuda", so the CPU fallback selected by the caller works.
      * A missing shard is logged and skipped; it no longer aborts the
        remaining shards.

    Args:
        dataset_paths: List of tar shard paths.
        out_roots: List of output directories, parallel to `dataset_paths`.
        device: torch device on which all three models run.
        is_test: Accepted for interface compatibility; currently unused.
    """
    logging.info(f"Processing on {device}")

    vae_model = AutoencoderKL.from_pretrained(
        "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
    )
    vae_model = vae_model.to(device).eval()
    vae_model.to(memory_format=torch.channels_last)

    t5tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pile-t5-xl", use_fast=False)
    t5tokenizer.pad_token = t5tokenizer.bos_token
    t5model = AutoModelForSeq2SeqLM.from_pretrained("EleutherAI/pile-t5-xl", torch_dtype=torch.bfloat16)
    t5model = t5model.to(device).eval()

    sscd_model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)

    # Shards are processed in chunks; with a chunk size of 1 every shard gets
    # its own output root.
    num_chunk = 1
    path_chunks = [
        dataset_paths[i : i + num_chunk] for i in range(0, len(dataset_paths), num_chunk)
    ]
    root_chunks = [out_roots[i : i + num_chunk] for i in range(0, len(out_roots), num_chunk)]

    for chunk_paths, chunk_roots in zip(path_chunks, root_chunks):
        missing = [path for path in chunk_paths if not os.path.exists(path)]
        if missing:
            for path in missing:
                logging.info(f"Dataset not found: {path}")
            # Skip only this chunk so the remaining shards are still processed.
            continue

        out_root = chunk_roots[0]

        dataset = wds.DataPipeline(
            wds.SimpleShardList(chunk_paths),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.decode("pil", handler=wds.warn_and_continue),
            wds.to_tuple("__key__", "jpg;png", "json", handler=wds.warn_and_continue),
            wds.map(wds_preprocess),
            wds.batched(64),
        )
        # batch_size=None: batching is already done inside the pipeline.
        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=2,
            shuffle=False,
            drop_last=False,
        )

        # Group samples by resolution. All kept samples of the chunk are held
        # in memory until they are written.
        localbatches_datasets = {}
        print("Start Making Bucket Batching")
        for _, batch in tqdm(enumerate(dataloader)):
            image_for_vae, captions, image_for_sscd, uids, watermark_class, est = batch
            for i in range(len(captions)):
                vae_image = image_for_vae[i][0]
                # NOTE(review): these are the last two tensor dims; for the
                # channel-first tensor built by prepare_image this is presumably
                # (height, width), so `w`/`h` (and the directory name and
                # "width_height" column derived from them) may be swapped.
                # Left unchanged to keep the output layout - confirm.
                w, h = vae_image.shape[-2:]
                if w * h <= 512 * 512:
                    # Only images with more than 512*512 pixels are kept.
                    continue
                localbatches_datasets.setdefault((w, h), []).append(
                    (vae_image, captions[i], image_for_sscd[i], uids[i], watermark_class[i], est[i])
                )
        print("End Making Bucket Batching")

        for resolution, samples in localbatches_datasets.items():
            print(f"Resolution {resolution} has {len(samples)} samples")

        for (w, h), localbatches in localbatches_datasets.items():
            # Buckets with fewer than 64 samples are dropped.
            if len(localbatches) < 64:
                continue
            print(f"Processing resolution {w}x{h}")

            out_root_res = os.path.join(out_root, f"{w}_{h}")
            sub_data_root = os.path.join(out_root_res, "data")
            # Start from an empty output directory.
            if os.path.exists(sub_data_root):
                for file in os.listdir(sub_data_root):
                    os.remove(os.path.join(sub_data_root, file))
            os.makedirs(sub_data_root, exist_ok=True)

            with MDSWriter(out=sub_data_root, columns=COLUMNS) as out:
                # Scale the batch size inversely with pixel count (64 at 512x512).
                BS = int(64 * 512 * 512 / (w * h))
                local_dl = DataLoader(
                    localbatches,
                    batch_size=BS,
                    num_workers=2,
                    prefetch_factor=4,
                    shuffle=False,
                    drop_last=False,
                )
                for image_for_vae, captions, image_for_sscd, uids, watermark_class, est in local_dl:
                    ### SSCD
                    image_for_sscd = image_for_sscd.to(
                        device, memory_format=torch.channels_last
                    )
                    sscd_embeddings = sscd_model(image_for_sscd)
                    sscd_embeddings = sscd_embeddings.cpu().numpy().astype(np.float16)

                    ### VAE
                    image_for_vae = image_for_vae.to(device).half()
                    vae_latents = vae_model.encode(image_for_vae).latent_dist.sample()
                    vae_outputs = vae_latents.cpu().numpy().astype(np.float16)

                    ### T5
                    t5_inputs = t5tokenizer(
                        captions,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=256,
                    )
                    t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()}
                    t5_outputs = t5model.encoder(**t5_inputs)[0]
                    # Zero the encoder states at padded positions.
                    mask = (
                        t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
                    )
                    t5_outputs = t5_outputs * mask
                    # Quantize: clip to [-0.25, 0.25], map linearly onto
                    # [0, 255] and store as uint8.
                    t5_outputs = (
                        ((t5_outputs.clip(-0.25, 0.25) / 0.5 + 0.5) * 255.0)
                        .to(torch.uint8)
                        .cpu()
                        .numpy()
                        .astype(np.uint8)
                    )

                    ### Write
                    for i in range(len(captions)):
                        sample = {
                            "vae_latents": vae_outputs[i],
                            "caption": str(captions[i]),
                            "t5_xl_embeddings": t5_outputs[i],
                            "sscd_embeddings": sscd_embeddings[i],
                            "key": uids[i],
                            # Per-sample score, not the whole batch.
                            "hps_score": str(_as_python_scalar(est[i])),
                            "width_height": f"{w}x{h}",
                        }
                        out.write(sample)
def main(datasetinfos, out_roots, is_test=False, device_name="cuda"):
    """Pick the compute device and run the shard-to-MDS conversion.

    Args:
        datasetinfos: List of input tar shard paths.
        out_roots: List of output directories, parallel to `datasetinfos`.
        is_test: Forwarded unchanged to `convert_to_mds`.
        device_name: Device to use when CUDA is available; otherwise "cpu"
            is used regardless of this value.
    """
    chosen_name = device_name if torch.cuda.is_available() else "cpu"
    device = torch.device(chosen_name)
    print(f"Processing on {device}")

    convert_to_mds(datasetinfos, out_roots, device, is_test=is_test)

    logging.info("Finished processing images.")
def detect_small_or_nonexistent_dirs(current_dir, start=0, end=18503, max_size=512):
    """List the shard indices whose output directory is missing or nearly empty.

    Directories are named by the zero-padded 5-digit index under `current_dir`.

    Args:
        current_dir: Directory containing the per-index output directories.
        start: First index to inspect (inclusive).
        end: Last index to inspect (inclusive).
        max_size: A directory whose files total fewer bytes than this counts
            as small.
    Returns:
        list[int]: Indices, in ascending order, that are either
        * missing and satisfy `index % 64 < 8`, or
        * an existing directory whose total file size is below `max_size`.
        Paths that exist but are not directories are ignored.
    """
    flagged = []
    for index in range(start, end + 1):
        dir_path = os.path.join(current_dir, f"{index:05d}")

        if not os.path.exists(dir_path):
            # Missing directories are only reported for this index pattern.
            if index % 64 < 8:
                flagged.append(index)
            continue

        if not os.path.isdir(dir_path):
            continue

        # Total size of every file below the directory, recursively.
        total_size = sum(
            os.path.getsize(os.path.join(root, name))
            for root, _subdirs, names in os.walk(dir_path)
            for name in names
        )
        if total_size < max_size:
            flagged.append(index)
    return flagged
def save_to_json(data, filename):
    """Write `data` to `filename` as JSON, overwriting any existing file.

    Args:
        data: A JSON-serializable object.
        filename: Path of the file to write.
    """
    with open(filename, "w") as handle:
        json.dump(data, handle)
if __name__ == "__main__":
    # Command-line entry point: select this worker's share of the tar shards
    # and convert them.
    parser = argparse.ArgumentParser(description="Convert images to MDS format.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for processing (cuda or cpu).")
    parser.add_argument("--file_index", type=int, default=0, help="File index to process.")
    parser.add_argument("--is_test", action="store_true", help="Run in test mode with reduced dataset.")
    parser.add_argument("--outdir_basepath", type=str, default="/home/ubuntu/bucketmds", help="Output directory path.")
    parser.add_argument("--tar_indir_basepath", type=str, default="/home/ubuntu/wds_deduped_0.3m_highres", help="Input directory path.")
    args = parser.parse_args()

    # Shard ids are 0..N-1, where N is the number of *.tar files found in the
    # input directory. Alternatives used previously: loading the ids from a
    # small_or_nonexistent_dirs.json file, calling
    # detect_small_or_nonexistent_dirs(), or a hand-written list of ids.
    tars = glob.glob(f"{args.tar_indir_basepath}/*.tar")
    reqsids = range(len(tars))

    # Work is split across 8 workers: this one takes every position whose
    # index modulo 8 equals --file_index.
    shard_names = [
        str(int(reqid)).zfill(5)
        for i, reqid in enumerate(reqsids)
        if i % 8 == args.file_index
    ]
    out_roots = [f"{args.outdir_basepath}/{name}" for name in shard_names]
    datasetinfos = [f"{args.tar_indir_basepath}/{name}.tar" for name in shard_names]

    main(datasetinfos, out_roots, is_test=args.is_test, device_name=args.device)
# Launcher: starts one background worker per index in START_INDEX..END_INDEX,
# each restricted to one GPU and pinned to a range of CPU cores, then waits
# for all of them to finish.
NUM_GPUS=8
START_INDEX=0
END_INDEX=7
NUM_CPUS=$(nproc) # Get the number of available CPUs
CORES_PER_PROCESS=24
echo "Number of GPUs: $NUM_GPUS | Number of CPUs: $NUM_CPUS | Cores per process: $CORES_PER_PROCESS"
for ((i=START_INDEX; i<=END_INDEX; i++)); do
# GPUs are assigned round-robin by worker index.
GPU_INDEX=$((i % NUM_GPUS))
# The first core wraps around modulo NUM_CPUS.
# NOTE(review): CPU_END is not clamped, so it can exceed NUM_CPUS-1 when
# NUM_CPUS is not a multiple of CORES_PER_PROCESS - confirm on the target host.
CPU_START=$(( (i * CORES_PER_PROCESS) % NUM_CPUS ))
CPU_END=$(( CPU_START + CORES_PER_PROCESS - 1 ))
# Exported before the launch so the worker started below sees only this GPU.
export CUDA_VISIBLE_DEVICES=$GPU_INDEX
# --file_index selects this worker's subset of the tar shards.
taskset -c $CPU_START-$CPU_END python /home/ubuntu/make_shards/run_featgen_hr_bucket.py --device cuda --file_index $i --tar_indir_basepath "/home/ubuntu/mjwds_deduped_0.3m_highres" --outdir_basepath "/home/ubuntu/mj_bucketmds" &
echo "Started process $i on GPU $GPU_INDEX with CPU $CPU_START-$CPU_END"
done
wait # Wait for all background processes to finish
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment