This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Vectorial TV loss using higher accuracy order finite difference operators.""" | |
import torch | |
FINITE_DIFFERENCE_COEFFS = { | |
1: torch.tensor([-1, 1]), | |
2: torch.tensor([-3 / 2, 2, -1 / 2]), | |
3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]), | |
4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]), |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Pads an image for Bluesky.""" | |
import argparse | |
import math | |
from pathlib import Path | |
from PIL import Image |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""JAX implementation of the 2D DWT and IDWT.""" | |
from einops import rearrange | |
import jax | |
import jax.numpy as jnp | |
import pywt | |
def get_filter_bank(wavelet, dtype=jnp.float32): | |
"""Get the filter bank for a given pywavelets wavelet name.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""A simple JAX process launcher for multiple devices on a single host. | |
You must import jax_local_cluster somewhere inside the script you are launching. | |
""" | |
import argparse | |
from functools import partial | |
import os |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Apache License | |
Version 2.0, January 2004 | |
http://www.apache.org/licenses/ | |
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | |
1. Definitions. | |
"License" shall mean the terms and conditions for use, reproduction, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import math | |
import multiprocessing as mp | |
import typing | |
from einops import rearrange | |
import flax | |
import flax.linen as nn | |
import jax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from functools import wraps | |
from hashlib import sha256 | |
from io import open | |
import json | |
import math | |
import logging | |
import os | |
from pathlib import Path | |
import shutil |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import optim | |
class AdamWFinetune(optim.Optimizer): | |
r"""Implements AdamW algorithm with optional weight decay toward the starting value, to | |
prevent overfitting to the new dataset during fine-tuning. | |
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Learning rate and EMA warmup schedulers for PyTorch.""" | |
import warnings | |
from torch import optim | |
class InverseLR(optim.lr_scheduler._LRScheduler): | |
"""Implements an inverse decay learning rate schedule with an optional exponential | |
warmup. When last_epoch=-1, sets initial lr as lr. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Computes the channel-wise means, standard deviations, and covariance | |
matrix of a dataset of images.""" | |
import argparse | |
import torch | |
from torch.utils import data | |
from torchvision import datasets, transforms as T |