Created
July 22, 2024 19:23
-
-
Save abhishek-ghose/7cf2cc8d9f6fb7f9b19855883f86c69c to your computer and use it in GitHub Desktop.
Sample code for Optimal Transport
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import os
import time
from datetime import datetime

import numpy as np
import pandas as pd
import ot
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import seaborn as sns; sns.set()
from scipy.stats import multivariate_normal as mvn, norm
from sklearn.cluster import KMeans
# Every figure and CSV produced by this script is written under this directory.
OP_DIR = r'./generated/optimal_transport'
if os.path.isdir(OP_DIR):
    print(f"The common output directory is {OP_DIR}.")
elif os.path.exists(OP_DIR):
    # A regular file is in the way: os.makedirs() would fail with a confusing FileExistsError
    # right after we claimed the directory "doesn't exist", so report the real problem instead.
    raise NotADirectoryError(f"Output path {OP_DIR} exists but is not a directory.")
else:
    print(f"Output dir. {OP_DIR} doesn't exist! Will be created...")
    # exist_ok=True: tolerate the directory being created between the check above and this call.
    os.makedirs(OP_DIR, exist_ok=True)
def get_plan_lines(opt_plan, X_source, X_target, threshold=0.0):
    """
    Convert a transport plan into line segments suitable for a matplotlib ``LineCollection``.

    Every entry ``opt_plan[i, j]`` strictly greater than ``threshold`` is treated as an
    assignment of source point ``i`` to target point ``j``. Each such entry yields the segment
    ``(X_source[i, :], X_target[j, :])``.

    :param opt_plan: 2-D array of shape (n_source, n_target), e.g. the output of ``ot.emd``.
    :param X_source: 2-D array whose row ``i`` holds the coordinates of source point ``i``.
    :param X_target: 2-D array whose row ``j`` holds the coordinates of target point ``j``.
    :param threshold: plan entries ``<= threshold`` are ignored. The default (0) keeps every
        positive entry. A small positive value helps with dense plans (entropic/Sinkhorn plans
        are typically positive almost everywhere), where the default would draw every pair.
    :return: list of ``(source_point, target_point)`` pairs, in row-major order of the plan.
    """
    rows, cols = np.nonzero(np.asarray(opt_plan) > threshold)
    return [(X_source[i, :], X_target[j, :]) for i, j in zip(rows, cols)]
def basic_2d():
    """
    Demo: exact optimal transport (EMD) between two nearly parallel columns of points.

    N source points sit on the vertical line x = x1 at random, sorted heights in [min_y, max_y).
    The N target points sit on x = x2 at the same heights plus N(0, 0.2) noise. Both sets get
    uniform weights and a Euclidean ground cost.

    The figure has three panels:
    - the points, with the EMD assignments drawn as dashed red lines;
    - the distance matrix;
    - the optimal plan.

    The figure is saved to ``{OP_DIR}/almost_bipartite.png``.
    """
    N = 10
    min_y, max_y = 0, 5
    x1, x2 = 2, 5

    fig = plt.figure(figsize=(24, 8))
    ax1, ax2, ax3 = [fig.add_subplot(130 + i) for i in range(1, 4)]

    heights = np.sort(min_y + np.random.random(N) * (max_y - min_y))
    X_source = np.column_stack((x1 * np.ones(N), heights))
    perturb = norm.rvs(0, 0.2, N)  # small perturbation so the targets are not an exact shifted copy
    X_target = np.column_stack((x2 * np.ones(N), heights + perturb))
    w_source, w_target = np.ones((N,)) / N, np.ones((N,)) / N  # uniform distribution on samples
    dist_mat = ot.dist(X_source, X_target, 'euclidean')

    # plot the points
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], ax=ax1).set(xlabel='x1', ylabel='x2')
    sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], ax=ax1)
    ax1.set_title("Points with optimal mapping (EMD).")

    # dist_mat[i, j] is the distance from source i to target j. A heatmap draws rows on the
    # y-axis and columns on the x-axis, so y is the source index and x is the target index.
    sns.heatmap(dist_mat, annot=False, ax=ax2).set(xlabel='X_target', ylabel='X_source')
    # Put row 0 at the bottom; sources are sorted by height, so this matches the scatter plot.
    ax2.invert_yaxis()
    ax2.set_title("Distance Matrix")

    # the plan has the same (source, target) layout as dist_mat
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    sns.heatmap(opt_plan, annot=False, ax=ax3).set(xlabel='X_target', ylabel='X_source')
    ax3.invert_yaxis()
    ax3.set_title("Optimal Transport Plan (EMD)")

    # overlay the plan's assignments on the scatter plot
    plan_lines = LineCollection(get_plan_lines(opt_plan, X_source, X_target),
                                linewidths=0.7, colors='r', linestyles='--')
    ax1.add_collection(plan_lines)
    fig.savefig(f"{OP_DIR}/almost_bipartite.png", bbox_inches='tight')
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
def _plot_ot_to_target(ax, X_source, X_target, point_size):
    """
    Draw ``X_source`` in red on ``ax`` and connect it to ``X_target`` along the exact OT (EMD) plan.

    Both point sets get uniform weights and a Euclidean ground cost.

    :return: the optimal transport cost, i.e. sum(plan * distance matrix).
    """
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], color='red', s=point_size * 1.5, ax=ax)
    dist_mat = ot.dist(X_source, X_target, 'euclidean')
    w_source = np.ones((len(X_source),)) / len(X_source)
    w_target = np.ones((len(X_target),)) / len(X_target)
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    ax.add_collection(LineCollection(get_plan_lines(opt_plan, X_source, X_target),
                                     linewidths=0.7, color='red'))
    return np.sum(opt_plan * dist_mat)


def representative_points_2d():
    """
    Demo: how well do ``num_components`` points represent a 2-D Gaussian mixture under OT?

    Samples ``N_target`` points from a mixture of ``num_components`` axis-aligned Gaussians,
    with means drawn uniformly in [0, 10] x [0, 10]. It then compares two choices of
    ``num_components`` source points, one per panel:
    - the k-means cluster centers of the samples;
    - points drawn uniformly in the samples' bounding box.

    Each panel shows the EMD plan and its cost. The figure is saved to
    ``{OP_DIR}/demo_representative_points.png``.
    """
    num_components = 5  # gaussian components
    N_target = 50
    x_lim, y_lim = (0, 10), (0, 10)  # limits of component means
    # split N_target evenly; the last component absorbs any remainder
    points_per_comp = [N_target // num_components] * num_components
    points_per_comp[-1] += N_target - sum(points_per_comp)

    # component means: uniform in [0, 1)^2, then scaled into the x_lim x y_lim box
    mean = np.random.random(size=num_components * 2).reshape(num_components, 2)
    mean[:, 0] = (x_lim[1] - x_lim[0]) * mean[:, 0] + x_lim[0]
    mean[:, 1] = (y_lim[1] - y_lim[0]) * mean[:, 1] + y_lim[0]
    # Zero cross-terms give axis-aligned Gaussians. They are NOT isotropic: the x and y
    # variances are drawn independently from [0, 1).
    covs = [np.diag([np.random.rand(), np.random.rand()]) for _ in mean]

    # reshape(-1, 2): mvn.rvs drops the sample axis when size == 1, which would break vstack
    X_target = np.vstack([mvn.rvs(mean=m_comp, cov=cov_comp, size=n_comp).reshape(-1, 2)
                          for n_comp, m_comp, cov_comp in zip(points_per_comp, mean, covs)])
    print(f"Created target points, of shape={np.shape(X_target)}.")

    fig = plt.figure(figsize=(14, 8))
    point_size = 100
    axes = [fig.add_subplot(120 + i) for i in range(1, 3)]
    for ax in axes:
        sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], s=point_size, ax=ax)

    # Panel 1: k-means centers as representatives. n_init=10 pins the pre-1.4 scikit-learn
    # default and silences the FutureWarning about that default changing.
    X_source = KMeans(n_clusters=num_components, n_init=10).fit(X_target).cluster_centers_
    opt_cost = _plot_ot_to_target(axes[0], X_source, X_target, point_size)
    axes[0].set_title(f"{num_components} Gauss. components, with {num_components} k-means centers.\nOpt cost={opt_cost:.2f}.")

    # Panel 2: uniform random points inside the bounding box of X_target
    X_source = np.random.random((num_components, 2))
    lo, hi = X_target.min(axis=0), X_target.max(axis=0)
    X_source = (hi - lo) * X_source + lo
    opt_cost = _plot_ot_to_target(axes[1], X_source, X_target, point_size)
    axes[1].set_title(f"{num_components} Gauss. components, with {num_components} random samples.\nOpt cost={opt_cost:.2f}.")

    fig.suptitle("Optimal Transport demo")
    fig.savefig(f"{OP_DIR}/demo_representative_points.png", bbox_inches='tight')
    plt.close(fig)  # release the figure so repeated calls do not accumulate open figures
def compute_runtimes(dims=None, num_points=None, lambda_param=None, num_trials=3, res_file=None):
    """
    Benchmark exact EMD against Sinkhorn on random point clouds and save the results as CSV.

    The sweep covers every combination of trial index, dimensionality, point count and
    ``lambda`` value. For each combination it:
    - draws two clouds of ``num_points`` points uniformly in the unit hypercube;
    - solves OT between them with uniform weights and a Euclidean cost;
    - records the transport cost (sum of plan * distance) and the solver's elapsed time.

    ``lambda == 0`` selects ``ot.emd``; any other value is passed as the regularisation
    term of ``ot.sinkhorn``.

    :param dims: dimensionalities to try; default [10, 50, 100, 200, 300, 400, 500].
    :param num_points: points per cloud to try; default [100, 500, 1000, 1500, 2000].
    :param lambda_param: regularisation values to try (0 = EMD); default [0, 0.01, 0.1, 1, 10].
    :param num_trials: repetitions of the whole sweep; default 3.
    :param res_file: output CSV path; default ``{OP_DIR}/runtimes.csv``. The file is written
        even if the sweep is interrupted, with the runs finished so far.
    :return: DataFrame with columns trial_idx, dims, num_points, lambda, ot_score, duration_sec.
    """
    if dims is None:
        dims = np.array([10, 50, 100, 200, 300, 400, 500], dtype=int)
    if num_points is None:
        num_points = np.array([100, 500, 1000, 1500, 2000], dtype=int)  # this is per dist.
    if lambda_param is None:
        lambda_param = np.array([0, 0.01, 0.1, 1, 10])
    if res_file is None:
        res_file = f"{OP_DIR}/runtimes.csv"
    columns = ['trial_idx', 'dims', 'num_points', 'lambda', 'ot_score', 'duration_sec']

    # Collect plain rows; concatenating DataFrames inside the loop is quadratic.
    rows = []
    try:
        for trial_idx, curr_dim, curr_num_points, curr_lambda_param in itertools.product(
                range(num_trials), dims, num_points, lambda_param):
            print(f"\n{trial_idx, curr_dim, curr_num_points, curr_lambda_param}")
            X_source = np.random.random((curr_num_points, curr_dim))
            X_target = np.random.random((curr_num_points, curr_dim))
            print(f"Created source points, shape={np.shape(X_source)}.")
            print(f"Created target points, shape={np.shape(X_target)}.")
            print("Calculating distances.")
            dist_mat = ot.dist(X_source, X_target, 'euclidean')
            w_source = np.ones((curr_num_points,)) / curr_num_points
            w_target = np.ones((curr_num_points,)) / curr_num_points

            print("Calculating opt plan.")
            # perf_counter is monotonic and high-resolution; datetime.now() can jump with clock changes.
            time_start = time.perf_counter()
            if curr_lambda_param == 0:
                opt_plan = ot.emd(w_source, w_target, dist_mat)
            else:
                # NOTE(review): with small regularisation (e.g. 0.01) plain Sinkhorn may underflow or
                # stop at its iteration cap; consider method='sinkhorn_log' - confirm before relying on scores.
                opt_plan = ot.sinkhorn(w_source, w_target, dist_mat, curr_lambda_param)
            duration_sec = time.perf_counter() - time_start
            print(f"Opt. transport finding took {duration_sec} sec.")

            # Computed outside the timed region. POT's ot.emd2 / ot.sinkhorn2 return this cost
            # directly, but using them would move this work inside the timing.
            ot_score = np.sum(dist_mat * opt_plan)
            rows.append([trial_idx, curr_dim, curr_num_points, curr_lambda_param, ot_score, duration_sec])
    finally:
        # Save whatever finished, even if the sweep is interrupted (e.g. Ctrl-C) or a solver fails.
        res_df = pd.DataFrame(rows, columns=columns)
        res_df.to_csv(res_file, index=False)
    return res_df
def _plot_runtimes(df):
    """Save mean runtime vs. #points and vs. #dims (one line per lambda) to ``{OP_DIR}/runtimes.png``."""
    # One colour per lambda level. A longer palette list makes seaborn warn (>= 0.12) or raise (0.11).
    palette = sns.color_palette("tab10", n_colors=df['lambda'].nunique())
    # "\\lambda": a bare "\l" is an invalid escape sequence (DeprecationWarning / SyntaxWarning).
    solver_note = "\nEMD is used for $\\lambda=0$, the rest use Sinkhorn."
    fig = plt.figure(figsize=(20, 8))
    ax_data_size, ax_dims = fig.add_subplot(121), fig.add_subplot(122)
    # (axes, x column, x label, label of the variable averaged over)
    panels = ((ax_data_size, 'num_points', '#points', '#dims'),
              (ax_dims, 'dims', '#dims', '#points'))
    for ax, x_col, x_label, avg_label in panels:
        temp_df = df.groupby(by=[x_col, 'lambda'], as_index=False).agg({'duration_sec': 'mean'})
        sns.lineplot(data=temp_df, x=x_col, y='duration_sec', hue='lambda', palette=palette,
                     marker='o', ax=ax)
        ax.set_title(f"Runtime wrt {x_label}, averaged over {avg_label}." + solver_note)
    fig.savefig(f"{OP_DIR}/runtimes.png", bbox_inches='tight')
    plt.close(fig)


def _plot_sinkhorn_quality(df):
    """Plot mean Sinkhorn OT score against mean EMD score for matching settings; save to sinkhorn_approx_quality.png."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # average over trials: one score per (dims, num_points, lambda)
    aggr_df = df.groupby(by=['dims', 'num_points', 'lambda'], as_index=False).agg({'ot_score': 'mean'})
    baseline_df = aggr_df[aggr_df['lambda'] == 0]
    t = np.sort(baseline_df['ot_score'].to_numpy())
    ax.plot(t, t, c='y', linestyle='--', label='no approx.')  # y = x: a perfect approximation
    # sorted(): deterministic legend order (iterating a set is not)
    for lambda_param in sorted(set(df['lambda']) - {0}):
        print(f"Analyzing data for lambda={lambda_param}.")
        temp_df = aggr_df[aggr_df['lambda'] == lambda_param]
        # join on identical settings; the _x columns are EMD, the _y columns are Sinkhorn
        joined_df = pd.merge(baseline_df, temp_df, on=['dims', 'num_points'], how='inner')
        plot_data = joined_df[['ot_score_x', 'ot_score_y']].sort_values(by='ot_score_x').to_numpy()
        ax.plot(plot_data[:, 0], plot_data[:, 1], label=f"{lambda_param}")
    ax.set_xlabel('OT score baseline (EMD)')
    ax.set_ylabel('OT score approx. (Sinkhorn)')
    ax.legend()
    fig.savefig(f"{OP_DIR}/sinkhorn_approx_quality.png", bbox_inches='tight')
    plt.close(fig)


def process_runtime_results(df):
    """
    Plot solver runtimes and Sinkhorn approximation quality from a ``compute_runtimes`` result table.

    :param df: DataFrame with columns trial_idx, dims, num_points, lambda, ot_score,
        duration_sec (e.g. loaded from ``{OP_DIR}/runtimes.csv``). Rows with ``lambda == 0``
        are the EMD baseline.
    Writes ``runtimes.png`` and ``sinkhorn_approx_quality.png`` into ``OP_DIR``.
    """
    _plot_runtimes(df)
    _plot_sinkhorn_quality(df)
if __name__ == "__main__":
    import argparse

    def _process_saved_runtimes():
        """Plot the results that compute_runtimes() saved to {OP_DIR}/runtimes.csv."""
        resfile = f"{OP_DIR}/runtimes.csv"
        if not os.path.isfile(resfile):
            # Fail with a hint instead of a bare FileNotFoundError traceback.
            raise SystemExit(f"{resfile} not found; run the 'compute_runtimes' task first.")
        process_runtime_results(pd.read_csv(resfile))

    # Task name -> function. Tasks run in the order given on the command line. With no
    # arguments only the saved runtimes are plotted (the original default behaviour).
    tasks = {
        'basic_2d': basic_2d,
        'representative_points_2d': representative_points_2d,
        'compute_runtimes': compute_runtimes,
        'process_runtime_results': _process_saved_runtimes,
    }
    parser = argparse.ArgumentParser(description="Optimal transport demos.")
    parser.add_argument('tasks', nargs='*', metavar='TASK',
                        help=f"tasks to run, in order: {', '.join(tasks)} "
                             f"(default: process_runtime_results)")
    args = parser.parse_args()
    unknown = [name for name in args.tasks if name not in tasks]
    if unknown:
        parser.error(f"unknown task(s): {', '.join(unknown)}")
    for name in args.tasks or ['process_runtime_results']:
        tasks[name]()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment