@cloneofsimo
Created August 8, 2024 05:11
vae_preprocess
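# Preprocess a WebDataset of image-caption tars into MosaicML MDS shards containing
# FLUX.1-schnell VAE latents, Pile-T5-XL encoder embeddings (uint8-quantized),
# SSCD copy-detection embeddings, and SigLIP text/image embeddings plus their similarity.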
import os
import torch
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from diffusers.models import AutoencoderKL
from streaming import MDSWriter
import logging
import time
import numpy as np
from typing import Any
from streaming.base.format.mds.encodings import Encoding, _encodings
from tqdm import tqdm
import webdataset as wds
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, SiglipModel
import argparse
from torchvision import transforms
import re
# Initialize logging
logging.basicConfig(level=logging.INFO)

def modify_caption(caption: str) -> str:
    cap = caption.replace("This image displays", '').strip()
    if cap.startswith(':'):
        cap = cap[1:]
    return cap.strip()
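
# Custom MDS column encodings: numpy arrays are stored as raw bytes and decoded
# back as uint8 / float16 when the shards are read.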
class uint8(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.uint8)


class np16(Encoding):
    def encode(self, obj: Any) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes) -> Any:
        return np.frombuffer(data, np.float16)


_encodings["np16"] = np16
_encodings["uint8"] = uint8

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

small_288 = transforms.Compose(
    [
        transforms.Resize(288),
        transforms.CenterCrop(288),
        transforms.ToTensor(),
        normalize,
    ]
)

def crop_to_center(image, new_size=768):
    width, height = image.size
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def prepare_image(pil_image):
    arr = np.array(pil_image.convert("RGB"))
    arr = arr.astype(np.float32) / 127.5 - 1
    arr = np.transpose(arr, [2, 0, 1])
    image = torch.from_numpy(arr)
    return image
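
# prepare_image scales RGB pixels to [-1, 1] for the VAE encoder; small_288 gives the
# ImageNet-normalized 288x288 tensor fed to the SSCD model.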
def wds_preprocess(x):
    key, pil_image, _json = x
    pil_image = pil_image.convert("RGB")

    # Resize so the short side is 512 px, then center-crop to 512x512.
    if pil_image.size[0] > pil_image.size[1]:
        pil_image = pil_image.resize((int(pil_image.size[0] * 512 / pil_image.size[1]), 512))
    else:
        pil_image = pil_image.resize((512, int(pil_image.size[1] * 512 / pil_image.size[0])))
    pil_image = crop_to_center(pil_image, new_size=512)

    image_for_vae = prepare_image(pil_image)
    image_for_sscd = small_288(pil_image)

    caption = _json["caption"] or ""
    uid = _json.get("uid", key)
    watermark_class = _json.get("watermark_class_id", 1)
    est = _json.get("aesthetic_score", 100)
    return (image_for_vae, caption, image_for_sscd, uid, watermark_class, est, [pil_image])

COLUMNS = {
    "key": "str",
    "caption": "str",
    "vae_512x512_latents": "np16",
    "t5_xl_embeddings": "uint8",
    "sscd_embeddings": "np16",
    "hps_score": "str",
    "siglip_text_vec": "np16",
    "siglip_image_vec": "np16",
    "siglip_sim": "float32",
}

@torch.no_grad()
def convert_to_mds(dataset_paths, out_roots, device, is_test=False):
    logging.info(f"Processing on {device}")

    # Load the feature extractors: FLUX VAE, Pile-T5-XL encoder, SSCD, and SigLIP.
    vae_model = AutoencoderKL.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.float16, subfolder='vae'
    ).to(device).eval()
    vae_model.to(memory_format=torch.channels_last)

    t5tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pile-t5-xl", use_fast=False)
    t5tokenizer.pad_token = t5tokenizer.bos_token
    t5model = AutoModelForSeq2SeqLM.from_pretrained("EleutherAI/pile-t5-xl", torch_dtype=torch.bfloat16).to(device).eval()

    sscd_model = torch.jit.load("sscd_disc_mixup.torchscript.pt").to(device)

    siglip_model = SiglipModel.from_pretrained("google/siglip-large-patch16-256").to(device).eval()
    siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")

    # Process input shards in chunks of 8 tars per output root.
    num_chunk = 8
    dataset_bulks = [dataset_paths[i:i + num_chunk] for i in range(0, len(dataset_paths), num_chunk)]
    out_roots_bulks = [out_roots[i:i + num_chunk] for i in range(0, len(out_roots), num_chunk)]
    for dataset_paths, out_roots in zip(dataset_bulks, out_roots_bulks):
        for dataset_path in dataset_paths:
            if not os.path.exists(dataset_path):
                logging.info(f"Dataset not found: {dataset_path}")
                return

        out_root = out_roots[0]

        dataset = wds.DataPipeline(
            wds.SimpleShardList(dataset_paths),
            wds.split_by_worker,
            wds.tarfile_to_samples(handler=wds.warn_and_continue),
            wds.decode("pil", handler=wds.warn_and_continue),
            wds.to_tuple("__key__", "jpg;png", "json", handler=wds.warn_and_continue),
            wds.map(wds_preprocess),
            wds.batched(64),
        )

        dataloader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=16,
            prefetch_factor=4,
            shuffle=False,
            drop_last=False,
        )
        t0 = time.time()
        sub_data_root = os.path.join(out_root, "data")

        # Start from a clean output directory.
        if os.path.exists(sub_data_root):
            for file in os.listdir(sub_data_root):
                os.remove(os.path.join(sub_data_root, file))
        os.makedirs(sub_data_root, exist_ok=True)

        inference_latencies = []
        keys = []

        with MDSWriter(out=sub_data_root, columns=COLUMNS) as out:
            for idx, batch in tqdm(enumerate(dataloader)):
                if is_test and idx > 0:
                    break
                start_time = time.time()

                image_for_vae, captions, image_for_sscd, uids, watermark_class, est, pil_images = batch
                pil_images = [pil_images[i][0] for i in range(len(pil_images))]

                # Keep only samples whose aesthetic score is above 3.
                est_idx = np.where(np.array(est) > 3)[0]
                if len(est_idx) == 0:
                    continue

                watermark_class = np.array(watermark_class).astype(int)

                image_for_vae = image_for_vae[est_idx]
                captions = [modify_caption(captions[i]) for i in est_idx]
                watermark_class = watermark_class[est_idx]
                # Prefix captions of images flagged as watermarked (class 0).
                for i in range(len(captions)):
                    if watermark_class[i] == 0:
                        captions[i] = "watermarked image " + captions[i]

                image_for_sscd = image_for_sscd[est_idx]
                uids = [uids[i] for i in est_idx]
                est = np.array(est)[est_idx]
                # SSCD copy-detection embeddings
                image_for_sscd = image_for_sscd.to(device, memory_format=torch.channels_last)
                sscd_embeddings = sscd_model(image_for_sscd)
                sscd_embeddings = sscd_embeddings.cpu().numpy().astype(np.float16)

                # VAE latents
                image_for_vae = image_for_vae.to(device).half()
                vae_latents = vae_model.encode(image_for_vae).latent_dist.sample()
                vae_outputs = vae_latents.cpu().numpy().astype(np.float16)

                # T5 encoder embeddings
                t5_inputs = t5tokenizer(
                    captions,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512,
                )
                t5_inputs = {k: v.to(device) for k, v in t5_inputs.items()}
                t5_outputs = t5model.encoder(**t5_inputs)[0]
                mask = t5_inputs["attention_mask"].unsqueeze(-1).expand(t5_outputs.shape)
                t5_outputs = t5_outputs * mask
                # Quantize to uint8: clip to [-0.25, 0.25], rescale to [0, 1], then scale to [0, 255].
                t5_outputs = ((t5_outputs.clip(-0.25, 0.25) / 0.5 + 0.5) * 255.0).to(torch.uint8).cpu().numpy().astype(np.uint8)
                # T5 outputs as a list of 2D arrays, each truncated at its first padded position.
                t5_output_as_list = [t5_outputs[i] for i in range(t5_outputs.shape[0])]
                mask = mask.cpu().numpy().astype(np.uint8)
                t5_masksize = [np.where(mask[i] == 0)[0][0] if 0 in mask[i] else mask[i].shape[0] for i in range(mask.shape[0])]
                t5_outputs = [t5_output_as_list[i][:t5_masksize[i]] for i in range(len(t5_output_as_list))]
                # SigLIP text / image embeddings
                siglip_inputs = siglip_processor(
                    text=captions,
                    images=pil_images,
                    return_tensors="pt",
                    padding='max_length',
                    truncation=True,
                ).to(device)
                siglip_outputs = siglip_model(**siglip_inputs)
                siglip_text_embeddings = siglip_outputs.text_embeds.cpu().numpy().astype(np.float16)
                siglip_image_embeddings = siglip_outputs.image_embeds.cpu().numpy().astype(np.float16)

                # Per-sample text-image similarity as a row-wise dot product.
                siglip_similarities = np.einsum("ij,ij->i", siglip_text_embeddings, siglip_image_embeddings)
                # Write samples, skipping those with very low text-image similarity.
                for i in range(len(captions)):
                    if siglip_similarities[i] < 0.05:
                        print("Oh no, not similar!")
                        # # write the image and caption as json, img at ./local_bad_images
                        # os.makedirs("./local_bad_images", exist_ok=True)
                        # pil_images[i].save(f"./local_bad_images/{uids[i]}.jpg")
                        # with open(f"./local_bad_images/{uids[i]}.json", "w") as f:
                        #     json.dump({"caption": captions[i], "similarity": float(siglip_similarities[i])}, f)
                        continue

                    sample = {
                        "vae_512x512_latents": vae_outputs[i],
                        "caption": str(captions[i]),
                        "t5_xl_embeddings": t5_outputs[i],
                        "sscd_embeddings": sscd_embeddings[i],
                        "key": uids[i],
                        "hps_score": str(est[i]),
                        "siglip_text_vec": siglip_text_embeddings[i],
                        "siglip_image_vec": siglip_image_embeddings[i],
                        "siglip_sim": float(siglip_similarities[i]),
                    }
                    out.write(sample)

                inference_latencies.append(time.time() - start_time)
                keys.extend(uids)

        logging.info(f"Average Inference Latency on {device}: {np.mean(inference_latencies)} seconds")
        logging.info(f"Total Inference Time on {device}: {time.time() - t0} seconds")
        save_to_json(keys, os.path.join(out_root, "keys.json"))

def main(datasetinfos, out_roots, is_test=False, device_name="cuda"):
    device = torch.device(device_name if torch.cuda.is_available() else "cpu")
    print(f"Processing on {device}")
    convert_to_mds(datasetinfos, out_roots, device, is_test=is_test)
    logging.info("Finished processing images.")

def detect_small_or_nonexistent_dirs(current_dir, start=0, end=18503, max_size=512):
    small_or_nonexistent_dirs = []
    for i in range(start, end + 1):
        dir_name = f"{i:05d}"
        dir_path = os.path.join(current_dir, dir_name)
        if not os.path.exists(dir_path):
            if i % 64 < 8:
                small_or_nonexistent_dirs.append(i)
        elif os.path.isdir(dir_path):
            total_size = 0
            for dirpath, dirnames, filenames in os.walk(dir_path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total_size += os.path.getsize(fp)
            if total_size < max_size:
                small_or_nonexistent_dirs.append(i)
    return small_or_nonexistent_dirs

def save_to_json(data, filename):
    with open(filename, "w") as f:
        json.dump(data, f)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert images to MDS format.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use for processing (cuda or cpu).")
    parser.add_argument("--file_index", type=int, default=0, help="File index to process.")
    parser.add_argument("--is_test", action="store_true", help="Run in test mode with reduced dataset.")
    parser.add_argument("--outdir_basepath", type=str, default="/jfs/mds_pp512_fvae_8.5M", help="Output directory path.")
    parser.add_argument("--tar_indir_basepath", type=str, default="/home/ubuntu/pprowds", help="Input directory path.")
    args = parser.parse_args()

    # Shard the 2000 input tar files across parallel processes by --file_index (mod 8).
    reqsids = list(range(2000))
    out_roots, datasetinfos = [], []
    for i, reqid in enumerate(reqsids):
        if i % 8 == args.file_index:
            out_root = f"{args.outdir_basepath}/{str(int(reqid)).zfill(5)}"
            dataset_path = f"{args.tar_indir_basepath}/{str(int(reqid)).zfill(5)}.tar"
            out_roots.append(out_root)
            datasetinfos.append(dataset_path)

    main(datasetinfos, out_roots, is_test=args.is_test, device_name=args.device)
@cloneofsimo commented Aug 8, 2024

wget https://dl.fbaipublicfiles.com/sscd-copy-detection/sscd_disc_mixup.torchscript.pt

NUM_GPUS=8
START_INDEX=0
END_INDEX=7
NUM_CPUS=$(nproc)  # Get the number of available CPUs
CORES_PER_PROCESS=24
echo "Number of GPUs: $NUM_GPUS | Number of CPUs: $NUM_CPUS | Cores per process: $CORES_PER_PROCESS"

for ((i=START_INDEX; i<=END_INDEX; i++)); do

    GPU_INDEX=$((i % NUM_GPUS))
    CPU_START=$(( (i * CORES_PER_PROCESS) % NUM_CPUS ))
    CPU_END=$(( CPU_START + CORES_PER_PROCESS - 1 ))

    export CUDA_VISIBLE_DEVICES=$GPU_INDEX
    taskset -c $CPU_START-$CPU_END python /home/ubuntu/make_shards/run_featgen_siglip_fvae.py --device cuda --file_index $i &
    echo "Started process $i on GPU $GPU_INDEX with CPU $CPU_START-$CPU_END"
done

wait  # Wait for all background processes to finish
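
To read the shards back, something like the following should work. This is a minimal sketch, not part of the original gist: it assumes mosaicml-streaming's StreamingDataset, that the np16/uint8 Encoding classes from the script are registered in the reading process as well, and it uses the script's default output path for shard 00000 only as an example. The uint8 T5 embeddings also have to be mapped back to floats.

# Minimal read-back sketch (assumptions: mosaicml-streaming installed, and the
# np16/uint8 encodings registered via _encodings in this process too).
import numpy as np
from streaming import StreamingDataset

ds = StreamingDataset(local="/jfs/mds_pp512_fvae_8.5M/00000/data", shuffle=False)
sample = ds[0]

vae_latents = sample["vae_512x512_latents"]   # flat float16 buffer; reshape to the VAE latent shape
t5_uint8 = sample["t5_xl_embeddings"]         # flat uint8 buffer; reshape to (seq_len, hidden_dim)
# Undo the uint8 quantization used in the script: [0, 255] -> [0, 1] -> [-0.25, 0.25]
t5_embeddings = (t5_uint8.astype(np.float32) / 255.0 - 0.5) * 0.5

print(sample["key"], sample["caption"], sample["siglip_sim"])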
