@tigershen23
Last active February 22, 2020 05:58
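import torch
import pyro
import pyro.distributions as dist

def model(data):
    alpha = torch.tensor(6.0)
    beta = torch.tensor(10.0)
    # expand() takes a shape such as (3,), not a bare int
    pay_probs = pyro.sample('pay_probs', dist.Beta(alpha, beta).expand((3,)).independent(1))
    # Rescale so the three probabilities sum to 1
    normalized_pay_probs = pay_probs / torch.sum(pay_probs)
    # Newer Pyro releases rename iarange to plate and .independent to .to_event
    with pyro.iarange('data_loop', len(data)):
        pyro.sample('obs', dist.Categorical(probs=normalized_pay_probs), obs=data)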
@willzeng

Found this useful as part of your Pyro example blogpost.

The code from your blog was close to running out of the box; it needed two small changes:

  • .expand no longer seems to accept an int and expects a shape instead, so .expand((3,)) runs
  • the import from torch.distributions import constraints, which the guide needs, is missing
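
For anyone hitting the same errors, here is a minimal sketch of both fixes using plain torch.distributions rather than Pyro, so it runs without the blog's guide code. The `Independent` wrapper is an assumption standing in for the gist's `.independent(1)`, and the `constraints.positive` check is only an illustration of what the import provides, not the blog's actual guide.

```python
import torch
from torch.distributions import Beta, Independent, constraints

alpha = torch.tensor(6.0)
beta = torch.tensor(10.0)

# expand() takes a batch shape (a tuple or torch.Size), so pass (3,)
base = Beta(alpha, beta).expand((3,))

# Treat the three Beta draws as one 3-dimensional event,
# which is the role .independent(1) plays in the gist
pay_probs_dist = Independent(base, 1)

sample = pay_probs_dist.sample()
normalized = sample / sample.sum()  # rescale so the entries sum to 1

# The constraints module describes valid parameter domains;
# a guide would typically use it to keep alpha and beta positive
params_ok = constraints.positive.check(torch.tensor([6.0, 10.0]))

print(base.batch_shape, pay_probs_dist.event_shape)
print(normalized, params_ok)
```

After the expand, the three draws live in the batch shape; the wrapper moves them into the event shape, so a single log_prob call scores all three together.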

Wanted to let you know in case you wanted to revise the blogpost to be current.
