Skip to content

Instantly share code, notes, and snippets.

@piEsposito
Created March 30, 2020 18:06
Show Gist options
  • Save piEsposito/e7c07b379f257eff0ab20fdeb30e998e to your computer and use it in GitHub Desktop.
import torch
from torch import nn
from blitz.modules import BayesianLinear
class BayesianRegressor(nn.Module):
    """Regressor: a deterministic linear projection followed by two Bayesian layers.

    The input goes through an ``nn.Linear`` projection that keeps the feature
    width at ``input_dim``. It then passes through two ``BayesianLinear``
    layers: ``input_dim -> hidden_dim -> output_dim``.

    Args:
        input_dim: Number of input features (size of the last dimension of ``x``).
        output_dim: Number of output features.
        hidden_dim: Width of the hidden Bayesian layer. Defaults to 64, the
            width that was previously hard-coded.

    Note:
        No activation function is applied between the layers. Assuming
        ``BayesianLinear`` acts like ``nn.Linear`` for a fixed weight sample
        (TODO confirm against blitz docs), each forward pass is an affine map
        of ``x``. Add a nonlinearity if the target is not linear in the features.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = 64):
        super().__init__()
        # The projection's output feeds ``blinear1``, which expects
        # ``input_dim`` features, so it must preserve the input width.
        # Previously this was ``nn.Linear(input_dim, output_dim)``, which broke
        # the forward pass whenever input_dim != output_dim. For the
        # input_dim == output_dim case the layer's shape is unchanged.
        self.linear = nn.Linear(input_dim, input_dim)
        self.blinear1 = BayesianLinear(input_dim, hidden_dim)
        self.blinear2 = BayesianLinear(hidden_dim, output_dim)

    def forward(self, x):
        """Return predictions for ``x``.

        ``x`` is expected to carry ``input_dim`` features in its last
        dimension. The result carries ``output_dim`` features in its last
        dimension.
        """
        x_ = self.linear(x)
        x_ = self.blinear1(x_)
        return self.blinear2(x_)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment