import sys
sys.path.append("../")

from itertools import permutations

import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import CrossEntropyLoss

import pytorch_lightning as pl
from einops import rearrange, repeat

from models.flows import build_maf, build_mlp

class SleepFlows(pl.LightningModule):
    def __init__(self, inference_net, n_out,
                 optimizer=torch.optim.AdamW, optimizer_kwargs={"weight_decay": 1e-5}, lr=3e-4,
                 scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
                 scheduler_kwargs={"patience": 5},
                 n_ps_max=5):
        super().__init__()

        self.save_hyperparameters("n_ps_max", "lr")

        self.n_ps_max = n_ps_max
        n_param_per_ps = 3  # Parameters per point source: x-y position and flux

        self.inference_net = inference_net  # Encoder

        self.optimizer = optimizer
        self.optimizer_kwargs = optimizer_kwargs
        self.scheduler = scheduler
        self.scheduler_kwargs = scheduler_kwargs
        self.lr = lr

        # Number of features conditioning each flow
        self.n_context_features = int(n_out / (self.n_ps_max + 1))

        # Instantiate normalizing flows---one for each possible number of sources
        self.flows = [build_maf(dim=int(n_param_per_ps * i), context_features=self.n_context_features, num_transforms=8, hidden_features=128) for i in range(1, self.n_ps_max + 1)]
        self.flows = nn.ModuleList(self.flows)

        # MLP for predicting p(n_sources)
        self.probs_mlp = build_mlp(input_dim=self.n_context_features, hidden_dim=256, output_dim=self.n_ps_max + 1, layers=3)

        # Pre-compute all source orderings for the permutation-matched loss below
        self.perms_list = [list(permutations(torch.arange(i))) for i in range(1, self.n_ps_max + 1)]
        self.n_perms = [len(perms) for perms in self.perms_list]

    def forward(self, x):
        x = self.inference_net(x)
        return x

    def configure_optimizers(self):
        optimizer = self.optimizer(self.parameters(), lr=self.lr, **self.optimizer_kwargs)
        return {"optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": self.scheduler(optimizer, **self.scheduler_kwargs),
                    "interval": "epoch",
                    "monitor": "val_loss",
                    "frequency": 1}
                }

    def loss(self, z_n, z_x, z_c, out):
        n_batch = out.shape[0]

        # From the output of the encoder, get the probability logits (after feeding through the MLP) and the conditioning context for the flows
        probs = self.probs_mlp(out[:, :self.n_context_features])  # Categorical probability logits
        x = out[:, self.n_context_features:].chunk(self.n_ps_max, dim=-1)  # The rest of the output is used as conditioning context for the flows

        # Loss on the number of sources; CrossEntropyLoss expects unnormalized logits and applies log-softmax internally,
        # so the raw MLP output is passed in directly
        n_source_logits = probs.view(n_batch, self.n_ps_max + 1)
        cross_entropy = CrossEntropyLoss(reduction="none")
        counter_loss = cross_entropy(n_source_logits, z_n.long())

        # Combine x-y positions and flux
        y = torch.cat([z_x, z_c.unsqueeze(2)], dim=2)
        z_n = z_n.int()

        # Position/flux losses---computed by taking the largest log_prob over all possible source permutations.
        # For speed, each of the n_ps_max flows is evaluated only once, by combining all elements in the batch that contain a given number of sources.
        log_prob = 0.
        for i in range(1, self.n_ps_max + 1):
            n_elems = int((z_n == i).sum())  # Number of batch elements with exactly i sources
            y_perm = rearrange(y[torch.where(z_n == i)[0]][:, self.perms_list[i - 1], :], 'ne np nps npps -> (ne np) (nps npps)', ne=n_elems, np=self.n_perms[i - 1], nps=i, npps=3)
            x_perm = rearrange(repeat(x[i - 1][z_n == i], 'ne ns -> ne np ns', np=self.n_perms[i - 1]), 'ne np ns -> (ne np) ns', ne=n_elems)
            log_probs = self.flows[i - 1].log_prob(y_perm, x_perm)
            log_prob += (torch.max(rearrange(log_probs, '(ne np) -> ne np', ne=n_elems), dim=1)[0]).sum()

        return (-log_prob / n_batch) + counter_loss.mean()

    def training_step(self, batch, batch_idx):
        z_n, z_x, z_c, x = batch
        out = self(x)
        loss = self.loss(z_n, z_x, z_c, out)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        z_n, z_x, z_c, x = batch
        out = self(x)
        loss = self.loss(z_n, z_x, z_c, out)
        self.log('val_loss', loss, on_epoch=True)
        return loss
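

# ----------------------------------------------------------------------------------------------
# Minimal usage sketch (not part of the original gist): one way to instantiate and train
# SleepFlows end-to-end. The toy CNN encoder, image size, synthetic catalogs, and DataLoader
# setup below are illustrative assumptions; only build_maf/build_mlp from models.flows and the
# (z_n, z_x, z_c, image) batch format are taken from the code above.
# ----------------------------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    n_ps_max = 5
    n_out = 64 * (n_ps_max + 1)  # Encoder output must split into (n_ps_max + 1) context chunks

    # Toy encoder (assumed architecture): maps a 1 x 32 x 32 image to a flat feature vector of length n_out
    inference_net = nn.Sequential(
        nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(8), nn.Flatten(),
        nn.Linear(16 * 8 * 8, n_out),
    )

    model = SleepFlows(inference_net=inference_net, n_out=n_out, n_ps_max=n_ps_max)

    # Synthetic catalogs: source counts, per-source x-y positions and fluxes (padded to n_ps_max), and images
    n_train = 256
    z_n = torch.randint(1, n_ps_max + 1, (n_train,))
    z_x = torch.rand(n_train, n_ps_max, 2)
    z_c = torch.rand(n_train, n_ps_max)
    images = torch.randn(n_train, 1, 32, 32)

    train_loader = DataLoader(TensorDataset(z_n, z_x, z_c, images), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(z_n, z_x, z_c, images), batch_size=32)

    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)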