Last active
January 18, 2024 14:24
-
-
Save piercus/fa6c7852199c4c852e4e9b5dd83a8953 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datasets import load_dataset, DownloadManager, Image | |
from sklearn.neighbors import NearestNeighbors | |
import pandas as pd | |
from PIL import Image | |
import numpy as np | |
from sklearn.metrics import ndcg_score | |
import time | |
from joblib import Parallel, delayed | |
def distance(a, b):
    """Return the Euclidean (L2) distance between two points.

    Both inputs are converted to floating point before subtracting so that
    unsigned integer arrays (for example ``uint8`` pixel values) cannot wrap
    around when a component of ``a`` is smaller than the matching component
    of ``b``.

    Args:
        a: Array-like coordinates of the first point.
        b: Array-like coordinates of the second point, broadcastable to ``a``.

    Returns:
        The L2 norm of ``a - b``.
    """
    delta = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    return np.linalg.norm(delta)
def calculate_std_dev(clusters, points, centroids):
    """Return the standard deviation of point-to-assigned-centroid distances.

    Each point is paired with the centroid named by its cluster label, the
    Euclidean distance of every pair is computed in one vectorized step, and
    the standard deviation over all of those distances is returned.

    Args:
        clusters: 1-D array of centroid indices, one label per row of
            ``points``.
        points: 2-D array with one point per row.
        centroids: 2-D array with one centroid per row.

    Returns:
        ``np.std`` of the distances (``nan`` when no point has a valid label).
    """
    # Work in floating point so unsigned pixel data cannot wrap on subtraction.
    points = np.asarray(points, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    clusters = np.asarray(clusters)

    # Points whose label does not name an existing centroid are ignored.
    valid = (clusters >= 0) & (clusters < len(centroids))
    assigned = clusters[valid].astype(np.intp)

    distances = np.linalg.norm(points[valid] - centroids[assigned], axis=1)
    return np.std(distances)
# Evaluation settings: the size every image is resized to, and how many of its
# pixels are then sampled.
IMAGE_SIZE = (256, 256)
SAMPLING_SIZE = 1000

# The parquet file comes from
# https://huggingface.co/datasets/1aurent/unsplash-lite-palette/tree/main/data
PARQUET_PATH = 'data/train-00000-of-00001.parquet'
df = pd.read_parquet(PARQUET_PATH)

# Used further down to download each row's image URL to a local file.
dl_manager = DownloadManager()
def palette_metrics(palette, points):
    """Score how well a color palette describes a set of pixel colors.

    Every point is assigned to its nearest palette color. Two numbers are then
    derived from that assignment:

    * the standard deviation of the distances between each point and the
      palette color it was assigned to, and
    * an NDCG score that compares the number of points attracted by each
      palette color against an ideal ranking in which earlier palette entries
      are the most relevant.

    Args:
        palette: Sequence of colors; the entries are stacked into the rows of
            the centroid matrix.
        points: 2-D array with one color per row.

    Returns:
        A ``(std_dev, ndcg)`` tuple.
    """
    num = len(palette)
    centroids = np.stack(palette)

    nn = NearestNeighbors(n_neighbors=num)
    nn.fit(centroids)
    _, neighbor_indices = nn.kneighbors(points)
    # Neighbors come back sorted by distance; column 0 is the closest centroid.
    nearest = neighbor_indices[:, 0]

    # minlength keeps one slot per palette color even when the last colors
    # attract no point at all.
    counts = np.bincount(nearest, minlength=num)

    if num > 1:
        # Relevance num, num-1, ..., 1: the first palette entry is expected to
        # be the one that attracts the most points.
        y_true_ranking = list(range(num, 0, -1))
        ndcg = ndcg_score([y_true_ranking], [counts])
    else:
        # A single-color palette has nothing to rank; score it as perfect.
        ndcg = 1.0

    std_dev = calculate_std_dev(nearest, points, centroids)
    return (std_dev, ndcg)
# For every palette size from 1 to 8, the list of metric tuples collected so far.
summary = {palette_size: [] for palette_size in range(1, 9)}
def random_palette(num):
    """Draw ``num`` random colors with three integer channels in [0, 255].

    Returns:
        An integer array of shape ``(num, 3)``.
    """
    shape = (num, 3)
    return np.random.randint(low=0, high=256, size=shape)
# Number of leading dataset rows that are evaluated.
MAX_IMAGES = 1000

for index, data in df.iloc[0:MAX_IMAGES].iterrows():
    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments the way time.time() differences can.
    start_time = time.perf_counter()
    palettes = data['palettes']
    filename = dl_manager.download(data['url'])

    # Keep the file open only while decoding; convert() and resize() return
    # new in-memory images that stay usable after the file is closed.
    with Image.open(filename) as source_img:
        resized_img = source_img.convert("RGB").resize(IMAGE_SIZE)

    points = np.array(resized_img.getdata())
    # Random subset of pixels (drawn with replacement) to keep scoring cheap.
    sample_points = points[np.random.choice(len(points), SAMPLING_SIZE)]

    for palette in palettes.values():
        num = len(palette)
        results = palette_metrics(palette, sample_points)
        # Baseline: a random palette of the same size scored on the same pixels.
        random_palette_results = palette_metrics(random_palette(num), sample_points)
        # Stored as (std_dev_db, ndcg_db, std_dev_random, ndcg_random).
        summary[num].append(results+random_palette_results)

    iteration_time = time.perf_counter() - start_time
    # NOTE(review): `index` is the DataFrame index label; this assumes the
    # default 0..n-1 integer index - confirm.
    if index > 0 and index % 10 == 0:
        print(f"[{index}] Time spent for iteration: {iteration_time} seconds")
# Print the averaged metrics per palette size as a Markdown table.
print("|Palette Size | Std Dev (db) | Std Dev (random) | NDCG (db)| NDCG (random)|")
print("|-|---|---|---|---|")
for palette_size in range(1, 9):
    rows = summary[palette_size]
    # Each row is (std_dev_db, ndcg_db, std_dev_random, ndcg_random).
    mean_std_db = np.mean([row[0] for row in rows])
    mean_std_random = np.mean([row[2] for row in rows])
    mean_ndcg_db = np.mean([row[1] for row in rows])
    mean_ndcg_random = np.mean([row[3] for row in rows])
    print(
        f"|{palette_size}"
        f"|{round(mean_std_db, 2)}"
        f"|{round(mean_std_random, 2)}"
        f"|{round(mean_ndcg_db, 2)}"
        f"|{round(mean_ndcg_random, 2)}|"
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment