#!/usr/bin/env python3

import math
import multiprocessing as mp
import typing

from einops import rearrange
import flax
import flax.linen as nn
import jax
from jax.experimental import maps
from jax.experimental import PartitionSpec
from jax.experimental.pjit import pjit
import jax.numpy as jnp
from jax.tree_util import Partial
from matplotlib import use
import lpips_jax
import numpy as np
import optax
from PIL import Image
from rich import print
from rich.traceback import install
import torch
from torch.utils import data
from torchvision import datasets, transforms

def n_params(tree):
    return sum(x.size for x in jax.tree_util.tree_leaves(tree))


def ema_update(params_ema, params, decay):
    return jax.tree_map(lambda p_ema, p: p_ema * decay + p * (1 - decay), params_ema, params)


def inverse_decay_schedule(init_value, steps=1., power=1., warmup=0., final_lr=0.):
    def schedule(count):
        return init_value * (1 - warmup ** (1 + count)) * jnp.maximum(final_lr, (1 + count / steps) ** -power)
    return schedule

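# Usage sketch, matching the hyperparameters used in main() below: with
# init_value=3e-4, steps=25000, warmup=0.99, the warmup factor suppresses the
# learning rate at first (sched(0) = 3e-4 * (1 - 0.99) = 3e-6) and approaches 1
# as training progresses, after which the rate decays roughly as
# 3e-4 * (1 + count / 25000) ** -1.
#   sched = inverse_decay_schedule(3e-4, steps=25000, warmup=0.99)
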
# source: https://arxiv.org/pdf/1902.02603.pdf
def ive_fraction_approx_2(v, z, eps=1e-20):
    def delta_a(a):
        lamb = v + (a - 1.0) / 2.0
        return (v - 0.5) + lamb / (2 * jnp.sqrt(jnp.clip(lamb ** 2 + z ** 2, a_min=eps)))
    delta_0 = delta_a(0.0)
    delta_2 = delta_a(2.0)
    B_0 = z / (delta_0 + jnp.clip(jnp.sqrt(delta_0 ** 2 + z ** 2), a_min=eps))
    B_2 = z / (delta_2 + jnp.clip(jnp.sqrt(delta_2 ** 2 + z ** 2), a_min=eps))
    return (B_0 + B_2) / 2

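# Note: this approximates the ratio of modified Bessel functions
# I_v(z) / I_{v-1}(z) by averaging two bounds on it (the delta_0 and delta_2
# terms), following the reference cited above. For a vMF distribution in p
# dimensions this ratio, evaluated at v = p / 2 and z = kappa, equals minus the
# derivative of the log normalizing constant with respect to kappa, which is
# how it is used below.
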
def von_mises_fisher_kl_div_from_uniform_approx_der(log_kappa, p):
    def fn(log_kappa):
        kappa = jnp.exp(log_kappa)
        return ive_fraction_approx_2(p / 2, kappa)
    _, der = jax.jvp(fn, (log_kappa,), (jnp.ones_like(log_kappa),))
    der_log_kappa = jnp.exp(log_kappa) * der
    return jnp.where(log_kappa <= jnp.log(10000) + jnp.log(p), der_log_kappa, (p - 1) / 2)


@jax.custom_jvp
@jnp.vectorize
def von_mises_fisher_kl_div_from_uniform_approx(log_kappa, p):
    """The KL divergence of a vMF distribution from the uniform prior."""
    x = jnp.linspace(-5, log_kappa, 101)
    y = von_mises_fisher_kl_div_from_uniform_approx_der(x, p)
    return jnp.trapz(y, x)


@von_mises_fisher_kl_div_from_uniform_approx.defjvp
def _vmf_kl_jvp(primals, tangents):
    primal_out = von_mises_fisher_kl_div_from_uniform_approx(*primals)
    tangent_out = von_mises_fisher_kl_div_from_uniform_approx_der(*primals) * tangents[0]
    return primal_out, tangent_out

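# How these two fit together, as a sketch (the values below are hypothetical):
# the primal integrates the derivative numerically with jnp.trapz starting from
# log_kappa = -5 (where the KL is essentially zero), while the custom JVP reuses
# the closed-form approximation of the derivative, so gradients never have to
# differentiate through the numerical integration.
#   kl = von_mises_fisher_kl_div_from_uniform_approx(jnp.array(2.0), 9)
#   dkl = jax.grad(von_mises_fisher_kl_div_from_uniform_approx)(jnp.array(2.0), 9)
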
def log_c_der(log_kappa, p):
    return -ive_fraction_approx_2(p / 2, jnp.exp(log_kappa))


@jax.custom_jvp
def log_c_for_der(log_kappa, p):
    """A fake vMF log concentration parameter function that has an
    approximation to its derivative."""
    return jnp.zeros_like(log_kappa)


@log_c_for_der.defjvp
def _log_c_for_der_jvp(primals, tangents):
    primal_out = log_c_for_der(*primals)
    tangent_out = log_c_der(*primals) * tangents[0]
    return primal_out, tangent_out


def von_mises_fisher_logpdf_for_grad(x, mu, log_kappa):
    """A fake vMF log density function that has an approximation to its
    derivative."""
    p = mu.shape[-1]
    unnorm = jnp.exp(log_kappa) * jnp.sum(x * mu, axis=-1, keepdims=True)
    c = log_c_for_der(log_kappa, p)
    return unnorm + c

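# Note: Autoencoder.sample below only differentiates this log density with
# respect to x, so the zero-valued stand-in for the normalizing constant never
# affects the sampler; log_c_for_der exists only so the expression also has an
# approximate derivative with respect to log_kappa if one is ever taken.
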
def projx(x):
    """Project x onto the manifold."""
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)


def proju(x, u):
    """Project u onto the manifold's tangent space at x."""
    return u - jnp.sum(x * u, axis=-1, keepdims=True) * x


def retr(x, u):
    """The manifold's retraction map."""
    return projx(x + u)

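# These three maps (on the unit hypersphere, with respect to the last axis) are
# what the Langevin sampler in Autoencoder.sample composes: project the Euclidean
# update onto the tangent space at x with proju, then retract back onto the
# sphere with retr. A hypothetical single Riemannian gradient step would read
#   x = retr(x, proju(x, step_size * euclidean_grad))
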
def scale_down(x):
    n, h, w, c = x.shape
    x = jax.image.resize(x, (n, h // 2, w // 2, c), jax.image.ResizeMethod.LINEAR)
    return x


def scale_up(x):
    n, h, w, c = x.shape
    x = jax.image.resize(x, (n, h * 2, w * 2, c), jax.image.ResizeMethod.LINEAR)
    return x

class ResConvBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        y = x
        y = nn.GroupNorm(16)(y)
        y = nn.gelu(y)
        y = nn.Conv(y.shape[-1], (3, 3))(y)
        y = nn.gelu(y)
        y = nn.Conv(y.shape[-1], (3, 3))(y)
        return x + y

class Encoder(nn.Module):
    features: int
    depths: typing.Sequence[int]
    widths: typing.Sequence[int]

    @nn.compact
    def __call__(self, x):
        for i in range(len(self.depths)):
            x = nn.Conv(self.widths[i], (1, 1), use_bias=False, kernel_init=nn.initializers.orthogonal())(x)
            for j in range(self.depths[i]):
                x = ResConvBlock()(x)
            if i < len(self.depths) - 1:
                x = scale_down(x)
        x = nn.Conv(self.features + 1, (1, 1))(x)
        mean, log_kappa = x[..., :-1], x[..., -1:]
        return projx(mean), jnp.clip(log_kappa, -10, 15)

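# Shape sketch for the configuration used in main() (depths (2, 2, 2, 2),
# features 9, 128x128 inputs): three downscales give a 16x16 spatial latent, so
# the encoder returns a (n, 16, 16, 9) unit-norm mean direction and a
# (n, 16, 16, 1) clipped log concentration, i.e. a 9-dimensional vMF
# distribution per spatial position.
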
class Decoder(nn.Module):
    features: int
    depths: typing.Sequence[int]
    widths: typing.Sequence[int]

    @nn.compact
    def __call__(self, x):
        x = x * jnp.sqrt(x.shape[-1])
        for i in range(len(self.depths)):
            x = nn.Conv(self.widths[i], (1, 1), use_bias=False, kernel_init=nn.initializers.orthogonal())(x)
            if i > 0:
                x = scale_up(x)
            for j in range(self.depths[i]):
                x = ResConvBlock()(x)
        x = nn.Conv(self.features, (1, 1), kernel_init=nn.initializers.zeros)(x)
        return x

class AutoencoderOutput(flax.struct.PyTreeNode):
    rec: jax.Array
    loss_kl: jax.Array
    scale_l2: jax.Array
    scale_p: jax.Array
    scale_adv: jax.Array


class Autoencoder(nn.Module):
    input_features: int
    features: int
    depths: typing.Sequence[int]
    widths: typing.Sequence[int]

    def setup(self):
        self.encoder = Encoder(self.features, self.depths, self.widths)
        self.decoder = Decoder(self.input_features, tuple(reversed(self.depths)), tuple(reversed(self.widths)))
        self.scale_l2 = self.param('scale_l2', nn.initializers.zeros, ())
        self.scale_p = self.param('scale_p', nn.initializers.zeros, ())
        self.scale_adv = self.param('scale_adv', nn.initializers.zeros, ())

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)

    def sample(self, key, mu, log_kappa):
        eps = 1
        steps = 50

        def rld_update(x, key, mu, log_kappa, step):
            """Preconditioned Riemannian Langevin dynamics."""
            grad = jax.grad(lambda x: jnp.sum(von_mises_fisher_logpdf_for_grad(x, mu, log_kappa)))(x)
            noise = jax.random.normal(key, x.shape)
            cur_eps = eps / (1 + step)
            precond = jnp.exp(-log_kappa)
            e_step = precond * grad * 0.5 * cur_eps + jnp.sqrt(precond * cur_eps) * noise
            r = proju(x, e_step)
            x = retr(x, r)
            return x

        def scan_fn(carry, _):
            x, key, step = carry
            key, subkey = jax.random.split(key)
            x = rld_update(x, subkey, mu, log_kappa, step)
            return (x, key, step + 1), jnp.array([])

        return jax.lax.scan(scan_fn, (mu, key, jnp.array(0.)), jnp.zeros([steps]))[0][0]
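
    # A sketch of one sampler step, for reference: the Euclidean Langevin update
    # precond * grad * cur_eps / 2 + sqrt(precond * cur_eps) * noise (with step
    # size eps decaying as 1 / (1 + step) and preconditioner exp(-log_kappa)) is
    # projected onto the tangent space at x and retracted back onto the sphere,
    # starting from the mean direction mu and running for 50 steps.
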
    def loss_kl(self, mu, log_kappa):
        losses = von_mises_fisher_kl_div_from_uniform_approx(log_kappa, self.features)
        return jnp.mean(losses) / 2 ** (2 * (len(self.depths) - 1))

    def scales(self):
        return self.scale_l2 * 10, self.scale_p * 10, self.scale_adv * 10

    def __call__(self, x):
        dist = self.encode(x)
        loss_kl = self.loss_kl(*dist)
        latent = self.sample(self.make_rng('sample'), *dist)
        rec = self.decode(latent)
        return AutoencoderOutput(rec, loss_kl, *self.scales())

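# The scale_l2 / scale_p / scale_adv parameters act as learned log-variance
# style balancing weights: in update() below the reconstruction, perceptual,
# and adversarial terms are each divided by exp(scale) with the scale itself
# penalized, so the model trades the terms off against each other. The factor
# of 10 presumably just lets the zero-initialized raw parameters move faster.
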
class Discriminator(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(64, (3, 3))(x)
        x = nn.gelu(x)
        x = scale_down(x)
        x = nn.Conv(128, (3, 3))(x)
        x = nn.gelu(x)
        x = scale_down(x)
        x = nn.Conv(256, (3, 3))(x)
        x = nn.gelu(x)
        x = scale_down(x)
        x = nn.Conv(256, (3, 3))(x)
        x = nn.gelu(x)
        x = scale_down(x)
        x = nn.Conv(1, (3, 3), kernel_init=nn.initializers.zeros)(x)
        return x

    def loss_d(self, x, rec):
        d_x, d_rec = jnp.split(self(jnp.concatenate([x, rec])), 2)
        return jnp.mean(nn.softplus(d_rec)) + jnp.mean(nn.softplus(-d_x))

    def loss_gp(self, key, x, rec):
        e = jax.random.uniform(key, [x.shape[0], 1, 1, 1])
        x_hat = e * x + (1 - e) * rec
        grads = jax.grad(lambda x: jnp.sum(jnp.mean(self(x), axis=(1, 2, 3))))(x_hat)
        gps = jnp.sum(jax.vmap(jnp.ravel)(grads) ** 2, axis=1) / 2
        return jnp.mean(gps)

    def loss_g(self, x, rec):
        d_rec = self(rec)
        return jnp.mean(nn.softplus(-d_rec))

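# loss_d and loss_g are the non-saturating logistic GAN losses in their
# softplus form, and loss_gp is a zero-centered gradient penalty evaluated on
# random interpolations between real images and reconstructions (weighted by
# 10 in update_d below).
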
class ToChannelsFirst:
    def __call__(self, x):
        # Note: despite the name, this rearranges the CHW tensors produced by
        # torchvision's ToTensor into the HWC (channels-last) layout the JAX
        # models expect.
        return rearrange(x, 'c h w -> h w c')

class TrainState(flax.struct.PyTreeNode):
    step: jax.Array
    params: typing.Any
    params_ema: typing.Any
    opt_state: typing.Any
    params_d: typing.Any
    opt_state_d: typing.Any


class Checkpoint(flax.struct.PyTreeNode):
    train_state: TrainState
    key: jax.Array

def main():
    install()
    try:
        jax.distributed.initialize()
    except ValueError:
        pass

    if jax.default_backend() == 'tpu':
        mp.set_start_method('spawn')
    elif jax.process_count() > 1:
        mp.set_start_method('forkserver')

    batch_size_per_device = 32
    batch_size_per_process = batch_size_per_device * jax.local_device_count()
    batch_size = batch_size_per_device * jax.device_count()
    if jax.process_index() == 0:
        print('Processes:', jax.process_count())
        print('Devices:', jax.device_count())
        print('Batch size per device:', batch_size_per_device)
        print('Batch size per process:', batch_size_per_process)
        print('Batch size:', batch_size, flush=True)

    mesh_shape = (jax.device_count(),)
    devices = np.asarray(jax.devices()).reshape(*mesh_shape)
    mesh = maps.Mesh(devices, ('n',))
    batch_sharded_image = PartitionSpec('n', None, None, None)

    key = jax.random.PRNGKey(9146)
    key = jax.random.split(key, jax.process_count())[jax.process_index()]
    key, subkey = jax.random.split(key)
    torch.manual_seed(jax.random.randint(subkey, [], -2 ** 31, 2 ** 31 - 1).item())

    size = 128
    tf = transforms.Compose([
        transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
        transforms.RandomCrop(size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
        ToChannelsFirst(),
    ])
    # dataset = datasets.ImageFolder('/fsx/ilsvrc2012/train', transform=tf)
    dataset = datasets.ImageFolder('/home/kat/datasets/ilsvrc2012/train', transform=tf)
    dataloader = data.DataLoader(dataset, batch_size_per_process, shuffle=True, drop_last=True, num_workers=4, persistent_workers=True)

    ch = 9
    model = Autoencoder(3, ch, (2, 2, 2, 2), (64, 128, 256, 256))
    d = Discriminator()
    key, subkey = jax.random.split(key)
    params = model.init({'params': subkey, 'sample': subkey}, jnp.zeros([1, size, size, 3]))['params']
    key, subkey = jax.random.split(key)
    params_d = d.init(subkey, jnp.zeros([1, size, size, 3]))['params']
    if jax.process_index() == 0:
        print(f'Parameters: {n_params(params)}', flush=True)
        print(f'D parameters: {n_params(params_d)}', flush=True)

    sched = inverse_decay_schedule(3e-4, steps=25000, warmup=0.99)
    opt = optax.adamw(sched, b2=0.99, weight_decay=1e-4)
    opt_state = opt.init(params)
    sched_d = inverse_decay_schedule(3e-4, steps=25000, warmup=0.99)
    opt_d = optax.adamw(sched_d, b1=0., b2=0.99, weight_decay=1e-4)
    opt_state_d = opt_d.init(params_d)
    lpips = lpips_jax.load(net='vgg16')
    train_state = TrainState(jnp.array(0, jnp.int32), params, params, opt_state, params_d, opt_state_d)

    d_start = 2000

    def kl_weight(step):
        return jnp.sin(jnp.minimum(1., step / d_start) * jnp.pi / 2) ** 2

    def ema_decay(step):
        return jnp.minimum(0.999, 1 - (1 + step) ** -0.75)

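    # kl_weight ramps the KL penalty from 0 to 1 with a sin^2 warmup over the
    # first d_start (2000) steps, and ema_decay ramps the EMA decay toward its
    # cap of 0.999, which it reaches around step 10000.
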
    @Partial(pjit, in_axis_resources=(None, None, None, batch_sharded_image), out_axis_resources=None, donate_argnums=0)
    def update(state, lpips, key, x):
        def loss_fn(params):
            out = model.apply({'params': params}, x, rngs={'sample': key})
            loss_l2 = 0.5 * (0.5 * jnp.mean(jnp.square(x - out.rec)) / jnp.exp(out.scale_l2) + 0.5 * out.scale_l2)
            loss_p = 0.5 * (0.5 * jnp.mean(lpips.model.apply(lpips.params, x, out.rec)) / jnp.exp(out.scale_p) + 0.5 * out.scale_p)
            loss_kl = out.loss_kl * kl_weight(state.step)

            def adv_fn():
                loss_adv = d.apply({'params': state.params_d}, x, out.rec, method=d.loss_g)
                return loss_adv / jnp.exp(out.scale_adv) + 0.5 * out.scale_adv

            loss_adv = jax.lax.cond(state.step >= d_start, adv_fn, lambda: 0.)
            loss = loss_l2 + loss_p + loss_kl + loss_adv
            return loss, {'loss': loss, 'l2': loss_l2, 'p': loss_p, 'kl': loss_kl, 'adv': loss_adv}

        (loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
        updates, opt_state = opt.update(grads, state.opt_state, state.params)
        params = optax.apply_updates(state.params, updates)
        # Use the step carried in the train state: referring to the Python-level
        # `step` variable here would bake a constant decay into the compiled function.
        params_ema = ema_update(state.params_ema, params, ema_decay(state.step))
        state = state.replace(step=state.step + 1, params=params, params_ema=params_ema, opt_state=opt_state)
        return state, losses

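    # Both the adversarial term above and the discriminator updates in the
    # training loop below only switch on once the step count reaches d_start,
    # so the autoencoder first trains on the reconstruction, perceptual, and KL
    # losses alone.
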
    @Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None, donate_argnums=0)
    def update_d(state, key, x):
        def loss_fn(params):
            key_, subkey = jax.random.split(key)
            rec = model.apply({'params': state.params}, x, rngs={'sample': key_}).rec
            loss_d = d.apply({'params': params}, x, rec, method=d.loss_d)
            loss_gp = d.apply({'params': params}, subkey, x, rec, method=d.loss_gp) * 10
            loss = loss_d + loss_gp
            return loss, {'loss': loss, 'd': loss_d, 'gp': loss_gp}

        (loss, losses), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params_d)
        updates, opt_state_d = opt_d.update(grads, state.opt_state_d, state.params_d)
        params_d = optax.apply_updates(state.params_d, updates)
        state = state.replace(params_d=params_d, opt_state_d=opt_state_d)
        return state, losses

    @Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None)
    def reconstruct(params, key, x):
        return model.apply({'params': params}, x, rngs={'sample': key}).rec

    # @Partial(pjit, in_axis_resources=(None, None, batch_sharded_image), out_axis_resources=None)
    def _sample(params, key, x):
        return model.apply({'params': params}, x, rngs={'sample': key}, method=model.decode)

    def sample(params, key, shape):
        keys = jax.random.split(key, 2)
        lat = projx(jax.random.normal(keys[0], shape))
        return _sample(params, keys[1], lat)

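    # Sampling from the prior: normalizing a standard Gaussian with projx gives
    # a uniform draw on the unit hypersphere, which is then pushed through the
    # decoder. The decode method does not use the 'sample' rng, so passing it is
    # harmless.
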
    @Partial(pjit, in_axis_resources=batch_sharded_image, out_axis_resources=None)
    def gather(x):
        return x

    def demo(params, key, x):
        keys = jax.random.split(key, 2)
        n_side = min(8, math.floor(batch_size ** 0.5))
        n = n_side * n_side
        n_per_process = jax.local_device_count() * math.ceil(n / jax.device_count())
        x = x[:n_per_process]
        rec = reconstruct(params, keys[0], x)[:n]
        x = gather(x)[:n]
        grid = rearrange([x, rec], 't (nh nw) h w c -> (nh h) (nw t w) c', nh=n_side, nw=n_side)
        # TODO: parallel sample
        samples = jax.jit(sample, static_argnums=2)(params, keys[1], (n_side * 4, 16, 16, ch))
        sample_grid = rearrange(samples, '(nh nw) h w c -> (nh h) (nw w) c', nh=2, nw=n_side * 2)
        grid = jnp.concatenate([grid, sample_grid], axis=0)
        grid = np.array(jnp.round(jnp.clip((grid + 1) * 127.5, 0, 255)).astype(jnp.uint8))
        if jax.process_index() == 0:
            Image.fromarray(grid).save(f'demo_{step:08}.png')
            print('📸 Output demo grid!', flush=True)

    def save(train_state, step, key):
        ckpt = Checkpoint(train_state, key)
        with open(f'ckpt_{step:08}.ckpt', 'wb') as file:
            file.write(flax.serialization.to_bytes(ckpt))
        print('💾 Saved a checkpoint!')

    def data_iterator(iterable):
        while True:
            for item in iterable:
                yield jax.tree_map(lambda x: jnp.array(x), item)

    step = train_state.step.item()
    data_iter = data_iterator(dataloader)

    try:
        while True:
            with mesh:
                x = next(data_iter)[0]

                if step % 250 == 0:
                    key, subkey = jax.random.split(key)
                    demo(train_state.params_ema, subkey, x)

                if step > 0 and step % 10000 == 0:
                    if jax.process_index() == 0:
                        save(train_state, step, key)

                key, subkey = jax.random.split(key)
                train_state, losses = update(train_state, lpips, subkey, x)

                if step >= d_start:
                    x = next(data_iter)[0]
                    key, subkey = jax.random.split(key)
                    train_state, losses_d = update_d(train_state, subkey, x)

                if step % 25 == 0:
                    if jax.process_index() == 0:
                        if step >= d_start:
                            print(f"step: {step}, loss: {losses['loss']:g}, l2: {losses['l2']:g}, p: {losses['p']:g}, kl: {losses['kl']:g}, adv: {losses['adv']:g}, d: {losses_d['d']:g}, gp: {losses_d['gp']:g}", flush=True)
                        else:
                            print(f"step: {step}, loss: {losses['loss']:g}, l2: {losses['l2']:g}, p: {losses['p']:g}, kl: {losses['kl']:g}", flush=True)

                step += 1
    except KeyboardInterrupt:
        pass


if __name__ == '__main__':
    main()