Image embedding using timm and Dinov2
from pathlib import Path

import timm
import torch
from PIL import Image


class FeatureExtractor:
    """Extract embeddings from images using timm's Dinov2 models"""

    model_names = (
        "timm/vit_small_patch14_dinov2.lvd142m",
        "timm/vit_base_patch14_dinov2.lvd142m",
        "timm/vit_large_patch14_dinov2.lvd142m",
        "timm/vit_giant_patch14_dinov2.lvd142m",
        "timm/vit_small_patch14_reg4_dinov2.lvd142m",
        "timm/vit_base_patch14_reg4_dinov2.lvd142m",
        "timm/vit_large_patch14_reg4_dinov2.lvd142m",
        "timm/vit_giant_patch14_reg4_dinov2.lvd142m",
    )
    def __init__(self, model_name, device="cpu"):
        if model_name not in self.model_names:
            raise ValueError(f"Unknown model name: {repr(model_name)}")
        self.device = device
        # `num_classes=0` drops the classification head, so the model returns pooled features
        self.model = timm.create_model(model_name, pretrained=True, num_classes=0).to(self.device)
        self.model.eval()
        # Use the preprocessing (resize, crop, normalization) that matches the pretrained weights
        data_config = timm.data.resolve_model_data_config(self.model)
        self.transforms = timm.data.create_transform(**data_config, is_training=False)
    def load_normalize(self, image: str | Path | Image.Image):
        """Open `image` if needed, apply the model transforms and return a 1-image batch tensor"""
        if isinstance(image, (str, Path)):
            img = Image.open(image)
        elif isinstance(image, Image.Image):
            img = image
        else:
            raise TypeError(f"Unsupported image type: {type(image)}")
        # Convert to RGB so grayscale/RGBA inputs also work with the 3-channel normalization
        return self.transforms(img.convert("RGB")).to(self.device).unsqueeze(0)
    @torch.no_grad()
    def extract(self, image: str | Path | Image.Image):
        """Extract the L2-normalized embedding of a single image"""
        input_tensor = self.load_normalize(image)
        output = self.model(input_tensor).squeeze(0)
        return torch.nn.functional.normalize(output, dim=-1)
    @torch.no_grad()
    def extract_batch(self, image_filenames):
        """Run a batched embedding extraction over `image_filenames`"""
        input_tensor = torch.cat([self.load_normalize(image_filename) for image_filename in image_filenames])
        output = self.model(input_tensor)
        return torch.nn.functional.normalize(output, dim=-1)
    @torch.no_grad()
    def extract_many(self, image_filenames, batch_size=16):
        """Extract embeddings from `image_filenames` in batches of `batch_size`, yielding (filename, embedding) pairs"""
        batch = []
        for filename in image_filenames:
            batch.append(filename)
            if len(batch) == batch_size:
                yield from zip(batch, self.extract_batch(batch))
                batch = []
        if batch:
            yield from zip(batch, self.extract_batch(batch))
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", "-b", type=int, default=16)
    parser.add_argument("--device", "-d", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--registers", "-r", action="store_true", help="Use model with registers")
    parser.add_argument("model_size", choices=["small", "base", "large", "giant"], help="Size of Dinov2 model")
    parser.add_argument("image_filename", type=Path, nargs="+")
    args = parser.parse_args()

    device = args.device
    batch_size = args.batch_size
    model_size = args.model_size
    use_registers = args.registers
    image_filenames = args.image_filename

    model_name = f"timm/vit_{model_size}_patch14_{'reg4_' if use_registers else ''}dinov2.lvd142m"
    extractor = FeatureExtractor(model_name, device=device)
    for image_filename, embedding in extractor.extract_many(image_filenames, batch_size=batch_size):
        print(f"{image_filename.absolute()}: {embedding.tolist()}")
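Example usage (a minimal sketch; the module name dinov2_embedding.py and the image paths are hypothetical, since the gist doesn't name them):

from dinov2_embedding import FeatureExtractor  # hypothetical module/file name

extractor = FeatureExtractor("timm/vit_small_patch14_dinov2.lvd142m", device="cpu")
embedding = extractor.extract("photo.jpg")  # 1-D, L2-normalized embedding tensor
print(embedding.shape)

The script can also be run directly; it prints one absolute path and embedding list per image:

python dinov2_embedding.py --device cpu small photo1.jpg photo2.jpg

Running it requires torch, timm and Pillow to be installed.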