# Gist by @piercus: evaluate the palettes of the unsplash-lite-palette dataset
# against random palettes, reporting per palette size the standard deviation of
# point-to-nearest-colour distances and the NDCG of the palette ordering.
from datasets import DownloadManager
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import ndcg_score
from PIL import Image
import pandas as pd
import numpy as np
import time

def distance(a, b):
    """Euclidean distance between two RGB points."""
    return np.linalg.norm(a - b)


def calculate_std_dev(clusters, points, centroids):
    """Standard deviation of the distance between each point and its assigned centroid."""
    distances_list = []
    for i in range(len(centroids)):
        cluster_points = points[np.where(clusters == i)]
        distances = [distance(p, centroids[i]) for p in cluster_points]
        distances_list.extend(distances)
    return np.std(distances_list)

SAMPLING_SIZE = 1000
IMAGE_SIZE = (256, 256)
# file comes from https://huggingface.co/datasets/1aurent/unsplash-lite-palette/tree/main/data
df = pd.read_parquet('data/train-00000-of-00001.parquet')
dl_manager = DownloadManager()

def palette_metrics(palette, points):
    """Return (std_dev, ndcg) for a palette evaluated on a set of RGB points."""
    num = len(palette)
    centroids = np.stack(palette)
    # Assign every point to its nearest palette colour.
    nn = NearestNeighbors(n_neighbors=1)
    nn.fit(centroids)
    distances, indices = nn.kneighbors(points)
    indices = indices[:, 0]
    # Count how many points map to each palette colour; pad in case some colours are never the nearest.
    counts = np.bincount(indices)
    counts = np.pad(counts, (0, num - len(counts)), 'constant')
    # NDCG of the usage counts against the palette order (earlier colours get higher relevance).
    y_true_ranking = list(range(num, 0, -1))
    if num > 1:
        ndcg = ndcg_score([y_true_ranking], [counts])
    else:
        ndcg = 1.0
    std_dev = calculate_std_dev(indices, points, centroids)
    return (std_dev, ndcg)

# Accumulate (std_dev, ndcg, std_dev_random, ndcg_random) tuples per palette size.
summary = {i: [] for i in range(1, 9)}


def random_palette(num):
    """Uniformly random RGB palette of `num` colours, used as a baseline."""
    return np.random.randint(0, 256, size=(num, 3))

for index, data in df.iloc[0:1000].iterrows():
    start_time = time.time()
    palettes = data['palettes']
    filename = dl_manager.download(data['url'])
    with Image.open(filename) as img:
        img = img.convert("RGB")
        resized_img = img.resize(IMAGE_SIZE)
        points = np.array(resized_img.getdata())
    # Evaluate on a random subsample of the pixels to keep each iteration fast.
    sample_points = points[np.random.choice(len(points), SAMPLING_SIZE)]
    for palette in palettes.values():
        num = len(palette)
        results = palette_metrics(palette, sample_points)
        random_palette_results = palette_metrics(random_palette(num), sample_points)
        summary[num].append(results + random_palette_results)
    end_time = time.time()
    iteration_time = end_time - start_time
    if index > 0 and index % 10 == 0:
        print(f"[{index}] Time spent for iteration: {iteration_time} seconds")
print(f"|Palette Size | Std Dev (db) | Std Dev (random) | NDCG (db)| NDCG (random)|")
print(f"|-|---|---|---|---|")
for i in range(1, 9):
std_devs = np.mean([x[0] for x in summary[i]])
std_devs_rdm = np.mean([x[2] for x in summary[i]])
ndcg = np.mean([x[1] for x in summary[i]])
ndcg_rdm = np.mean([x[3] for x in summary[i]])
print(f"|{i}|{round(std_devs, 2)}|{round(std_devs_rdm, 2)}|{round(ndcg, 2)}|{round(ndcg_rdm, 2)}|")