import numpy as np, itertools
import pandas as pd
import ot
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import seaborn as sns; sns.set()
from scipy.stats import multivariate_normal as mvn, norm
from sklearn.cluster import KMeans
import os
from datetime import datetime
# Every figure and CSV file produced by this script is written below this directory.
OP_DIR = r'./generated/optimal_transport'
if not os.path.isdir(OP_DIR):
    print(f"Output dir. {OP_DIR} doesn't exist! Will be created...")
    # Actually create the directory (and any missing parents); without this every later
    # savefig()/to_csv() call fails. If the path exists but is a regular file, makedirs raises
    # FileExistsError, which is the right signal for a misconfigured output location.
    os.makedirs(OP_DIR, exist_ok=True)
print(f"The common output directory is {OP_DIR}.")
def get_plan_lines(opt_plan, X_source, X_target):
    """
    Plotting helper that turns an optimal transport plan into line segments.

    Every strictly positive entry (i, j) of the plan counts as a source->target assignment and yields one
    segment running from row i of X_source to row j of X_target. Segments are emitted in the row-major order
    in which the positive entries occur in the plan.

    :param opt_plan: 2D array; entry (i, j) is compared against 0 to decide whether i is assigned to j.
    :param X_source: 2D array whose row i holds the coordinates of source point i.
    :param X_target: 2D array whose row j holds the coordinates of target point j.
    :return: list of (source_point, target_point) tuples, one per positive plan entry.
    """
    positive_entries = np.nonzero(opt_plan > 0)
    return [(X_source[src, :], X_target[tgt, :]) for src, tgt in zip(*positive_entries)]
def basic_2d():
    """
    Demo of exact optimal transport (EMD) between two almost aligned sets of 2D points.

    N source points are placed on the vertical line x=2 at sorted random heights in [0, 5). The target points
    sit on the line x=5 at the same heights plus a small Gaussian perturbation (std. dev. 0.2). Both point sets
    carry uniform weights. The resulting figure has three panels - the points with the optimal assignments drawn
    as dashed red lines, the Euclidean distance matrix, and the optimal transport plan - and is saved to
    ``{OP_DIR}/almost_bipartite.png``.
    """
    N = 10
    min_y, max_y = 0, 5
    x1, x2 = 2, 5
    fig = plt.figure(figsize=(24, 8))
    ax1, ax2, ax3 = [fig.add_subplot(1, 3, i) for i in range(1, 4)]

    # source points: fixed x-coordinate, sorted random heights
    heights = np.sort(min_y + np.random.random(N) * (max_y - min_y))
    X_source = np.column_stack((x1 * np.ones(N), heights))
    perturb = norm.rvs(0, 0.2, N)  # add a small perturbation for the second set of points
    X_target = np.column_stack((x2 * np.ones(N), heights + perturb))
    w_source, w_target = np.ones((N,)) / N, np.ones((N,)) / N  # uniform distribution on samples
    dist_mat = ot.dist(X_source, X_target, 'euclidean')

    # plot the points
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], ax=ax1).set(xlabel='x1', ylabel='x2')
    sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], ax=ax1)
    ax1.set_title("Points with optimal mapping (EMD).")

    # plot the distance matrix; its rows index the source points and its columns the target points, and a
    # heatmap draws rows along the y-axis, so the source label belongs on y and the target label on x
    sns.heatmap(dist_mat, annot=False, ax=ax2).set(xlabel='X_target', ylabel='X_source')
    ax2.set_title("Distance Matrix")

    # get the optimal plan and plot that matrix (same row/column layout as the distance matrix)
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    sns.heatmap(opt_plan, annot=False, ax=ax3).set(xlabel='X_target', ylabel='X_source')
    ax3.set_title("Optimal Transport Plan (EMD)")

    # add the transport plan to the scatter plot; a LineCollection is only rendered once it has been
    # attached to an axes
    plan_lines = LineCollection(get_plan_lines(opt_plan, X_source, X_target), linewidths=0.7, colors='r',
                                linestyles='--')
    ax1.add_collection(plan_lines)

    # save this figure explicitly instead of relying on pyplot's notion of the "current" figure
    fig.savefig(f"{OP_DIR}/almost_bipartite.png", bbox_inches='tight')
def _plot_source_assignments(ax, X_source, X_target, point_size):
    """
    Draw the source points and their EMD assignments to the target points on ``ax``.

    Both point sets get uniform weights and the cost is the Euclidean distance.

    :param ax: matplotlib axes to draw on (the target points are expected to be plotted already).
    :param X_source: 2D array of source point coordinates, one point per row.
    :param X_target: 2D array of target point coordinates, one point per row.
    :param point_size: marker size of the target points; source points are drawn 1.5x as large.
    :return: the optimal transport cost, i.e. the sum of plan * distance over all pairs.
    """
    sns.scatterplot(x=X_source[:, 0], y=X_source[:, 1], c='r', s=point_size * 1.5, ax=ax)
    dist_mat = ot.dist(X_source, X_target, 'euclidean')
    # uniform distribution on samples
    w_source = np.ones((len(X_source),)) / len(X_source)
    w_target = np.ones((len(X_target),)) / len(X_target)
    opt_plan = ot.emd(w_source, w_target, dist_mat)
    # the segments only show up once the collection is attached to the axes
    ax.add_collection(LineCollection(get_plan_lines(opt_plan, X_source, X_target), linewidths=0.7, color='red'))
    return np.sum(opt_plan * dist_mat)


def representative_points_2d():
    """
    Compare how well k-means centers vs. random points "represent" a Gaussian mixture sample under OT.

    50 target points are sampled from 5 Gaussian components with random means in [0, 10) x [0, 10) and random
    diagonal covariances. Two sets of 5 source points are then transported onto the targets with EMD: the
    k-means cluster centers of the targets (left panel) and points drawn uniformly from the bounding box of
    the targets (right panel). Each panel shows the assignments as red lines and reports the optimal cost in
    its title. The figure is saved to ``{OP_DIR}/demo_representative_points.png``.
    """
    num_components = 5  # gaussian components
    N_target = 50
    x_lim, y_lim = (0, 10), (0, 10)  # limits of component means

    points_per_comp = [N_target // num_components] * num_components
    # assign leftovers to the last component if N_target can't be evenly divided
    points_per_comp[-1] += N_target - sum(points_per_comp)

    # components are Gaussian - get their mean and cov
    mean = np.random.random((num_components, 2))
    # scale the means to be in the right bounding box
    mean[:, 0] = (x_lim[1] - x_lim[0]) * mean[:, 0] + x_lim[0]
    mean[:, 1] = (y_lim[1] - y_lim[0]) * mean[:, 1] + y_lim[0]
    # axis-aligned Gaussians: the cross-terms of the cov. are 0, but the two variances are drawn
    # independently, so the components are generally NOT isotropic
    covs = [np.array([[np.random.rand(), 0], [0, np.random.rand()]]) for _ in mean]

    # create points; the reshape keeps a component with a single sample two-dimensional
    samples = [np.reshape(mvn.rvs(mean=m_comp, cov=cov_comp, size=n_comp), (-1, 2))
               for n_comp, m_comp, cov_comp in zip(points_per_comp, mean, covs)]
    X_target = np.concatenate([np.empty((0, 2))] + samples)
    print(f"Created target points, of shape={np.shape(X_target)}.")

    fig = plt.figure(figsize=(14, 8))
    point_size = 100
    axes = [fig.add_subplot(1, 2, i) for i in range(1, 3)]
    for ax in axes:
        sns.scatterplot(x=X_target[:, 0], y=X_target[:, 1], s=point_size, ax=ax)

    # let's see what OT wrt cluster centers look like
    X_source = KMeans(n_clusters=num_components).fit(X_target).cluster_centers_
    opt_cost = _plot_source_assignments(axes[0], X_source, X_target, point_size)
    axes[0].set_title(f"{num_components} Gauss. components, with {num_components} k-means centers.\nOpt cost={opt_cost:.2f}.")

    # random samples, scaled to the bounding box of the target points
    X_source = np.random.random((num_components, 2))
    x_min, x_max = min(X_target[:, 0]), max(X_target[:, 0])
    y_min, y_max = min(X_target[:, 1]), max(X_target[:, 1])
    X_source[:, 0] = (x_max - x_min) * X_source[:, 0] + x_min
    X_source[:, 1] = (y_max - y_min) * X_source[:, 1] + y_min
    opt_cost = _plot_source_assignments(axes[1], X_source, X_target, point_size)
    axes[1].set_title(f"{num_components} Gauss. components, with {num_components} random samples.\nOpt cost={opt_cost:.2f}.")

    fig.suptitle("Optimal Transport demo")
    fig.savefig(f"{OP_DIR}/demo_representative_points.png", bbox_inches='tight')
def compute_runtimes():
    """
    Benchmark exact (EMD) vs. entropy-regularized (Sinkhorn) optimal transport on random data.

    For every combination of trial index, dimensionality, number of points per distribution and regularization
    strength ``lambda``, two point sets are drawn uniformly from the unit hypercube and transported onto each
    other with uniform weights and Euclidean cost. ``lambda == 0`` selects the exact solver ``ot.emd``; every
    other value is passed to ``ot.sinkhorn`` as its regularization parameter. Only the solver call is timed.

    One row per run, with the columns 'trial_idx', 'dims', 'num_points', 'lambda', 'ot_score' and
    'duration_sec', is written to ``{OP_DIR}/runtimes.csv``. The file is rewritten after every run, so partial
    results survive an interrupted benchmark.
    """
    # monotonic, high-resolution clock; unlike wall-clock timestamps it is unaffected by system clock changes
    from time import perf_counter

    dims = np.array([10, 50, 100, 200, 300, 400, 500], dtype=int)
    num_points = np.array([100, 500, 1000, 1500, 2000], dtype=int)  # this is per dist.
    lambda_param = np.array([0, 0.01, 0.1, 1, 10])
    num_trials = 3
    res_file = f"{OP_DIR}/runtimes.csv"
    columns = ['trial_idx', 'dims', 'num_points', 'lambda', 'ot_score', 'duration_sec']
    # collect plain rows and build the frame from them; repeatedly concatenating onto an initially empty
    # DataFrame copies all previous rows on every iteration
    rows = []

    for trial_idx, curr_dim, curr_num_points, curr_lambda_param in itertools.product(range(num_trials), dims,
                                                                                     num_points, lambda_param):
        print(f"\n{trial_idx, curr_dim, curr_num_points, curr_lambda_param}")
        X_source = np.random.random((curr_num_points, curr_dim))
        X_target = np.random.random((curr_num_points, curr_dim))
        print(f"Created source points, shape={np.shape(X_source)}.")
        print(f"Created target points, shape={np.shape(X_target)}.")

        print("Calculating distances.")
        dist_mat = ot.dist(X_source, X_target, 'euclidean')
        w_source, w_target = (np.ones((curr_num_points,)) / curr_num_points,
                              np.ones((curr_num_points,)) / curr_num_points)

        print("Calculating opt plan.")
        time_start = perf_counter()
        if curr_lambda_param == 0:
            opt_plan = ot.emd(w_source, w_target, dist_mat)
        else:
            # Sinkhorn must only run for a positive regularization strength; calling it for lambda == 0
            # would also overwrite the exact EMD plan computed above
            opt_plan = ot.sinkhorn(w_source, w_target, dist_mat, curr_lambda_param)
        duration_sec = perf_counter() - time_start
        print(f"Opt. transport finding took {duration_sec} sec.")

        ot_score = np.sum(dist_mat * opt_plan)  # TODO: can we directly return this score instead of computing?
        rows.append([trial_idx, curr_dim, curr_num_points, curr_lambda_param, ot_score, duration_sec])
        # checkpoint after every run
        pd.DataFrame(rows, columns=columns).to_csv(res_file, index=False)
def process_runtime_results(df):
    """
    Visualize benchmark results: solver runtimes and the quality of the Sinkhorn approximation.

    Two figures are written below ``OP_DIR``:

    * ``runtimes.png`` - mean 'duration_sec' per 'lambda' as a function of 'num_points' (left) and of 'dims'
      (right).
    * ``sinkhorn_approx_quality.png`` - for every non-zero 'lambda', the mean 'ot_score' plotted against the
      mean 'ot_score' of the ``lambda == 0`` baseline for the same ('dims', 'num_points') setting. The dashed
      diagonal marks a perfect match with the baseline.

    :param df: DataFrame with the columns 'dims', 'num_points', 'lambda', 'ot_score' and 'duration_sec'.
    """
    # plot runtimes
    fig = plt.figure(figsize=(20, 8))
    ax_data_size, ax_dims = fig.add_subplot(121), fig.add_subplot(122)

    temp_df = df.groupby(by=['num_points', 'lambda'], as_index=False).agg({'duration_sec': 'mean'})
    sns.lineplot(data=temp_df, x='num_points', y='duration_sec', hue='lambda', palette=sns.color_palette("tab10"),
                 marker='o', ax=ax_data_size)
    # "\\lambda": the backslash has to be escaped, "\l" is not a valid escape sequence
    ax_data_size.set_title("Runtime wrt #points, averaged over #dims."
                           "\nEMD is used for $\\lambda=0$, the rest use Sinkhorn.")

    temp_df = df.groupby(by=['dims', 'lambda'], as_index=False).agg({'duration_sec': 'mean'})
    sns.lineplot(data=temp_df, x='dims', y='duration_sec', hue='lambda', palette=sns.color_palette("tab10"), marker='o',
                 ax=ax_dims)
    ax_dims.set_title("Runtime wrt #dims, averaged over #points."
                      "\nEMD is used for $\\lambda=0$, the rest use Sinkhorn.")
    fig.savefig(f"{OP_DIR}/runtimes.png", bbox_inches='tight')

    # plot approx. accuracy
    fig = plt.figure()
    ax = fig.add_subplot(111)
    aggr_df = df.groupby(by=['dims', 'num_points', 'lambda'], as_index=False).agg({'ot_score': 'mean'})
    baseline_df = aggr_df[aggr_df['lambda'] == 0]
    t = sorted(baseline_df['ot_score'].to_numpy().flatten())
    ax.plot(t, t, c='y', linestyle='--', label='no approx.')

    # sorted, so that the plotting order (and with it colors and legend order) is deterministic
    other_lambdas = sorted(set(df['lambda']) - {0})
    for lambda_param in other_lambdas:
        print(f"Analyzing data for lambda={lambda_param}.")
        temp_df = aggr_df[aggr_df['lambda'] == lambda_param]
        # join on identical settings; the baseline score ends up in 'ot_score_x', the approximation in
        # 'ot_score_y'
        joined_df = pd.merge(baseline_df, temp_df, on=['dims', 'num_points'], how='inner')
        plot_data = joined_df[['ot_score_x', 'ot_score_y']].sort_values(by='ot_score_x').to_numpy()
        ax.plot(plot_data[:, 0], plot_data[:, 1], label=f"{lambda_param}")

    ax.set_xlabel('OT score baseline (EMD)')
    ax.set_ylabel('OT score approx. (Sinkhorn)')
    # the lines carry labels, but those are only shown once a legend is requested
    ax.legend()
    fig.savefig(f"{OP_DIR}/sinkhorn_approx_quality.png", bbox_inches='tight')
if __name__ == "__main__":
    # The demos and the (slow) benchmark are switched off by default; uncomment a call to regenerate its output.
    # basic_2d()
    # representative_points_2d()
    # compute_runtimes()

    # Load the benchmark results written by compute_runtimes() and turn them into plots; loading the file
    # without processing it would leave the script without any effect.
    resfile = f"{OP_DIR}/runtimes.csv"
    df = pd.read_csv(resfile)
    process_runtime_results(df)
