# Mini rectified flow on MNIST in JAX.
# Copyright 2025 Google LLC.
# SPDX-License-Identifier: Apache-2.0
import flax.linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import tensorflow_datasets as tfds
from tqdm import tqdm
def resnet(ch: int, x: jax.Array, cond: jax.Array) -> jax.Array:
    """Apply one conditioned residual block built from inline flax modules.

    The block runs LayerNorm -> swish -> 3x3 conv twice. Between the two
    convolutions a dense projection of `cond` is added, broadcast over the two
    axes that follow the leading one. The skip path is a dense projection of
    the input to `ch` features.

    The flax modules are instantiated inline, so this helper is meant to be
    called from inside a module's compact `__call__`.

    Args:
        ch: Number of output features of every layer in the block.
        x: Input feature map (assumed [batch, height, width, channels] --
            TODO confirm against callers).
        cond: Conditioning vector (assumed [batch, features]).

    Returns:
        The sum of the transformed features and the projected input.
    """
    kernel = (3, 3)

    # First norm / activation / conv stage.
    out = nn.LayerNorm()(x)
    out = nn.swish(out)
    out = nn.Conv(ch, kernel)(out)

    # Inject the conditioning signal, broadcast across the two middle axes.
    cond_bias = nn.Dense(ch)(cond)
    out = out + cond_bias[:, None, None]

    # Second norm / activation / conv stage.
    out = nn.LayerNorm()(out)
    out = nn.swish(out)
    out = nn.Conv(ch, kernel)(out)

    # Project the input so its feature width matches before the residual add.
    skip = nn.Dense(ch)(x)
    return out + skip
def embed_pos(timesteps: jax.Array, dim: int, scale=10_000) -> jax.Array:
    """Return sinusoidal embeddings of `timesteps`.

    Frequencies are spaced geometrically from 1 down to `1 / scale` over
    `dim // 2` channels; the sines of all channels come first, followed by the
    cosines.

    Args:
        timesteps: 1-D array of positions to embed, one per batch element.
        dim: Width of the returned embedding.
        scale: Ratio between the largest and the smallest frequency.

    Returns:
        Array of shape [batch, dim]. For odd `dim` the last column is zero.
    """
    half = dim // 2
    # Guard the denominator: for dim in (2, 3) `half - 1` is zero, which would
    # make the exponent 0 * inf and turn the whole embedding into NaN.
    log_step = jnp.log(scale) / max(half - 1, 1)
    freqs = jnp.exp(jnp.arange(half, dtype=jnp.float32) * -log_step)
    angles = timesteps[:, None] * freqs[None, :]
    emb = jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], -1)
    if dim % 2 == 1:
        # sin/cos only fill 2 * half columns; zero-pad so the width equals dim.
        emb = jnp.pad(emb, ((0, 0), (0, 1)))
    return emb  # [batch, dim]
class Unet(nn.Module):
    """Class- and time-conditioned U-Net over image-like inputs.

    The network embeds the time `t` sinusoidally and the label `y` through a
    learned table, merges both into one conditioning vector, and feeds that
    vector to every residual block. The encoder stacks `num_block` residual
    blocks per resolution and halves the resolution between levels with a
    strided convolution. The decoder mirrors it, consuming one stored encoder
    activation per residual block and doubling the resolution between levels.

    NOTE(review): several expressions in the original text of this class were
    truncated (`embed_pos(t * self.t_scale,`, `nn.Embed(self.n_class,`,
    `nn.Dense(`, `resnet( * ci`). They are reconstructed below using `self.ch`
    as the missing operand -- confirm against the upstream source.
    """

    # Number of distinct labels accepted by the label embedding table.
    n_class: int = 10
    # Base feature width; each level uses `ch * ch_multi[level]` features.
    ch: int = 128
    # Residual blocks per encoder level (the decoder uses one more per level).
    num_block: int = 2
    # Per-level width multipliers; its length sets the number of resolutions.
    ch_multi: tuple[int, ...] = (1, 2, 4)
    # Factor applied to `t` before it is embedded.
    t_scale: float = 999

    # Submodules are created inline, which flax only allows in compact methods.
    @nn.compact
    def __call__(self, x: jax.Array, t: jax.Array, y: jax.Array) -> jax.Array:
        """Run the network.

        Args:
            x: Input batch (assumed [batch, height, width, channels]).
            t: Per-example time values, shape [batch].
            y: Per-example integer labels, shape [batch].

        Returns:
            Array with the same trailing channel count as `x`.
        """
        # Conditioning vector: time embedding and label embedding, merged.
        temb = embed_pos(t * self.t_scale, self.ch)
        iemb = nn.Embed(self.n_class, self.ch)(y)
        emb = jnp.concatenate([temb, iemb], -1)
        # NOTE(review): projection width was lost in the source; self.ch assumed.
        emb = nn.Dense(self.ch)(emb)

        # Encoder: remember every activation for the decoder's skip inputs.
        hs = [x]
        for i, ci in enumerate(self.ch_multi):
            for _ in range(self.num_block):
                hs.append(resnet(self.ch * ci, hs[-1], emb))
            if i != len(self.ch_multi) - 1:
                hs.append(nn.Conv(hs[-1].shape[-1], (3, 3), (2, 2))(hs[-1]))  # Downsample.

        # Decoder: each block consumes one stored activation, newest first.
        h = hs[-1]
        for i, ci in reversed(list(enumerate(self.ch_multi))):
            for _ in range(self.num_block + 1):
                h = jnp.concatenate([h, hs.pop()], -1)
                h = resnet(self.ch * ci, h, emb)
            if i != 0:
                h = jnp.repeat(jnp.repeat(h, 2, 2), 2, 1)  # Upsample.
                h = nn.Conv(h.shape[-1], (3, 3))(h)

        # Final projection back to the input's channel count.
        return nn.Conv(x.shape[-1], (3, 3))(nn.swish(nn.LayerNorm()(h)))
def step(rng, state: TrainState, x: jax.Array, y: jax.Array) -> tuple[TrainState, jax.Array]:
    """Run one optimisation step on a batch and return the new state and loss.

    For every example a time `t` is drawn uniformly from [0, 1) and Gaussian
    noise of the same shape as the example is drawn. The model sees the linear
    interpolation `t * x + (1 - t) * noise` together with `t` and `y`, and is
    regressed with a mean-squared error onto `x - noise`.

    Args:
        rng: PRNG key for this step; it is split internally.
        state: Train state providing `apply_fn`, `params` and
            `apply_gradients`.
        x: Batch of inputs with the batch axis first.
        y: Per-example labels forwarded to the model.

    Returns:
        A tuple of the state after the gradient update and the scalar loss.
    """
    # Use independent keys for the two draws; reusing one key would make the
    # noise and the sampled times statistically dependent.
    noise_rng, t_rng = jax.random.split(rng)
    noise = jax.random.normal(noise_rng, x.shape, dtype=x.dtype)
    t = jax.random.uniform(t_rng, [x.shape[0]])

    # Reshape t to [batch, 1, ..., 1] so it broadcasts against x.
    tx = jnp.expand_dims(t, tuple(range(1, x.ndim)))
    perturbed = tx * x + noise * (1 - tx)

    def loss_fn(params):
        predict = state.apply_fn({'params': params}, perturbed, t, y)
        return jnp.mean(jnp.square(x - noise - predict))

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss
def sample(state: TrainState, x_init: jax.Array, y: jax.Array, n=1000, eps=1e-3) -> jax.Array:
    """Integrate the model's prediction from `x_init` with `n` fixed steps.

    At step `k` (0-based) the model is evaluated at time
    `k / n * (1 - eps) + eps`, shared by the whole batch, and its output
    divided by `n` is added to the current sample.

    Args:
        state: Train state providing `apply_fn` and `params`.
        x_init: Starting batch; the batch axis comes first.
        y: Labels forwarded unchanged to every model call.
        n: Number of integration steps.
        eps: Time used for the first step.

    Returns:
        The batch after `n` updates, with the same shape as `x_init`.
    """
    variables = {'params': state.params}

    def euler_step(k, current):
        time = k / n * (1 - eps) + eps
        # Give every example in the batch the same scalar time.
        t_batch = jnp.repeat(time, current.shape[0], 0)
        velocity = state.apply_fn(variables, current, t_batch, y)
        return current + velocity / n

    return jax.lax.fori_loop(0, n, euler_step, x_init)
def plot(state: TrainState, n_sample: int = 10, n_class: int = 10) -> None:
    """Draw a grid of generated 28x28 images, one row per class.

    Starts from Gaussian noise drawn with a fixed seed, runs `sample` with
    `n_sample` copies of each label in `0 .. n_class - 1`, and shows the
    results as a single matrix plot with `n_class` rows and `n_sample`
    columns of tiles.

    Args:
        state: Train state passed through to `sample`.
        n_sample: Number of images generated per class.
        n_class: Number of classes to generate.
    """
    side = 28
    total = n_class * n_sample

    key = jax.random.PRNGKey(0)  # Fix seed to check improvements.
    start = jax.random.normal(key, [total, side, side, 1], jnp.float32)
    # Labels are grouped: n_sample zeros, then n_sample ones, and so on.
    labels = jnp.repeat(jnp.arange(n_class, dtype=jnp.int32), n_sample, 0)

    images = sample(state, start, labels)
    tiles = jnp.reshape(images[:total], [n_class, n_sample, side, side])
    # [class, sample, row, col] -> [class, row, sample, col], so that a plain
    # reshape lays the tiles out as a 2-D mosaic.
    mosaic = jnp.swapaxes(tiles, 1, 2)
    plt.matshow(jnp.reshape(mosaic, [n_class * side, n_sample * side]))
# Training script: fit the U-Net on the MNIST training split with Adam.
rng = jax.random.PRNGKey(0)
model = Unet()
tx = optax.adam(2e-4)
ds = tfds.load(
    'mnist',
    split='train',
    batch_size=128,
    as_supervised=True,
    shuffle_files=True,
)

# The train state is created lazily from the first batch, whose arrays are
# used as example inputs for parameter initialisation.
state = None
for epoch in range(10):
    loss = 0  # Running mean of the per-batch loss within this epoch.
    with tqdm(tfds.as_numpy(ds)) as t:
        for i, (x, y) in enumerate(t):
            if state is None:
                # zeros_like(y) stands in for the per-example time argument.
                params = model.init(rng, x, jnp.zeros_like(y), y)['params']
                state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

            rng, step_rng = jax.random.split(rng)
            # Rescale pixel values so that 0 maps to -1 and 255 maps to 1.
            images = x / 255 * 2 - 1
            labels = y.astype(jnp.int32)
            state, new_loss = step(step_rng, state, images, labels)

            # Incremental mean over the i + 1 batches seen so far.
            batch_loss = float(jnp.mean(new_loss))
            loss = (loss * i + batch_loss) / (i + 1)
