from typing import Optional

import pytorch_lightning as pl
import torch
import torch.optim as optim
from omegaconf import OmegaConf

from sfu_compression.losses import RateDistortionLoss
from sfu_compression.models import SFUDenoiseScalable
from sfu_compression.utils import (
    create_noise_model,
    git_branch_name,
    git_common_ancestor_hash,
    git_current_hash,
)


class LitSFUDenoiseScalable(pl.LightningModule):
    """Lightning wrapper around SFUDenoiseScalable.

    Uses manual optimization: the main network parameters and the auxiliary
    (entropy-bottleneck quantile) parameters are updated by separate optimizers.
    """

    def __init__(
        self,
        conf: Optional[OmegaConf] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(conf)
        self.save_hyperparameters(kwargs)
        self.model = SFUDenoiseScalable(
            N=self.hparams.architecture.num_channels,
            BASE_N=self.hparams.architecture.num_base_channels,
        )
        self.criterion = RateDistortionLoss(
            lmbda=self.hparams.training.lmbda,
            w1d=self.hparams.training.w1d,
            w2d=self.hparams.training.w2d,
            w3d=self.hparams.training.w3d,
            w1r=self.hparams.training.w1r,
            w2r=self.hparams.training.w2r,
            w3r=self.hparams.training.w3r,
        )
        self.noise_model = create_noise_model(self.hparams.noise_model)
        self.automatic_optimization = False

    def forward(self, x):
        # TODO compress, decompress?
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)

        optimizer, aux_optimizer = self.optimizers()
        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        # Main rate-distortion objective on the (noisy, clean) pair.
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        loss = out_criterion["loss"]
        self.manual_backward(loss)
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.hparams.training.clip_max_norm
        )
        optimizer.step()

        # Auxiliary loss updates the entropy-bottleneck quantile parameters.
        aux_loss = self.model.aux_loss()
        self.manual_backward(aux_loss)
        aux_optimizer.step()

        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"train/{k}": v for k, v in log_dict.items()}
        self.log_dict(log_dict)

    def validation_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        aux_loss = self.model.aux_loss()
        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"val/{k}": v for k, v in log_dict.items()}
        log_dict["val_loss"] = out_criterion["loss"]
        self.log_dict(log_dict)

    def validation_epoch_end(self, outputs):
        # With manual optimization, the LR scheduler must be stepped explicitly.
        sch = self.lr_schedulers()
        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["val/loss"])
        else:
            raise TypeError(f"Expected ReduceLROnPlateau scheduler, got {type(sch)}")

    def test_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)
        enc_dict = self.model.compress(x_noise)
        encoded = [s[0] for s in enc_dict["strings"]]
        result = self.model.decompress(**enc_dict)
        x_hat = result["x_hat"].cpu().numpy()[0]
        # TODO log metrics, etc; on_epoch, on_step

    def configure_optimizers(self):
        # Delegates to the module-level helper below, which splits the main and
        # auxiliary (quantile) parameters into two optimizers.
        optimizer, aux_optimizer = configure_optimizers(
            self.model, self.hparams.training
        )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "monitor": "val/loss",
                },
            },
            {
                "optimizer": aux_optimizer,
            },
        )

    def on_fit_start(self):
        params = {
            "git": {
                "branch_name": git_branch_name(),
                "hash": git_current_hash(),
                "master_hash": git_common_ancestor_hash(),
            },
            **self.hparams,
        }
        metrics = {"hp/metric": -1}
        self.logger.log_hyperparams(params, metrics)

    def on_load_checkpoint(self, checkpoint):
        # Re-add the "model." prefix stripped by on_save_checkpoint.
        prefix = "model."
checkpoint["state_dict"] = { f"{prefix}{k}": v for k, v in checkpoint["state_dict"].items() } def on_save_checkpoint(self, checkpoint): prefix_len = len("model.") checkpoint["state_dict"] = { k[prefix_len:]: v for k, v in checkpoint["state_dict"].items() } def configure_optimizers(net, args): """Separate parameters for the main optimizer and the auxiliary optimizer. Return two optimizers""" parameters = { n for n, p in net.named_parameters() if not n.endswith(".quantiles") and p.requires_grad } aux_parameters = { n for n, p in net.named_parameters() if n.endswith(".quantiles") and p.requires_grad } # Make sure we don't have an intersection of parameters params_dict = dict(net.named_parameters()) inter_params = parameters & aux_parameters union_params = parameters | aux_parameters assert len(inter_params) == 0 assert len(union_params) - len(params_dict.keys()) == 0 optimizer = optim.Adam( (params_dict[n] for n in sorted(parameters)), lr=args.learning_rate, ) aux_optimizer = optim.Adam( (params_dict[n] for n in sorted(aux_parameters)), lr=args.aux_learning_rate, ) return optimizer, aux_optimizer