Skip to content

Instantly share code, notes, and snippets.

@smsharma
Created February 10, 2022 20:07
Show Gist options
  • Save smsharma/802aa1a27a21b28e8a431349d99f9e9e to your computer and use it in GitHub Desktop.
Save smsharma/802aa1a27a21b28e8a431349d99f9e9e to your computer and use it in GitHub Desktop.
import sys
sys.path.append("../")
from itertools import permutations
import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import CrossEntropyLoss
import pytorch_lightning as pl
from einops import rearrange, repeat
from models.flows import build_maf, build_mlp
class SleepFlows(pl.LightningModule):
    """Encoder + categorical head + one conditional flow per source count.

    The encoder output (width ``n_out``) is split into ``n_ps_max + 1``
    equally sized slices of ``n_context_features`` features each:

    * the first slice is fed to an MLP that outputs ``n_ps_max + 1`` logits
      over the number of sources (0 .. ``n_ps_max``);
    * slice ``i`` (1-based) is the conditioning context of the flow that
      models the ``3 * i`` parameters of a configuration with ``i`` sources.

    Args:
        inference_net: Encoder module; called on the raw input in ``forward``.
        n_out: Width of the encoder output. NOTE(review): the slicing in
            ``loss`` only lines up when this is a multiple of
            ``n_ps_max + 1`` -- confirm against the encoder being used.
        optimizer: Optimizer class, instantiated in ``configure_optimizers``.
        optimizer_kwargs: Extra keyword arguments for ``optimizer``.
            ``None`` means ``{"weight_decay": 1e-5}``.
        lr: Learning rate passed to ``optimizer``.
        scheduler: LR scheduler class, instantiated in
            ``configure_optimizers``.
        scheduler_kwargs: Extra keyword arguments for ``scheduler``.
            ``None`` means ``{"patience": 5}``.
        n_ps_max: Largest number of sources the model can represent.
    """

    def __init__(self, inference_net, n_out,
                 optimizer=torch.optim.AdamW, optimizer_kwargs=None, lr=3e-4,
                 scheduler=torch.optim.lr_scheduler.ReduceLROnPlateau,
                 scheduler_kwargs=None,
                 n_ps_max=5):
        super().__init__()
        self.save_hyperparameters("n_ps_max", "lr")

        self.n_ps_max = n_ps_max
        # Each source is described by 3 numbers (x, y, flux); kept on the
        # instance so that ``loss`` does not have to hard-code it again.
        self.n_param_per_ps = 3

        self.inference_net = inference_net  # Encoder

        self.optimizer = optimizer
        # Fresh dicts per instance instead of shared mutable default arguments.
        self.optimizer_kwargs = (
            {"weight_decay": 1e-5} if optimizer_kwargs is None else optimizer_kwargs
        )
        self.scheduler = scheduler
        self.scheduler_kwargs = (
            {"patience": 5} if scheduler_kwargs is None else scheduler_kwargs
        )
        self.lr = lr

        # Number of features conditioning each flow (and feeding the count MLP)
        self.n_context_features = int(n_out // (self.n_ps_max + 1))

        # Instantiate normalizing flows---one for each possible number of sources
        self.flows = nn.ModuleList([
            build_maf(
                dim=int(self.n_param_per_ps * i),
                context_features=self.n_context_features,
                num_transforms=8,
                hidden_features=128,
            )
            for i in range(1, self.n_ps_max + 1)
        ])

        # MLP for predicting p(n_sources)
        self.probs_mlp = build_mlp(
            input_dim=self.n_context_features,
            hidden_dim=256,
            output_dim=self.n_ps_max + 1,
            layers=3,
        )

        # perms_list[i - 1] holds every ordering of i sources as tuples of
        # plain ints; n_perms[i - 1] == i!.
        self.perms_list = [
            list(permutations(range(i))) for i in range(1, self.n_ps_max + 1)
        ]
        self.n_perms = [len(perms) for perms in self.perms_list]

    def forward(self, x):
        """Return the encoder output for input ``x``."""
        return self.inference_net(x)

    def configure_optimizers(self):
        """Build the optimizer and an epoch-wise scheduler monitoring ``val_loss``."""
        optimizer = self.optimizer(
            self.parameters(), lr=self.lr, **self.optimizer_kwargs
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": self.scheduler(optimizer, **self.scheduler_kwargs),
                "interval": "epoch",
                "monitor": "val_loss",
                "frequency": 1,
            },
        }

    def loss(self, z_n, z_x, z_c, out):
        """Negative log-likelihood of the source parameters plus the count loss.

        Args:
            z_n: True number of sources per batch element, 1-D of length
                ``batch``.
            z_x: Source x-y positions; concatenated with ``z_c`` along dim 2,
                so presumably ``(batch, n_ps_max, 2)`` -- TODO confirm.
            z_c: Source fluxes, 2-D; presumably ``(batch, n_ps_max)`` --
                TODO confirm.
            out: Encoder output, ``(batch, n_out)``.

        Returns:
            Scalar tensor: ``-sum(best-permutation log_prob) / batch`` plus the
            mean cross-entropy on the number of sources.
        """
        n_batch = out.shape[0]

        # From the output of the encoder, get the categorical logits (after
        # feeding the first slice through the MLP) ...
        n_source_logits = self.probs_mlp(out[:, :self.n_context_features])
        # ... and use the rest as conditioning context, one chunk per flow.
        contexts = out[:, self.n_context_features:].chunk(self.n_ps_max, dim=-1)

        # Loss on number of sources. CrossEntropyLoss expects *unnormalised*
        # logits and applies log-softmax itself, so the MLP output is passed
        # as-is (a softmax beforehand would normalise twice).
        cross_entropy = CrossEntropyLoss(reduction="none")
        counter_loss = cross_entropy(
            n_source_logits.view(n_batch, self.n_ps_max + 1), z_n.long()
        )

        # Combine x-y positions and flux
        y = torch.cat([z_x, z_c.unsqueeze(2)], dim=2)
        z_n = z_n.int()

        # Position/flux losses---computed by taking the largest log_prob over
        # all possible orderings of the sources.
        # For speed, each flow is evaluated only once per step by gathering all
        # batch elements that contain the corresponding number of sources.
        log_prob = 0.
        for i in range(1, self.n_ps_max + 1):
            idx = torch.where(z_n == i)[0]
            n_elems = int(idx.shape[0])
            if n_elems == 0:
                # No element with i sources: contributes nothing, and skipping
                # avoids pushing an empty batch through einops and the flow.
                continue

            n_perms = self.n_perms[i - 1]
            perm_idx = torch.as_tensor(
                self.perms_list[i - 1], dtype=torch.long, device=y.device
            )

            # (ne, np, i, 3) -> one flattened parameter vector per permutation
            y_perm = rearrange(
                y[idx][:, perm_idx, :],
                'ne np nps npps -> (ne np) (nps npps)',
                ne=n_elems, np=n_perms, nps=i, npps=self.n_param_per_ps,
            )
            # Same context repeated for every permutation of an element
            x_perm = repeat(
                contexts[i - 1][idx], 'ne ns -> (ne np) ns', np=n_perms
            )

            log_probs = self.flows[i - 1].log_prob(y_perm, x_perm)
            log_probs = rearrange(
                log_probs, '(ne np) -> ne np', ne=n_elems, np=n_perms
            )
            log_prob = log_prob + log_probs.max(dim=1).values.sum()

        return (-log_prob / n_batch) + counter_loss.mean()

    def _shared_step(self, batch, stage):
        """Compute the loss for ``batch`` and log it as ``f"{stage}_loss"``."""
        z_n, z_x, z_c, x = batch
        out = self(x)
        loss = self.loss(z_n, z_x, z_c, out)
        self.log(f'{stage}_loss', loss, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_step(batch, 'val')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment