"""Image deduplication.

Finds near-duplicate images (by CLIP embedding similarity) among a list of
URLs or local paths, and copies one representative of each group into a
target directory.

Originally published as GitHub Gist nielsrolf/eb9cb6aa093ce1ded7a34d35a05c8736
by @nielsrolf (last active February 7, 2023 11:32).
"""
import os
import shutil
from io import BytesIO
from urllib.parse import urlparse

import click
import clip
import requests
import torch
from PIL import Image
# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# `model.encode_image` produces the image embeddings; `preprocess` converts a
# PIL image into the input tensor the model consumes. Both are shared by
# every call to deduplicate().
model, preprocess = clip.load("ViT-B/32", device=device)
def get_image(img_url, timeout=30):
    """Load an image from an HTTP(S) URL or from a local file path.

    Args:
        img_url: Either a URL starting with ``http`` (downloaded with
            ``requests``) or a path to a local image file.
        timeout: Seconds to wait for the HTTP server before giving up.
            Ignored for local paths.

    Returns:
        A ``PIL.Image.Image`` opened from the downloaded bytes or the file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    if img_url.startswith("http"):
        # Without a timeout, requests can block forever on a stalled server.
        response = requests.get(img_url, timeout=timeout)
        # Fail loudly on error pages rather than handing an HTML body to PIL,
        # which would only report that it cannot identify the image.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    return Image.open(img_url)
def deduplicate(img_urls, cutoff_sim=0.975):
    """Return the images that are not near-duplicates of an earlier image.

    Each image is embedded with the module-level CLIP ``model``. Two images
    count as duplicates when the cosine similarity of their embeddings
    exceeds ``cutoff_sim``. Images are scanned in input order, and the first
    member of each duplicate group is kept. Every newly detected duplicate
    pair is printed.

    Args:
        img_urls: Iterable of image URLs or local paths (see ``get_image``).
        cutoff_sim: Cosine-similarity threshold above which an image is
            treated as a duplicate of an earlier kept image.

    Returns:
        List of the kept entries of ``img_urls``, in input order. An empty
        input yields an empty list.
    """
    img_urls = list(img_urls)
    if not img_urls:
        # torch.cat([]) raises, so short-circuit before touching the model.
        return []

    features = []
    with torch.no_grad():
        for img_url in img_urls:
            image = preprocess(get_image(img_url)).unsqueeze(0).to(device)
            features.append(model.encode_image(image))
    # Compute the similarities in float32 so the comparison against a
    # threshold close to 1.0 is not dominated by low-precision rounding.
    features = torch.cat(features).float()  # (n, d)
    features = features / features.norm(dim=-1, keepdim=True)
    # Move the (n, n) matrix off the device once; indexing a CUDA tensor
    # element by element inside the loop would force a sync per access.
    sim = (features @ features.T).cpu()

    deduped = []
    # Index-based flags give O(1) lookups (the old list scan was O(n) each).
    is_duplicate = [False] * len(img_urls)
    for i, img_url in enumerate(img_urls):
        if is_duplicate[i]:
            continue
        for j in range(i + 1, len(img_urls)):
            # Skip images already flagged so each duplicate is reported once.
            if not is_duplicate[j] and sim[i, j] > cutoff_sim:
                print(f"Duplicates: \n- {img_urls[i]} \n- {img_urls[j]}")
                is_duplicate[j] = True
        deduped.append(img_url)
    return deduped
@click.command()
@click.argument("img_urls", nargs=-1)
@click.option(
    "--cutoff-sim",
    default=0.975,
    show_default=True,
    help="Cosine similarity above which two images count as duplicates.",
)
@click.option(
    # --target_dir is kept for backward compatibility; --target-dir matches
    # the hyphenated spelling of --cutoff-sim.
    "--target-dir",
    "--target_dir",
    "target_dir",
    default="deduped",
    show_default=True,
    help="Directory that receives the non-duplicate images.",
)
def main(img_urls, cutoff_sim, target_dir):
    """Copy the non-duplicate images among IMG_URLS into TARGET_DIR.

    IMG_URLS may mix HTTP(S) URLs and local paths. Remote images are
    downloaded and re-saved under the file name taken from the URL path.
    Local files are copied as-is.
    """
    os.makedirs(target_dir, exist_ok=True)
    deduped = deduplicate(img_urls, cutoff_sim)

    kept = set(deduped)
    removed = [img_url for img_url in img_urls if img_url not in kept]
    # Prefix every entry (including the first) and keep the input order.
    print("Found duplicates:" + "".join(f"\n - {img_url}" for img_url in removed))

    for img_url in deduped:
        if img_url.startswith("http"):
            # Use only the URL path for the file name, so query strings or
            # fragments (e.g. "?w=200") do not end up in it.
            filename = os.path.basename(urlparse(img_url).path)
            get_image(img_url).save(os.path.join(target_dir, filename))
        else:
            # shutil.copy avoids the shell: os.system("cp ...") broke on
            # paths with spaces and executed shell metacharacters.
            shutil.copy(img_url, target_dir)


if __name__ == "__main__":
    main()
# Discussion of this snippet happens on its GitHub Gist page; sign in to GitHub to comment.