Skip to content

Instantly share code, notes, and snippets.

@wiseodd
Created March 30, 2023 21:03
Show Gist options
  • Save wiseodd/62fadb452f77e488acc3716ed3822ac7 to your computer and use it in GitHub Desktop.
Save wiseodd/62fadb452f77e488acc3716ed3822ac7 to your computer and use it in GitHub Desktop.
import warnings
warnings.filterwarnings('ignore')
import torch
from laplace import Laplace
from helper.dataloaders import get_sinusoid_example
from laplace import Laplace
# Run configuration.
# NOTE(review): n_epochs is never read anywhere in this script (no training
# loop is visible here) — confirm whether it is still needed.
n_epochs = 1000

# Seed torch's global RNG once, so the whole script is reproducible run-to-run.
torch.manual_seed(711)

# Build the toy sinusoid regression problem. sigma_noise=0.3 presumably sets
# the noise level of the targets — TODO confirm in helper.dataloaders.
X_train, y_train, train_loader, X_test = get_sinusoid_example(
    sigma_noise=0.3,
)
class Test():
    """Hold a small MLP together with a Laplace approximation fitted to it.

    All work happens in the constructor: it builds the network, switches it
    to eval mode, fits ``Laplace(..., 'regression', subset_of_weights='all',
    hessian_structure='kron')`` on ``train_loader`` and then runs
    ``optimize_prior_precision(n_steps=10)``.

    Attributes:
        nn: ``torch.nn.Sequential`` of Linear(1, 50) -> Tanh -> Linear(50, 1),
            moved to ``device``.
        train_loader: the loader the approximation was fitted on.
        bnn: the fitted ``Laplace`` object; calling it on inputs returns a
            pair (callers in this script unpack two values).
    """

    def __init__(self, train_loader, device='cpu'):
        super().__init__()
        # Layers are created in the same order as a literal Sequential(...)
        # call, so parameter initialisation consumes the RNG identically.
        layers = [
            torch.nn.Linear(1, 50),
            torch.nn.Tanh(),
            torch.nn.Linear(50, 1),
        ]
        self.nn = torch.nn.Sequential(*layers).to(device)
        self.train_loader = train_loader
        self.nn.eval()

        self.bnn = la = Laplace(
            self.nn,
            'regression',
            subset_of_weights='all',
            hessian_structure='kron',
        )
        la.fit(self.train_loader)
        la.optimize_prior_precision(n_steps=10)
def get(train_loader):
    """Build a fresh MLP and fit a Laplace approximation to it.

    This is the plain-function counterpart of the ``Test`` class. The network
    is Linear(1, 50) -> Tanh -> Linear(50, 1), put in eval mode. It is wrapped
    in ``Laplace(..., 'regression', subset_of_weights='all',
    hessian_structure='kron')``, fitted on ``train_loader``, and then has
    ``optimize_prior_precision(n_steps=10)`` applied.

    Returns:
        Tuple ``(network, laplace_object)``.
    """
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 50),
        torch.nn.Tanh(),
        torch.nn.Linear(50, 1),
    )
    net.eval()

    la = Laplace(
        net,
        'regression',
        subset_of_weights='all',
        hessian_structure='kron',
    )
    la.fit(train_loader)
    la.optimize_prior_precision(n_steps=10)
    return net, la
# Test case 1: construct a Test, predict with it, then construct a second
# Test and predict with that one.
try:
    print('Test case 1: Class init-predict-init-predict:')
    model1 = Test(train_loader)
    _, _ = model1.bnn(X_test)
    model2 = Test(train_loader)
    _, _ = model2.bnn(X_test)
except Exception as err:
    # Report what actually failed. A bare 'Error!' makes this reproduction
    # undiagnosable from its output.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 2: construct two Test instances first, then predict with each.
try:
    print('Test case 2: Class init-init-predict-predict:')
    model1 = Test(train_loader)
    model2 = Test(train_loader)
    _, _ = model1.bnn(X_test)
    _, _ = model2.bnn(X_test)
except Exception as err:
    # Report what actually failed. A bare 'Error!' makes this reproduction
    # undiagnosable from its output.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 3: build via get(), predict, then build again and predict.
try:
    print('Test case 3: Function init-predict-init-predict:')
    model1, bnn1 = get(train_loader)
    _, _ = bnn1(X_test)
    model2, bnn2 = get(train_loader)
    _, _ = bnn2(X_test)
except Exception as err:
    # Report what actually failed. A bare 'Error!' makes this reproduction
    # undiagnosable from its output.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# Test case 4: build two models via get() first, then predict with each.
try:
    print('Test case 4: Function init-init-predict-predict:')
    model1, bnn1 = get(train_loader)
    model2, bnn2 = get(train_loader)
    _, _ = bnn1(X_test)
    _, _ = bnn2(X_test)
except Exception as err:
    # Report what actually failed. A bare 'Error!' makes this reproduction
    # undiagnosable from its output.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
print()
# --------------- SANITY CHECK ---------------------
# Same init-init-predict-predict sequence as test cases 2 and 4, but written
# inline at module level: no class and no helper function are involved.
try:
    print('Sanity check: No class, no function init-init-predict-predict:')
    # Laplace 1
    nn1 = torch.nn.Sequential(
        torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
    )
    nn1.eval()
    bnn1 = Laplace(
        nn1, 'regression',
        subset_of_weights='all', hessian_structure='kron',
    )
    bnn1.fit(train_loader)
    bnn1.optimize_prior_precision(n_steps=10)
    # Laplace 2
    nn2 = torch.nn.Sequential(
        torch.nn.Linear(1, 50), torch.nn.Tanh(), torch.nn.Linear(50, 1)
    )
    nn2.eval()
    bnn2 = Laplace(
        nn2, 'regression',
        subset_of_weights='all', hessian_structure='kron',
    )
    bnn2.fit(train_loader)
    bnn2.optimize_prior_precision(n_steps=10)
    # Predict with BOTH approximations. The original called bnn1 twice, so
    # bnn2 was never exercised and the check did not match its label.
    _, _ = bnn1(X_test)
    _, _ = bnn2(X_test)
except Exception as err:
    # Report what actually failed instead of a bare 'Error!'.
    print(f'Error! {type(err).__name__}: {err}')
else:
    print('Success!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment