"""Vectorial TV loss using higher accuracy order finite difference operators."""
import torch
FINITE_DIFFERENCE_COEFFS = {
    1: torch.tensor([-1, 1]),
    2: torch.tensor([-3 / 2, 2, -1 / 2]),
    3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]),
    4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]),
}  # truncated preview; the gist may define further orders
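
A minimal sketch (not from the gist) of how these forward-difference coefficients could be turned into a loss: apply the 1-D stencil along each spatial axis with a depthwise convolution and take the vector norm over channels. The function name, cropping, and reduction below are assumptions.

import torch.nn.functional as F

def tv_loss_sketch(x, order=2):
    """Vectorial TV of an NCHW batch using an order-`order` forward difference stencil."""
    coeffs = FINITE_DIFFERENCE_COEFFS[order].to(x)
    n, c = coeffs.numel(), x.shape[1]
    kx = coeffs.view(1, 1, 1, n).repeat(c, 1, 1, 1)  # stencil along width, one kernel per channel
    ky = coeffs.view(1, 1, n, 1).repeat(c, 1, 1, 1)  # stencil along height
    dx = F.conv2d(x, kx, groups=c)
    dy = F.conv2d(x, ky, groups=c)
    h, w = dy.shape[2], dx.shape[3]  # crop both gradients to the common valid region
    mag = torch.sqrt((dx[..., :h, :w] ** 2 + dy[..., :h, :w] ** 2).sum(dim=1) + 1e-12)
    return mag.mean()
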
#!/usr/bin/env python3
"""Pads an image for Bluesky."""
import argparse
import math
from pathlib import Path
from PIL import Image
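
The preview stops at the imports. A rough sketch of the general idea, assuming the script pads images out to a maximum aspect ratio on a solid background; the helper name, the 2:1 limit, and the fill color are assumptions, not taken from the gist. It reuses the math and PIL imports above.

def pad_to_max_ratio(image, max_ratio=2.0, fill=(255, 255, 255)):
    """Pad `image` on a solid background so its aspect ratio stays within max_ratio:1."""
    w, h = image.size
    new_w = max(w, math.ceil(h / max_ratio))  # widen images that are too tall
    new_h = max(h, math.ceil(w / max_ratio))  # heighten images that are too wide
    if (new_w, new_h) == (w, h):
        return image
    canvas = Image.new(image.mode, (new_w, new_h), fill)
    canvas.paste(image, ((new_w - w) // 2, (new_h - h) // 2))
    return canvas
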
crowsonkb / jax_wavelet.py
Last active February 10, 2023 17:56
JAX implementation of the 2D DWT and IDWT.
"""JAX implementation of the 2D DWT and IDWT."""
from einops import rearrange
import jax
import jax.numpy as jnp
import pywt
def get_filter_bank(wavelet, dtype=jnp.float32):
"""Get the filter bank for a given pywavelets wavelet name."""
crowsonkb / jax_local_cluster.py
Last active March 7, 2024 11:37
A simple JAX process launcher for multiple devices on a single host.
#!/usr/bin/env python3
"""A simple JAX process launcher for multiple devices on a single host.
You must import jax_local_cluster somewhere inside the script you are launching.
"""
import argparse
from functools import partial
import os
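
The preview ends at the imports. A rough sketch of the pattern the description and docstring suggest: the launcher spawns one process per device and passes rank information through the environment, and the imported jax_local_cluster module reads it back to call jax.distributed.initialize(). The environment variable names and port are assumptions.

import multiprocessing as mp
import os
import runpy

def _worker(script, rank, num_procs):
    # Hand the child its rank via the environment (variable names are hypothetical).
    os.environ["JAX_LOCAL_CLUSTER_RANK"] = str(rank)
    os.environ["JAX_LOCAL_CLUSTER_SIZE"] = str(num_procs)
    runpy.run_path(script, run_name="__main__")

def launch(script, num_procs):
    procs = [mp.Process(target=_worker, args=(script, i, num_procs)) for i in range(num_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()

# On import, jax_local_cluster itself could then join the local cluster with something like:
#     jax.distributed.initialize(coordinator_address="localhost:6379",  # assumed port
#                                num_processes=int(os.environ["JAX_LOCAL_CLUSTER_SIZE"]),
#                                process_id=int(os.environ["JAX_LOCAL_CLUSTER_RANK"]))
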
crowsonkb / LICENSE
Last active December 25, 2023 04:28
Adam Langevin Dynamics for optax, by Katherine Crowson
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
#!/usr/bin/env python3
import math
import multiprocessing as mp
import typing
from einops import rearrange
import flax
import flax.linen as nn
import jax
crowsonkb / biggan.py
Last active November 1, 2022 06:11
BigGAN + CLIP, Langevin dynamics method
import copy
from functools import wraps
from hashlib import sha256
from io import open
import json
import math
import logging
import os
from pathlib import Path
import shutil
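
The preview here shows only vendored import boilerplate. As a very rough illustration of the method named in the description (not the gist's code), CLIP-guided optimization of a BigGAN latent with Langevin noise might take steps like the one below; generator, clip_model, and preprocess are hypothetical stand-ins for the actual models.

import torch

def langevin_step(z, class_vec, text_features, generator, clip_model, preprocess, lr=0.05):
    z = z.detach().requires_grad_()
    image = generator(z, class_vec)                            # hypothetical BigGAN forward pass
    image_features = clip_model.encode_image(preprocess(image))
    loss = -torch.cosine_similarity(image_features, text_features, dim=-1).mean()
    grad, = torch.autograd.grad(loss, z)
    noise = torch.randn_like(z) * (2 * lr) ** 0.5              # Langevin noise term
    return (z - lr * grad + noise).detach()
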
import math
import torch
from torch import optim
class AdamWFinetune(optim.Optimizer):
r"""Implements AdamW algorithm with optional weight decay toward the starting value, to
prevent overfitting to the new dataset during fine-tuning.
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
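
The preview cuts off in the docstring. The key change the description points at is decaying each parameter toward a snapshot of its starting value rather than toward zero; a minimal sketch of that update (the function name and where the snapshot lives are assumptions):

import torch

@torch.no_grad()
def decay_toward_init(param, init_param, lr, weight_decay):
    # Decoupled decay toward the pretrained weights: p <- p - lr * wd * (p - p0),
    # instead of standard AdamW's p <- p - lr * wd * p.
    param.add_(param - init_param, alpha=-lr * weight_decay)
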
"""Learning rate and EMA warmup schedulers for PyTorch."""
import warnings
from torch import optim
class InverseLR(optim.lr_scheduler._LRScheduler):
"""Implements an inverse decay learning rate schedule with an optional exponential
warmup. When last_epoch=-1, sets initial lr as lr.
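
One common form of an inverse decay schedule with exponential warmup, consistent with the docstring above; the parameter names inv_gamma, power, and warmup are assumptions, not necessarily the gist's.

def inverse_lr(base_lr, step, inv_gamma=1.0, power=1.0, warmup=0.0):
    # Warmup factor rises exponentially from (1 - warmup) toward 1, then the
    # base rate decays as an inverse power of the step count.
    warmup_factor = 1 - warmup ** (step + 1)
    return warmup_factor * base_lr * (1 + step / inv_gamma) ** -power
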
#!/usr/bin/env python3
"""Computes the channel-wise means, standard deviations, and covariance
matrix of a dataset of images."""
import argparse
import torch
from torch.utils import data
from torchvision import datasets, transforms as T
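
The preview stops at the imports. One straightforward way to accumulate the statistics the docstring names is to treat every pixel as a C-vector and keep running first and second moments over the whole dataset; a sketch under that assumption (the names and the population-covariance convention are mine):

def channel_stats(loader):
    """Per-channel mean, std, and covariance over all pixels of (N, C, H, W) batches."""
    n, s, ss = 0, None, None
    for images, _ in loader:
        x = images.flatten(2).transpose(1, 2).reshape(-1, images.shape[1]).double()
        n += x.shape[0]
        s = x.sum(0) if s is None else s + x.sum(0)
        ss = x.T @ x if ss is None else ss + x.T @ x
    mean = s / n
    cov = ss / n - torch.outer(mean, mean)  # biased (population) covariance
    std = cov.diagonal().sqrt()
    return mean, std, cov
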