import torch
from torch import nn

# Public API of this module: the network class and its preset factory
# functions. DenseBlock is deliberately not listed here.
__all__ = ['FCDenseNet', 'fcdensenet_tiny', 'fcdensenet56_nodrop',
           'fcdensenet56', 'fcdensenet67', 'fcdensenet103',
           'fcdensenet103_nodrop']


class DenseBlock(nn.Module):
    """A stack of densely connected convolutional layers.

    Layer ``i`` receives the block input concatenated (along the channel
    axis) with the outputs of all earlier layers, and contributes
    ``growth_rate`` new channels.

    Args:
        nIn: number of channels of the block input.
        growth_rate: number of channels produced by every layer.
        depth: number of layers in the block.
        drop_rate: probability handed to the ``nn.Dropout`` that ends every
            layer.
        only_new: if true, ``forward`` returns only the newly produced
            feature maps (``depth * growth_rate`` channels); otherwise it
            returns the input concatenated with them.
        bottle_neck: falsy to disable; otherwise a multiplier ``k`` such that
            layers whose input is wider than ``k * growth_rate`` channels get
            a 1x1 convolution reducing to that width first.
    """

    def __init__(self, nIn, growth_rate, depth, drop_rate=0, only_new=False,
                 bottle_neck=False):
        super(DenseBlock, self).__init__()
        self.only_new = only_new
        self.depth = depth
        self.growth_rate = growth_rate
        stages = []
        for idx in range(depth):
            # Every earlier layer has widened the running tensor by
            # ``growth_rate`` channels.
            stage_in = nIn + idx * growth_rate
            stages.append(self.get_transform(
                stage_in, growth_rate, bottle_neck, drop_rate))
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Run all layers, concatenating each output onto the running input."""
        if not self.only_new:
            for idx in range(self.depth):
                grown = self.layers[idx](x)
                x = torch.cat((x, grown), 1)
            return x

        fresh_maps = []
        for idx in range(self.depth):
            grown = self.layers[idx](x)
            x = torch.cat((x, grown), 1)
            fresh_maps.append(grown)
        return torch.cat(fresh_maps, 1)

    def get_transform(self, nIn, nOut, bottle_neck=None, drop_rate=0):
        """Build one layer: BN-ReLU-Conv3x3-Dropout, optionally bottlenecked.

        The 1x1 bottleneck (BN-ReLU-Conv1x1 in front of the 3x3 stage) is
        inserted only when ``bottle_neck`` is truthy and ``nIn`` exceeds
        ``nOut * bottle_neck``.
        """
        modules = [nn.BatchNorm2d(nIn), nn.ReLU(True)]
        conv3_in = nIn
        if bottle_neck and nIn > nOut * bottle_neck:
            conv3_in = nOut * bottle_neck
            modules.append(
                nn.Conv2d(nIn, conv3_in, 1, stride=1, padding=0, bias=True))
            modules.append(nn.BatchNorm2d(conv3_in))
            modules.append(nn.ReLU(True))
        modules.append(
            nn.Conv2d(conv3_in, nOut, 3, stride=1, padding=1, bias=True))
        modules.append(nn.Dropout(drop_rate))
        return nn.Sequential(*modules)


class FCDenseNet(nn.Module):
    """Encoder-decoder dense-prediction network built from dense blocks.

    Layout (``2 * n_scales + 1`` dense blocks in total):

    * encoder: ``n_scales`` times (dense block -> transition down); the
      output of each dense block is kept as a skip connection;
    * bottleneck: one dense block that emits only its new feature maps;
    * decoder: ``n_scales`` times (transposed-conv upsampling -> concatenate
      with the matching skip -> dense block). All decoder blocks except the
      last emit only their new feature maps.

    A final 1x1 convolution maps to ``n_classes`` channels and a log-softmax
    over the channel dimension is applied.

    Args:
        depths: number of layers per dense block; either one int used for
            every block or a sequence of length ``2 * n_scales + 1``.
        growth_rates: growth rate per dense block; int or sequence, same
            convention as ``depths``.
        n_scales: number of down/up-sampling stages (must be >= 1).
        n_channel_start: output channels of the first 3x3 convolution.
        n_classes: number of output channels / classes.
        drop_rate: dropout probability used in dense blocks and, when
            greater than zero, in the transition-down layers.
        bottle_neck: forwarded to every ``DenseBlock``.

    Raises:
        ValueError: if ``n_scales < 1`` or if ``depths`` / ``growth_rates``
            do not have ``2 * n_scales + 1`` entries.
    """

    def __init__(self, depths, growth_rates, n_scales=5, n_channel_start=48,
                 n_classes=12, drop_rate=0, bottle_neck=False):
        super(FCDenseNet, self).__init__()
        if n_scales < 1:
            raise ValueError(
                'n_scales must be at least 1, got {}'.format(n_scales))
        n_blocks = 2 * n_scales + 1

        self.n_scales = n_scales
        self.n_classes = n_classes
        self.n_channel_start = n_channel_start
        self.depths = [depths] * n_blocks \
            if isinstance(depths, int) else depths
        self.growth_rates = [growth_rates] * n_blocks \
            if isinstance(growth_rates, int) else growth_rates
        self.drop_rate = drop_rate
        # Explicit check (not ``assert``) so it survives ``python -O``.
        if not (len(self.depths) == len(self.growth_rates) == n_blocks):
            raise ValueError(
                'depths and growth_rates must each have {} entries '
                '(2 * n_scales + 1), got {} and {}'.format(
                    n_blocks, len(self.depths), len(self.growth_rates)))

        self.conv_first = nn.Conv2d(
            3, n_channel_start, 3, stride=1, padding=1, bias=True)
        self.dense_blocks = nn.ModuleList([])
        self.transition_downs = nn.ModuleList([])
        self.transition_ups = nn.ModuleList([])

        # Encoder: remember the channel count of every skip connection.
        nskip = []
        nIn = self.n_channel_start
        for i in range(n_scales):
            self.dense_blocks.append(
                DenseBlock(nIn, self.growth_rates[i], self.depths[i],
                           drop_rate=drop_rate, bottle_neck=bottle_neck))
            nIn += self.growth_rates[i] * self.depths[i]
            nskip.append(nIn)
            self.transition_downs.append(self.get_TD(nIn, drop_rate))

        # Bottleneck: only the newly created feature maps are passed upward.
        self.dense_blocks.append(
            DenseBlock(nIn, self.growth_rates[n_scales], self.depths[n_scales],
                       only_new=True, drop_rate=drop_rate,
                       bottle_neck=bottle_neck))
        nIn = self.growth_rates[n_scales] * self.depths[n_scales]

        # Decoder: the final block keeps its input channels as well, every
        # earlier one forwards only its new feature maps.
        for i in range(n_scales):
            block_idx = n_scales + 1 + i
            is_last = i == n_scales - 1
            self.transition_ups.append(nn.ConvTranspose2d(
                nIn, nIn, 3, stride=2, padding=1, bias=True))
            nIn += nskip.pop()
            self.dense_blocks.append(
                DenseBlock(nIn, self.growth_rates[block_idx],
                           self.depths[block_idx], only_new=not is_last,
                           drop_rate=drop_rate, bottle_neck=bottle_neck))
            n_new = self.growth_rates[block_idx] * self.depths[block_idx]
            nIn = nIn + n_new if is_last else n_new

        self.conv_last = nn.Conv2d(nIn, n_classes, 1, bias=True)
        # forward() only works on 4-D input, whose channel axis is dim 1;
        # stating it avoids the deprecated implicit-dimension warning.
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Return log-softmax scores with the spatial size of the input.

        ``x`` must be 4-D with 3 channels in dim 1 (``conv_first`` expects 3
        input channels and dims 2 and 3 are read as the spatial sizes).
        """
        x = self.conv_first(x)
        skip_connects = []
        # down sample
        for i in range(self.n_scales):
            x = self.dense_blocks[i](x)
            skip_connects.append(x)
            x = self.transition_downs[i](x)
        # bottle neck
        x = self.dense_blocks[self.n_scales](x)
        # up sample
        for i in range(self.n_scales):
            skip = skip_connects.pop()
            x = self._upsample_to(self.transition_ups[i], x, skip)
            x = torch.cat((skip, x), 1)
            x = self.dense_blocks[self.n_scales + 1 + i](x)
        x = self.conv_last(x)
        return self.logsoftmax(x)

    @staticmethod
    def _upsample_to(transition_up, x, skip):
        """Upsample ``x`` with ``transition_up`` to the spatial size of ``skip``.

        The padding of the transposed convolution is fitted per call so that
        the requested output size is reachable (pooling an odd-sized map
        loses a row/column, so a fixed padding cannot serve both parities).
        The module's configured padding is restored afterwards so a forward
        pass leaves no lasting change on the layer.
        """
        stride = transition_up.stride
        kernel = transition_up.kernel_size
        fitted = tuple(
            ((x.size(2 + d) - 1) * stride[d] - skip.size(2 + d)
             + kernel[d] + 1) // 2
            for d in range(2))
        configured = transition_up.padding
        transition_up.padding = fitted
        try:
            return transition_up(x, output_size=skip.size())
        finally:
            transition_up.padding = configured

    def get_TD(self, nIn, drop_rate):
        """Build a transition-down: BN-ReLU-Conv1x1[-Dropout]-MaxPool(2).

        The dropout layer is only inserted when ``drop_rate`` is positive.
        The channel count ``nIn`` is preserved.
        """
        layers = [nn.BatchNorm2d(nIn), nn.ReLU(
            True), nn.Conv2d(nIn, nIn, 1, bias=True)]
        if drop_rate > 0:
            layers.append(nn.Dropout(drop_rate))
        layers.append(nn.MaxPool2d(2))
        return nn.Sequential(*layers)


def fcdensenet_tiny(drop_rate=0):
    """Build a small FCDenseNet: 2 layers per dense block, growth rate 6."""
    return FCDenseNet(depths=2, growth_rates=6, drop_rate=drop_rate)


def fcdensenet56_nodrop():
    """Build an FCDenseNet with 4 layers per block, growth 12, no dropout."""
    return FCDenseNet(depths=4, growth_rates=12, drop_rate=0)


def fcdensenet56(drop_rate=0.2):
    """Build an FCDenseNet with 4 layers per dense block and growth rate 12."""
    return FCDenseNet(depths=4, growth_rates=12, drop_rate=drop_rate)


def fcdensenet67(drop_rate=0.2):
    """Build an FCDenseNet with 5 layers per dense block and growth rate 16."""
    return FCDenseNet(depths=5, growth_rates=16, drop_rate=drop_rate)


def fcdensenet103(drop_rate=0.2):
    """Build an FCDenseNet with per-block depths 4..15..4 and growth rate 16.

    The eleven depths cover five encoder blocks, the bottleneck and five
    decoder blocks, in that order.
    """
    block_depths = [4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4]
    return FCDenseNet(depths=block_depths, growth_rates=16,
                      drop_rate=drop_rate)


def fcdensenet103_nodrop(drop_rate=0):
    """Same layout as ``fcdensenet103`` but with dropout off by default.

    Uses per-block depths 4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4 and growth
    rate 16; ``drop_rate`` is still forwarded if the caller overrides it.
    """
    block_depths = [4, 5, 7, 10, 12, 15, 12, 10, 7, 5, 4]
    return FCDenseNet(depths=block_depths, growth_rates=16,
                      drop_rate=drop_rate)