Last active
February 10, 2025 14:58
-
-
Save ShigekiKarita/1496fdd5e5694d320c0f08eca7484bba to your computer and use it in GitHub Desktop.
A minimal rectified flow (https://github.com/gnobitab/RectifiedFlow) trained on MNIST, implemented in JAX
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2025 Google LLC. | |
# SPDX-License-Identifier: Apache-2.0 | |
import tensorflow_datasets as tfds | |
from flax.training.train_state import TrainState | |
import flax.linen as nn | |
import jax | |
import jax.numpy as jnp | |
import matplotlib.pyplot as plt | |
import optax | |
from tqdm import tqdm | |
def resnet(ch: int, x: jax.Array, cond: jax.Array) -> jax.Array:
    """Residual block whose features are shifted by a projected conditioning vector.

    Creates Flax submodules, so it must be called from inside an
    ``@nn.compact`` method.

    Args:
        ch: Number of output channels.
        x: Channels-last feature map; ``cond`` is broadcast over axes 1 and 2.
        cond: Per-example conditioning vector with a leading batch axis.

    Returns:
        ``conv(norm_act(conv(norm_act(x)) + proj(cond))) + proj(x)``, with
        ``ch`` channels.
    """

    def norm_act(v):
        # Pre-activation: LayerNorm over the channel axis followed by swish.
        return nn.swish(nn.LayerNorm()(v))

    out = nn.Conv(ch, (3, 3))(norm_act(x))
    cond_proj = nn.Dense(ch)(cond)
    # Add the conditioning as a per-channel bias shared by every pixel.
    out = out + cond_proj[:, None, None]
    out = nn.Conv(ch, (3, 3))(norm_act(out))
    # Dense on the channel axis acts as a 1x1 projection so the skip matches `ch`.
    shortcut = nn.Dense(ch)(x)
    return out + shortcut
def embed_pos(timesteps: jax.Array, dim: int, scale=10_000) -> jax.Array:
    """Sinusoidal embedding of scalar positions or timesteps.

    Frequencies are geometrically spaced from 1 down to ``1 / scale`` over
    ``dim // 2`` entries. The sines of all frequencies come first, then the
    cosines.

    Args:
        timesteps: 1-D array of positions, one per batch element.
        dim: Embedding width. An odd value yields ``dim - 1`` columns, and
            ``dim`` must be at least 4 to avoid dividing by zero.
        scale: Ratio between the highest and lowest frequency.

    Returns:
        Array of shape ``[batch, 2 * (dim // 2)]``.
    """
    half = dim // 2
    log_step = jnp.log(scale) / (half - 1)
    freqs = jnp.exp(-log_step * jnp.arange(half, dtype=jnp.float32))
    angles = timesteps[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)
class Unet(nn.Module):
    """Class-conditional U-Net whose output has the same shape as its image input.

    Inputs are channels-last images ``(batch, height, width, channels)``.
    Height and width must be divisible by ``2 ** (len(ch_multi) - 1)``.
    Otherwise the upsampled features do not line up with the stored skip
    connections and the concatenation fails.
    """

    n_class: int = 10  # Number of rows in the label embedding table.
    ch: int = 128  # Base width; resolution level i uses ch * ch_multi[i] channels.
    num_block: int = 2  # Residual blocks per resolution level on the down path.
    # Per-level channel multipliers. The annotation matches the immutable tuple
    # default (Flax modules are dataclasses, which reject list defaults).
    ch_multi: tuple[int, ...] = (1, 2, 4)
    t_scale: float = 999  # Multiplier applied to t before the sinusoidal embedding.

    @nn.compact
    def __call__(self, x: jax.Array, t: jax.Array, y: jax.Array) -> jax.Array:
        """Predict an output with the same shape as ``x``.

        Args:
            x: Channels-last image batch ``(batch, H, W, C)``.
            t: Per-example scalar times, shape ``(batch,)``.
            y: Per-example integer class labels, shape ``(batch,)``.

        Returns:
            Array with ``x``'s shape.
        """
        # Conditioning vector: time embedding concatenated with label embedding, then an MLP.
        temb = embed_pos(t * self.t_scale, self.ch)
        iemb = nn.Embed(self.n_class, self.ch)(y)
        emb = jnp.concatenate([temb, iemb], -1)
        emb = nn.Dense(self.ch)(nn.swish(nn.Dense(self.ch)(emb)))

        # Down path: every intermediate activation is kept as a skip connection.
        hs = [x]
        for i, ci in enumerate(self.ch_multi):
            for _ in range(self.num_block):
                hs.append(resnet(self.ch * ci, hs[-1], emb))
            if i != len(self.ch_multi) - 1:
                hs.append(nn.Conv(hs[-1].shape[-1], (3, 3), (2, 2))(hs[-1]))  # Downsample.

        # Up path: num_block + 1 pops per level consume exactly the
        # num_block resnets + 1 downsample (or the input x) pushed above.
        h = hs[-1]
        for i, ci in reversed(list(enumerate(self.ch_multi))):
            for _ in range(self.num_block + 1):
                h = jnp.concatenate([h, hs.pop()], -1)
                h = resnet(self.ch * ci, h, emb)
            if i != 0:
                # Nearest-neighbour 2x upsample on both spatial axes, then a conv.
                h = jnp.repeat(jnp.repeat(h, 2, 2), 2, 1)
                h = nn.Conv(h.shape[-1], (3, 3))(h)
        return nn.Conv(x.shape[-1], (3, 3))(nn.swish(nn.LayerNorm()(h)))
@jax.jit
def step(rng, state: TrainState, x: jax.Array, y: jax.Array) -> tuple[TrainState, jax.Array]:
    """Run one rectified-flow training step on a batch.

    Each example is interpolated linearly between Gaussian noise (t = 0) and
    the data point (t = 1) at a uniformly drawn time t. The model is then
    regressed onto the constant velocity ``x - noise`` of that straight path.

    Args:
        rng: PRNG key for this step. It is split internally so the noise and
            the times come from independent streams.
        state: Training state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        x: Clean data batch; the leading axis is the batch axis.
        y: Class labels, one per example.

    Returns:
        The updated state and the scalar mean-squared-error loss.
    """
    # Drawing noise and t from the same key correlates them; use separate keys.
    noise_rng, t_rng = jax.random.split(rng)
    noise = jax.random.normal(noise_rng, x.shape, dtype=x.dtype)
    t = jax.random.uniform(t_rng, [x.shape[0]])
    # Reshape t to (batch, 1, ..., 1) so it broadcasts against x.
    t_b = jnp.expand_dims(t, tuple(range(1, x.ndim)))
    perturbed = t_b * x + noise * (1 - t_b)
    target = x - noise

    def loss_fn(params):
        predict = state.apply_fn({'params': params}, perturbed, t, y)
        return jnp.mean(jnp.square(target - predict))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
@jax.jit
def sample(state: TrainState, x_init: jax.Array, y: jax.Array, n=1000, eps=1e-3) -> jax.Array:
    """Integrate the learned velocity field with ``n`` forward-Euler steps.

    At step k the model is evaluated at time ``eps + k / n * (1 - eps)``.
    The state then moves by ``velocity / n``.

    Args:
        state: Holds ``apply_fn`` and ``params`` of the velocity model.
        x_init: Starting batch (the caller passes Gaussian noise).
        y: Class labels, one per example.
        n: Number of Euler steps.
        eps: Smallest time at which the model is evaluated.

    Returns:
        The batch after ``n`` Euler steps, with ``x_init``'s shape.
    """
    batch = x_init.shape[0]

    def euler_step(k, x_k):
        t_k = jnp.repeat(k / n * (1 - eps) + eps, batch, 0)
        velocity = state.apply_fn({'params': state.params}, x_k, t_k, y)
        return x_k + velocity / n

    return jax.lax.fori_loop(0, n, euler_step, x_init)
def plot(state: TrainState, n_sample: int = 10, n_class: int = 10) -> None:
    """Sample a class-conditional image grid and display it.

    Row block c holds ``n_sample`` generated images for label c. The starting
    noise uses a fixed PRNG seed, so grids from successive calls are directly
    comparable as training progresses.
    """
    side = 28  # Image height and width.
    total = n_class * n_sample
    noise = jax.random.normal(jax.random.PRNGKey(0), [total, side, side, 1], jnp.float32)
    labels = jnp.repeat(jnp.arange(n_class, dtype=jnp.int32), n_sample, 0)
    images = sample(state, noise, labels)[:total]
    # [class, sample, row, col] -> [class, row, sample, col] lays samples side by side.
    grid = images.reshape(n_class, n_sample, side, side).transpose(0, 2, 1, 3)
    plt.matshow(grid.reshape(n_class * side, n_sample * side))
    plt.show()
    plt.close()
# Training script: 10 epochs over MNIST; after each epoch a sample grid is plotted.
rng = jax.random.PRNGKey(0)
model = Unet()
tx = optax.adam(2e-4)
# `shuffle_files` only permutes the order of the underlying record files.
# Shuffling individual examples (re-done every epoch) stops each epoch from
# seeing identical batches in identical order.
ds = (
    tfds.load('mnist', split='train', as_supervised=True, shuffle_files=True)
    .shuffle(10_000, reshuffle_each_iteration=True)
    .batch(128)
)
state = None
for epoch in range(10):
    loss = 0
    with tqdm(tfds.as_numpy(ds)) as t:
        for i, (x, y) in enumerate(t):
            x = jnp.asarray(x, jnp.float32) / 255 * 2 - 1  # uint8 [0, 255] -> float [-1, 1].
            y = y.astype(jnp.int32)
            if state is None:
                # Initialise lazily from the first batch so parameter shapes
                # follow the data. Use the same normalized inputs and dtypes as
                # training, and a dedicated key that is never reused by `step`.
                rng, init_rng = jax.random.split(rng)
                params = model.init(init_rng, x, jnp.zeros(y.shape, jnp.float32), y)['params']
                state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
            rng, step_rng = jax.random.split(rng)
            state, new_loss = step(step_rng, state, x, y)
            # Running mean of the per-batch loss over this epoch.
            loss = (loss * i + float(new_loss)) / (i + 1)
            t.set_postfix(loss=loss)
    plot(state)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment