Forked from daviesl/multinomial_score_function_estimator.py
""" | |
An example of the score function gradient estimator, a.k.a. | |
[REINFORCE](http://stillbreeze.github.io/REINFORCE-vs-Reparameterization-trick/) | |
written to provide a concrete example to a blog post on | |
[Monte Carlo gradient estimation](https://danmackinlay.name/notebook/mc_grad.html). | |
This method can estimate very general estimands. | |
In this case we will try to find the parameters which minimises a difference | |
between the categorical distribution we sample from, and some target distrbution. | |
""" | |
import torch

# True target distribution probabilities
true_probs = torch.tensor([0.1, 0.6, 0.3])
# Optimization parameters
n_batch = 1000
n_iter = 3000
lr = 0.01
def f(x):
    """
    The target function: the negative log-likelihood of a categorical
    distribution with the given probabilities.
    The minus sign is important, since this algorithm _minimises_ its target.
    """
    return -torch.distributions.Multinomial(
        total_count=1, probs=true_probs).log_prob(x)
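# A quick sanity check (hypothetical usage, shown only to illustrate f):
#   f(torch.tensor([0., 1., 0.]))  # == -log(0.6) ≈ 0.51 for the likeliest category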
# Set the seed for reproducibility
torch.manual_seed(42)
# Initialize the parameter estimates
theta_hat = torch.nn.Parameter(torch.tensor([0., 0., 0.]))
optimizer = torch.optim.Adam([theta_hat], lr=lr)

for epoch in range(n_iter):
    optimizer.zero_grad()
    # Sample from the estimated distribution
    x_sample = torch.distributions.Multinomial(
        1, logits=theta_hat).sample((n_batch,))
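    # x_sample has shape (n_batch, 3): each row is a one-hot draw from the
    # categorical distribution with probabilities softmax(theta_hat)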
    # Evaluate the log density at the sample points
    log_p_theta_x = torch.distributions.Multinomial(
        1, logits=theta_hat).log_prob(x_sample)
    # Evaluate the target function at the sample points
    f_hat = f(x_sample)
    # Compute the score-function gradient estimate. Passing f_hat as
    # `grad_outputs` folds the weighting into the vector-Jacobian product,
    # so autograd returns sum_i f(x_i) * grad_theta log p_theta(x_i)
    # in a single pass, pairing each sample with its own score.
    weighted_grad = torch.autograd.grad(
        outputs=log_p_theta_x, inputs=theta_hat,
        grad_outputs=f_hat.detach())[0]
    # Average over the sample points to get the Monte Carlo gradient estimate
    final_gradients = weighted_grad / n_batch
    theta_hat.grad = final_gradients
    optimizer.step()
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Estimated Probs: "
              f"{torch.softmax(theta_hat, dim=0).detach().numpy()}")
# Display the final estimated probabilities
estimated_final_probs = torch.softmax(theta_hat, dim=0)
print("Final Estimated Probabilities:"
      f" {estimated_final_probs.detach().numpy()}"
      f" (True Probabilities: {true_probs.detach().numpy()})")