Created
August 13, 2025 06:37
-
-
Save brianlow/6045fc636bd5504ebfbd4910c903f843 to your computer and use it in GitHub Desktop.
dino-embedding.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Compute DINO ViT-B/16 embeddings for every image under ``image_dir``.

Images are expected one directory level below ``image_dir``; the name of
that sub-directory is used as the label of each image it contains.
"""
import itertools
from pathlib import Path

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image

# Layout read by this script: {image_dir}/{part-num}/*.(jpg|jpeg|png)
image_dir = Path('../images')
# NOTE(review): output_file (and numpy) are not used in this excerpt;
# presumably the full script stores the collected embeddings here — confirm.
output_file = Path('embeddings/lego-classify-10-1200.dino.npz')

# Pick the best available accelerator instead of hard-coding "mps", which
# fails on machines without Apple Metal support.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the dino_vitb16 model, move it to the device and switch to eval mode
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
model.to(device)
model.eval()

# Image preprocessing: bicubic resize, 224 centre crop, tensor conversion,
# then per-channel normalisation with the commonly used ImageNet mean/std.
transform = T.Compose([
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Collect image paths: part directories in sorted order, and within each
# part directory all jpg/jpeg/png files in sorted order.
part_dirs = sorted(d for d in image_dir.iterdir() if d.is_dir())
image_paths_to_process = []
for part_dir in part_dirs:
    image_paths_to_process.extend(sorted(itertools.chain(
        part_dir.glob('*.jpg'),
        part_dir.glob('*.jpeg'),
        part_dir.glob('*.png'),
    )))

for image_path in image_paths_to_process:
    # The label is the name of the part directory holding the image
    label = image_path.parent.name

    # Context manager guarantees the underlying file is closed promptly
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # Add a leading batch dimension of 1 before handing the image to the model
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        embedding = model(img_tensor)

    # before adding to an array the script calls: embedding.cpu().numpy().squeeze()
    print(f"Processed {image_path} with label '{label}' -> embedding shape: {embedding.shape}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment