This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import csv | |
from pathlib import Path | |
import torch | |
from torch import optim, nn | |
from torch.nn import functional as F | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
from collections import defaultdict | |
import csv | |
import math | |
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class PseudoHuberLoss(nn.Module): | |
"""The Pseudo-Huber loss.""" | |
reductions = {'mean': torch.mean, 'sum': torch.sum, 'none': lambda x: x} | |
def __init__(self, beta=1, reduction='mean'): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from torch.nn import functional as F | |
class SoftPool2d(nn.Module): | |
"""Applies a 2D soft pooling over an input signal composed of several | |
input planes. See https://arxiv.org/abs/2101.00440""" | |
def __init__(self, kernel_size, ceil_mode=False, temperature=1.): | |
super().__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Trains IMLE on the MNIST dataset.""" | |
import torch | |
from torch import optim, nn | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF | |
from tqdm import tqdm | |
from vgg_loss import vgg_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Trains IMLE on the MNIST dataset.""" | |
import torch | |
from torch import optim, nn | |
from torch.nn import functional as F | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF | |
from tqdm import tqdm |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Good differentiable image resampling for PyTorch.""" | |
from functools import update_wrapper | |
import math | |
import torch | |
from torch.nn import functional as F | |
def sinc(x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from math import ceil | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Downsample2d(nn.Module): | |
kernels = { | |
'binomial2': [0.25, 0.5, 0.25], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import abc | |
import torch | |
from torch import nn, optim | |
class Constraint(nn.Module, metaclass=abc.ABCMeta): | |
def __init__(self, fn, damping): | |
super().__init__() | |
self.fn = fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn, optim | |
class Constraint(nn.Module): | |
def __init__(self, fn, maximum, damping=1e-2): | |
super().__init__() | |
self.fn = fn | |
self.register_buffer('maximum', torch.as_tensor(maximum)) | |
self.register_buffer('damping', torch.as_tensor(damping)) |