Skip to content

Instantly share code, notes, and snippets.

@ovuruska
Last active April 25, 2022 13:03
Show Gist options
  • Select an option

  • Save ovuruska/dbe3c60f3469d4ec7b6762d96b932a73 to your computer and use it in GitHub Desktop.

Select an option

Save ovuruska/dbe3c60f3469d4ec7b6762d96b932a73 to your computer and use it in GitHub Desktop.
Compute a 2-D UMAP embedding of the PNG images in a folder tree and plot it, colored by each image's parent-folder name.
import numpy as np
import glob
import umap
from tqdm import tqdm
import cv2
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
def scatter_text(x, y, text_column, data, title, xlabel, ylabel):
    """Draw a seaborn scatter plot of ``data`` colored by ``text_column``.

    Based on this answer: https://stackoverflow.com/a/54789170/2641825
    Unlike that answer, no text annotation is drawn next to the points;
    ``text_column`` is only used as the hue variable.

    Args:
        x: Forwarded to ``sns.scatterplot`` as ``x`` (the caller in this
            file passes a column name of ``data``).
        y: Forwarded to ``sns.scatterplot`` as ``y``.
        text_column: Forwarded to ``sns.scatterplot`` as ``hue``.
        data: Forwarded to ``sns.scatterplot`` as ``data``.
        title: Title set on the current matplotlib axes.
        xlabel: Label set on the x axis.
        ylabel: Label set on the y axis.

    Returns:
        The object returned by ``sns.scatterplot``.
    """
    # x and y are passed by keyword: seaborn >= 0.12 made every parameter
    # after ``data`` keyword-only, so the positional form raises TypeError.
    # Keyword form is accepted by older seaborn releases as well.
    #
    # NOTE(review): ``size`` is seaborn's size *semantic*, so the scalar 8 is
    # treated as a constant size variable, not as a marker size. If a fixed
    # marker size was intended, matplotlib's ``s=`` is the usual spelling --
    # confirm before changing, since it alters the rendered figure.
    p1 = sns.scatterplot(x=x, y=y, hue=text_column, data=data, size=8)

    # Set title and axis labels on the current axes.
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return p1
if __name__ == "__main__":
    # Script entry point: load every PNG below the current directory as a
    # grayscale image, embed the flattened pixels with UMAP, save the
    # embedding to result.npy and a labeled scatter plot to out.png.

    # Side length (pixels) every image is brought to before flattening, so
    # that all rows of the feature matrix have the same width.
    image_size = 448
    output_plot = "out.png"

    # "**/*.png" also matches PNG files directly in the current directory,
    # so a second run would otherwise ingest the out.png written by the
    # first one. Sorting makes the row order (and therefore the seeded UMAP
    # result) independent of the file system's listing order.
    output_abspath = os.path.abspath(output_plot)
    images = sorted(
        path
        for path in glob.glob("**/*.png", recursive=True)
        if os.path.abspath(path) != output_abspath
    )
    total_images = len(images)
    if total_images == 0:
        raise SystemExit("No PNG images found under the current directory.")

    # 8-bit pixel values are exactly representable in float32, which needs
    # half the memory of NumPy's float64 default.
    pixels = np.zeros((total_images, image_size, image_size), dtype=np.float32)
    image_labels = []

    for ind, image_path in enumerate(tqdm(images)):
        image = cv2.imread(image_path)
        # cv2.imread reports failure by returning None rather than raising.
        if image is None:
            raise SystemExit("Could not read image: {}".format(image_path))

        # OpenCV loads color images in BGR channel order, so BGR2GRAY (not
        # RGB2GRAY) applies the luminance weights to the right channels.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        # Bring differently sized images to the common shape instead of
        # failing on the array assignment below.
        if gray.shape != (image_size, image_size):
            gray = cv2.resize(gray, (image_size, image_size))

        pixels[ind] = gray
        # The name of the containing folder serves as the class label.
        image_labels.append(os.path.basename(os.path.dirname(image_path)))

    reducer = umap.UMAP(random_state=42)
    result = reducer.fit_transform(pixels.reshape(total_images, -1))
    np.save("result.npy", result)
    # To re-plot without recomputing the embedding, load the saved array
    # instead: result = np.load("result.npy")

    df = pd.DataFrame(
        {
            "x": result[:, 0],
            "y": result[:, 1],
            "labels": image_labels,
        }
    )
    scatter_text(
        "x",
        "y",
        "labels",
        data=df,
        title="UMAP of posture dataset",
        xlabel="X",
        ylabel="Y",
    )
    plt.savefig(output_plot)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment