Skip to content

Instantly share code, notes, and snippets.

@turicas
Created November 25, 2024 19:30
Show Gist options
  • Save turicas/b36fb1876b40888d92f5b2eefa2e9779 to your computer and use it in GitHub Desktop.
Save turicas/b36fb1876b40888d92f5b2eefa2e9779 to your computer and use it in GitHub Desktop.
Image embedding using timm and DINOv2
from pathlib import Path
import timm
import torch
from PIL import Image
class FeatureExtractor:
    """Extract L2-normalized embeddings from images using timm's DINOv2 models.

    The model is created once in ``__init__`` (pretrained weights, no
    classification head) together with the matching evaluation transform
    pipeline. Images can then be embedded one at a time (``extract``), as a
    single batch (``extract_batch``) or as a lazily-evaluated stream of
    batches (``extract_many``).
    """

    # Model identifiers accepted by ``__init__``.
    model_names = (
        "timm/vit_small_patch14_dinov2.lvd142m",
        "timm/vit_base_patch14_dinov2.lvd142m",
        "timm/vit_large_patch14_dinov2.lvd142m",
        "timm/vit_giant_patch14_dinov2.lvd142m",
        "timm/vit_small_patch14_reg4_dinov2.lvd142m",
        "timm/vit_base_patch14_reg4_dinov2.lvd142m",
        "timm/vit_large_patch14_reg4_dinov2.lvd142m",
        "timm/vit_giant_patch14_reg4_dinov2.lvd142m",
    )

    def __init__(self, model_name, device="cpu"):
        """Load `model_name` onto `device` and build its preprocessing transforms.

        Args:
            model_name: One of the identifiers listed in ``model_names``.
            device: Torch device (name or object) on which inference runs.

        Raises:
            ValueError: If `model_name` is not listed in ``model_names``.
        """
        if model_name not in self.model_names:
            raise ValueError(f"Unknown model name: {repr(model_name)}")
        self.device = device
        # num_classes=0 asks timm for a model without a classifier head, so
        # calling the model yields feature vectors rather than class logits.
        self.model = timm.create_model(model_name, pretrained=True, num_classes=0).to(self.device)
        self.model.eval()
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)

    # NOTE: the annotations below are quoted so they are not evaluated when
    # the class is defined (the `X | Y` union syntax needs Python 3.10+).
    def load_normalize(self, image: "str | Path | Image.Image"):
        """Load `image`, apply the model transforms and return a 1-image batch.

        Args:
            image: Path to an image file, or an already opened PIL image.

        Returns:
            The transformed image as a tensor on ``self.device`` with a
            leading batch dimension of size 1.

        Raises:
            TypeError: If `image` is neither a path nor a PIL image.
        """
        if isinstance(image, (str, Path)):
            # `convert` fully decodes the pixel data into a new image, so the
            # underlying file can be closed as soon as the `with` block ends.
            with Image.open(image) as opened:
                img = opened.convert("RGB")
        elif isinstance(image, Image.Image):
            # Grayscale/palette/RGBA inputs must become 3-channel RGB before
            # the transforms are applied.
            img = image.convert("RGB")
        else:
            raise TypeError(
                f"Expected str, Path or PIL.Image.Image, got {type(image).__name__}"
            )
        return self.transforms(img).to(self.device).unsqueeze(0)

    @torch.no_grad()
    def extract(self, image: "str | Path | Image.Image"):
        """Return the L2-normalized embedding of a single image.

        The result has the batch dimension removed.
        """
        input_tensor = self.load_normalize(image)
        output = self.model(input_tensor).squeeze(0)
        return torch.nn.functional.normalize(output, dim=-1)

    @torch.no_grad()
    def extract_batch(self, image_filenames):
        """Run a batch embedding extraction using `image_filenames`.

        Args:
            image_filenames: Iterable of image paths (or PIL images); every
                item is processed in one single forward pass.

        Returns:
            A tensor with one L2-normalized embedding per input, in input
            order. An empty input yields an empty ``(0, num_features)`` tensor.
        """
        tensors = [self.load_normalize(item) for item in image_filenames]
        if not tensors:
            # torch.cat refuses an empty list; return an empty result instead.
            return torch.empty((0, self.model.num_features), device=self.device)
        output = self.model(torch.cat(tensors))
        return torch.nn.functional.normalize(output, dim=-1)

    # No `no_grad` decorator is needed here: all model calls go through
    # `extract_batch`, which already disables gradient tracking.
    def extract_many(self, image_filenames, batch_size=16):
        """Extract embeddings from `image_filenames` in batches, using `batch_size`.

        Args:
            image_filenames: Iterable of image paths (or PIL images). It is
                consumed lazily, one batch at a time.
            batch_size: Maximum number of images per forward pass (>= 1).

        Yields:
            ``(item, embedding)`` pairs in input order.

        Raises:
            ValueError: If `batch_size` is smaller than 1 (raised when the
                generator is first advanced).
        """
        if batch_size < 1:
            # A non-positive size would never trigger a flush below and would
            # silently push every image through one unbounded batch.
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        batch = []
        for filename in image_filenames:
            batch.append(filename)
            if len(batch) == batch_size:
                yield from zip(batch, self.extract_batch(batch))
                batch = []
        if batch:
            # Flush the last, partially filled batch.
            yield from zip(batch, self.extract_batch(batch))
def main(argv=None):
    """Command-line entry point: print one embedding per image file.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Each output line has the form ``<absolute path>: <embedding as list>``.
    An invalid command line (including a batch size below 1) terminates the
    process through ``parser.error``.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", "-b", type=int, default=16)
    parser.add_argument("--device", "-d", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--registers", "-r", action="store_true", help="Use model with registers")
    parser.add_argument("model_size", choices=["small", "base", "large", "giant"], help="Size of Dinov2 model")
    parser.add_argument("image_filename", type=Path, nargs="+")
    args = parser.parse_args(argv)
    if args.batch_size < 1:
        # Reject this up front, before the (slow) model load.
        parser.error("--batch-size must be a positive integer")

    registers = "reg4_" if args.registers else ""
    model_name = f"timm/vit_{args.model_size}_patch14_{registers}dinov2.lvd142m"
    extractor = FeatureExtractor(model_name, device=args.device)
    pairs = extractor.extract_many(args.image_filename, batch_size=args.batch_size)
    for image_filename, embedding in pairs:
        print(f"{image_filename.absolute()}: {embedding.tolist()}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment