Created
March 12, 2023 08:25
-
-
Save cobanov/4c69032a20e534a19024b5a5b8cd9463 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import timm | |
import torch | |
import shutil | |
import itertools | |
import torchvision | |
import numpy as np | |
from PIL import Image | |
from tqdm import tqdm | |
from pathlib import Path | |
from torchvision import transforms | |
from scipy.cluster.vq import kmeans2 | |
from sklearn.decomposition import PCA | |
from torch.utils.data import DataLoader | |
def prep_dataset(folder_path, batch_size=8):
    """Build an image DataLoader over a class-per-subdirectory folder.

    Images are resized to 224x224 (the input size expected by
    ImageNet-pretrained backbones) and converted to tensors.

    Args:
        folder_path: Root directory laid out for torchvision's
            ``ImageFolder`` (one subdirectory per class).
        batch_size: Images per batch. Defaults to 8, matching the
            previously hard-coded value, so existing callers are unchanged.

    Returns:
        ``(dataloader, dataset)``; the dataset is returned too so callers
        can recover file paths from ``dataset.imgs``.
    """
    # Transform pipeline applied to every image on load.
    data_transform = transforms.Compose(
        [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
    )
    # Read the directory tree as an ImageFolder dataset.
    image_dataset = torchvision.datasets.ImageFolder(
        folder_path, transform=data_transform
    )
    # shuffle stays at its default (False) so batch order matches
    # the order of dataset.imgs — embeddings stay aligned with file paths.
    image_dataloader = DataLoader(image_dataset, batch_size=batch_size)
    return image_dataloader, image_dataset
def calculate_embeddings(model, dataloader):
    """Run every batch through ``model`` and stack the outputs.

    Args:
        model: Callable (e.g. a torch module) producing a 2-D
            ``(batch, features)`` tensor per batch.
        dataloader: Yields ``(images, labels)`` batches; labels are ignored.

    Returns:
        ``np.ndarray`` of shape ``(num_images, features)``.
    """
    # eval() is essential for pretrained backbones: BatchNorm must use its
    # running statistics (not per-batch stats) or the embeddings vary with
    # batch composition.
    model.eval()
    feature_embeddings = []
    with torch.no_grad():
        for image, label in tqdm(dataloader):
            embedding = model(image)
            # .cpu() makes this safe even if the model is moved to a GPU.
            feature_embeddings.append(embedding.cpu().numpy())
    np_embeddings = np.vstack(feature_embeddings)
    print(np_embeddings.shape)
    return np_embeddings
def calculate_pca(embeddings, n_components=16):
    """Project embeddings to a lower-dimensional space with PCA.

    Args:
        embeddings: ``(num_samples, features)`` array.
        n_components: Target dimensionality. Defaults to 16, matching the
            previously hard-coded value, so existing callers are unchanged.

    Returns:
        ``(num_samples, n_components)`` array of reduced embeddings.
    """
    pca = PCA(n_components=n_components)
    return pca.fit_transform(embeddings)
def save_embeddings(output_path, embeddings, filelist):
    """Persist embeddings alongside their source file paths.

    Writes a ``.npz`` archive containing two named arrays,
    ``embeddings`` and ``filelist``, so they can be reloaded together.
    """
    arrays = {"embeddings": embeddings, "filelist": filelist}
    np.savez(output_path, **arrays)
def find_clusters(pca_embeddings, k=16):
    """Partition the reduced embeddings into ``k`` groups with k-means.

    ``minit="points"`` seeds the centroids from randomly chosen data
    points rather than random coordinates.

    Returns:
        ``(centroids, labels)``: centroids is a ``(k, dim)`` array and
        labels assigns each input row a cluster index in ``[0, k)``.
    """
    centroids, cluster_labels = kmeans2(pca_embeddings, k=k, minit="points")
    # Report cluster occupancy so imbalanced clusters are visible in the log.
    print(np.bincount(cluster_labels))
    return centroids, cluster_labels
def copy_to_clusters(labels, filelist, output_dir="./output"):
    """Copy each image into a per-cluster subdirectory of ``output_dir``.

    Args:
        labels: 1-D integer array assigning a cluster id to each file.
        filelist: File paths, aligned index-for-index with ``labels``.
        output_dir: Destination root; one ``<cluster_id>`` subdirectory is
            created per non-empty cluster. Defaults to ``"./output"``,
            matching the previously hard-coded path.
    """
    for label_number in range(len(np.bincount(labels))):
        members = list(itertools.compress(filelist, labels == label_number))
        if not members:
            # Preserve original behavior: no directory for empty clusters.
            continue
        cluster_dir = Path(output_dir) / str(label_number)
        # Create the directory once per cluster, not once per image.
        cluster_dir.mkdir(parents=True, exist_ok=True)
        for img_path in members:
            # Path(...).name is portable, unlike splitting on "/".
            shutil.copy2(img_path, cluster_dir / Path(img_path).name)
def main():
    """End-to-end pipeline: embed images from ./root, reduce with PCA,
    k-means cluster, then copy files into per-cluster output folders."""
    CLUSTER_RANGE = 16
    # num_classes=0 strips the classification head so the model outputs
    # pooled backbone features (true embeddings, 512-d for resnet34)
    # rather than 1000 ImageNet class logits.
    model = timm.create_model("resnet34", pretrained=True, num_classes=0)
    # Inference mode: BatchNorm uses running statistics, dropout disabled.
    model.eval()
    dataloader, image_dataset = prep_dataset("./root")
    # dataset.imgs holds (path, class_index) pairs in dataloader order.
    filelist = [path for path, label in image_dataset.imgs]
    feature_embeddings = calculate_embeddings(model, dataloader)
    save_embeddings("./embeddings.npz", feature_embeddings, filelist)
    pca_embeddings = calculate_pca(feature_embeddings)
    centroid, labels = find_clusters(pca_embeddings, k=CLUSTER_RANGE)
    copy_to_clusters(labels, filelist)
# Entry-point guard: run the pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment