Created
March 30, 2023 21:03
-
-
Save wiseodd/62fadb452f77e488acc3716ed3822ac7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import warnings | |
warnings.filterwarnings('ignore') | |
import torch | |
from laplace import Laplace | |
from helper.dataloaders import get_sinusoid_example | |
from laplace import Laplace | |
# Seed torch's global RNG so the toy data and the random weight
# initialisations of every network built below repeat from run to run.
torch.manual_seed(711)

# NOTE(review): unused — nothing in this script trains the networks, so each
# Laplace approximation below wraps randomly initialised weights.
n_epochs = 1000

# Toy regression data from the project helper; sigma_noise=0.3 is presumably
# the observation-noise scale (TODO confirm in helper.dataloaders).
(X_train, y_train, train_loader, X_test) = get_sinusoid_example(
    sigma_noise=0.3,
)
class Test:
    """Hold an untrained 1-50-1 tanh MLP plus a fitted Laplace approximation.

    The network is put in eval mode, wrapped in a ``Laplace`` object
    (regression likelihood, all weights, Kronecker-factored Hessian), fitted on
    ``train_loader``, and its prior precision is optimised for 10 steps.

    Attributes:
        nn: the ``torch.nn.Sequential`` network, moved to ``device``.
        train_loader: the loader the Laplace approximation was fitted on.
        bnn: the fitted ``Laplace`` object; calling it on inputs predicts.
    """

    def __init__(self, train_loader, device='cpu'):
        super().__init__()
        # Layers are created in the same order as before so the seeded RNG
        # produces identical initial weights.
        layers = [
            torch.nn.Linear(1, 50),
            torch.nn.Tanh(),
            torch.nn.Linear(50, 1),
        ]
        self.nn = torch.nn.Sequential(*layers).to(device)
        self.train_loader = train_loader
        self.nn.eval()

        self.bnn = Laplace(
            self.nn,
            'regression',
            subset_of_weights='all',
            hessian_structure='kron',
        )
        self.bnn.fit(self.train_loader)
        self.bnn.optimize_prior_precision(n_steps=10)
def get(train_loader):
    """Build an untrained MLP and fit a Laplace approximation around it.

    Function counterpart of ``Test``: same 1-50-1 tanh architecture and the
    same Laplace settings (regression, all weights, 'kron' Hessian, 10 prior
    precision steps). Unlike ``Test``, the network is not moved to a device.

    Args:
        train_loader: data used by ``Laplace.fit``.

    Returns:
        A ``(nn, bnn)`` tuple: the ``torch.nn.Sequential`` network and the
        fitted ``Laplace`` object wrapping it.
    """
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, 1),
    )
    net.eval()

    la = Laplace(net, 'regression', subset_of_weights='all', hessian_structure='kron')
    la.fit(train_loader)
    la.optimize_prior_precision(n_steps=10)
    return net, la
# Test case 1: build a Test instance and predict with it, then build and
# predict with a second one.
print('Test case 1: Class init-predict-init-predict:')
try:
    model1 = Test(train_loader)
    _, _ = model1.bnn(X_test)
    model2 = Test(train_loader)
    _, _ = model2.bnn(X_test)
except Exception as err:
    # Report what went wrong instead of discarding it; this script exists to
    # reproduce a failure, so the exception type/message is the useful output.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 2: construct both Test instances first, then predict with each.
print('Test case 2: Class init-init-predict-predict:')
try:
    model1 = Test(train_loader)
    model2 = Test(train_loader)
    _, _ = model1.bnn(X_test)
    _, _ = model2.bnn(X_test)
except Exception as err:
    # Surface the failure details rather than a bare 'Error!'.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 3: same sequence as case 1, but using the get() helper function
# instead of the Test class.
print('Test case 3: Function init-predict-init-predict:')
try:
    model1, bnn1 = get(train_loader)
    _, _ = bnn1(X_test)
    model2, bnn2 = get(train_loader)
    _, _ = bnn2(X_test)
except Exception as err:
    # Surface the failure details rather than a bare 'Error!'.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 4: same sequence as case 2 (both built before either predicts),
# but using the get() helper function instead of the Test class.
print('Test case 4: Function init-init-predict-predict:')
try:
    model1, bnn1 = get(train_loader)
    model2, bnn2 = get(train_loader)
    _, _ = bnn1(X_test)
    _, _ = bnn2(X_test)
except Exception as err:
    # Surface the failure details rather than a bare 'Error!'.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# --------------- SANITY CHECK ---------------------
# The init-init-predict-predict sequence of test cases 2 and 4, written out
# inline with no class and no helper function.
print('Sanity check: No class, no function init-init-predict-predict:')
try:
    # Laplace 1
    nn1 = torch.nn.Sequential(
        torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
    )
    nn1.eval()
    bnn1 = Laplace(
        nn1, 'regression',
        subset_of_weights='all', hessian_structure='kron',
    )
    bnn1.fit(train_loader)
    bnn1.optimize_prior_precision(n_steps=10)

    # Laplace 2
    nn2 = torch.nn.Sequential(
        torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
    )
    nn2.eval()
    bnn2 = Laplace(
        nn2, 'regression',
        subset_of_weights='all', hessian_structure='kron',
    )
    bnn2.fit(train_loader)
    bnn2.optimize_prior_precision(n_steps=10)

    # Predict with BOTH approximations, in construction order, so that the
    # second Laplace object is actually exercised (matching cases 2 and 4).
    _, _ = bnn1(X_test)
    _, _ = bnn2(X_test)
except Exception as err:
    # Surface the failure details rather than a bare 'Error!'.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment