This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Trains IMLE on the MNIST dataset.""" | |
import torch | |
from torch import optim, nn | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF | |
from tqdm import tqdm | |
from vgg_loss import vgg_loss |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch import nn | |
from torch.nn import functional as F | |
class SoftPool2d(nn.Module): | |
"""Applies a 2D soft pooling over an input signal composed of several | |
input planes. See https://arxiv.org/abs/2101.00440""" | |
def __init__(self, kernel_size, ceil_mode=False, temperature=1.): | |
super().__init__() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class PseudoHuberLoss(nn.Module): | |
"""The Pseudo-Huber loss.""" | |
reductions = {'mean': torch.mean, 'sum': torch.sum, 'none': lambda x: x} | |
def __init__(self, beta=1, reduction='mean'): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
from collections import defaultdict | |
import csv | |
import math | |
import torch | |
from torch import nn, optim | |
from torch.nn import functional as F |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import csv | |
from pathlib import Path | |
import torch | |
from torch import optim, nn | |
from torch.nn import functional as F | |
from torch.utils import data | |
from torchvision import datasets, transforms, utils | |
from torchvision.transforms import functional as TF |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import copy | |
from functools import wraps | |
from hashlib import sha256 | |
from io import open | |
import json | |
import math | |
import logging |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from | |
https://www.zijianhu.com/post/pytorch/ema/. | |
""" | |
from copy import deepcopy | |
import torch | |
from torch import nn | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Binomial2Pool2d(nn.Module): | |
def __init__(self, ceil_mode=False): | |
super().__init__() | |
self.ceil_mode = ceil_mode | |
kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from | |
https://www.zijianhu.com/post/pytorch/ema/. | |
""" | |
from copy import deepcopy | |
import torch | |
from torch import nn | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Generates images from saved embeddings with CLIP.""" | |
import argparse | |
from concurrent import futures | |
import sys | |
import torch | |
from torch import nn, optim |