Skip to content

Instantly share code, notes, and snippets.

@crowsonkb
crowsonkb / mnist_imle_vgg.py
Last active February 5, 2021 21:13
Trains IMLE on the MNIST dataset.
"""Trains IMLE on the MNIST dataset."""
import torch
from torch import optim, nn
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
from tqdm import tqdm
from vgg_loss import vgg_loss
@crowsonkb
crowsonkb / softpool.py
Created February 7, 2021 12:21
Applies a 2D soft pooling over an input signal composed of several input planes. See https://arxiv.org/abs/2101.00440
from torch import nn
from torch.nn import functional as F
class SoftPool2d(nn.Module):
"""Applies a 2D soft pooling over an input signal composed of several
input planes. See https://arxiv.org/abs/2101.00440"""
def __init__(self, kernel_size, ceil_mode=False, temperature=1.):
super().__init__()
@crowsonkb
crowsonkb / pseudo_huber.py
Created February 7, 2021 12:33
The Pseudo-Huber loss
import torch
from torch import nn
class PseudoHuberLoss(nn.Module):
"""The Pseudo-Huber loss."""
reductions = {'mean': torch.mean, 'sum': torch.sum, 'none': lambda x: x}
def __init__(self, beta=1, reduction='mean'):
#!/usr/bin/env python3
import argparse
from collections import defaultdict
import csv
import math
import torch
from torch import nn, optim
from torch.nn import functional as F
import argparse
import csv
from pathlib import Path
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from torchvision import datasets, transforms, utils
from torchvision.transforms import functional as TF
#!/usr/bin/env python3
import argparse
import copy
from functools import wraps
from hashlib import sha256
from io import open
import json
import math
import logging
@crowsonkb
crowsonkb / ema.py
Last active February 28, 2021 22:45
Parameter averaging for PyTorch
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
@crowsonkb
crowsonkb / binomial_pool.py
Created February 22, 2021 18:08
Binomial2Pool2d
import torch
from torch import nn
from torch.nn import functional as F
class Binomial2Pool2d(nn.Module):
def __init__(self, ceil_mode=False):
super().__init__()
self.ceil_mode = ceil_mode
kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]]
@crowsonkb
crowsonkb / ema_biased.py
Created February 26, 2021 19:09
Biased EMA for PyTorch
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
#!/usr/bin/env python3
"""Generates images from saved embeddings with CLIP."""
import argparse
from concurrent import futures
import sys
import torch
from torch import nn, optim