Last active
February 7, 2023 11:32
-
-
Save nielsrolf/eb9cb6aa093ce1ded7a34d35a05c8736 to your computer and use it in GitHub Desktop.
Image deduplication
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil
from io import BytesIO
from urllib.parse import urlsplit

import click
import clip
import requests
import torch
from PIL import Image
# Run CLIP on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Loaded once at import time and shared by every function below:
# ``model`` encodes images, ``preprocess`` is the PIL-image transform paired with it.
model, preprocess = clip.load("ViT-B/32", device=device)
def _is_url(img_url):
    """Return True if *img_url* is an HTTP(S) URL rather than a local path."""
    # Match the whole scheme so local paths like "http_images/a.png" are not
    # mistaken for URLs (a bare startswith("http") would accept them).
    return img_url.lower().startswith(("http://", "https://"))


def get_image(img_url, timeout=30):
    """Open an image from an HTTP(S) URL or a local path.

    Args:
        img_url: an ``http://`` / ``https://`` URL, or a local filesystem path.
        timeout: seconds to wait on the HTTP server (URLs only). ``requests``
            applies no timeout by default and could block indefinitely.

    Returns:
        A ``PIL.Image.Image``. ``Image.open`` is lazy; pixels decode on first use.

    Raises:
        requests.HTTPError: the server answered with a 4xx/5xx status (instead of
            passing the error page to PIL).
        requests.RequestException: network failure or timeout.
        OSError: the data is not a readable image (PIL raises a subclass).
    """
    if _is_url(img_url):
        response = requests.get(img_url, timeout=timeout)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    return Image.open(img_url)
def _embed_images(img_urls):
    """Return unit-length CLIP embeddings, one row per entry of *img_urls*, in order."""
    features = []
    with torch.no_grad():
        for img_url in img_urls:
            # ``with`` releases the underlying file/buffer once preprocessing is done.
            with get_image(img_url) as img:
                image = preprocess(img).unsqueeze(0).to(device)
            features.append(model.encode_image(image))
    # clip.load keeps fp16 weights on CUDA, so embeddings may be half precision;
    # fp16 resolution near 1.0 (~5e-4) is too coarse for cutoffs like 0.975,
    # so normalise and compare in float32.
    features = torch.cat(features).float()  # (n, 512) for ViT-B/32
    return features / features.norm(dim=-1, keepdim=True)


def _greedy_unique_indices(sim, cutoff_sim):
    """Greedily choose one representative per group from an (n, n) similarity matrix.

    Indices are visited in order. Each index not already claimed is kept, and
    every later unclaimed index whose similarity to it is > cutoff_sim is
    claimed as its duplicate.

    Returns:
        (kept, pairs): kept indices in input order, and (kept_idx, dup_idx) pairs
        in the order they were found.
    """
    # One host transfer instead of n^2 scalar reads (each would sync the device).
    rows = sim.tolist()
    n = len(rows)
    kept, pairs, claimed = [], [], set()
    for i in range(n):
        if i in claimed:
            continue
        kept.append(i)
        for j in range(i + 1, n):
            # Skip already-claimed j: the kept set is unaffected, and this avoids
            # reporting the same duplicate more than once.
            if j not in claimed and rows[i][j] > cutoff_sim:
                claimed.add(j)
                pairs.append((i, j))
    return kept, pairs


def deduplicate(img_urls, cutoff_sim=0.975):
    """Drop near-duplicate images, keeping the first unclaimed image of each group.

    Args:
        img_urls: either actual urls or local paths (anything ``get_image`` accepts).
        cutoff_sim: two images are duplicates when the cosine similarity of
            their CLIP embeddings is strictly greater than this value.

    Returns:
        The kept entries of *img_urls*, in input order. Every duplicate pair found
        is printed. Grouping is greedy: an image is only compared against
        images that were kept, not against earlier duplicates.
    """
    img_urls = list(img_urls)
    if not img_urls:
        # torch.cat would raise on an empty list; nothing to deduplicate.
        return []
    features = _embed_images(img_urls)
    sim = features @ features.T  # (n, n) cosine similarities of unit vectors
    kept, pairs = _greedy_unique_indices(sim, cutoff_sim)
    for i, j in pairs:
        print(f"Duplicates: \n- {img_urls[i]} \n- {img_urls[j]}")
    return [img_urls[i] for i in kept]
@click.command()
@click.argument("img_urls", nargs=-1)
@click.option(
    "--cutoff-sim",
    default=0.975,
    show_default=True,
    help="CLIP cosine similarity above which two images count as duplicates.",
)
@click.option(
    "--target_dir",
    "--target-dir",
    "target_dir",
    default="deduped",
    show_default=True,
    help="Directory that receives one copy of each unique image.",
)
def main(img_urls, cutoff_sim, target_dir):
    """Copy the unique images among IMG_URLS (URLs or local paths) into TARGET_DIR."""
    os.makedirs(target_dir, exist_ok=True)
    if not img_urls:
        print("No images given.")
        return
    deduped = deduplicate(img_urls, cutoff_sim)
    kept = set(deduped)
    # Order-preserving and de-duplicated, unlike a set difference.
    removed = [u for u in dict.fromkeys(img_urls) if u not in kept]
    print("Found duplicates:" + "".join(f"\n - {u}" for u in removed))
    for img_url in deduped:
        if img_url.lower().startswith(("http://", "https://")):
            # Write the downloaded bytes unchanged. Re-saving through PIL would
            # re-encode (lossy for JPEG) and fails when the name has no known
            # extension. Note: this downloads the image a second time.
            response = requests.get(img_url, timeout=30)
            response.raise_for_status()
            # Use the URL *path* so query strings don't end up in the filename.
            name = os.path.basename(urlsplit(img_url).path) or "image"
            # NOTE(review): URLs sharing a basename overwrite each other here.
            with open(os.path.join(target_dir, name), "wb") as f:
                f.write(response.content)
        else:
            # No shell involved: safe for paths with spaces or shell metacharacters.
            shutil.copy(img_url, target_dir)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment