Skip to content

Instantly share code, notes, and snippets.

#!/usr/bin/env python3
"""Dumps a Caffe binary model to a pickle of NumPy arrays."""
import argparse
from collections import OrderedDict
import os
import pickle
"""Matrix square roots with backward passes.
Cleaned up from https://github.com/msubhransu/matrix-sqrt.
"""
import torch
def sqrtm_ns(a, num_iters=10):
if a.ndim < 2:
#!/usr/bin/env python3
"""Learns the parity function."""
import torch
from torch import nn, optim
from tqdm import trange, tqdm
class GatedUnit(nn.Module):
@crowsonkb
crowsonkb / complex_optim.py
Created June 4, 2021 16:56
Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431.
"""Complex momentum SGD and Adam. See https://arxiv.org/abs/2102.08431."""
import math
import torch
from torch import optim
class ComplexSGD(optim.Optimizer):
def __init__(self, params, lr=1e-2, momentum=0.9, angle=math.pi / 8, weight_decay=0.):
@crowsonkb
crowsonkb / spherical_avg.py
Created April 26, 2021 22:51
Spherical weighted average
import geoopt
def spherical_avg(p, w=None, tol=1e-6):
sphere = geoopt.Sphere()
if w is None:
w = p.new_ones([p.shape[0]])
assert p.ndim == 2 and w.ndim == 1 and len(p) == len(w)
w = w / w.sum()
p = sphere.projx(p)
#!/usr/bin/env python3
"""Generates images from saved embeddings with CLIP."""
import argparse
from concurrent import futures
import sys
import torch
from torch import nn, optim
@crowsonkb
crowsonkb / ema_biased.py
Created February 26, 2021 19:09
Biased EMA for PyTorch
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
@crowsonkb
crowsonkb / binomial_pool.py
Created February 22, 2021 18:08
Binomial2Pool2d
import torch
from torch import nn
from torch.nn import functional as F
class Binomial2Pool2d(nn.Module):
def __init__(self, ceil_mode=False):
super().__init__()
self.ceil_mode = ceil_mode
kernel = [[[[1/16, 1/8, 1/16], [1/8, 1/4, 1/8], [1/16, 1/8, 1/16]]]]
@crowsonkb
crowsonkb / ema.py
Last active February 28, 2021 22:45
Parameter averaging for PyTorch
"""Exponential moving average for PyTorch. Adapted from
https://www.zijianhu.com/post/pytorch/ema/.
"""
from copy import deepcopy
import torch
from torch import nn
#!/usr/bin/env python3
import argparse
import copy
from functools import wraps
from hashlib import sha256
from io import open
import json
import math
import logging