Last active: March 8, 2022 10:18
Save biggzlar/39199226dfb3d0923aaf72acb76e6978 to your computer and use it in GitHub Desktop.
Argmax vs. softargmax (play with beta)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt
import numpy as np
def softmax(x, beta=1.0):
    """Return the softmax of ``beta * x``, normalised over all elements.

    The result is computed as ``exp(z - max(z)) / sum(exp(z - max(z)))`` with
    ``z = beta * x``. Subtracting the maximum leaves the result mathematically
    unchanged. It does, however, keep ``np.exp`` from overflowing to ``inf``
    (and the ratio from becoming ``nan``) when ``beta`` or the scores are
    large.

    Args:
        x: Array-like of scores. It is converted with ``np.asarray`` so that a
            plain Python list is scaled element-wise. Otherwise ``beta * list``
            would repeat the list for an int ``beta`` or raise for a float one.
        beta: Inverse temperature. Larger values sharpen the distribution
            towards the maximum; ``beta=0`` yields a uniform distribution.

    Returns:
        A float ndarray with the same shape as ``x`` whose elements sum to 1.
        An empty input yields an empty array.
    """
    z = beta * np.asarray(x, dtype=float)
    if z.size:
        # Shift so the largest exponent is exp(0) == 1; prevents overflow.
        z = z - np.max(z)
    e = np.exp(z)
    return e / np.sum(e)
def softargmax(x, beta=1.0):
    """Return the softmax-weighted mean index of a 1-D score vector.

    This is a differentiable relaxation of ``argmax``. With weights
    ``w = softmax(beta * x)`` it returns ``sum_i w[i] * i``, which is a convex
    combination of the indices ``0 .. len(x) - 1``. As ``beta`` grows the
    result approaches ``argmax(x)``. At ``beta=0`` the weights are uniform and
    the result is the mean index ``(len(x) - 1) / 2``.

    The weights are computed after subtracting ``max(beta * x)``. This keeps
    ``np.exp`` from overflowing and turning the result into ``nan`` for large
    ``beta``.

    Args:
        x: 1-D array-like of scores (plain lists are accepted).
        beta: Inverse temperature of the softmax.

    Returns:
        A NumPy float scalar in ``[0, len(x) - 1]``, or ``0.0`` for empty
        input.
    """
    z = beta * np.asarray(x, dtype=float)
    if z.size:
        # Shift so the largest exponent is exp(0) == 1; prevents overflow.
        z = z - np.max(z)
    weights = np.exp(z)
    weights = weights / np.sum(weights)
    return np.sum(weights * np.arange(len(z)))
# --- Experiment configuration -------------------------------------------------
dim = 8          # length of each score vector
samples = 32     # max number of matching inputs shown per panel
beta = 1         # inverse temperature; try e.g. 0.1, 10, 100
trials = 10000   # number of random candidate vectors tested against x
tol = 0.005      # |softargmax(x) - softargmax(b)| below this counts as equal
seed = 0         # fixed for a reproducible figure; set to None for a new draw

rng = np.random.default_rng(seed)
x = rng.random(dim)

# Reference outputs for x do not change inside the loop, so compute them once.
# softmax(v, beta) is exp(beta * v) divided by a positive constant, i.e. a
# monotonic map of beta * v. So argmax(softmax(v, beta)) equals
# argmax(beta * v) (up to floating-point ties), and comparing the scaled
# scores directly avoids exponentiating at all.
x_argmax = np.argmax(beta * x)
x_softargmax = softargmax(x, beta)

# Collect random candidates whose (soft)argmax matches that of x.
softmax_a = []
softargmax_a = []
for _ in range(trials):
    b = rng.random(dim)
    # argmax is an integer index, so "equal" means exactly equal.
    if np.argmax(beta * b) == x_argmax:
        softmax_a.append(b)
    if abs(softargmax(b, beta) - x_softargmax) < tol:
        softargmax_a.append(b)

# reshape (rather than squeeze) keeps a 2-D (n_matches, dim) array even when
# zero or one candidate matched, which imshow requires.
softmax_a = np.reshape(softmax_a, (-1, dim))
softargmax_a = np.reshape(softargmax_a, (-1, dim))

# --- Plot ---------------------------------------------------------------------
fig, axs = plt.subplots(2)
fig.suptitle('Illustration of inputs with equal outputs @' + r'$\beta=' + f'{beta}$')
panels = (
    (axs[0], 'Inputs with equal argmax', softmax_a),
    (axs[1], 'Inputs with equal ' + r'$\bf{soft}$' + 'argmax', softargmax_a),
)
for ax, title, matches in panels:
    ax.set_title(title)
    if len(matches):
        # Transposed: one column per matching input, one row per component.
        ax.imshow(matches[:samples].T)
    else:
        ax.text(0.5, 0.5, 'no matching inputs', ha='center', va='center',
                transform=ax.transAxes)
plt.tight_layout()
plt.savefig("soft_arg_max.png", dpi=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.