This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Learning rate and EMA warmup schedulers for PyTorch.""" | |
import warnings | |
from torch import optim | |
class InverseLR(optim.lr_scheduler._LRScheduler): | |
"""Implements an inverse decay learning rate schedule with an optional exponential | |
warmup. When last_epoch=-1, sets initial lr as lr. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Computes the channel-wise means, standard deviations, and covariance | |
matrix of a dataset of images.""" | |
import argparse | |
import torch | |
from torch.utils import data | |
from torchvision import datasets, transforms as T |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Dumps a Caffe binary model to a pickle of NumPy arrays.""" | |
import argparse | |
from collections import OrderedDict | |
import os | |
import pickle | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Matrix square roots with backward passes. | |
Cleaned up from https://github.com/msubhransu/matrix-sqrt. | |
""" | |
import torch | |
def sqrtm_ns(a, num_iters=10): | |
if a.ndim < 2: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Learns the parity function.""" | |
import torch | |
from torch import nn, optim | |
from tqdm import trange, tqdm | |
class GatedUnit(nn.Module): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.""" | |
import math | |
import torch | |
from torch import optim | |
class ComplexSGD(optim.Optimizer): | |
def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import geoopt | |
def spherical_avg(p, w=None, tol=1e-6): | |
sphere = geoopt.Sphere() | |
if w is None: | |
w = p.new_ones([p.shape[0]]) | |
assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w) | |
w = w / w.sum() | |
p = sphere.projx(p) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Generates images from saved embeddings with CLIP.""" | |
import argparse | |
from concurrent import futures | |
import sys | |
import torch | |
from torch import nn, optim |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from | |
https://www.zijianhu.com/post/pytorch/ema/. | |
""" | |
from copy import deepcopy | |
import torch | |
from torch import nn | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Binomial2Pool2d(nn.Module): | |
def __init__(self, ceil_mode=False): | |
super().__init__() | |
self.ceil_mode = ceil_mode | |
kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]] |