#!/usr/bin/env python3
# Gist by @crowsonkb, last active November 11, 2022.
import math
import multiprocessing as mp
import typing
from einops import rearrange
import flax
import flax.linen as nn
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import jax.numpy as jnp
from jax.tree_util import Partial
from matplotlib import use
import lpips_jax
import numpy as np
import optax
from PIL import Image
from rich import print
from rich.traceback import install
import torch
from torch.utils import data
from torchvision import datasets, transforms
def n_params(tree):
return sum(x.size for x in jax.tree_util.tree_leaves(tree))
def ema_update(params_ema, params, decay):
return jax.tree_map(lambda p_ema, p: p_ema * decay + p * (1 - decay), params_ema, params)
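# Learning rate schedule: an exponential warmup factor (1 - warmup ** (1 + step))
# times an inverse power-law decay, floored at final_lr times the initial rate.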
def inverse_decay_schedule(init_value, steps=1., power=1., warmup=0., final_lr=0.):
def schedule(count):
return init_value * (1 - warmup ** (1 + count)) * jnp.maximum(final_lr, (1 + count / steps) ** -power)
return schedule
# source: https://arxiv.org/pdf/1902.02603.pdf
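# Approximates (per the reference above) the ratio of modified Bessel functions
# I_v(z) / I_{v-1}(z) that shows up in von Mises-Fisher gradient and KL expressions.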
def ive_fraction_approx_2(v, z, eps=1e-20):
def delta_a(a):
lamb = v + (a - 1.0) / 2.0
return (v - 0.5) + lamb / (2 * jnp.sqrt(jnp.clip(lamb ** 2 + z ** 2, a_min=eps)))
delta_0 = delta_a(0.0)
delta_2 = delta_a(2.0)
B_0 = z / (delta_0 + jnp.clip(jnp.sqrt(delta_0 ** 2 + z ** 2), a_min=eps))
B_2 = z / (delta_2 + jnp.clip(jnp.sqrt(delta_2 ** 2 + z ** 2), a_min=eps))
return (B_0 + B_2) / 2
def von_mises_fisher_kl_div_from_uniform_approx_der(log_kappa, p):
def fn(log_kappa):
kappa = jnp.exp(log_kappa)
return ive_fraction_approx_2(p / 2, kappa)
_, der = jax.jvp(fn, (log_kappa,), (jnp.ones_like(log_kappa),))
der_log_kappa = jnp.exp(log_kappa) * der
return jnp.where(log_kappa <= jnp.log(10000) + jnp.log(p), der_log_kappa, (p - 1) / 2)
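# Rather than a closed form, the vMF-to-uniform KL is obtained by numerically integrating
# its derivative w.r.t. log kappa (trapezoid rule over 101 points starting at log kappa = -5);
# the custom JVP below supplies that derivative directly.
# A minimal usage sketch (p = 9 matches the latent channel count used further down):
#     kl = von_mises_fisher_kl_div_from_uniform_approx(jnp.array(3.0), 9)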
@jax.custom_jvp
@jnp.vectorize
def von_mises_fisher_kl_div_from_uniform_approx(log_kappa, p):
"""The KL divergence of a vMF distribution from the uniform prior."""
x = jnp.linspace(-5, log_kappa, 101)
y = von_mises_fisher_kl_div_from_uniform_approx_der(x, p)
return jnp.trapz(y, x)
@von_mises_fisher_kl_div_from_uniform_approx.defjvp
def _vmf_kl_jvp(primals, tangents):
primal_out = von_mises_fisher_kl_div_from_uniform_approx(*primals)
tangent_out = von_mises_fisher_kl_div_from_uniform_approx_der(*primals) * tangents[0]
return primal_out, tangent_out
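# Only the gradient of the vMF log normalizer is needed, so log_c_for_der returns zeros
# but carries an approximate derivative through a custom JVP.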
def log_c_der(log_kappa, p):
return -ive_fraction_approx_2(p / 2, jnp.exp(log_kappa))
@jax.custom_jvp
def log_c_for_der(log_kappa, p):
"""A fake vMF log concentration parameter function that has an
approximation to its derivative."""
return jnp.zeros_like(log_kappa)
@log_c_for_der.defjvp
def _log_c_for_der_jvp(primals, tangents):
primal_out = log_c_for_der(*primals)
tangent_out = log_c_der(*primals) * tangents[0]
return primal_out, tangent_out
def von_mises_fisher_logpdf_for_grad(x, mu, log_kappa):
"""A fake vMF log density function that has an approximation to its
derivative."""
p = mu.shape[-1]
unnorm = jnp.exp(log_kappa) * jnp.sum(x * mu, axis=-1, keepdims=True)
c = log_c_for_der(log_kappa, p)
return unnorm + c
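# Unit-hypersphere manifold helpers used by the Riemannian Langevin sampler below.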
def projx(x):
"""Project x onto the manifold."""
return x / jnp.linalg.norm(x, axis=-1, keepdims=True)
def proju(x, u):
"""Project u onto the manifold's tangent space at x."""
return u - jnp.sum(x * u, axis=-1, keepdims=True) * x
def retr(x, u):
"""The manifold's retraction map."""
return projx(x + u)
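# Bilinear 2x downsampling/upsampling used between encoder/decoder stages.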
def scale_down(x):
n, h, w, c = x.shape
x = jax.image.resize(x, (n, h // 2, w // 2, c), jax.image.ResizeMethod.LINEAR)
return x
def scale_up(x):
n, h, w, c = x.shape
x = jax.image.resize(x, (n, h * 2, w * 2, c), jax.image.ResizeMethod.LINEAR)
return x
class ResConvBlock(nn.Module):
@nn.compact
def __call__(self, x):
y = x
y = nn.GroupNorm(16)(y)
y = nn.gelu(y)
y = nn.Conv(y.shape[-1], (3, 3))(y)
y = nn.gelu(y)
y = nn.Conv(y.shape[-1], (3, 3))(y)
return x + y
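# Encoder: widens channels with a 1x1 conv, applies residual blocks, and downsamples 2x
# between stages; the final 1x1 conv yields the vMF mean direction and log concentration.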
class Encoder(nn.Module):
features: int
depths: typing.Sequence[int]
widths: typing.Sequence[int]
@nn.compact
def __call__(self, x):
for i in range(len(self.depths)):
x = nn.Conv(self.widths[i], (1, 1), use_bias=False, kernel_init=nn.initializers.orthogonal())(x)
for j in range(self.depths[i]):
x = ResConvBlock()(x)
if i < len(self.depths) - 1:
x = scale_down(x)
x = nn.Conv(self.features + 1, (1, 1))(x)
mean, log_kappa = x[..., :-1], x[..., -1:]
return projx(mean), jnp.clip(log_kappa, -10, 15)
class Decoder(nn.Module):
features: int
depths: typing.Sequence[int]
widths: typing.Sequence[int]
@nn.compact
def __call__(self, x):
x = x * jnp.sqrt(x.shape[-1])
for i in range(len(self.depths)):
x = nn.Conv(self.widths[i], (1, 1), use_bias=False, kernel_init=nn.initializers.orthogonal())(x)
if i > 0:
x = scale_up(x)
for j in range(self.depths[i]):
x = ResConvBlock()(x)
x = nn.Conv(self.features, (1, 1), kernel_init=nn.initializers.zeros)(x)
return x
class AutoencoderOutput(flax.struct.PyTreeNode):
rec: jax.Array
loss_kl: jax.Array
scale_l2: jax.Array
scale_p: jax.Array
scale_adv: jax.Array
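# Autoencoder with a von Mises-Fisher latent at each spatial position: the encoder outputs
# a mean direction and a log concentration, the KL from the uniform prior regularizes the
# latent, and scale_l2/scale_p/scale_adv are learned loss-balancing parameters.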
class Autoencoder(nn.Module):
input_features: int
features: int
depths: typing.Sequence[int]
widths: typing.Sequence[int]
def setup(self):
self.encoder = Encoder(self.features, self.depths, self.widths)
self.decoder = Decoder(self.input_features, tuple(reversed(self.depths)), tuple(reversed(self.widths)))
self.scale_l2 = self.param('scale_l2', nn.initializers.zeros, ())
self.scale_p = self.param('scale_p', nn.initializers.zeros, ())
self.scale_adv = self.param('scale_adv', nn.initializers.zeros, ())
def encode(self, x):
return self.encoder(x)
def decode(self, x):
return self.decoder(x)
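    # Sampling from the vMF posterior: 50 steps of preconditioned Riemannian Langevin
    # dynamics on the sphere, initialized at the mean direction, with step size eps / (1 + t).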
def sample(self, key, mu, log_kappa):
eps = 1
steps = 50
def rld_update(x, key, mu, log_kappa, step):
"""Preconditioned Riemannian Langevin dynamics."""
grad = jax.grad(lambda x: jnp.sum(von_mises_fisher_logpdf_for_grad(x, mu, log_kappa)))(x)
noise = jax.random.normal(key, x.shape)
cur_eps = eps / (1 + step)
precond = jnp.exp(-log_kappa)
e_step = precond * grad * 0.5 * cur_eps + jnp.sqrt(precond * cur_eps) * noise
r = proju(x, e_step)
x = retr(x, r)
return x
def scan_fn(carry, _):
x, key, step = carry
key, subkey = jax.random.split(key)
x = rld_update(x, subkey, mu, log_kappa, step)
return (x, key, step + 1), jnp.array([])
return jax.lax.scan(scan_fn, (mu, key, jnp.array(0.)), jnp.zeros([steps]))[0][0]
def loss_kl(self, mu, log_kappa):
losses = von_mises_fisher_kl_div_from_uniform_approx(log_kappa, self.features)
return jnp.mean(losses) / 2 ** (2 * (len(self.depths) - 1))
def scales(self):
return self.scale_l2 * 10, self.scale_p * 10, self.scale_adv * 10
def __call__(self, x):
dist = self.encode(x)
loss_kl = self.loss_kl(*dist)
latent = self.sample(self.make_rng('sample'), *dist)
rec = self.decode(latent)
return AutoencoderOutput(rec, loss_kl, *self.scales())
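# A small convolutional (patch-style) discriminator trained with a non-saturating logistic
# loss (softplus) plus a gradient penalty on random interpolations between real images and
# reconstructions.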
class Discriminator(nn.Module):
@nn.compact
def __call__(self, x):
x = nn.Conv(64, (3, 3))(x)
x = nn.gelu(x)
x = scale_down(x)
x = nn.Conv(128, (3, 3))(x)
x = nn.gelu(x)
x = scale_down(x)
x = nn.Conv(256, (3, 3))(x)
x = nn.gelu(x)
x = scale_down(x)
x = nn.Conv(256, (3, 3))(x)
x = nn.gelu(x)
x = scale_down(x)
x = nn.Conv(1, (3, 3), kernel_init=nn.initializers.zeros)(x)
return x
def loss_d(self, x, rec):
d_x, d_rec = jnp.split(self(jnp.concatenate([x, rec])), 2)
return jnp.mean(nn.softplus(d_rec)) + jnp.mean(nn.softplus(-d_x))
def loss_gp(self, key, x, rec):
e = jax.random.uniform(key, [x.shape[0], 1, 1, 1])
x_hat = e * x + (1 - e) * rec
grads = jax.grad(lambda x: jnp.sum(jnp.mean(self(x), axis=(1, 2, 3))))(x_hat)
gps = jnp.sum(jax.vmap(jnp.ravel)(grads) ** 2, axis=1) / 2
return jnp.mean(gps)
def loss_g(self, x, rec):
d_rec = self(rec)
return jnp.mean(nn.softplus(-d_rec))
class ToChannelsLast:  # convert CHW tensors from torchvision's ToTensor() to HWC for JAX
def __call__(self, x):
return rearrange(x, 'c h w -> h w c')
class TrainState(flax.struct.PyTreeNode):
step: jax.Array
params: typing.Any
params_ema: typing.Any
opt_state: typing.Any
params_d: typing.Any
opt_state_d: typing.Any
class Checkpoint(flax.struct.PyTreeNode):
train_state: TrainState
key: jax.Array
def main():
install()
try:
jax.distributed.initialize()
except ValueError:
pass
if jax.default_backend() == 'tpu':
mp.set_start_method('spawn')
elif jax.process_count() > 1:
mp.set_start_method('forkserver')
batch_size_per_device = 32
batch_size_per_process = batch_size_per_device * jax.local_device_count()
batch_size = batch_size_per_device * jax.device_count()
if jax.process_index() == 0:
print('Processes:', jax.process_count())
print('Devices:', jax.device_count())
print('Batch size per device:', batch_size_per_device)
print('Batch size per process:', batch_size_per_process)
print('Batch size:', batch_size, flush=True)
mesh_shape = (jax.device_count(),)
devices = np.asarray(jax.devices()).reshape(*mesh_shape)
    mesh = maps.Mesh(devices, ('n',))
batch_sharded_image = PartitionSpec('n', None, None, None)
key = jax.random.PRNGKey(9146)
key = jax.random.split(key, jax.process_count())[jax.process_index()]
key, subkey = jax.random.split(key)
torch.manual_seed(jax.random.randint(subkey, [], -2 ** 31, 2 ** 31 - 1).item())
size = 128
tf = transforms.Compose([
transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
        ToChannelsLast(),
])
# dataset = datasets.ImageFolder('/fsx/ilsvrc2012/train', transform=tf)
dataset = datasets.ImageFolder('/home/kat/datasets/ilsvrc2012/train', transform=tf)
dataloader = data.DataLoader(dataset, batch_size_per_process, shuffle=True, drop_last=True, num_workers=4, persistent_workers=True)
ch = 9
model = Autoencoder(3, ch, (2, 2, 2, 2), (64, 128, 256, 256))
d = Discriminator()
key, subkey = jax.random.split(key)
params = model.init({'params': subkey, 'sample': subkey}, jnp.zeros([1, size, size, 3]))['params']
key, subkey = jax.random.split(key)
params_d = d.init(subkey, jnp.zeros([1, size, size, 3]))['params']
if jax.process_index() == 0:
print(f'Parameters: {n_params(params)}', flush=True)
print(f'D parameters: {n_params(params_d)}', flush=True)
sched = inverse_decay_schedule(3e-4, steps=25000, warmup=0.99)
opt = optax.adamw(sched, b2=0.99, weight_decay=1e-4)
opt_state = opt.init(params)
sched_d = inverse_decay_schedule(3e-4, steps=25000, warmup=0.99)
opt_d = optax.adamw(sched_d, b1=0., b2=0.99, weight_decay=1e-4)
opt_state_d = opt_d.init(params_d)
lpips = lpips_jax.load(net='vgg16')
train_state = TrainState(jnp.array(0, jnp.int32), params, params, opt_state, params_d, opt_state_d)
d_start = 2000
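    # The adversarial loss and discriminator updates only start after d_start steps;
    # the KL weight ramps from 0 to 1 over the same window with a sin^2 schedule.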
def kl_weight(step):
return jnp.sin(jnp.minimum(1., step / d_start) * jnp.pi / 2) ** 2
def ema_decay(step):
return jnp.minimum(0.999, 1 - (1 + step) ** -0.75)
@Partial(pjit, in_axis_resources=(None, None, None, batch_sharded_image), out_axis_resources=None, donate_argnums=0)
def update(state, lpips, key, x):
def loss_fn(params):
out = model.apply({'params': params}, x, rngs={'sample': key})
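            # Reconstruction (L2) and perceptual (LPIPS) losses use learned uncertainty
            # weighting: each term is divided by exp(scale) with a 0.5 * scale regularizer.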
loss_l2 = 0.5 * (0.5 * jnp.mean(jnp.square(x - out.rec)) / jnp.exp(out.scale_l2) + 0.5 * out.scale_l2)
loss_p = 0.5 * (0.5 * jnp.mean(lpips.model.apply(lpips.params, x, out.rec)) / jnp.exp(out.scale_p) + 0.5 * out.scale_p)
loss_kl = out.loss_kl * kl_weight(state.step)
def adv_fn():
loss_adv = d.apply({'params': state.params_d}, x, out.rec, method=d.loss_g)
return loss_adv / jnp.exp(out.scale_adv) + 0.5 * out.scale_adv
loss_adv = jax.lax.cond(state.step >= d_start, adv_fn, lambda: 0.)
loss = loss_l2 + loss_p + loss_kl + loss_adv
return loss, {'loss': loss, 'l2': loss_l2, 'p': loss_p, 'kl': loss_kl, 'adv': loss_adv}
(loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
updates, opt_state = opt.update(grads, state.opt_state, state.params)
params = optax.apply_updates(state.params, updates)
        params_ema = ema_update(state.params_ema, params, ema_decay(state.step))
state = state.replace(step=state.step + 1, params=params, params_ema=params_ema, opt_state=opt_state)
return state, losses
@Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None, donate_argnums=0)
def update_d(state, key, x):
def loss_fn(params):
key_, subkey = jax.random.split(key)
rec = model.apply({'params': state.params}, x, rngs={'sample': key_}).rec
loss_d = d.apply({'params': params}, x, rec, method=d.loss_d)
loss_gp = d.apply({'params': params}, subkey, x, rec, method=d.loss_gp) * 10
loss = loss_d + loss_gp
return loss, {'loss': loss, 'd': loss_d, 'gp': loss_gp}
(loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params_d)
updates, opt_state_d = opt_d.update(grads, state.opt_state_d, state.params_d)
params_d = optax.apply_updates(state.params_d, updates)
state = state.replace(params_d=params_d, opt_state_d=opt_state_d)
return state, losses
@Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None)
def reconstruct(params, key, x):
return model.apply({'params': params}, x, rngs={'sample': key}).rec
# @Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None)
def _sample(params, key, x):
return model.apply({'params': params}, x, rngs={'sample': key}, method=model.decode)
def sample(params, key, shape):
keys = jax.random.split(key, 2)
lat = projx(jax.random.normal(keys[0], shape))
return _sample(params, keys[1], lat)
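    # Identity pjit whose only purpose is to reshard the batch-sharded input to a fully
    # replicated array so every process can slice the demo images.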
@Partial(pjit, in_axis_resources=batch_sharded_image, out_axis_resources=None)
def gather(x):
return x
def demo(params, key, x):
keys = jax.random.split(key, 2)
n_side = min(8, math.floor(batch_size ** 0.5))
n = n_side * n_side
n_per_process = jax.local_device_count() * math.ceil(n / jax.device_count())
x = x[:n_per_process]
rec = reconstruct(params, keys[0], x)[:n]
x = gather(x)[:n]
grid = rearrange([x, rec], 't (nh nw) h w c -> (nh h) (nw t w) c', nh=n_side, nw=n_side)
# TODO: parallel sample
samples = jax.jit(sample, static_argnums=2)(params, keys[1], (n_side * 4, 16, 16, ch))
sample_grid = rearrange(samples, '(nh nw) h w c -> (nh h) (nw w) c', nh=2, nw=n_side * 2)
grid = jnp.concatenate([grid, sample_grid], axis=0)
grid = np.array(jnp.round(jnp.clip((grid + 1) * 127.5, 0, 255)).astype(jnp.uint8))
if jax.process_index() == 0:
Image.fromarray(grid).save(f'demo_{step:08}.png')
print('📸 Output demo grid!', flush=True)
def save(train_state, step, key):
ckpt = Checkpoint(train_state, key)
with open(f'ckpt_{step:08}.ckpt', 'wb') as file:
file.write(flax.serialization.to_bytes(ckpt))
print('💾 Saved a checkpoint!')
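    # Infinite iterator over the DataLoader that converts each torch batch to jnp arrays.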
def data_iterator(iterable):
while True:
for item in iterable:
yield jax.tree_map(lambda x: jnp.array(x), item)
step = train_state.step.item()
data_iter = data_iterator(dataloader)
try:
while True:
with mesh:
x = next(data_iter)[0]
if step % 250 == 0:
key, subkey = jax.random.split(key)
demo(train_state.params_ema, subkey, x)
if step > 0 and step % 10000 == 0:
if jax.process_index() == 0:
save(train_state, step, key)
key, subkey = jax.random.split(key)
train_state, losses = update(train_state, lpips, subkey, x)
if step >= d_start:
x = next(data_iter)[0]
key, subkey = jax.random.split(key)
train_state, losses_d = update_d(train_state, subkey, x)
if step % 25 == 0:
if jax.process_index() == 0:
if step >= d_start:
print(f"step: {step}, loss: {losses['loss']:g}, l2: {losses['l2']:g}, p: {losses['p']:g}, kl: {losses['kl']:g}, adv: {losses['adv']:g}, d: {losses_d['d']:g}, gp: {losses_d['gp']:g}", flush=True)
else:
print(f"step: {step}, loss: {losses['loss']:g}, l2: {losses['l2']:g}, p: {losses['p']:g}, kl: {losses['kl']:g}", flush=True)
step += 1
except KeyboardInterrupt:
pass
if __name__ == '__main__':
main()