Utility functions for the paper 'Metric Learning with Adaptive Density Discrimination' (magnet loss), implemented with Theano, NumPy, and scikit-learn.
from math import ceil

import numpy as np
from sklearn.cluster import KMeans
import theano
import theano.tensor as T


def compute_reps(extract_fn, X, chunk_size):
    """Compute representations for the input in fixed-size chunks."""
    chunks = int(ceil(float(X.shape[0]) / chunk_size))
    reps = []
    for i in range(chunks):
        start = i * chunk_size
        stop = min(start + chunk_size, X.shape[0])
        chunk_reps = extract_fn(X[start:stop])
        reps.append(chunk_reps)
    return np.vstack(reps)
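

# --- Example (not part of the original gist) ---------------------------------
# A minimal sketch of how compute_reps is called: `extract_fn` is assumed to be
# a compiled Theano function mapping a batch of inputs to their embeddings. The
# identity "network" below is purely illustrative.
def _example_compute_reps():
    x = T.matrix('x')
    extract_fn = theano.function([x], x)  # stand-in for a real embedding net
    X = np.random.randn(1000, 32).astype(theano.config.floatX)
    return compute_reps(extract_fn, X, chunk_size=256)  # shape (1000, 32)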


class ClusterBatchBuilder(object):
    """Sample minibatches for magnet loss."""

    def __init__(self, labels, k, m, d):
        self.num_classes = np.unique(labels).shape[0]
        self.labels = labels
        self.k = k  # clusters per class
        self.m = m  # clusters per minibatch
        self.d = d  # examples sampled per cluster

        self.centroids = None
        self.assignments = np.zeros_like(labels, dtype=int)
        self.cluster_assignments = {}
        self.cluster_classes = np.repeat(range(self.num_classes), k)
        self.example_losses = None
        self.cluster_losses = None
        self.has_loss = None
    def update_clusters(self, rep_data, max_iter=20):
        """
        Given an array of representations for the entire training set,
        recompute the clusters and store the example-to-cluster assignments
        in a form that can be sampled from quickly.
        """
        # Lazily allocate the array for centroids
        if self.centroids is None:
            self.centroids = np.zeros([self.num_classes * self.k, rep_data.shape[1]])

        for c in range(self.num_classes):
            # Cluster the examples of each class separately with KMeans
            class_mask = self.labels == c
            class_examples = rep_data[class_mask]
            kmeans = KMeans(n_clusters=self.k, init='k-means++', n_init=1, max_iter=max_iter)
            kmeans.fit(class_examples)

            # Save the cluster centroids for finding impostor clusters
            start = self.get_cluster_ind(c, 0)
            stop = self.get_cluster_ind(c, self.k)
            self.centroids[start:stop] = kmeans.cluster_centers_

            # Update assignments with the new global cluster indexes
            self.assignments[class_mask] = self.get_cluster_ind(c, kmeans.predict(class_examples))

        # Construct a map from cluster to example indexes for fast batch creation
        for cluster in range(self.k * self.num_classes):
            cluster_mask = self.assignments == cluster
            self.cluster_assignments[cluster] = np.flatnonzero(cluster_mask)
    def update_losses(self, indexes, losses):
        """
        Given a list of example indexes and their corresponding losses,
        store the new losses and update the affected cluster losses.
        """
        # Lazily allocate structures for losses
        if self.example_losses is None:
            self.example_losses = np.zeros_like(self.labels, dtype=float)
            self.cluster_losses = np.zeros([self.k * self.num_classes], dtype=float)
            self.has_loss = np.zeros_like(self.labels, dtype=bool)

        # Update example losses and mark these examples as measured
        indexes = np.array(indexes)
        self.example_losses[indexes] = losses
        self.has_loss[indexes] = True

        # Find affected clusters and update the corresponding cluster losses
        clusters = np.unique(self.assignments[indexes])
        for cluster in clusters:
            cluster_inds = self.assignments == cluster
            cluster_example_losses = self.example_losses[cluster_inds]

            # Take the average loss over those examples in the cluster for
            # which a loss has been measured
            self.cluster_losses[cluster] = np.mean(
                cluster_example_losses[self.has_loss[cluster_inds]])
    def gen_batch(self):
        """
        Sample a batch by first drawing a seed cluster proportionally to
        the mean loss of the clusters, then finding its nearest-neighbor
        "impostor" clusters, and finally sampling d examples uniformly
        from each cluster. The generated batch consists of m clusters,
        each contributing d consecutive examples.
        """
        # Sample the seed cluster proportionally to cluster losses, if available
        if self.cluster_losses is not None:
            p = self.cluster_losses / np.sum(self.cluster_losses)
            seed_cluster = np.random.choice(self.num_classes * self.k, p=p)
        else:
            seed_cluster = np.random.choice(self.num_classes * self.k)

        # Get impostor clusters by ranking centroids by distance to the seed
        sq_dists = ((self.centroids[seed_cluster] - self.centroids) ** 2).sum(axis=1)

        # Ensure only clusters of a different class than the seed are chosen
        sq_dists[self.get_class_ind(seed_cluster) == self.cluster_classes] = np.inf

        # Take the m - 1 nearest impostor clusters and prepend the seed
        clusters = np.argpartition(sq_dists, self.m - 1)[:self.m - 1]
        clusters = np.concatenate([[seed_cluster], clusters])

        # Sample d examples uniformly from each chosen cluster
        # (assumes every cluster contains at least d examples)
        batch_indexes = np.empty([self.m * self.d], dtype=int)
        for i, c in enumerate(clusters):
            x = np.random.choice(self.cluster_assignments[c], self.d, replace=False)
            start = i * self.d
            stop = start + self.d
            batch_indexes[start:stop] = x

        # Translate global class indexes to class indexes within the batch
        class_inds = self.get_class_ind(clusters)
        batch_class_inds = []
        inds_map = {}
        class_count = 0
        for c in class_inds:
            if c not in inds_map:
                inds_map[c] = class_count
                class_count += 1
            batch_class_inds.append(inds_map[c])

        return batch_indexes, np.repeat(batch_class_inds, self.d)
    def get_cluster_ind(self, c, i):
        """
        Given a class index and a cluster index within that class,
        return the global cluster index.
        """
        return c * self.k + i

    def get_class_ind(self, c):
        """
        Given a global cluster index, return the class index.
        """
        return c // self.k  # integer division so the result stays an index
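

# --- Example (not part of the original gist) ---------------------------------
# Hedged sketch of the magnet-loss training cycle from the paper: refresh the
# clusters from the current embeddings, sample batches of m clusters x d
# examples, and feed the per-example losses back for loss-weighted cluster
# sampling. `extract_fn` and `train_fn` are assumed to be compiled Theano
# functions (see the sketch after magnet_loss below); all names and default
# hyperparameters here are hypothetical.
def _example_training_loop(X, y, extract_fn, train_fn,
                           k=8, m=8, d=8, refresh_every=50, steps=1000):
    builder = ClusterBatchBuilder(y, k, m, d)
    builder.update_clusters(compute_reps(extract_fn, X, 256))
    for step in range(steps):
        batch_inds, batch_classes = builder.gen_batch()
        total_loss, example_losses = train_fn(X[batch_inds], batch_classes)
        builder.update_losses(batch_inds, example_losses)
        if (step + 1) % refresh_every == 0:
            builder.update_clusters(compute_reps(extract_fn, X, 256))
    return builder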


def magnet_loss(r, classes, clusters, cluster_classes, n_clusters, alpha=1.0):
    """Compute the magnet loss.

    Given a tensor of features `r`, the assigned class for each example,
    the assigned cluster for each example, the assigned class for each
    cluster, the total number of clusters, and the separation
    hyperparameter, compute the magnet loss according to equation (4) of
    http://arxiv.org/pdf/1511.05939v2.pdf:

        loss = mean_n { -log( exp(-||r_n - mu(r_n)||^2 / (2 sigma^2) - alpha)
                  / sum_{c != C(r_n), k} exp(-||r_n - mu_k^c||^2 / (2 sigma^2)) ) }_+

    where mu(r_n) is the centroid of r_n's own cluster and sigma^2 is the
    variance of the distances of all examples to their own centroid.

    Note that cluster and class indexes should be sequential, starting at 0.

    Args:
        r: A batch of features.
        classes: Class labels for each example.
        clusters: Cluster labels for each example.
        cluster_classes: Class label for each cluster.
        n_clusters: Total number of clusters.
        alpha: The cluster separation gap hyperparameter.

    Returns:
        total_loss: The total magnet loss for the batch.
        losses: The loss for each example in the batch.
    """
    def comparison_mask(a_labels, b_labels):
        return T.eq(a_labels.reshape((-1, 1)), b_labels.reshape((1, -1)))

    N = r.shape[0]

    # Take the cluster means within the batch; scan stacks its step outputs,
    # so this is an (n_clusters, rep_dim) tensor
    cluster_means, _ = theano.scan(
        lambda i, r, clusters: T.mean(r[T.eq(clusters, i).nonzero()], 0),
        sequences=np.arange(n_clusters), non_sequences=[r, clusters])

    # Compute the squared distance of each example to each cluster centroid,
    # shape (N, n_clusters)
    sample_costs = ((cluster_means - r.dimshuffle((0, 'x', 1))) ** 2).sum(2)

    # Select the distance of each example to its own centroid
    intra_cluster_mask = comparison_mask(clusters, np.arange(n_clusters))
    intra_cluster_costs = T.sum(intra_cluster_mask * sample_costs, 1)

    # Compute the variance of intra-cluster distances
    variance = T.sum(intra_cluster_costs) / (N - 1.0)
    # NOTE: this squares the variance, whereas equation (4) of the paper
    # normalizes by -1 / (2 * sigma^2)
    var_normalizer = -1.0 / (2 * variance ** 2)

    # Compute the numerator: distance to the own centroid, shifted by alpha
    numerator = T.exp(var_normalizer * intra_cluster_costs - alpha)

    # Compute the denominator: sum over all clusters of a different class
    diff_class_mask = T.neq(classes.reshape((-1, 1)), cluster_classes.reshape((1, -1)))
    denom_sample_costs = T.exp(var_normalizer * sample_costs)
    denominator = T.sum(diff_class_mask * denom_sample_costs, 1)

    # Compute the per-example losses (hinged at zero) and the total loss
    epsilon = 1e-8
    losses = T.maximum(-T.log(numerator / (denominator + epsilon) + epsilon), 0)
    total_loss = T.mean(losses)

    return total_loss, losses
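

# --- Example (not part of the original gist) ---------------------------------
# A minimal sketch of compiling magnet_loss into the `train_fn`/`extract_fn`
# pair used in the training-loop sketch above, assuming the fixed batch layout
# produced by gen_batch (m clusters x d consecutive examples). The linear
# embedding is purely illustrative; a real model would supply its own
# parameters and updates.
def _example_compile_magnet_loss(input_dim=32, emb_dim=16, m=8, d=8, alpha=1.0):
    x = T.matrix('x')
    classes = T.lvector('classes')          # class index per example
    clusters = np.repeat(np.arange(m), d)   # batch layout is fixed, so this is a constant
    cluster_classes = classes[::d]          # first example of each cluster carries its class

    W = theano.shared(np.random.randn(input_dim, emb_dim).astype(theano.config.floatX))
    r = T.dot(x, W)                         # stand-in embedding network

    total_loss, losses = magnet_loss(r, classes, clusters, cluster_classes, m, alpha)
    updates = [(W, W - 0.01 * T.grad(total_loss, W))]
    train_fn = theano.function([x, classes], [total_loss, losses], updates=updates)
    extract_fn = theano.function([x], r)
    return train_fn, extract_fn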


def unsupervised_clustering_accuracy(emb, labels):
    """Cluster the embeddings with KMeans and report the accuracy under the
    best one-to-one matching between clusters and true labels."""
    from scipy.optimize import linear_sum_assignment

    k = np.unique(labels).size
    kmeans = KMeans(n_clusters=k, max_iter=35, n_init=15).fit(emb)
    emb_labels = kmeans.labels_

    # Build the negated contingency matrix between clusters and true labels
    G = np.zeros((k, k))
    for i in range(k):
        lbl = labels[emb_labels == i]
        values, counts = np.unique(lbl, return_counts=True)
        for uu, cc in zip(values, counts):
            G[i, uu] = -cc

    # Hungarian matching maximizes the total agreement (minimizes -counts)
    row_ind, col_ind = linear_sum_assignment(G)
    acc = -G[row_ind, col_ind].sum()
    return acc / float(len(labels))
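

# --- Example (not part of the original gist) ---------------------------------
# Minimal self-check: on well-separated synthetic blobs the matched clustering
# accuracy should be close to 1.0.
def _example_clustering_accuracy():
    from sklearn.datasets import make_blobs
    emb, labels = make_blobs(n_samples=300, centers=4, cluster_std=0.3,
                             random_state=0)
    return unsupervised_clustering_accuracy(emb, labels)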


def plot_embedding(X, y, imgs=None, title=None):
    """Scatter-plot a 2-D embedding, labelling points by class and
    optionally overlaying image thumbnails."""
    import matplotlib.pyplot as plt
    from matplotlib import offsetbox

    # Adapted from http://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html
    # Scale the embedding to the unit square
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    X = (X - x_min) / (x_max - x_min)

    # Plot the class labels as colored numbers
    plt.figure(figsize=(10, 10))
    ax = plt.subplot(111)
    for i in range(X.shape[0]):
        plt.text(X[i, 0], X[i, 1], str(y[i]),
                 color=plt.cm.Set1(y[i] / 10.),
                 fontdict={'weight': 'bold', 'size': 9})

    # Add image overlays (thumbnails require matplotlib > 1.0)
    if imgs is not None and hasattr(offsetbox, 'AnnotationBbox'):
        shown_images = np.array([[1., 1.]])  # seed far away so the first thumbnail is drawn
        for i in range(X.shape[0]):
            dist = np.sum((X[i] - shown_images) ** 2, 1)
            if np.min(dist) < 4e-3:
                # Don't overlay thumbnails that are too close together
                continue
            shown_images = np.r_[shown_images, [X[i]]]
            imagebox = offsetbox.AnnotationBbox(
                offsetbox.OffsetImage(imgs[i], cmap=plt.cm.gray_r), X[i])
            ax.add_artist(imagebox)

    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
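

# --- Example (not part of the original gist) ---------------------------------
# Typical use: project the learned embeddings to 2-D first (e.g. with t-SNE),
# then plot. `reps` would come from compute_reps above.
def _example_plot_embedding(reps, y):
    import matplotlib.pyplot as plt
    from sklearn.manifold import TSNE
    X2 = TSNE(n_components=2).fit_transform(reps)
    plot_embedding(X2, y, title='magnet embedding (t-SNE)')
    plt.show()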