@Mikubill · Last active August 30, 2023 14:36
Batch inference for CLIP Aesthetic Score
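Scores every image under a directory with the sac+logos+ava1 aesthetic predictor (CLIP ViT-L/14 embeddings fed to a small MLP), and records a perceptual hash, MD5/SHA-1 checksums, and a Laplacian-variance blur metric for each image.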
import sys
import hashlib

import torch
import torch.nn as nn
import torch.nn.functional as F
import clip
import pandas as pd
import numpy
import cv2
import scipy.fftpack
from tqdm.auto import tqdm
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
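# Assumed dependencies: torch, OpenAI CLIP (pip install
# git+https://github.com/openai/CLIP.git), pandas plus pyarrow or fastparquet
# (for the parquet output), scipy, opencv-python, Pillow, tqdm.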
# if you changed the MLP architecture during training, change it also here:
class MLP(nn.Module):
    def __init__(self, input_size, xcol="emb", ycol="avg_rating"):
        super().__init__()
        self.input_size = input_size
        self.xcol = xcol
        self.ycol = ycol
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 128),
            # nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            # nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 16),
            # nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        return self.layers(x)
    # pytorch-lightning-style training hooks kept from the training script;
    # they are unused at inference time.
    def training_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x = batch[self.xcol]
        y = batch[self.ycol].reshape(-1, 1)
        x_hat = self.layers(x)
        loss = F.mse_loss(x_hat, y)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
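# Sanity check with random input: MLP(768)(torch.randn(1, 768)) returns a
# (1, 1) tensor; with the pretrained weights loaded below, scores land roughly
# on the 1-10 AVA rating scale.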
def _binary_array_to_hex(arr):
    # Pack a boolean array into a hex string (same scheme as the imagehash library).
    bit_string = ''.join(str(b) for b in 1 * arr.flatten())
    width = int(numpy.ceil(len(bit_string) / 4))
    return '{:0>{width}x}'.format(int(bit_string, 2), width=width)
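# Example: _binary_array_to_hex(numpy.array([True, False, True, True]))
# packs the bits 1011 into the hex string "b".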
def phashstr(image: Image.Image, hash_size: int = 8, highfreq_factor: int = 4) -> str:
    # Perceptual hash (pHash): DCT of a downscaled grayscale image, thresholded
    # at the median of the low-frequency block.
    if hash_size < 2:
        raise ValueError('Hash size must be greater than or equal to 2')
    img_size = hash_size * highfreq_factor
    image = image.convert('L').resize((img_size, img_size), Image.Resampling.LANCZOS)
    pixels = numpy.asarray(image)
    dct = scipy.fftpack.dct(scipy.fftpack.dct(pixels, axis=0), axis=1)
    dctlowfreq = dct[:hash_size, :hash_size]
    med = numpy.median(dctlowfreq)
    diff = dctlowfreq > med
    return _binary_array_to_hex(diff.flatten())
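# With the default hash_size=8 this yields a 64-bit hash as a 16-character hex
# string; a small Hamming distance between two hashes indicates visually
# similar images.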
def normalized(a, axis=-1, order=2):
    # L2-normalize along the given axis, leaving zero vectors untouched.
    l2 = numpy.atleast_1d(numpy.linalg.norm(a, order, axis))
    l2[l2 == 0] = 1
    return a / numpy.expand_dims(l2, axis)
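# Example: normalized(numpy.array([[3.0, 4.0]])) -> array([[0.6, 0.8]])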
class ImageDataset(Dataset):
    def __init__(self, imgdir, transform):
        self.img_list = [
            str(p)
            for ext in ["*.jpg", "*.png", "*.jpeg", "*.webp"]
            for p in Path(imgdir).rglob(ext)
        ]
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_path = self.img_list[idx]
        # File-level hashes for dedup bookkeeping
        with open(img_path, "rb") as file:
            data = file.read()
        md5 = hashlib.md5(data).hexdigest()
        sha1 = hashlib.sha1(data).hexdigest()
        image = Image.open(img_path)
        phash = phashstr(image)
        # Variance of the Laplacian as a simple sharpness/blur metric
        image_np = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        laplacian_variance = cv2.Laplacian(image_np, cv2.CV_64F).var()
        image = self.transform(image)
        metrics = (phash, md5, sha1, laplacian_variance)
        return image, img_path, metrics
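# Note: the default DataLoader collate batches the string metrics into lists
# and the Laplacian variances into a float64 tensor, which the loop below unpacks.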
if __name__ == "__main__":
    print("Loading model sac+logos+ava1-l14-linearMSE.pth")
    model = MLP(768)  # CLIP embedding dim is 768 for CLIP ViT-L/14
    pthpath = "https://huggingface.co/trojblue/clip-aesthetics/resolve/main/sac%2Blogos%2Bava1-l14-linearMSE.pth"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("Using CPU, this might be slow.")
    model.load_state_dict(torch.hub.load_state_dict_from_url(pthpath, map_location=device))
    model.to(device).eval()
    model2, preprocess = clip.load("ViT-L/14", device=device)  # or e.g. "RN50x64"
    imgdir = sys.argv[1]
    dataset = ImageDataset(imgdir, preprocess)
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)
    data = []  # one row per image for the final DataFrame
    bar = tqdm(total=len(dataset), desc="Processing", position=0)
    clip_avg = 0.0
    for inputs, img_paths, metrics in dataloader:
        inputs = inputs.to(device)
        with torch.no_grad():
            # CLIP image embeddings, L2-normalized before scoring
            img_emb = model2.encode_image(inputs)
            img_emb = normalized(img_emb.cpu().detach().numpy())
            # .float() instead of torch.cuda.FloatTensor so the CPU path works too
            predictions = model(torch.from_numpy(img_emb).float().to(device))
        for idx, (prediction, img_path) in enumerate(zip(predictions, img_paths)):
            phash, md5, sha1, laplacian_variance = (metric[idx] for metric in metrics)
            data.append([img_path, prediction.item(), phash, md5, sha1, laplacian_variance.item()])
            clip_avg += prediction.item()
        bar.update(len(inputs))
        bar.set_postfix(avg=clip_avg / bar.n)
    df = pd.DataFrame(
        data,
        columns=[
            "filename",
            "clip_aesthetic",
            "phash",
            "md5",
            "sha1",
            "laplacian_variance",
        ],
    )
    df.to_parquet("metrics.parquet")
    df.to_csv("metrics.csv", index=False)
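# Usage (assuming the gist is saved as, say, clip_aesthetic.py):
#   python clip_aesthetic.py /path/to/images
# Writes metrics.parquet and metrics.csv to the working directory.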