"""
An example of the score function gradient estimator, a.k.a.
[REINFORCE](http://stillbreeze.github.io/REINFORCE-vs-Reparameterization-trick/)
written to provide a concrete example to a blog post on
[Monte Carlo gradient estimation](https://danmackinlay.name/notebook/mc_grad.html).
This method can estimate very general estimands.
In this case we will try to find the parameters which minimises a difference
between the categorical distribution we sample from, and some target distrbution.
"""
import torch
# True target distribution probabilities
true_probs = torch.tensor([0.1, 0.6, 0.3])
# Optimization parameters
n_batch = 1000
n_iter = 3000
lr = 0.01
def f(x):
    """
    The target function: the negative log-likelihood of a categorical
    distribution with the given target probabilities, evaluated at the
    one-hot sample x.
    The minus sign matters, because this algorithm _minimises_ its objective.
    """
    return -torch.distributions.Multinomial(
        total_count=1, probs=true_probs).log_prob(x)
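# For instance (an illustrative check, not part of the original gist): the
# one-hot vector for the second category has target probability 0.6, so
# f(torch.tensor([0., 1., 0.])) evaluates to -log(0.6) ≈ 0.51.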
# Set the seed for reproducibility
torch.manual_seed(42)
# Initialize the parameter estimates
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.]))
optimizer = torch.optim.Adam([theta_hat], lr=lr)
for epoch in range(n_iter):
    optimizer.zero_grad()
    # Sample from the estimated distribution
    x_sample = torch.distributions.Multinomial(
        1, logits=theta_hat).sample((n_batch,))
    # Evaluate the log density at the sample points
    log_p_theta_x = torch.distributions.Multinomial(
        1, logits=theta_hat).log_prob(x_sample)
    # Evaluate the target function at the sample points
    f_hat = f(x_sample)
    # Compute the score-function gradient estimate. Passing `f_hat` as
    # `grad_outputs` forms the vector-Jacobian product
    #   sum_i f(x_i) * d/dtheta log p_theta(x_i),
    # i.e. each sample's score is weighted by its own target value.
    weighted_score_sum = torch.autograd.grad(
        outputs=log_p_theta_x, inputs=theta_hat,
        grad_outputs=f_hat.detach())[0]
    # The final gradient estimate averages over the sample points
    final_gradients = weighted_score_sum / n_batch
    theta_hat.grad = final_gradients
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Estimated Probs: "
              f"{torch.softmax(theta_hat, dim=0).detach().numpy()}")
# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print("Final Estimated Probabilities: "
f" {estimated_final_probs.detach().numpy()}"
f" (True Probabilities: {true_probs.detach().numpy()}")