This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def main(): | |
loss_fn = torch.nn.BCEWithLogitsLoss() | |
scaler = torch.cuda.amp.GradScaler() | |
model = UNET(3, 64, 1, padding=0, downhill=4).to(DEVICE) | |
optim = Adam(model.parameters(), lr=LEARNING_RATE) | |
if CHECKPOINT: | |
load_model_checkpoint(CHECKPOINT, model) | |
load_optim_checkpoint(CHECKPOINT, optim) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def __getitem__(self, idx): | |
img = np.array(Image.open(os.path.join(self.db_root_dir, self.img_list[idx])).convert("RGB"), dtype=np.float32) | |
gt = np.array(Image.open(os.path.join(self.db_root_dir, self.labels[idx])).convert("L"), dtype=np.float32) | |
gt = ((gt/np.max([gt.max(), 1e-8])) > 0.5).astype(np.float32) | |
#gt = gt.astype(np.bool).astype(np.float32) | |
if self.transform is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DAVIS2017(Dataset): | |
"""DAVIS 2017 dataset constructed using the PyTorch built-in functionalities""" | |
def __init__(self, train=True, | |
db_root_dir=ROOT_DIR, | |
transform=None, | |
seq_name=None, | |
pad_mirroring=None): | |
"""Loads image to label pairs for tool pose estimation | |
db_root_dir: dataset directory with subfolders "JPEGImages" and "Annotations" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UNET(nn.Module): | |
def __init__(self, | |
in_channels, | |
first_out_channels, | |
exit_channels, | |
downhill, | |
padding=0 | |
): | |
super(UNET, self).__init__() | |
self.encoder = Encoder(in_channels, first_out_channels, padding=padding, downhill=downhill) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Decoder(nn.Module): | |
""" | |
Parameters: | |
in_channels (int): number of in_channels of the first ConvTranspose2d | |
out_channels (int): number of out_channels of the first ConvTranspose2d | |
padding (int): padding applied in each convolution | |
uphill (int): number times a ConvTranspose2d + CNNBlocks it's applied. | |
""" | |
def __init__(self, | |
in_channels, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Encoder(nn.Module): | |
""" | |
Parameters: | |
in_channels (int): number of in_channels of the first CNNBlocks | |
out_channels (int): number of out_channels of the first CNNBlocks | |
padding (int): padding applied in each convolution | |
downhill (int): number times a CNNBlocks + MaxPool2D it's applied. | |
""" | |
def __init__(self, | |
in_channels, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CNNBlocks(nn.Module): | |
""" | |
Parameters: | |
n_conv (int): creates a block of n_conv convolutions | |
in_channels (int): number of in_channels of the first block's convolution | |
out_channels (int): number of out_channels of the first block's convolution | |
expand (bool) : if True after the first convolution of a blocl the number of channels doubles | |
""" | |
def __init__(self, | |
n_conv, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CNNBlock(nn.Module): | |
def __init__(self, | |
in_channels, | |
out_channels, | |
kernel_size=3, | |
stride=1, | |
padding=0): | |
super(CNNBlock, self).__init__() | |
self.seq_block = nn.Sequential( |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def softmax(self, logits): | |
if self.library == "tf": | |
exp = tf.exp(logits) | |
denom = tf.math.reduce_sum(exp, 1, keepdims=True) | |
else: | |
exp = torch.exp(logits) | |
denom = torch.sum(exp, dim=1, keepdim=True) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def accuracy(self, y_hat, Y):
    """Return the fraction of samples whose arg-max prediction equals the label.

    Dispatches on ``self.library``: ``"tf"`` uses TensorFlow ops; any other
    value uses PyTorch ops.

    Parameters:
        y_hat: prediction scores; the predicted class is the arg-max along
            axis 1 (presumably shape (batch, n_classes) -- TODO confirm).
        Y: labels compared element-wise against that arg-max (presumably
            integer class indices of shape (batch,) -- TODO confirm).

    Returns:
        A scalar tensor equal to (number of matches) / ``Y.shape[0]``.
        For an empty batch this divides by zero (PyTorch yields NaN).
    """
    if self.library == "tf":
        # Cast the arg-max to the labels' dtype so the element-wise
        # comparison is between like-typed tensors.
        argmax = tf.cast(tf.argmax(y_hat, axis=1), Y.dtype)
        acc = tf.math.reduce_sum(tf.cast(argmax == Y, tf.int32)) / Y.shape[0]
    else:
        argmax = torch.argmax(y_hat, dim=1)
        # Summing the boolean match mask counts correct predictions; the
        # division by a Python int is true division, giving a float tensor.
        acc = torch.sum(torch.eq(argmax, Y)) / Y.shape[0]
    return acc