@jayelm
Created April 26, 2021 03:53
eps_greedy_gumbel_softmax.py
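```python
import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel

B = 1000  # batch size: number of independent samples
# Log-probabilities of a 3-way categorical p = [0.1, 0.2, 0.7], repeated for every row.
logits = torch.tensor([np.log(.1), np.log(.2), np.log(.7)]).unsqueeze(0).expand(B, -1)

# Standard sample of Gumbel-softmax: with hard=True each row is one-hot,
# so the column means approximate p, i.e. roughly [0.1, 0.2, 0.7].
standard_samples = F.gumbel_softmax(logits, tau=0.01, hard=True)
print(standard_samples.mean(0))

# "Epsilon" (exploration) sample: Gumbel noise with scale 10000 swamps the logits,
# so the argmax is close to uniform and the column means approach [1/3, 1/3, 1/3].
random_noise = Gumbel(torch.zeros((B, 3)), torch.full((B, 3), 10000.)).sample()
eps_samples = F.gumbel_softmax(logits + random_noise, tau=0.01, hard=True)
print(eps_samples.mean(0))
```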
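Going by the filename, the script appears to build the two ingredients of an epsilon-greedy sampler (a standard Gumbel-softmax sample and a near-uniform one) but does not combine them. Below is a minimal sketch of the combination, assuming that "epsilon-greedy" here means each row independently takes the noise-dominated sample with probability `eps` and the standard sample otherwise. The function `eps_greedy_gumbel_softmax` and its `eps` and `noise_scale` parameters are hypothetical names introduced for illustration; only `F.gumbel_softmax` and `torch.distributions.Gumbel` come from the original script.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Gumbel


def eps_greedy_gumbel_softmax(logits, eps, tau=0.01, noise_scale=1e4):
    """Hypothetical helper (not part of the original gist).

    Returns straight-through one-hot samples of shape (B, K). Each row is,
    with probability `eps`, the near-uniform sample obtained by adding
    large-scale Gumbel noise to the logits, and otherwise a standard
    Gumbel-softmax sample.
    """
    B = logits.shape[0]
    # Standard sample: the one-hot argmax follows softmax(logits) (Gumbel-max trick).
    standard = F.gumbel_softmax(logits, tau=tau, hard=True)
    # Exploration sample: Gumbel noise with a huge scale swamps the logits,
    # so the argmax is close to uniform over the K categories.
    noise = Gumbel(torch.zeros_like(logits),
                   torch.full_like(logits, noise_scale)).sample()
    explore = F.gumbel_softmax(logits + noise, tau=tau, hard=True)
    # One coin flip per row: True -> take the exploration sample.
    use_explore = torch.rand(B, 1, device=logits.device) < eps
    return torch.where(use_explore, explore, standard)


torch.manual_seed(0)
B = 1000
p = torch.tensor([0.1, 0.2, 0.7])
logits = p.log().unsqueeze(0).expand(B, -1)
print(eps_greedy_gumbel_softmax(logits, eps=0.5).mean(0))
```

Drawing both samples for every row and selecting with `torch.where` keeps the code batched and branch-free. Under the stated assumption, the expected column means are `eps / 3 + (1 - eps) * p`, which for `eps = 0.5` and `p = [0.1, 0.2, 0.7]` is about [0.22, 0.27, 0.52].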