This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Decoder(nn.Module): | |
""" | |
Parameters: | |
in_channels (int): number of in_channels of the first ConvTranspose2d | |
out_channels (int): number of out_channels of the first ConvTranspose2d | |
padding (int): padding applied in each convolution | |
uphill (int): number of times a ConvTranspose2d + CNNBlocks stage is applied. | |
""" | |
def __init__(self, | |
in_channels, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UNET(nn.Module): | |
def __init__(self, | |
in_channels, | |
first_out_channels, | |
exit_channels, | |
downhill, | |
padding=0 | |
): | |
super(UNET, self).__init__() | |
self.encoder = Encoder(in_channels, first_out_channels, padding=padding, downhill=downhill) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DAVIS2017(Dataset): | |
"""DAVIS 2017 dataset constructed using the PyTorch built-in functionalities""" | |
def __init__(self, train=True, | |
db_root_dir=ROOT_DIR, | |
transform=None, | |
seq_name=None, | |
pad_mirroring=None): | |
"""Loads image to label pairs for tool pose estimation | |
db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __getitem__(self, idx): | |
img = np.array(Image.open(os.path.join(self.db_root_dir, self.img_list[idx])).convert("RGB"), dtype=np.float32) | |
gt = np.array(Image.open(os.path.join(self.db_root_dir, self.labels[idx])).convert("L"), dtype=np.float32) | |
gt = ((gt/np.max([gt.max(), 1e-8])) > 0.5).astype(np.float32) | |
#gt = gt.astype(np.bool).astype(np.float32) | |
if self.transform is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(): | |
loss_fn = torch.nn.BCEWithLogitsLoss() | |
scaler = torch.cuda.amp.GradScaler() | |
model = UNET(3, 64, 1, padding=0, downhill=4).to(DEVICE) | |
optim = Adam(model.parameters(), lr=LEARNING_RATE) | |
if CHECKPOINT: | |
load_model_checkpoint(CHECKPOINT, model) | |
load_optim_checkpoint(CHECKPOINT, optim) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CBL(nn.Module): | |
def __init__(self, in_channels, out_channels, kernel_size, stride, padding): | |
super(CBL, self).__init__() | |
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False) | |
bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.03) | |
self.cbl = nn.Sequential( | |
conv, | |
bn, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Bottleneck(nn.Module): | |
""" | |
Parameters: | |
in_channels (int): number of channels of the input tensor | |
out_channels (int): number of channels of the output tensor | |
width_multiple (float): it controls the number of channels (and weights) | |
of all the convolutions besides the | |
first and last one. The closer to 0, | |
the simpler the model. The closer to 1, | |
the more complex the model becomes |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class C3(nn.Module): | |
""" | |
Parameters: | |
in_channels (int): number of channels of the input tensor | |
out_channels (int): number of channels of the output tensor | |
width_multiple (float): it controls the number of channels (and weights) | |
of all the convolutions besides the | |
first and last one. The closer to 0, | |
the simpler the model. The closer to 1, | |
the more complex the model becomes |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class SPPF(nn.Module): | |
def __init__(self, in_channels, out_channels): | |
super(SPPF, self).__init__() | |
c_ = int(in_channels//2) | |
self.c1 = CBL(in_channels, c_, 1, 1, 0) | |
self.pool = nn.MaxPool2d(kernel_size=5, stride=1, padding=2) | |
self.c_out = CBL(c_ * 4, out_channels, 1, 1, 0) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
self.backbone += [ | |
CBL(in_channels=3, out_channels=first_out, kernel_size=6, stride=2, padding=2), | |
CBL(in_channels=first_out, out_channels=first_out*2, kernel_size=3, stride=2, padding=1), | |
C3(in_channels=first_out*2, out_channels=first_out*2, width_multiple=0.5, depth=2), | |
CBL(in_channels=first_out*2, out_channels=first_out*4, kernel_size=3, stride=2, padding=1), | |
C3(in_channels=first_out*4, out_channels=first_out*4, width_multiple=0.5, depth=4), | |
CBL(in_channels=first_out*4, out_channels=first_out*8, kernel_size=3, stride=2, padding=1), | |
C3(in_channels=first_out*8, out_channels=first_out*8, width_multiple=0.5, depth=6), | |
CBL(in_channels=first_out*8, out_channels=first_out*16, kernel_size=3, stride=2, padding=1), | |
C3(in_channels=first_out*16, out_channels=first_out*16, width_multiple=0.5, depth=2), |