@piEsposito
Created March 30, 2020 18:37
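import torch.nn as nn
from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator


@variational_estimator
class BayesianRegressor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Deterministic equivalent, kept for comparison:
        # self.linear = nn.Linear(input_dim, output_dim)
        self.blinear1 = BayesianLinear(input_dim, 512)
        self.blinear2 = BayesianLinear(512, output_dim)

    def forward(self, x):
        x_ = self.blinear1(x)
        return self.blinear2(x_)


# input_dim, output_dim, dataset, labels and criterion are assumed to be
# defined elsewhere (e.g. criterion = nn.MSELoss()).
reg = BayesianRegressor(input_dim, output_dim)
preds = reg(dataset)
fit_loss = criterion(preds, labels)  # data-fit term
complexity_loss = reg.nn_kl_divergence()  # KL-divergence complexity term
true_loss = fit_loss + complexity_loss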