Created
July 7, 2021 14:53
-
-
Save j20232/e6bed077b993556273541d87ad8ae0d9 to your computer and use it in GitHub Desktop.
Tiny snippets of Learning Aggregation Functions [Pellegrini et al. IJCAI2021] https://arxiv.org/abs/2012.08482
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import random | |
from tqdm import tqdm | |
import numpy as np | |
import torch | |
def seed_everything(seed=1116):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for deterministic
    behaviour so that repeated runs produce the same numbers.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Hash randomisation of the *running* interpreter is fixed at start-up;
    # this only influences processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so switch it off explicitly.
    torch.backends.cudnn.benchmark = False
class TinyLAF(torch.nn.Module):
    """Minimal Learning Aggregation Function (LAF) with twelve scalar parameters.

    For an input tensor ``x`` the module evaluates::

        alpha * L(x; a, b) + beta  * L(1 - x; c, d)
        -------------------------------------------
        gamma * L(x; e, f) + delta * L(1 - x; g, h)

    where ``L(v; p, q) = (sum_i v_i ** q) ** p`` and the sum runs over every
    element of ``v``. All twelve parameters are initialised to one.
    """

    # Multiplicative weights of the four L terms.
    _COEFF_NAMES = ("alpha", "beta", "gamma", "delta")
    # Exponents of the four L terms, two per term.
    _EXP_NAMES = ("a", "b", "c", "d", "e", "f", "g", "h")

    def __init__(self):
        super().__init__()
        # Coefficients are registered first, then exponents; keeping this
        # order stable keeps parameters() / state_dict() ordering stable.
        for name in self._COEFF_NAMES + self._EXP_NAMES:
            setattr(self, name, torch.nn.Parameter(torch.ones(1)))

    def _as_scalars(self, names):
        """Return ``{name: value}`` with each named parameter as a NumPy scalar."""
        # .cpu() keeps this working when the module has been moved to a GPU,
        # where calling .numpy() directly would raise.
        return {
            name: getattr(self, name).detach().cpu().numpy()[0]
            for name in names
        }

    def get_coeff(self):
        """Return the current ``alpha``, ``beta``, ``gamma`` and ``delta`` values."""
        return self._as_scalars(self._COEFF_NAMES)

    def get_exp(self):
        """Return the current exponent values ``a`` through ``h``."""
        return self._as_scalars(self._EXP_NAMES)

    @staticmethod
    def _weighted_power_sum(coeff, x, a, b):
        """Compute ``coeff * (sum(x ** b)) ** a`` over all elements of ``x``."""
        return coeff * torch.pow(torch.sum(torch.pow(x, b)), a)

    def forward(self, x):
        """Aggregate ``x`` into a single value.

        Args:
            x: Tensor of values to aggregate. Presumably expected to lie in
                ``[0, 1]`` because ``1 - x`` is raised to a learnable power
                (TODO confirm against callers).

        Returns:
            Tensor of shape ``(1,)`` holding the ratio described in the class
            docstring.
        """
        complement = 1 - x
        numerator = (
            self._weighted_power_sum(self.alpha, x, self.a, self.b)
            + self._weighted_power_sum(self.beta, complement, self.c, self.d)
        )
        denominator = (
            self._weighted_power_sum(self.gamma, x, self.e, self.f)
            + self._weighted_power_sum(self.delta, complement, self.g, self.h)
        )
        return numerator / denominator
def optimize(x, gt_val, label, epochs=5000, lr=1e-2, max_esr=1000):
    """Fit a ``TinyLAF`` so its output on ``x`` matches ``gt_val``, then report it.

    Trains with Adam on a mean-squared-error loss, stops early once the loss
    has not improved for more than ``max_esr`` consecutive rounds, restores
    the parameters of the best round and prints the fitted output, the
    target, the best round and the learned parameters.

    Args:
        x: Array-like of input values to aggregate.
        gt_val: Scalar target the aggregation should reproduce.
        label: Name of the target function, used only in the printed header.
        epochs: Maximum number of optimisation rounds.
        lr: Adam learning rate.
        max_esr: Number of consecutive non-improving rounds tolerated before
            stopping early.
    """
    # np.asarray lets plain Python sequences / scalars through as well.
    inputs = torch.from_numpy(np.asarray(x).astype(np.float32))
    target = torch.tensor(np.asarray(gt_val, dtype=np.float32))

    model = TinyLAF()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    stale_rounds = 0
    best_loss = float("inf")
    best_round = -1
    best_state = None
    with tqdm(range(epochs)) as progress:
        for epoch in progress:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs)[0], target)
            loss.backward()
            loss_val = loss.item()
            # A NaN loss compares False here and therefore counts as "no
            # improvement", so a diverged run still terminates.
            if loss_val < best_loss:
                best_round = epoch
                best_loss = loss_val
                stale_rounds = 0
                # The loss was measured with the pre-step parameters, so the
                # snapshot must be taken before optimizer.step() changes them.
                best_state = {
                    key: value.detach().clone()
                    for key, value in model.state_dict().items()
                }
            else:
                stale_rounds += 1
                if stale_rounds > max_esr:
                    break
            optimizer.step()

    # Report the best parameters seen, not whatever the last iterate was.
    if best_state is not None:
        model.load_state_dict(best_state)
    with torch.no_grad():
        out = model(inputs)

    print(f"----- {label}-----")
    print("Out: ", out.detach().numpy()[0])
    print("GT: ", gt_val)
    print("Best round: ", best_round)
    print("Coeff: ", model.get_coeff())
    print("Exp coeff: ", model.get_exp())
if __name__ == '__main__':
    # Fix all seeds first so the sampled inputs are the same on every run.
    seed_everything()
    n_samples = 50
    x = np.random.rand(n_samples)

    # (target value, label) for each aggregation the model should learn,
    # in the order they are fitted.
    experiments = [
        (np.max(x), "max"),
        (np.min(x), "min"),
        (np.mean(x), "mean"),
        (np.array(np.nonzero(x)[0].shape[0]), "nonzero count"),
        (np.min(x) / np.max(x), "min/max"),
        (np.max(x) / np.min(x), "max/min"),
    ]
    for target, label in experiments:
        optimize(x, target, label)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment