from typing import Optional

import pytorch_lightning as pl
import torch
import torch.optim as optim
from omegaconf import DictConfig

from sfu_compression.losses import RateDistortionLoss
from sfu_compression.models import SFUDenoiseScalable
from sfu_compression.utils import (
    create_noise_model,
    git_branch_name,
    git_common_ancestor_hash,
    git_current_hash,
)


class LitSFUDenoiseScalable(pl.LightningModule):
    def __init__(
        self,
        conf: Optional[DictConfig] = None,
        **kwargs,
    ):
        super().__init__()
        self.save_hyperparameters(conf)
        self.save_hyperparameters(kwargs)
        self.model = SFUDenoiseScalable(
            N=self.hparams.architecture.num_channels,
            BASE_N=self.hparams.architecture.num_base_channels,
        )
        self.criterion = RateDistortionLoss(
            lmbda=self.hparams.training.lmbda,
            w1d=self.hparams.training.w1d,
            w2d=self.hparams.training.w2d,
            w3d=self.hparams.training.w3d,
            w1r=self.hparams.training.w1r,
            w2r=self.hparams.training.w2r,
            w3r=self.hparams.training.w3r,
        )
        self.noise_model = create_noise_model(self.hparams.noise_model)
        # Two optimizers (main + auxiliary) are stepped by hand in training_step.
        self.automatic_optimization = False

    def forward(self, x):
        # TODO compress, decompress?
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # The batch is treated as clean data; synthetic noise is added so the
        # model learns to compress the noisy input while reconstructing a
        # denoised output (see the criterion targets below).
        x = batch
        x_noise = self.noise_model(x)

        optimizer, aux_optimizer = self.optimizers()
        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = self.model(x_noise)

        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        loss = out_criterion["loss"]
        self.manual_backward(loss)
        # Clip gradients only when a positive max norm is configured; calling
        # clip_grad_norm_ with 0 would zero out all gradients.
        if self.hparams.training.clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.hparams.training.clip_max_norm
            )
        optimizer.step()

        # The auxiliary loss trains the entropy bottleneck quantiles via the
        # separate auxiliary optimizer.
        aux_loss = self.model.aux_loss()
        self.manual_backward(aux_loss)
        aux_optimizer.step()

        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"train/{k}": v for k, v in log_dict.items()}
        self.log_dict(log_dict)

    def validation_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)

        out_net = self.model(x_noise)
        out_criterion = self.criterion(out_net, {"x": x_noise, "x_denoise": x})
        aux_loss = self.model.aux_loss()

        log_dict = {**out_criterion, "aux_loss": aux_loss}
        log_dict = {f"val/{k}": v for k, v in log_dict.items()}
        log_dict["val_loss"] = out_criterion["loss"]
        self.log_dict(log_dict)

    def validation_epoch_end(self, outputs):
        # With manual optimization, Lightning does not step the scheduler
        # automatically, so step it here on the monitored validation loss.
        sch = self.lr_schedulers()

        if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
            sch.step(self.trainer.callback_metrics["val/loss"])
        else:
            raise TypeError(f"Unexpected LR scheduler type: {type(sch).__name__}")

    def test_step(self, batch, batch_idx):
        x = batch
        x_noise = self.noise_model(x)

        # Exercise the full entropy coding path: compress to bitstreams, then
        # decompress and recover the reconstruction.
        enc_dict = self.model.compress(x_noise)
        encoded = [s[0] for s in enc_dict["strings"]]
        result = self.model.decompress(**enc_dict)
        x_hat = result["x_hat"].cpu().numpy()[0]

        # TODO log metrics, etc; on_epoch, on_step

    def configure_optimizers(self):
        # Delegates to the module-level configure_optimizers() helper below,
        # which splits the network parameters from the entropy bottleneck
        # quantile parameters. The scheduler is stepped manually in
        # validation_epoch_end.
        optimizer, aux_optimizer = configure_optimizers(
            self.model, self.hparams.training
        )
        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
        return (
            {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": lr_scheduler,
                    "monitor": "val/loss",
                },
            },
            {
                "optimizer": aux_optimizer,
            },
        )

    def on_fit_start(self):
        # Log git provenance (branch, commit, merge base) alongside the
        # hyperparameters so each run can be traced back to the exact code.
        params = {
            "git": {
                "branch_name": git_branch_name(),
                "hash": git_current_hash(),
                "master_hash": git_common_ancestor_hash(),
            },
            **self.hparams,
        }
        metrics = {"hp/metric": -1}
        self.logger.log_hyperparams(params, metrics)

    def on_load_checkpoint(self, checkpoint):
        # Checkpoints are stored without the "model." prefix (see
        # on_save_checkpoint), so restore the prefix for the LightningModule.
        prefix = "model."
        checkpoint["state_dict"] = {
            f"{prefix}{k}": v for k, v in checkpoint["state_dict"].items()
        }

    def on_save_checkpoint(self, checkpoint):
        # Strip the "model." prefix so the saved state_dict can be loaded
        # directly into a bare SFUDenoiseScalable model.
        prefix_len = len("model.")
        checkpoint["state_dict"] = {
            k[prefix_len:]: v for k, v in checkpoint["state_dict"].items()
        }


def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.
    Return two optimizers"""

    # Only the entropy bottleneck quantile parameters (names ending in
    # ".quantiles") go to the auxiliary optimizer; everything else is trained
    # by the main optimizer.
    parameters = {
        n
        for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n
        for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }

    # Make sure we don't have an intersection of parameters
    params_dict = dict(net.named_parameters())
    inter_params = parameters & aux_parameters
    union_params = parameters | aux_parameters

    assert len(inter_params) == 0
    assert len(union_params) - len(params_dict.keys()) == 0

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer
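

# ---------------------------------------------------------------------------
# Usage sketch (kept as a comment, not executed). The config keys mirror the
# hparams accessed in LitSFUDenoiseScalable above; the numeric values and the
# noise_model entry are illustrative placeholders, not values taken from this
# project, and the DataLoaders must be supplied by the caller.
#
#     from omegaconf import OmegaConf
#
#     conf = OmegaConf.create(
#         {
#             "architecture": {"num_channels": 192, "num_base_channels": 96},
#             "training": {
#                 "lmbda": 0.01,
#                 "w1d": 1.0, "w2d": 1.0, "w3d": 1.0,
#                 "w1r": 1.0, "w2r": 1.0, "w3r": 1.0,
#                 "clip_max_norm": 1.0,
#                 "learning_rate": 1e-4,
#                 "aux_learning_rate": 1e-3,
#             },
#             "noise_model": ...,  # whatever create_noise_model() expects
#         }
#     )
#     module = LitSFUDenoiseScalable(conf)
#     trainer = pl.Trainer(max_epochs=100)
#     trainer.fit(module, train_dataloader, val_dataloader)
# ---------------------------------------------------------------------------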