Gist tigershen23/8aab1ffc8bdef002f1515cb38269ce4d, last active February 22, 2020 05:58
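import torch
import pyro
import pyro.distributions as dist

def model(data):
    alpha = torch.tensor(6.0)  # Beta(6, 10) prior on each pay probability
    beta = torch.tensor(10.0)
    # expand() takes a shape, not an int; to_event(1) replaces the older independent(1)
    pay_probs = pyro.sample('pay_probs', dist.Beta(alpha, beta).expand((3,)).to_event(1))
    normalized_pay_probs = pay_probs / torch.sum(pay_probs)  # rescale so they sum to 1
    # pyro.plate replaces pyro.iarange from older Pyro releases
    with pyro.plate('data_loop', len(data)):
        pyro.sample('obs', dist.Categorical(probs=normalized_pay_probs), obs=data)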
Found this useful as part of your Pyro example blog post.
The code from your blog was close to running out of the box, but needed two small changes:
- `.expand` no longer seems to take an `int`; it expects a shape (a tuple or `torch.Size`), so `.expand((3,))` runs
- `from torch.distributions import constraints` is needed for the guide
Wanted to let you know in case you want to revise the blog post to be current.