Batch inference for CLIP Aesthetic Score
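Usage: the script takes one positional argument, a directory that is scanned recursively for jpg/png/jpeg/webp images, e.g. python clip_aesthetic_batch.py /path/to/images (the script name is only a placeholder; the gist does not name the file). It needs torch, the OpenAI clip package, pandas, numpy, scipy, opencv-python, tqdm, Pillow, and pyarrow (or fastparquet) for the parquet output; the sac+logos+ava1-l14-linearMSE weights are downloaded from Hugging Face on first run, and the results are written to metrics.parquet and metrics.csv in the working directory.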
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import pandas as pd
import hashlib
import numpy
import cv2
from tqdm.auto import tqdm
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
# if you changed the MLP architecture during training, change it also here:
class MLP(nn.Module):
    def __init__(self, input_size, xcol="emb", ycol="avg_rating"):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)

    def training_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
def _binary_array_to_hex(arr):
    # Pack a boolean/binary array into a hex string.
    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
    width = int(numpy.ceil(len(bit_string) / 4))
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
def phashstr(image, hash_size=8, highfreq_factor=4):
    # type: (Image.Image, int, int) -> str
    # Perceptual hash (pHash): low-frequency DCT coefficients thresholded at their median.
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    import scipy.fftpack
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = numpy.asarray(image)
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    med = numpy.median(dctlowfreq)
    diff = dctlowfreq > med
    return _binary_array_to_hex(diff.flatten())
def normalized(a, axis=-1, order=2):
    # L2-normalize `a` along `axis`; zero-norm rows are left unchanged.
    import numpy as np  # pylint: disable=import-outside-toplevel
    l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / np.expand_dims(l2, axis)
class ImageDataset(Dataset):
    def __init__(self, imgdir, transform):
        self.img_list = [
            str(p)
            for ext in ["*.jpg", "*.png", "*.jpeg", "*.webp"]
            for p in Path(imgdir).rglob(ext)
        ]
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        with open(img_path, "rb") as file:
            data = file.read()
            md5 = hashlib.md5(data).hexdigest()
            sha1 = hashlib.sha1(data).hexdigest()
        image = Image.open(img_path)
        phash = phashstr(image)
        # Laplacian variance of the grayscale image as a simple sharpness/blur metric.
        image_np = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
        image = self.transform(image)
        metrics = (phash, md5, sha1, laplacian_variance)
        return image, img_path, metrics
if __name__ == "__main__":
    print("Loading model sac+logos+ava1-l14-linearMSE.pth")
    model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT-L/14
    pthpath = "https://huggingface.co/trojblue/clip-aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("Using CPU, this might be slow.")
    model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
    model.to(device).eval()

    model2, preprocess = clip.load("ViT-L/14", device=device)  # RN50x64

    if len(sys.argv) < 2:
        sys.exit("Usage: python <script> <image_dir>")
    imgdir = sys.argv[1]
    dataset = ImageDataset(imgdir, preprocess)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    data = []  # Create an empty list to store one row per image
    bar = tqdm(total=len(dataset), desc="Processing", position=0)
    clip_avg = 0.0
    for inputs, img_paths, metrics in dataloader:
        inputs = inputs.to(device)
        with torch.no_grad():
            img_emb = model2.encode_image(inputs)
            img_emb = normalized(img_emb.cpu().detach().numpy())
            # .float() rather than torch.cuda.FloatTensor so this also runs on CPU
            predictions = model(torch.from_numpy(img_emb).to(device).float())
        for idx, (prediction, img_path) in enumerate(zip(predictions, img_paths)):
            phashes, md5s, sha1s, laplacian_variances = [metric[idx] for metric in metrics]
            laplacian_variances = laplacian_variances.item()
            data.append([img_path, prediction.item(), phashes, md5s, sha1s, laplacian_variances])
            clip_avg += prediction.item()
        bar.update(len(inputs))
        bar.set_postfix(avg=clip_avg / bar.n)

    df = pd.DataFrame(
        data,
        columns=[
            "filename",
            "clip_aesthetic",
            "phash",
            "md5",
            "sha1",
            "laplacian_variance",
        ],
    )
    df.to_parquet("metrics.parquet")
    df.to_csv("metrics.csv")
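After a run, the two output files can be loaded back with pandas for a quick look; a minimal sketch against the default output paths above (not part of the script itself):

import pandas as pd

# Load the per-image metrics written by the script above.
df = pd.read_parquet("metrics.parquet")

# Highest CLIP aesthetic scores first.
print(df.sort_values("clip_aesthetic", ascending=False).head(10))

# Likely-blurry images: lowest Laplacian variance.
print(df.nsmallest(10, "laplacian_variance")[["filename", "laplacian_variance"]])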