This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import random | |
class WeightedSampler: | |
"""Samples k elements from a stream of weighted items without replacement. | |
See Weighted Random Sampling (Efraimidis, Spirakis 2005). | |
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Stochastic beam search. | |
Implements "Stochastic Beams and Where to Find Them: The Gumbel-Top-k Trick for | |
Sampling Sequences Without Replacement" (https://arxiv.org/abs/1903.06059)""" | |
import math | |
import torch | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Vectorial TV loss using higher accuracy order finite difference operators.""" | |
import torch | |
FINITE_DIFFERENCE_COEFFS = { | |
1: torch.tensor([-1, 1]), | |
2: torch.tensor([-3 / 2, 2, -1 / 2]), | |
3: torch.tensor([-11 / 6, 3, -3 / 2, 1 / 3]), | |
4: torch.tensor([-25 / 12, 4, -3, 4 / 3, -1 / 4]), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""Pads an image for Bluesky.""" | |
import argparse | |
import math | |
from pathlib import Path | |
from PIL import Image |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""JAX implementation of the 2D DWT and IDWT.""" | |
from einops import rearrange | |
import jax | |
import jax.numpy as jnp | |
import pywt | |
def get_filter_bank(wavelet, dtype=jnp.float32): | |
"""Get the filter bank for a given pywavelets wavelet name.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
"""A simple JAX process launcher for multiple devices on a single host. | |
You must import jax_local_cluster somewhere inside the script you are launching. | |
""" | |
import argparse | |
from functools import partial | |
import os |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Apache License | |
Version 2.0, January 2004 | |
http://www.apache.org/licenses/ | |
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | |
1. Definitions. | |
"License" shall mean the terms and conditions for use, reproduction, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import math | |
import multiprocessing as mp | |
import typing | |
from einops import rearrange | |
import flax | |
import flax.linen as nn | |
import jax |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from functools import wraps | |
from hashlib import sha256 | |
from io import open | |
import json | |
import math | |
import logging | |
import os | |
from pathlib import Path | |
import shutil |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import torch | |
from torch import optim | |
class AdamWFinetune(optim.Optimizer): | |
r"""Implements AdamW algorithm with optional weight decay toward the starting value, to | |
prevent overfitting to the new dataset during fine-tuning. | |
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. |