# The name of the Pygments (syntax highlighting) style to use.
import sys

sys.path.append('.')
import pygments.styles
import base16_spacemacs_kat

# Register the local style module under the pygments.styles namespace so that
# Pygments can import it by its dotted name, then point STYLE_MAP at the class.
# pygments.styles.base16_spacemacs_kat = base16_spacemacs_kat
sys.modules['pygments.styles.base16_spacemacs_kat'] = base16_spacemacs_kat
pygments.styles.STYLE_MAP['base16-spacemacs-kat'] = 'base16_spacemacs_kat::Base16SpacemacsStyle'

pygments_style = 'base16-spacemacs-kat'  # 'friendly'
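If the registration needs a quick sanity check, something like the following could be run in the same interpreter session after the snippet above has executed (the registration only lives in that process; get_style_by_name is the standard Pygments lookup, and the attribute printed is just an example):

from pygments.styles import get_style_by_name

style = get_style_by_name('base16-spacemacs-kat')  # resolves via the STYLE_MAP entry above
print(style.background_color)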
from prettyprinter import register_pretty, pretty_call
import pyparsing as pp


@register_pretty(pp.ParseResults)
def pretty_parse_results(value, ctx):
    # Render a ParseResults as a constructor-style call showing both its
    # list form and its named results.
    return pretty_call(ctx, pp.ParseResults, value.asList(), value.asDict())
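With the printer registered (that is, after the module above has been imported), prettyprinter renders parse results readably. A small usage sketch; the grammar here is arbitrary:

import prettyprinter
import pyparsing as pp

word = pp.Word(pp.alphas)
greeting = word('greeting') + ',' + word('name') + '!'
prettyprinter.cpprint(greeting.parseString('Hello, World!'))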
from torch import nn


class Lambda(nn.Module):
    """Wraps a callable in an :class:`nn.Module` without registering it."""

    def __init__(self, func):
        super().__init__()
        # Bypass nn.Module.__setattr__ so func is not registered as a
        # submodule or parameter; the instance attribute shadows forward().
        object.__setattr__(self, 'forward', func)
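A minimal usage sketch (the flattening callable is just an example):

import torch

flatten = Lambda(lambda x: x.reshape(x.shape[0], -1))
print(flatten(torch.zeros(2, 3, 4)).shape)  # torch.Size([2, 12])
print(list(flatten.children()))             # [] -- nothing was registered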
from torch.nn import functional as F
from torchvision import transforms


def random_shift(input, max_shift, mode='replicate', value=0):
    # Pad every spatial edge by max_shift pixels, then randomly crop back to
    # the original spatial size, which shifts the image by up to max_shift
    # pixels in each direction.
    padded = F.pad(input, [max_shift] * 4, mode=mode, value=value)
    return transforms.RandomCrop(input.shape[2:])(padded)
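A quick usage sketch on a random NCHW batch:

import torch

images = torch.rand(8, 3, 32, 32)
shifted = random_shift(images, max_shift=4)
print(shifted.shape)  # torch.Size([8, 3, 32, 32]) -- same size, contents shifted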
"""Computes the covariance matrix in PyTorch.""" | |
def cov_mean(input, unbiased=True, keepdims=False): | |
n = input.shape[-1] - unbiased | |
mean = input.mean(dim=-1, keepdims=True) | |
dev = input - mean | |
mean = mean if keepdims else mean[..., 0] | |
return dev @ dev.transpose(-1, -2) / n, mean |
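A small check against torch.cov (PyTorch 1.10+), which uses the same variables-by-observations convention and unbiased normalization by default:

import torch

x = torch.randn(5, 200)
cov, mean = cov_mean(x)
print(torch.allclose(cov, torch.cov(x), atol=1e-6))      # True
print(torch.allclose(mean, x.mean(dim=-1), atol=1e-6))   # True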
import torch
from torch import nn, optim


class Constraint(nn.Module):
    def __init__(self, fn, maximum, damping=1e-2):
        super().__init__()
        self.fn = fn
        # Buffers, not parameters: they move with the module across devices and
        # are saved in its state_dict, but are not updated by optimizers.
        self.register_buffer('maximum', torch.as_tensor(maximum))
        self.register_buffer('damping', torch.as_tensor(damping))
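The class is shown only up to its constructor. A quick illustration of what the buffer registration above buys; the fn passed in is an arbitrary example:

c = Constraint(lambda x: x.pow(2).mean(), maximum=1.0)
print(dict(c.named_buffers()))  # {'maximum': tensor(1.), 'damping': tensor(0.0100)}
print(list(c.parameters()))     # [] -- nothing for an optimizer to update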
import abc

import torch
from torch import nn, optim


class Constraint(nn.Module, metaclass=abc.ABCMeta):
    def __init__(self, fn, damping):
        super().__init__()
        self.fn = fn
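This second, abstract variant is also shown only up to the point where fn is stored. One consequence of the plain attribute assignment (in contrast to the Lambda helper above, which deliberately bypasses registration): if fn is itself a module, it becomes a registered submodule. A minimal illustration in which the subclass, the linear layer, and the penalty form are assumptions for demonstration only:

class L2Constraint(Constraint):
    # Hypothetical concrete subclass, purely to make the base instantiable here.
    def forward(self, *args):
        return self.fn(*args).pow(2).mean()

c = L2Constraint(nn.Linear(4, 1), damping=1e-2)
print(len(list(c.parameters())))  # 2 -- the Linear's weight and bias were registered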
from math import ceil

import torch
from torch import nn
from torch.nn import functional as F


class Downsample2d(nn.Module):
    kernels = {
        # 3-tap binomial lowpass (the [1, 2, 1] / 4 filter), applied separably
        # before subsampling to reduce aliasing.
        'binomial2': [0.25, 0.5, 0.25],
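The class body is cut off above, but the listed kernel already shows the idea: blur with a small separable lowpass, then subsample. Purely as an illustrative sketch of that technique (this is not the class's actual forward), a standalone 2x downsampler using the same binomial2 taps might look like:

import torch
from torch.nn import functional as F

def binomial2_downsample(x):
    # x: an NCHW batch. Build the separable 3x3 kernel and apply it depthwise.
    k1d = torch.tensor([0.25, 0.5, 0.25], dtype=x.dtype, device=x.device)
    k2d = torch.outer(k1d, k1d)                   # 3x3 kernel, sums to 1
    weight = k2d.repeat(x.shape[1], 1, 1, 1)      # one copy per channel
    x = F.pad(x, (1, 1, 1, 1), mode='replicate')  # pad so borders stay reasonable
    return F.conv2d(x, weight, stride=2, groups=x.shape[1])

print(binomial2_downsample(torch.rand(1, 3, 8, 8)).shape)  # torch.Size([1, 3, 4, 4])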
"""Good differentiable image resampling for PyTorch.""" | |
from functools import update_wrapper | |
import math | |
import torch | |
from torch.nn import functional as F | |
def sinc(x): |
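sinc is the building block for windowed-sinc resampling kernels. As an illustrative sketch only (not necessarily the helper this file actually defines), a Lanczos window of order a is usually built from it like this:

def lanczos(x, a):
    # sinc(x) * sinc(x / a) inside |x| < a, zero outside; normalizing the taps
    # to sum to 1 keeps the filter from changing overall brightness.
    out = sinc(x) * sinc(x / a) * (x.abs() < a)
    return out / out.sum()

taps = lanczos(torch.arange(-2.5, 3), 3)  # 6-tap Lanczos-3 kernel for a half-pixel shift
print(taps.sum())  # ~1.0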
"""Trains IMLE on the MNIST dataset.""" | |
import torch | |
from torch import optim, nn | |
from torch.nn import functional as F | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF | |
from tqdm import tqdm |
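The script is truncated after its imports. For orientation, here is a hedged sketch of the core IMLE update such a training loop performs (implicit maximum likelihood estimation: for each real example, draw several latents, keep the generated sample nearest to it, and pull that sample closer). The generator G, latent size, candidate count, and squared-error distance are assumptions, not the script's own code:

def imle_step(G, opt, reals, n_candidates=16, latent_dim=64):
    # Sample a pool of candidate latents per real example.
    z = torch.randn(n_candidates, reals.shape[0], latent_dim, device=reals.device)
    with torch.no_grad():
        fakes = torch.stack([G(zi) for zi in z])           # (n_cand, N, C, H, W)
        dists = (fakes - reals).flatten(2).pow(2).mean(2)  # (n_cand, N)
        best = dists.argmin(0)                             # nearest candidate per real
    # Re-generate from the selected latents with gradients enabled and
    # minimize the distance to the corresponding real examples.
    chosen = z[best, torch.arange(reals.shape[0], device=reals.device)]
    loss = F.mse_loss(G(chosen), reals)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()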