This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def main(): | |
| loss_fn = torch.nn.BCEWithLogitsLoss() | |
| scaler = torch.cuda.amp.GradScaler() | |
| model = UNET(3, 64, 1, padding=0, downhill=4).to(DEVICE) | |
| optim = Adam(model.parameters(), lr=LEARNING_RATE) | |
| if CHECKPOINT: | |
| load_model_checkpoint(CHECKPOINT, model) | |
| load_optim_checkpoint(CHECKPOINT, optim) | |
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def __getitem__(self, idx): | |
| img = np.array(Image.open(os.path.join(self.db_root_dir, self.img_list[idx])).convert("RGB"), dtype=np.float32) | |
| gt = np.array(Image.open(os.path.join(self.db_root_dir, self.labels[idx])).convert("L"), dtype=np.float32) | |
| gt = ((gt/np.max([gt.max(), 1e-8])) > 0.5).astype(np.float32) | |
| #gt = gt.astype(np.bool).astype(np.float32) | |
| if self.transform is not None: | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class DAVIS2017(Dataset): | |
| """DAVIS 2017 dataset constructed using the PyTorch built-in functionalities""" | |
| def __init__(self, train=True, | |
| db_root_dir=ROOT_DIR, | |
| transform=None, | |
| seq_name=None, | |
| pad_mirroring=None): | |
| """Loads image to label pairs for tool pose estimation | |
| db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class UNET(nn.Module): | |
| def __init__(self, | |
| in_channels, | |
| first_out_channels, | |
| exit_channels, | |
| downhill, | |
| padding=0 | |
| ): | |
| super(UNET, self).__init__() | |
| self.encoder = Encoder(in_channels, first_out_channels, padding=padding, downhill=downhill) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class Decoder(nn.Module): | |
| """ | |
| Parameters: | |
| in_channels (int): number of in_channels of the first ConvTranspose2d | |
| out_channels (int): number of out_channels of the first ConvTranspose2d | |
| padding (int): padding applied in each convolution | |
| uphill (int): number of times a ConvTranspose2d + CNNBlocks is applied. | |
| """ | |
| def __init__(self, | |
| in_channels, | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class Encoder(nn.Module): | |
| """ | |
| Parameters: | |
| in_channels (int): number of in_channels of the first CNNBlocks | |
| out_channels (int): number of out_channels of the first CNNBlocks | |
| padding (int): padding applied in each convolution | |
| downhill (int): number of times a CNNBlocks + MaxPool2D is applied. | |
| """ | |
| def __init__(self, | |
| in_channels, | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class CNNBlocks(nn.Module): | |
| """ | |
| Parameters: | |
| n_conv (int): creates a block of n_conv convolutions | |
| in_channels (int): number of in_channels of the first block's convolution | |
| out_channels (int): number of out_channels of the first block's convolution | |
| expand (bool) : if True, after the first convolution of a block the number of channels doubles | |
| """ | |
| def __init__(self, | |
| n_conv, | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | class CNNBlock(nn.Module): | |
| def __init__(self, | |
| in_channels, | |
| out_channels, | |
| kernel_size=3, | |
| stride=1, | |
| padding=0): | |
| super(CNNBlock, self).__init__() | |
| self.seq_block = nn.Sequential( | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | def softmax(self, logits): | |
| if self.library == "tf": | |
| exp = tf.exp(logits) | |
| denom = tf.math.reduce_sum(exp, 1, keepdims=True) | |
| else: | |
| exp = torch.exp(logits) | |
| denom = torch.sum(exp, dim=1, keepdim=True) | 
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  def accuracy(self, y_hat, Y):
      """Return the fraction of samples whose arg-max prediction equals the label.

      The arg-max of ``y_hat`` is taken along axis 1 and compared element-wise
      with ``Y``; the number of matches is divided by ``Y.shape[0]``.

      Parameters:
          y_hat: per-class scores; presumably shaped (batch, classes) -- confirm
              against the caller.
          Y: reference labels, one per sample along the first axis.

      Returns:
          A scalar tensor from the backend selected by ``self.library``
          (TensorFlow when it equals "tf", PyTorch otherwise).
      """
      if self.library == "tf":
          # Cast the predicted indices to the label dtype so that the
          # element-wise comparison is between tensors of the same type.
          predicted = tf.cast(tf.argmax(y_hat, axis=1), Y.dtype)
          hits = tf.math.reduce_sum(tf.cast(predicted == Y, tf.int32))
          return hits / Y.shape[0]

      predicted = torch.argmax(y_hat, dim=1)
      hits = torch.sum(torch.eq(predicted, Y))
      return hits / Y.shape[0]