This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Dumps a Caffe binary model to a pickle of NumPy arrays.""" | |
import argparse | |
from collections import OrderedDict | |
import os | |
import pickle | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Matrix square roots with backward passes. | |
Cleaned up from https://github.com/msubhransu/matrix-sqrt. | |
""" | |
import torch | |
def sqrtm_ns(a, num_iters=10): | |
if a.ndim < 2: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Learns the parity function.""" | |
import torch | |
from torch import nn, optim | |
from tqdm import trange, tqdm | |
class GatedUnit(nn.Module): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.""" | |
import math | |
import torch | |
from torch import optim | |
class ComplexSGD(optim.Optimizer): | |
def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import geoopt | |
def spherical_avg(p, w=None, tol=1e-6): | |
sphere = geoopt.Sphere() | |
if w is None: | |
w = p.new_ones([p.shape[0]]) | |
assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w) | |
w = w / w.sum() | |
p = sphere.projx(p) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Generates images from saved embeddings with CLIP.""" | |
import argparse | |
from concurrent import futures | |
import sys | |
import torch | |
from torch import nn, optim |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from | |
https://www.zijianhu.com/post/pytorch/ema/. | |
""" | |
from copy import deepcopy | |
import torch | |
from torch import nn | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class Binomial2Pool2d(nn.Module): | |
def __init__(self, ceil_mode=False): | |
super().__init__() | |
self.ceil_mode = ceil_mode | |
kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Exponential moving average for PyTorch. Adapted from | |
https://www.zijianhu.com/post/pytorch/ema/. | |
""" | |
from copy import deepcopy | |
import torch | |
from torch import nn | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import argparse | |
import copy | |
from functools import wraps | |
from hashlib import sha256 | |
from io import open | |
import json | |
import math | |
import logging |