#!/usr/bin/env python3

import argparse

from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import functional as TF

import mdmm_2 as mdmm


class TVLoss(nn.Module):
    """Total variation of a tensor, computed from forward differences.

    The loss has no parameters or state; all work happens in ``forward``.
    """

    def forward(self, input):
        """Return the summed gradient magnitude of ``input`` as a scalar tensor.

        The last two dimensions are extended by one replicated element on the
        right and bottom so every original position has a forward neighbour
        (the differences at the far edges are therefore zero). Squared
        horizontal and vertical differences, plus ``1e-8`` to keep the square
        root differentiable at zero, are summed over dimension 1 before the
        square root is taken; the result is then summed over all remaining
        positions.

        Assumes ``input`` is laid out as (batch, channel, height, width), so
        that dimension 1 is the channel axis -- TODO confirm against callers.
        """
        padded = F.pad(input, (0, 1, 0, 1), 'replicate')
        # Every position of the original grid, used as the base of both
        # forward differences.
        base = padded[..., :-1, :-1]
        horizontal = padded[..., :-1, 1:] - base
        vertical = padded[..., 1:, :-1] - base
        squared = horizontal**2 + vertical**2 + 1e-8
        magnitude = squared.sum(dim=1).sqrt()
        return magnitude.sum()


def main():
    """Optimize an image against its original under a total-variation limit.

    Parses the command line, loads the input image resized to 128x128, and
    repeatedly minimizes the summed squared error between an optimized copy
    and the original. ``TVLoss`` of the copy is passed to
    ``mdmm.MaxConstraint`` with a limit of ``--max-tv`` times the number of
    tensor elements. The loop runs until the Lagrangian is no longer finite
    or the user presses Ctrl+C; the current image is then clamped to [0, 1]
    and written to the output path.
    """
    p = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    p.add_argument('input_image', type=str,
                   help='the input image')
    p.add_argument('output_image', type=str, nargs='?', default='out.png',
                   help='the output image')
    p.add_argument('--max-tv', type=float, default=0.02,
                   help='the maximum allowable total variation per sample')
    p.add_argument('--damping', type=float, default=1e-2,
                   help='the constraint damping strength')
    p.add_argument('--lr', type=float, default=2e-3,
                   help='the learning rate')
    args = p.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # The context manager releases the file handle as soon as the resized,
    # fully in-memory copy exists.
    with Image.open(args.input_image) as opened:
        source = opened
        if opened.mode == 'P':
            # Palette images hold colour-table indices, not intensities.
            # Expand them to real channels first, otherwise the resize and
            # the tensor conversion would operate on the raw indices.
            expanded_mode = 'RGBA' if 'transparency' in opened.info else 'RGB'
            source = opened.convert(expanded_mode)
        pil_image = source.resize((128, 128), Image.LANCZOS)

    # [None] adds a leading batch dimension of size 1.
    target = TF.to_tensor(pil_image)[None].to(device)
    image = target.clone().requires_grad_()
    n_elems = image.numel()
    # Disabled experiment: uncomment to add Gaussian noise to the target.
    # torch.manual_seed(0)
    # target += torch.randn_like(target) / 10
    # target.clamp_(0, 1)

    crit_l2 = nn.MSELoss(reduction='sum')
    crit_tv = TVLoss()
    # --max-tv is expressed per element; the constraint works on the total.
    max_tv = args.max_tv * n_elems

    # The lambda closes over `image`, so the constraint is re-evaluated on the
    # current values each time the MDMM module is called.
    mdmm_mod = mdmm.MDMM([mdmm.MaxConstraint(lambda: crit_tv(image), max_tv, args.damping)])
    opt = mdmm_mod.make_optimizer([image], lr=args.lr)

    try:
        i = 0
        while True:
            i += 1
            loss = crit_l2(image, target)
            # Returns the value to differentiate and a sequence whose first
            # entry is reported below as the current total variation.
            lagrangian, losses = mdmm_mod(loss)
            msg = '{} l2={:g}, tv={:g}'
            print(msg.format(i,
                             loss.item() / n_elems,
                             losses[0].item() / n_elems))
            # Stop before stepping on a NaN/inf objective.
            if not lagrangian.isfinite():
                break
            opt.zero_grad()
            lagrangian.backward()
            opt.step()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to end the run; fall through and save.
        pass

    # Detach from autograd and move to host memory before PIL conversion.
    result = image[0].detach().clamp(0, 1).cpu()
    TF.to_pil_image(result).save(args.output_image)

# Run the optimization only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()