Created December 20, 2024 01:17
-
-
Save jakelevi1996/7d021992dcf47b4c214f49e53e856d56 to your computer and use it in GitHub Desktop.
I spent way too long doing this
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from jutility import plotting | |
def d(k: int) -> np.ndarray:
    """Return the pairwise "highest differing bit" distance between k-bit labels.

    With ``n = 2**k`` labels ``0 .. n-1``, ``d[i, j]`` is ``1 + p`` where ``p``
    is the highest bit position in which ``i`` and ``j`` differ, and ``0`` when
    ``i == j``; equivalently ``(i ^ j).bit_length()``. Reading the labels as
    the leaves of a complete binary tree of depth ``k``, this is the number of
    levels from the leaves up to the lowest common ancestor of ``i`` and ``j``.
    All values lie in ``[0, k]``.

    Args:
        k: number of bits per label; must be >= 0 (a negative ``k`` makes the
            shift ``1 << k`` raise ValueError).

    Returns:
        Integer array of shape ``(n, n)``.
    """
    n = 1 << k
    a_n11 = np.arange(n).reshape(n, 1, 1)
    p_11k = np.arange(k).reshape(1, 1, k)
    # Bit p of every label, shape (n, 1, k), plus a transposed view (1, n, k).
    b_n1k = (a_n11 >> p_11k) & 1
    b_1nk = b_n1k.reshape(1, n, k)
    # For every pair and bit position: 1 + p where the bits differ, else 0.
    s_nnk = np.where(b_n1k == b_1nk, 0, 1 + p_11k)
    # The highest differing bit determines the distance. ``initial=0`` lets
    # k == 0 (one label, no bit positions) yield [[0]] instead of raising on
    # an empty reduction.
    d_nn = s_nnk.max(axis=-1, initial=0)
    return d_nn
def d_perm(n: int, d_nn: np.ndarray, pi_n: np.ndarray) -> np.ndarray:
    """Relabel an ``(n, n)`` pairwise matrix by a permutation.

    Returns ``dpi`` with ``dpi[pi[i], pi[j]] == d_nn[i, j]`` for all ``i, j``;
    equivalently ``dpi[a, b] == d_nn[pi^{-1}(a), pi^{-1}(b)]``.

    Args:
        n: number of labels.
        d_nn: matrix to relabel; must have shape ``(n, n)``.
        pi_n: permutation of ``0 .. n-1`` (anything ``np.asarray`` accepts).

    Returns:
        A new array with the same shape and dtype as ``d_nn``.

    Raises:
        ValueError: if ``pi_n`` is not a permutation of ``range(n)`` or
            ``d_nn`` does not have shape ``(n, n)``.
    """
    pi_n = np.asarray(pi_n)
    # Explicit checks rather than ``assert``, which is stripped under ``python -O``.
    if pi_n.shape != (n,) or not np.array_equal(np.sort(pi_n), np.arange(n)):
        raise ValueError("pi_n must be a permutation of range(%i)" % n)
    if np.shape(d_nn) != (n, n):
        raise ValueError(
            "d_nn must have shape (%i, %i), got %s" % (n, n, np.shape(d_nn))
        )
    pi_n1 = pi_n.reshape(n, 1)
    pi_1n = pi_n.reshape(1, n)
    # pi_n is a permutation, so the broadcast (n, n) index grid hits every
    # entry exactly once and the uninitialised buffer is fully overwritten.
    dpi_nn = np.empty_like(d_nn)
    dpi_nn[pi_n1, pi_1n] = d_nn
    return dpi_nn
def d_perm_set(n: int, d_nn: np.ndarray, *pi_n: np.ndarray) -> np.ndarray:
    """Element-wise minimum of ``d_perm(n, d_nn, pi)`` over the given permutations.

    Each permutation relabels ``d_nn``; the result keeps, for every entry, the
    smallest value any of those relabellings puts there. At least one
    permutation is required (``np.stack`` raises ValueError on an empty list).
    """
    # One relabelled matrix per permutation, stacked on a new axis -3.
    layers = np.stack(
        [d_perm(n, d_nn, perm) for perm in pi_n],
        axis=-3,
    )
    # Collapse that axis, keeping the smallest value per entry.
    return np.min(layers, axis=-3)
def plot_d(k: int, d_nn: np.ndarray) -> plotting.Subplot:
    """Wrap the matrix ``d_nn`` in a single-image subplot.

    The image is created with ``vmin=0`` and ``vmax=k``. Presumably this pins
    the colour range to ``[0, k]`` (the value range of ``d(k)``) so panels
    drawn with the same ``k`` are comparable. TODO confirm against jutility's
    ``ImShow``.
    """
    image = plotting.ImShow(d_nn, vmin=0, vmax=k)
    subplot = plotting.Subplot(image)
    return subplot
def plot_perm_set(k: int, d_nn: np.ndarray, *pi_n: np.ndarray) -> plotting.MultiPlot:
    """Build a figure comparing each relabelled matrix with their minimum.

    With ``n = 2**k``, the first sub-figure holds one ``plot_d`` image of
    ``d_perm(n, d_nn, pi)`` per permutation ``pi``. The second holds a single
    ``plot_d`` image of ``d_perm_set(n, d_nn, *pi_n)``.
    """
    size = 1 << k
    combined_nn = d_perm_set(size, d_nn, *pi_n)
    per_perm_plots = [plot_d(k, d_perm(size, d_nn, perm)) for perm in pi_n]
    per_perm_panel = plotting.MultiPlot(
        *per_perm_plots,
        title="$\\{D(\\pi_i)\\}_{i=1}^m$",
    )
    combined_panel = plotting.MultiPlot(
        plot_d(k, combined_nn),
        title="$D(\\pi_{1:m})$",
    )
    return plotting.MultiPlot(
        per_perm_panel,
        combined_panel,
        figsize=[10, 5],
        space=0.2,
    )
def best_perm_set(k: int) -> np.ndarray:
    """Build ``n - 1`` permutations of ``0 .. n-1`` (``n = 2**k``) covering all pairs.

    Row ``r`` of the result is a permutation ``pi``. Positions ``2q`` and
    ``2q + 1`` differ only in bit 0, so ``d(k)[2q, 2q + 1] == 1``. Relabelling
    by ``pi`` (as ``d_perm`` does) therefore puts the value pair
    ``{pi[2q], pi[2q + 1]}`` at distance 1.

    The construction below makes the ``n/2 * (n - 1)`` such pairs across all
    rows exactly the ``n * (n - 1) / 2`` unordered pairs of distinct labels,
    each appearing once. So ``d_perm_set(n, d(k), *rows)`` is 1 everywhere off
    the diagonal. ``n - 1`` rows is the fewest that can achieve this, since
    each row contributes only ``n/2`` pairs.

    Args:
        k: number of bits per label; must be >= 0 (a negative ``k`` makes the
            shift ``1 << k`` raise ValueError).

    Returns:
        Integer array of shape ``(n - 1, n)``. For ``k == 0`` it has shape
        ``(0, 1)``: a single label has no pairs to cover.
    """
    n = 1 << k
    pi_mn = np.zeros([n - 1, n], dtype=int)
    if n == 1:
        # k == 0: no pairs, so no permutations (the base case below would
        # index a row that does not exist).
        return pi_mn
    # Base case k = 1: the single permutation [0, 1] (column 0 is already 0).
    pi_mn[0, 1] = 1
    for s in range(1, k):
        # Grow from labels 0..t-1 (t - 1 rows) to labels 0..2t-1 (2t - 1 rows).
        t = 1 << s
        t2 = t << 1
        tm1 = t - 1
        t2m1 = t2 - 1
        # Existing rows: copy each onto positions t..2t-1 shifted by t, so
        # every pair already covered within {0..t-1} is mirrored within
        # {t..2t-1}.
        pi_mn[:tm1, t:t2] = pi_mn[:tm1, :t] + t
        # t new rows: row r puts q at position 2q and t + (q + r) % t at
        # position 2q + 1. Over r = 0..t-1 this pairs every lower label with
        # every upper label exactly once.
        a_1t = np.arange(t).reshape(1, t)
        a_t1 = a_1t.reshape(t, 1)
        pi_mn[tm1:t2m1, 0:t2:2] = a_1t
        pi_mn[tm1:t2m1, 1:t2:2] = ((a_1t + a_t1) % t) + t
    return pi_mn
# Example: the distance matrix for k = 4, and the effect of relabelling it by
# a single permutation.
k = 4
d_nn = d(k)
print(d_nn)
plotting.plot(plotting.ImShow(d_nn), figsize=[6, 6])
n = 1 << k
# pi is the 3-cycle 1 -> 7 -> 12 -> 1; every other label is fixed.
i_marked = np.array([1, 7, 12])
pi_n = np.arange(n)
pi_n[i_marked] = np.roll(i_marked, -1)
dpi_nn = d_perm(n, d_nn, pi_n)
# Mark d(i, i ^ 1) for each i in the cycle. i and i ^ 1 differ only in bit 0,
# so every marked entry equals 1. In d_pi the same values sit at
# (pi(i), pi(i ^ 1)), because d_perm sets dpi[pi[i], pi[j]] = d[i, j].
j_marked = i_marked ^ 1
rows_d = i_marked.tolist()            # [1, 7, 12]
cols_d = j_marked.tolist()            # [0, 6, 13]
rows_dpi = pi_n[i_marked].tolist()    # [7, 12, 1]
cols_dpi = pi_n[j_marked].tolist()    # [0, 6, 13] (fixed by this pi)
ticks = {"xticks": cols_d, "yticks": rows_dpi}
labels = {"xlabel": "i", "ylabel": "j"}
plotting.MultiPlot(
    plotting.Subplot(
        plotting.ImShow(d_nn, axis_off=False),
        plotting.Scatter(cols_d, rows_d, c="r", m="x", s=15),
        **ticks,
        **labels,
        title="$d(i, j)$",
    ),
    plotting.Subplot(
        plotting.ImShow(dpi_nn, axis_off=False),
        plotting.Scatter(cols_dpi, rows_dpi, c="r", m="x", s=15),
        **ticks,
        **labels,
        title="$d_\\pi(i, j)=d(\\pi^{-1}(i), \\pi^{-1}(j))$",
    ),
    # Upper end of the label range comes from n, so it stays right if k changes.
    title="$\\pi(0:%i)$ = %s" % (n - 1, pi_n.tolist()),
    title_font_size=12,
    figsize=[6, 3],
).save()
# Hand-built permutations to compare. Plot each relabelled matrix next to
# their element-wise minimum.
a_n = np.arange(n)
half = np.arange(n >> 1)
# pi2 interleaves the lower half with the upper half: 0, 8, 1, 9, ... (n = 16).
pi2 = np.stack([half, half + (n >> 1)], axis=-1).reshape(n)
# pi3 interleaves the lower half with the reversed upper half: 0, 15, 1, 14, ...
pi3 = np.stack([half, n - 1 - half], axis=-1).reshape(n)
pi_list = [a_n, pi_n, pi2, pi3, pi3[pi3], pi3[pi3[pi3]]]
plot_perm_set(k, d_nn, *pi_list).save()
# The (n - 1) x n array from best_perm_set for three depths, one image each.
plotting.MultiPlot(
    *[
        plotting.Subplot(plotting.ImShow(best_perm_set(depth)), title="k=%i" % depth)
        for depth in (4, 6, 8)
    ],
    figsize=[9, 3],
    num_rows=1,
).save("im1")
# Plot every relabelling from best_perm_set(4) next to their element-wise minimum.
k = 4
n = 1 << k
d_nn = d(k)
pi_mn = best_perm_set(k)
# Iterating a 2-D array yields its rows: one 1-D permutation of length n each.
pi_list = list(pi_mn)
plot_perm_set(k, d_nn, *pi_list).save("im2")
# Large example (k = 10): show the labels at odd positions and at even
# positions of every permutation as two separate images.
pi_mn = best_perm_set(10)
plotting.MultiPlot(
    *(plotting.Subplot(plotting.ImShow(pi_mn[:, first::2])) for first in (1, 0)),
    figsize=[10, 10],
    pad=0,
).save("im3")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.