Last active
October 25, 2024 17:34
-
-
Save trougnouf/9b5b5fe94341394ab29b977d5b69e65f to your computer and use it in GitHub Desktop.
Denoising architecture used in RawNIND (for both linear RGB and Bayer input images)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UtNet2(Denoiser):
    """U-Net style denoiser for linear-RGB (3-channel) or Bayer (4-channel) input.

    Standard 4-level U-Net encoder/decoder with skip connections. For Bayer
    input (in_channels == 4) the output is brought back to full-resolution
    3-channel RGB either by bilinear pre-upsampling of the input
    (``preupsample=True``) or by a final ``PixelShuffle(2)`` head.
    """

    def __init__(
        self,
        in_channels: int,
        funit: int = 32,
        activation: str = "LeakyReLU",
        preupsample: bool = False,
    ):
        """Build the network.

        Args:
            in_channels: 3 (linear RGB) or 4 (packed Bayer).
            funit: base number of feature channels; doubled at each level.
            activation: name understood by ``get_activation_class_params``.
            preupsample: if True (4-channel input only), bilinearly upsample
                the input 2x before the network instead of pixel-shuffling
                the output.

        Raises:
            ValueError: on an unsupported ``in_channels``/``preupsample``
                combination.
        """
        super().__init__(in_channels=in_channels)
        # Explicit check instead of `assert`: asserts are stripped under -O.
        if not ((in_channels == 3 and not preupsample) or in_channels == 4):
            raise ValueError(f"unsupported combination: {in_channels=}, {preupsample=}")
        activation_fun, activation_params = get_activation_class_params(activation)
        # self.pad = nn.ReflectionPad2d(2)
        if preupsample:
            self.preprocess = torch.nn.Upsample(
                scale_factor=2, mode="bilinear", align_corners=False
            )
        else:
            self.preprocess = torch.nn.Identity()
        # Encoder: two 3x3 convs per level, channels funit -> 8*funit.
        self.convs1 = nn.Sequential(
            nn.Conv2d(in_channels, funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(funit, funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.maxpool = nn.MaxPool2d(2)
        self.convs2 = nn.Sequential(
            nn.Conv2d(funit, 2 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(2 * funit, 2 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.convs3 = nn.Sequential(
            nn.Conv2d(2 * funit, 4 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(4 * funit, 4 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.convs4 = nn.Sequential(
            nn.Conv2d(4 * funit, 8 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(8 * funit, 8 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        # Bottleneck at 16*funit channels.
        self.bottom = nn.Sequential(
            nn.Conv2d(8 * funit, 16 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(16 * funit, 16 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        # Decoder: transpose-conv upsampling, then two 3x3 convs on the
        # concatenation with the matching encoder feature map.
        self.up1 = nn.ConvTranspose2d(16 * funit, 8 * funit, 2, stride=2)
        self.tconvs1 = nn.Sequential(
            nn.Conv2d(16 * funit, 8 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(8 * funit, 8 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.up2 = nn.ConvTranspose2d(8 * funit, 4 * funit, 2, stride=2)
        self.tconvs2 = nn.Sequential(
            nn.Conv2d(8 * funit, 4 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(4 * funit, 4 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.up3 = nn.ConvTranspose2d(4 * funit, 2 * funit, 2, stride=2)
        self.tconvs3 = nn.Sequential(
            nn.Conv2d(4 * funit, 2 * funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(2 * funit, 2 * funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        self.up4 = nn.ConvTranspose2d(2 * funit, funit, 2, stride=2)
        self.tconvs4 = nn.Sequential(
            nn.Conv2d(2 * funit, funit, 3, padding=1),
            activation_fun(**activation_params),
            nn.Conv2d(funit, funit, 3, padding=1),
            activation_fun(**activation_params),
        )
        # Output head: 1x1 conv to RGB; Bayer input without pre-upsampling
        # additionally pixel-shuffles 2x back to full resolution.
        if in_channels == 3 or preupsample:
            self.output_module = nn.Sequential(nn.Conv2d(funit, 3, 1))
        elif in_channels == 4:
            self.output_module = nn.Sequential(
                nn.Conv2d(funit, 4 * 3, 1), nn.PixelShuffle(2)
            )
        else:
            raise NotImplementedError(f"{in_channels=}")
        # self.unpad = nn.ZeroPad2d(-2)
        # Single pass over both conv types (bodies were identical).
        # NOTE(review): nonlinearity="relu" although the net may use
        # LeakyReLU/ELU — the Kaiming gain differs slightly; confirm intended.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
        # TODO try xavier_normal_ ?

    def forward(self, l):
        """Denoise input batch ``l``.

        Args:
            l: input tensor; presumably (N, in_channels, H, W) with H, W
               divisible by 16 so the four 2x poolings round-trip —
               TODO confirm against callers.

        Returns:
            Denoised 3-channel RGB tensor.
        """
        l1 = self.preprocess(l)
        # l = self.pad(l)
        l1 = self.convs1(l1)
        l2 = self.convs2(self.maxpool(l1))
        l3 = self.convs3(self.maxpool(l2))
        l4 = self.convs4(self.maxpool(l3))
        l = torch.cat([self.up1(self.bottom(self.maxpool(l4))), l4], dim=1)
        l = torch.cat([self.up2(self.tconvs1(l)), l3], dim=1)
        l = torch.cat([self.up3(self.tconvs2(l)), l2], dim=1)
        l = torch.cat([self.up4(self.tconvs3(l)), l1], dim=1)
        l = self.tconvs4(l)
        # l = self.unpad(l)
        return self.output_module(l)
def get_activation_class_params(activation: str) -> tuple:
    """Map an activation name to its ``torch.nn`` class and constructor kwargs.

    Args:
        activation: one of "PReLU", "ELU", "Hardswish", "LeakyReLU".

    Returns:
        ``(activation_class, kwargs)`` such that ``activation_class(**kwargs)``
        builds the activation module.

    Raises:
        ValueError: if ``activation`` is not a recognized name.
    """
    if activation == "PReLU":
        return nn.PReLU, {}
    elif activation == "ELU":
        return nn.ELU, {"inplace": True}
    elif activation == "Hardswish":
        return nn.Hardswish, {"inplace": True}
    elif activation == "LeakyReLU":
        # negative_slope per
        # https://github.com/lavi135246/pytorch-Learning-to-See-in-the-Dark/blob/master/model.py
        return nn.LeakyReLU, {"inplace": True, "negative_slope": 0.2}
    else:
        # Raise instead of exit(): exit() raises SystemExit, which bypasses
        # normal error handling and hides the failure from callers.
        raise ValueError(
            f"get_activation_class_params: unknown activation function: {activation}"
        )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment