Skip to content

Instantly share code, notes, and snippets.

@tigershen23
Created June 9, 2018 17:46
Show Gist options
  • Save tigershen23/a9a7810e9429be43fa205409956505ac to your computer and use it in GitHub Desktop.
Save tigershen23/a9a7810e9429be43fa205409956505ac to your computer and use it in GitHub Desktop.
def print_progress():
    """Print the current payment-probability estimate for each person.

    Reads the ``"alphas"`` and ``"betas"`` entries from the Pyro parameter
    store and, for each index ``i``, computes:

    * ``mean_i = alpha_i / (alpha_i + beta_i)``
    * ``normalized_mean_i = mean_i / sum(mean)`` so the reported values
      sum to one
    * ``stdev_i = normalized_mean_i * sqrt(beta_i / (alpha_i * (1 + alpha_i + beta_i)))``

    The mean and the square-root factor match the mean and relative
    standard deviation of a Beta(alpha, beta) distribution.

    The result is written to stdout as a single line. Nothing is returned.

    Notes:
        * ``step`` is not a parameter; it is looked up in the enclosing
          (module) scope at call time and must be defined there.
        * Both parameters must hold at least three entries, reported in
          the order Tiger, Jason, James.
    """
    alphas = pyro.param("alphas")
    betas = pyro.param("betas")

    # No device transfer is performed here. ``Tensor.cuda()`` returns a copy
    # instead of moving the tensor in place, so calling it without using the
    # result only costs a transfer. The arithmetic below runs on whichever
    # device the parameters already live on.

    # This is a read-only diagnostic, so skip autograd bookkeeping.
    with torch.no_grad():
        means = alphas / (alphas + betas)
        normalized_means = means / torch.sum(means)
        factors = betas / (alphas * (1.0 + alphas + betas))
        stdevs = normalized_means * torch.sqrt(factors)

    names = ("Tiger", "Jason", "James")
    pays_strings = [
        "probability {0} pays: {1:.3f} +/- {2:.2f}".format(
            name, normalized_means[index], stdevs[index]
        )
        for index, name in enumerate(names)
    ]
    # print() joins its arguments with single spaces, so joining the
    # per-person strings with " | " reproduces "[ step | a | b | c ]".
    print("[", step, "|", " | ".join(pays_strings), "]")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment