Skip to content

Instantly share code, notes, and snippets.

@cobanov
Created March 12, 2023 08:25
Show Gist options
  • Save cobanov/4c69032a20e534a19024b5a5b8cd9463 to your computer and use it in GitHub Desktop.
import timm
import torch
import shutil
import itertools
import torchvision
import numpy as np
from PIL import Image
from tqdm import tqdm
from pathlib import Path
from torchvision import transforms
from scipy.cluster.vq import kmeans2
from sklearn.decomposition import PCA
from torch.utils.data import DataLoader
def prep_dataset(folder_path, batch_size=8):
    """Load an image directory as a resized-tensor dataset wrapped in a DataLoader.

    Args:
        folder_path: Root directory laid out for torchvision's ImageFolder
            (one subdirectory per class).
        batch_size: Images per batch. Defaults to 8, matching the value that
            was previously hard-coded.

    Returns:
        (dataloader, dataset) tuple; the dataset is returned as well so
        callers can read `.imgs` for the ordered file list.
    """
    # Resize to 224x224 (standard ImageNet input size) and convert to a
    # [0, 1] float tensor.
    data_transform = transforms.Compose(
        [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
    )
    # Read from directory as an ImageFolder object.
    image_dataset = torchvision.datasets.ImageFolder(
        folder_path, transform=data_transform
    )
    # No shuffle: keeps batch order aligned with image_dataset.imgs.
    image_dataloader = DataLoader(image_dataset, batch_size=batch_size)
    return image_dataloader, image_dataset
def calculate_embeddings(model, dataloader):
    """Run every batch through `model` and return the stacked embeddings.

    Args:
        model: A torch module that maps an image batch to embedding vectors.
        dataloader: Yields (image_batch, label) pairs; labels are ignored.

    Returns:
        np.ndarray of shape (num_images, embedding_dim).
    """
    # Bug fix: switch to inference mode so BatchNorm/Dropout layers behave
    # deterministically; without this, embeddings depend on batch contents.
    model.eval()
    feature_embeddings = []
    with torch.no_grad():
        for image, _label in tqdm(dataloader):
            embedding = model(image)
            # .cpu() is a no-op for CPU tensors but makes this safe if the
            # model is ever moved to CUDA (.numpy() requires a CPU tensor).
            feature_embeddings.append(embedding.cpu().numpy())
    np_embeddings = np.vstack(feature_embeddings)
    print(np_embeddings.shape)
    return np_embeddings
def calculate_pca(embeddings, n_components=16):
    """Project embeddings onto their top principal components.

    Args:
        embeddings: 2-D array of shape (num_samples, embedding_dim).
        n_components: Output dimensionality. Defaults to 16, matching the
            value that was previously hard-coded.

    Returns:
        2-D array of shape (num_samples, n_components).
    """
    pca = PCA(n_components=n_components)
    return pca.fit_transform(embeddings)
def save_embeddings(output_path, embeddings, filelist):
    """Persist the embedding matrix and its matching file list to a .npz archive."""
    arrays = {"embeddings": embeddings, "filelist": filelist}
    np.savez(output_path, **arrays)
def find_clusters(pca_embeddings, k=16):
    """Group the PCA-reduced embeddings into k clusters via k-means.

    Uses data points as initial centroids (minit="points"). Prints the
    per-cluster sample counts as a quick balance check.

    Returns:
        (centroids, labels) as produced by scipy's kmeans2.
    """
    centroids, assignments = kmeans2(pca_embeddings, k=k, minit="points")
    # How many samples landed in each cluster.
    print(np.bincount(assignments))
    return centroids, assignments
def copy_to_clusters(labels, filelist, output_dir="./output"):
    """Copy each image into a per-cluster subdirectory of `output_dir`.

    Args:
        labels: 1-D integer cluster label per file (parallel to filelist).
        filelist: Source image paths, in the same order as labels.
        output_dir: Root for the per-cluster folders. Defaults to "./output",
            matching the path that was previously hard-coded.
    """
    labels = np.asarray(labels)
    # np.unique visits only labels that actually occur, so no empty
    # directories are created (same observable result as the original,
    # which only mkdir'd when a cluster had images).
    for label_number in np.unique(labels):
        cluster_dir = Path(output_dir) / str(label_number)
        # Hoisted out of the per-image loop: create the directory once.
        cluster_dir.mkdir(parents=True, exist_ok=True)
        mask = labels == label_number
        for img_path in itertools.compress(filelist, mask):
            # Path(...).name is portable; splitting on "/" breaks on Windows.
            shutil.copy2(img_path, cluster_dir / Path(img_path).name)
def main():
    """Run the full pipeline: embed images under ./root, reduce, cluster, sort into folders."""
    NUM_CLUSTERS = 16

    # Pretrained ResNet-34 backbone supplies the image feature vectors.
    backbone = timm.create_model("resnet34", pretrained=True)

    loader, dataset = prep_dataset("./root")
    image_paths = [p for p, _ in dataset.imgs]

    embeddings = calculate_embeddings(backbone, loader)
    save_embeddings("./embeddings.npz", embeddings, image_paths)

    reduced = calculate_pca(embeddings)
    _centroids, cluster_labels = find_clusters(reduced, k=NUM_CLUSTERS)
    copy_to_clusters(cluster_labels, image_paths)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment