This snippet compares the response and derivative of the sigmoid function with those of the Gumbel-sigmoid, a stochastic relaxation that allows differentiable sampling of (near-)binary values.
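Why this works, in brief: if u1, u2 ~ Uniform(0, 1), then G_i = -log(-log(u_i)) is Gumbel-distributed, and the difference G1 - G2 = log(log(u2) / log(u1)) follows a Logistic(0, 1) distribution. Adding this noise to the logits and squashing with a temperature-scaled sigmoid, s = sigmoid((logits + G1 - G2) / T), gives a reparameterized Bernoulli relaxation: P(s > 0.5) = sigmoid(logits), samples sharpen toward exact 0/1 draws as T decreases, and the whole sampling path stays differentiable.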
import torch
import matplotlib.pyplot as plt

def sample_gumbel_diff(*shape):
    # The difference of two independent Gumbel(0, 1) variables follows a
    # Logistic(0, 1) distribution: with u1, u2 ~ Uniform(0, 1) and
    # G_i = -log(-log(u_i)), we get G1 - G2 = log(log(u2) / log(u1)).
    eps = 1e-20  # guards against log(0)
    u1 = torch.rand(shape)
    u2 = torch.rand(shape)
    diff = torch.log(torch.log(u2 + eps) / torch.log(u1 + eps) + eps)
    return diff
def gumbel_sigmoid(logits, T=1.0, hard=False):
    g = sample_gumbel_diff(*logits.shape)
    g = g.to(logits.device)
    y = (g + logits) / T
    s = torch.sigmoid(y)
    if hard:
        # Straight-through estimator: the forward pass outputs the rounded
        # (binary) value, while the backward pass uses the soft sample's
        # gradient.
        s_hard = s.round()
        s = (s_hard - s).detach() + s
    return s
fig, ax = plt.subplots(1, 3, figsize=(12, 4))

# Sigmoid
x = torch.linspace(-8, 8, 100)      # input x in [-8, 8]
x.requires_grad = True              # track gradients w.r.t. x
y_sig = torch.sigmoid(x)            # sigmoid function
y_sig.backward(torch.ones_like(x))  # backpropagate gradients
ax[0].plot(x.detach().numpy(),
           y_sig.detach().numpy(),
           label="sigmoid(x)", c="r")
ax[1].plot(x.detach().numpy(),
           x.grad.detach().numpy(),
           label="sigmoid'(x)", c="r")

# Sigmoid with increased slope
x = torch.linspace(-8, 8, 100)      # fresh x so gradients do not accumulate
x.requires_grad = True
y_sig = torch.sigmoid(x * 10)       # sigmoid with a 10x steeper slope
y_sig.backward(torch.ones_like(x))
ax[2].plot(x.detach().numpy(),
           y_sig.detach().numpy(),
           label="sigmoid(x*10)", c="g")
ax[2].plot(x.detach().numpy(),
           x.grad.detach().numpy(),
           label="sigmoid'(x*10)", c="g", linestyle="--")
# Gumbel-sigmoid
x = torch.linspace(-8, 8, 100).repeat(200, 1)  # 200 samples per point to
                                               # visualize the stochasticity
x.requires_grad = True
yg = gumbel_sigmoid(x, T=1.0, hard=False)  # default temperature, soft samples
yg.backward(torch.ones_like(yg))           # backpropagate gradients
ax[0].scatter(x.detach().reshape(-1),
              yg.detach().reshape(-1),
              label="gumbel(x)", c="b", alpha=0.05)
ax[1].scatter(x.detach().reshape(-1),
              x.grad.reshape(-1),
              label="gumbel'(x)", c="b", alpha=0.05)
ax[0].legend()
ax[1].legend()
ax[2].legend()
plt.show()
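As a follow-up, here is a minimal sanity check (not in the original snippet) for the hard=True path: the empirical mean of hard samples should approach sigmoid(logits), and the straight-through trick should keep gradients flowing to the logits despite the rounding in the forward pass.

# Illustrative usage sketch: verify the straight-through (hard) path.
logits = torch.tensor([-2.0, 0.0, 2.0], requires_grad=True)
samples = gumbel_sigmoid(logits.expand(10000, -1), T=1.0, hard=True)
print(samples.mean(dim=0))  # should be close to sigmoid(logits) = [0.12, 0.50, 0.88]
samples.sum().backward()
print(logits.grad)          # nonzero: gradients flow through the rounding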