"""
DeepLabv3+ model (https://arxiv.org/abs/1802.02611)
Author: Jacob Reinhold ([email protected])
"""
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import resnet101
from torchvision.models._utils import IntermediateLayerGetter
"""
U-Net architecture in PyTorch (https://arxiv.org/abs/1505.04597)
Author: Jacob Reinhold ([email protected])
"""
import torch
from torch import nn
from torch.nn import functional as F
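from typing import Tuple
import torch

def classification_uncertainty(logits:torch.Tensor, sigmas:torch.Tensor, eps:float=1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """ calculate epistemic, entropy, and aleatory uncertainty quantities
    (inputs are stacked over MC samples along dim=0; sigmas are log std. devs. of the logits) """
    probits = torch.sigmoid(logits)
    # epistemic: variance of the predicted probabilities across MC samples
    epistemic = probits.var(dim=0, unbiased=True)
    # entropy (in bits) of the mean predicted probability
    probit = probits.mean(dim=0)
    entropy = -1 * (probit * (probit + eps).log2() + ((1 - probit) * (1 - probit + eps).log2()))
    # aleatory: mean predicted standard deviation of the logits
    aleatory = torch.exp(sigmas).mean(dim=0)
    return epistemic, entropy, aleatory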
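A scalar, pure-Python sketch of what classification_uncertainty computes for a single pixel, assuming the logits and log standard deviations come from several stochastic forward passes (for example Monte Carlo dropout; this is an assumption, the code above does not say how the samples are drawn). The helpers sigmoid, binary_entropy_bits, and pixel_uncertainty are hypothetical names used only for illustration.

```python
import math
import statistics
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def binary_entropy_bits(p: float, eps: float = 1e-6) -> float:
    # same expression as the entropy line in classification_uncertainty
    return -(p * math.log2(p + eps) + (1 - p) * math.log2(1 - p + eps))


def pixel_uncertainty(mc_logits: List[float], mc_log_sigmas: List[float]) -> Tuple[float, float, float]:
    """Hypothetical scalar analogue of classification_uncertainty.

    mc_logits and mc_log_sigmas hold one value per Monte Carlo sample.
    """
    probs = [sigmoid(z) for z in mc_logits]
    epistemic = statistics.variance(probs)  # unbiased (n - 1) sample variance
    entropy = binary_entropy_bits(statistics.mean(probs))
    aleatory = statistics.mean(math.exp(s) for s in mc_log_sigmas)
    return epistemic, entropy, aleatory


# samples that agree give zero epistemic variance; samples that disagree do not
print(pixel_uncertainty([2.0, 2.0, 2.0], [0.0, 0.0, 0.0]))
print(pixel_uncertainty([-3.0, 0.0, 3.0], [0.0, 0.0, 0.0]))
```

In the second call the mean probability is about 0.5, so the entropy is close to its 1-bit maximum even though each individual sample is fairly confident.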
from typing import Tuple
import torch

def regression_uncertainty(yhat:torch.Tensor, s:torch.Tensor, mse:bool=True) -> Tuple[torch.Tensor, torch.Tensor]:
    """ calculate epistemic and aleatory uncertainty quantities based on whether MSE or L1 loss used """
    # variance over samples (dim=0), mean over channels (dim=1, after reduction by variance calculation)
    epistemic = torch.mean(yhat.var(dim=0, unbiased=True), dim=1, keepdim=True)
    # MSE: s is the log variance, so exp(s) is the variance;
    # L1: s is the log Laplace scale b, and a Laplace variance is 2*b**2
    aleatory = torch.mean(torch.exp(s), dim=0) if mse else torch.mean(2*torch.exp(s)**2, dim=0)
    return epistemic, aleatory
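The two branches of the aleatory term differ because, under ExtendedMSELoss, s is a log variance, while under ExtendedL1Loss s is the log scale b of a Laplace distribution, whose variance is 2b². A stdlib-only sketch that checks the Laplace formula by sampling; aleatoric_variance and laplace_sample are hypothetical names, and drawing Laplace noise as a difference of two exponentials is a standard construction swapped in here for illustration.

```python
import math
import random
import statistics


def aleatoric_variance(s: float, mse: bool = True) -> float:
    """Hypothetical scalar version of the aleatory term in regression_uncertainty.

    mse=True:  s is log(sigma**2) of a Gaussian, so the variance is exp(s).
    mse=False: s is log(b) of a Laplace distribution, whose variance is 2 * b**2.
    """
    return math.exp(s) if mse else 2 * math.exp(s) ** 2


def laplace_sample(b: float, rng: random.Random) -> float:
    # the difference of two i.i.d. exponentials with mean b is Laplace(0, b)
    return rng.expovariate(1 / b) - rng.expovariate(1 / b)


# Monte Carlo check of the Laplace variance formula for b = 0.5
rng = random.Random(0)
s = math.log(0.5)
draws = [laplace_sample(math.exp(s), rng) for _ in range(100_000)]
m = statistics.fmean(draws)
sample_var = statistics.fmean((x - m) ** 2 for x in draws)
print(sample_var, aleatoric_variance(s, mse=False))
```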
import torch
from torch import nn
from torch.nn import functional as F

class ExtendedBCELoss(nn.Module):
    """ modified BCE loss for variance fitting """
    def forward(self, out:torch.Tensor, y:torch.Tensor, n_samp:int=10) -> torch.Tensor:
        # out is a (logit, sigma) pair; sigma is the log std. dev. of the logit
        logit, sigma = out
        dist = torch.distributions.Normal(logit, torch.exp(sigma))
        # reparameterized samples keep the loss differentiable w.r.t. logit and sigma
        mc_logs = dist.rsample((n_samp,))
        loss = 0.
        for mc_log in mc_logs:
            loss += F.binary_cross_entropy_with_logits(mc_log, y)
        loss /= n_samp
        return loss
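ExtendedBCELoss treats each logit as Gaussian with standard deviation exp(sigma) and averages the BCE over reparameterized samples. The scalar sketch below mirrors that for one pixel in plain Python; bce_with_logits and extended_bce are hypothetical helpers, and drawing eps with random.gauss stands in for Normal.rsample, which likewise returns loc + scale * eps with eps ~ N(0, 1).

```python
import math
import random


def bce_with_logits(x: float, y: float) -> float:
    # numerically stable BCE on a logit: max(x, 0) - x*y + log(1 + exp(-|x|))
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


def extended_bce(logit: float, sigma: float, y: float, n_samp: int = 10, seed: int = 0) -> float:
    """Hypothetical scalar sketch of ExtendedBCELoss.forward for one pixel."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samp):
        # reparameterization: mc_logit = logit + exp(sigma) * eps
        mc_logit = logit + math.exp(sigma) * rng.gauss(0.0, 1.0)
        total += bce_with_logits(mc_logit, y)
    return total / n_samp


print(extended_bce(3.0, math.log(2.0), 1.0, n_samp=2000), bce_with_logits(3.0, 1.0))
```

Because BCE is convex in the logit, the Monte Carlo average is, in expectation, never below the plain BCE at the mean logit (Jensen's inequality); with a vanishing sigma the two coincide.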
import torch
from torch import nn
from torch.nn import functional as F

class ExtendedL1Loss(nn.Module):
    """ modified L1 loss for scale param. fitting """
    def forward(self, out:torch.Tensor, y:torch.Tensor) -> torch.Tensor:
        # s is the log scale of a Laplace distribution; the loss is its NLL up to a constant
        yhat, s = out
        loss = torch.mean((torch.exp(-s) * F.l1_loss(yhat, y, reduction='none')) + s)
        return loss
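ExtendedL1Loss is the negative log-likelihood of a Laplace distribution with scale b = exp(s), minus the constant log 2. The stdlib-only check below (extended_l1 and laplace_nll are hypothetical names) confirms that identity numerically and shows that, for a fixed residual r, the loss is minimized at s = log|r|, so the predicted scale tracks the absolute error.

```python
import math


def extended_l1(residual: float, s: float) -> float:
    # per-pixel term of ExtendedL1Loss: exp(-s) * |y - yhat| + s
    return math.exp(-s) * abs(residual) + s


def laplace_nll(residual: float, b: float) -> float:
    # negative log-density of Laplace(0, b) evaluated at the residual
    return math.log(2 * b) + abs(residual) / b


r = 0.3
# the extended loss differs from the Laplace NLL only by the constant log(2)
for s in (-2.0, -1.0, 0.0, 1.0):
    assert abs(laplace_nll(r, math.exp(s)) - extended_l1(r, s) - math.log(2)) < 1e-12

# d/ds [exp(-s)|r| + s] = 0 gives s* = log|r|; a grid search agrees
grid = [i / 1000 for i in range(-5000, 1001)]
s_best = min(grid, key=lambda s: extended_l1(r, s))
print(s_best, math.log(r))
```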
import torch
from torch import nn
from torch.nn import functional as F

class ExtendedMSELoss(nn.Module):
    """ modified MSE loss for variance fitting """
    def forward(self, out:torch.Tensor, y:torch.Tensor) -> torch.Tensor:
        # s is the log variance of a Gaussian; the loss is its NLL up to a constant
        yhat, s = out
        loss = torch.mean(0.5 * (torch.exp(-s) * F.mse_loss(yhat, y, reduction='none') + s))
        return loss
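ExtendedMSELoss is the Gaussian negative log-likelihood with variance exp(s), minus the constant 0.5 * log(2*pi); for a fixed residual r it is minimized at s = log(r**2), i.e. when the predicted variance equals the squared error. A stdlib-only numerical check, with hypothetical helper names extended_mse and gaussian_nll:

```python
import math


def extended_mse(residual: float, s: float) -> float:
    # per-pixel term of ExtendedMSELoss: 0.5 * (exp(-s) * (y - yhat)**2 + s)
    return 0.5 * (math.exp(-s) * residual ** 2 + s)


def gaussian_nll(residual: float, var: float) -> float:
    # negative log-density of N(0, var) evaluated at the residual
    return 0.5 * math.log(2 * math.pi * var) + residual ** 2 / (2 * var)


r = 0.3
for s in (-2.0, -1.0, 0.0, 1.0):
    gap = gaussian_nll(r, math.exp(s)) - extended_mse(r, s)
    # the gap is the same constant for every s, so both losses share gradients
    assert abs(gap - 0.5 * math.log(2 * math.pi)) < 1e-12

s_star = math.log(r ** 2)  # stationary point of the convex loss in s
print(extended_mse(r, s_star))
```

PyTorch also provides torch.nn.GaussianNLLLoss, which takes the variance itself rather than its logarithm, so using it here would mean passing torch.exp(s).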
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
tif_to_nii
command line executable to convert a directory of tif images
(slices from one image) to a nifti image stacked along a user-specified axis
call as: python tif_to_nii.py /path/to/tif/ /path/to/nifti
(append optional arguments to the call as desired)