Skip to content

Instantly share code, notes, and snippets.

@cheery
Created March 27, 2026 18:50
Show Gist options
  • Select an option

  • Save cheery/137e01b9da8259a0bf96f54f8a6e428a to your computer and use it in GitHub Desktop.

Select an option

Save cheery/137e01b9da8259a0bf96f54f8a6e428a to your computer and use it in GitHub Desktop.
claude's version of SEDD
"""
Score Entropy Discrete Diffusion (SEDD)
========================================
PyTorch implementation of the algorithms from:
"Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution"
(Lou, Meng, Ermon — ICML 2024)
Implements:
- Algorithm 1: Score Entropy Training (DWDSE loss)
- Algorithm 2: Unconditional Sampling (Euler & Tweedie)
- Algorithm 3: Conditional Sampling (infilling / prompting)
The score network s_θ : X × R → R^{d×n} learns the concrete score,
i.e. the ratios p_t(y)/p_t(x) for Hamming-distance-1 neighbours.
Two transition matrices are supported: Q^uniform and Q^absorb (Eqs 15–16).
"""
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import torch
directory = Path(__file__).parent
def load_kalevala():
    """Load the Kalevala corpus as one long string, newlines replaced by spaces."""
    path = (directory / "../../../data/kalevala.plain.txt").resolve()
    raw = path.read_text(encoding="utf-8")
    return raw.replace("\n", " ")
def create_dataloader(text,
                      encoder=None,
                      batch_size=4,
                      length=256,
                      stride=128,
                      shuffle=True,
                      drop_last=False,
                      num_workers=0):
    """Build a DataLoader over sliding windows of `text`.

    `encoder` maps the text to a 1-D tensor of token ids; when omitted,
    `text_to_tensor` (UTF-8 bytes) is used.
    """
    ds = KalevalaDataset(text, encoder or text_to_tensor, length, stride)
    return DataLoader(ds,
                      batch_size=batch_size,
                      shuffle=shuffle,
                      drop_last=drop_last,
                      num_workers=num_workers)
class KalevalaDataset(Dataset):
    """Sliding-window next-token dataset: targets are inputs shifted by one."""

    def __init__(self, text, encoder, length, stride):
        data = encoder(text)
        starts = range(0, len(data) - length - 1, stride)
        self.inputs = [data[i:i + length] for i in starts]
        self.targets = [data[i + 1:i + 1 + length] for i in starts]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]
def text_to_tensor(text):
    """Encode `text` as UTF-8 and return the byte values as a (len,) long tensor."""
    encoded = text.encode("utf-8")
    return torch.tensor(list(encoded), dtype=torch.long)
def as_text(p: torch.Tensor) -> str:
    """Decode a tensor of byte values back to a UTF-8 string (lossy on bad bytes)."""
    data = bytes(p.cpu().to(torch.uint8).tolist())
    return data.decode("utf-8", errors="replace")
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Literal
# ============================================================
# Noise Schedules (Appendix C.1)
# ============================================================
# σ̄(t) = cumulative noise = ∫₀ᵗ σ(s)ds
# σ(t) = instantaneous rate = dσ̄/dt
# t ∈ [0, 1]. At t=0 little noise; at t=1 base distribution.
class GeometricSchedule:
    r"""
    Geometric noise schedule (Appendix C.1):

        σ̄(t) = σ_min^{1-t} · σ_max^t
        σ(t)  = σ̄(t) · ln(σ_max / σ_min)
    """

    def __init__(self, sigma_min: float = 1e-5, sigma_max: float = 20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_ratio = math.log(sigma_max / sigma_min)

    def sigma_bar(self, t: torch.Tensor) -> torch.Tensor:
        # Geometric interpolation: sigma_min at t=0, sigma_max at t=1.
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        # d/dt σ̄(t) = σ̄(t) · ln(σ_max/σ_min).
        return self.log_ratio * self.sigma_bar(t)
class LogLinearSchedule:
    r"""
    Log-linear noise schedule (Appendix C.1):

        σ̄(t) = -log(1 - (1-ε)t)
        σ(t)  = (1-ε) / (1 - (1-ε)t)
    """

    def __init__(self, eps: float = 1e-3):
        self.eps = eps

    def sigma_bar(self, t: torch.Tensor) -> torch.Tensor:
        scale = 1.0 - self.eps
        return -torch.log1p(-scale * t)

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        scale = 1.0 - self.eps
        return scale / (1.0 - scale * t)
# ============================================================
# Forward Transition p_{t|0}(·|x₀) (Section 3.3, Algorithm 1)
# ============================================================
# Each token is perturbed independently via
# x_t^i ~ p_{t|0}^{tok}(·|x₀^i) = exp(σ̄(t) Q^{tok})_{x₀^i} (Eq 14)
# Closed forms for the two Q matrices:
def _forward_probs_absorb(sigma_bar: torch.Tensor, x0: torch.Tensor, n: int):
"""
Absorbing diffusion (Eq 16). MASK = token index n-1.
p_{t|0}(y | x₀) = e^{-σ̄}·δ(y, x₀) + (1 - e^{-σ̄})·δ(y, MASK)
Args:
sigma_bar: (B,)
x0: (B, d) tokens in {0, …, n-1}
n: vocab size (last token = MASK)
Returns:
(B, d, n) transition probabilities
"""
B, d = x0.shape
sb = sigma_bar[:, None, None] # (B,1,1)
stay = torch.exp(-sb) # prob of no change
probs = torch.zeros(B, d, n, device=x0.device, dtype=sb.dtype)
probs.scatter_(2, x0.unsqueeze(-1), stay.expand(B, d, 1))
probs[:, :, n - 1] += (1.0 - stay).squeeze(-1) # go to MASK
return probs
def _forward_probs_uniform(sigma_bar: torch.Tensor, x0: torch.Tensor, n: int):
"""
Uniform diffusion (Eq 15).
p_{t|0}(y | x₀) = (1 - e^{-σ̄})/n ∀y
+ e^{-σ̄} additionally when y = x₀
(Algorithm 1, "else if Q is Uniform" branch.)
"""
B, d = x0.shape
sb = sigma_bar[:, None, None]
unif = (1.0 - torch.exp(-sb)) / n # (B,1,1)
probs = unif.expand(B, d, n).clone()
probs.scatter_add_(2, x0.unsqueeze(-1),
torch.exp(-sb).expand(B, d, 1))
return probs
def forward_transition(sigma_bar, x0, n, mode):
    """Closed-form forward transition p_{t|0} for the chosen Q matrix (Eqs 15-16)."""
    kernel = _forward_probs_absorb if mode == "absorb" else _forward_probs_uniform
    return kernel(sigma_bar, x0, n)
def sample_xt(x0, sigma_bar, n, mode):
    """Draw x_t ~ p_{t|0}(·|x₀) independently per token (Eq 13).

    Returns the noised tokens (same shape as x0) together with the full
    (B, d, n) transition probabilities used for the draw.
    """
    probs = forward_transition(sigma_bar, x0, n, mode)
    draws = torch.multinomial(probs.reshape(-1, n), 1)
    return draws.reshape(x0.shape), probs
# ============================================================
# Algorithm 1 — Score Entropy Training Loss (L̂_DWDSE)
# ============================================================
# L̂ = σ(t) Σ_i Σ_{y≠x_t^i} [ s_θ(x_t,t)_{i,y}
# - p_{t|0}(y|x₀^i) / p_{t|0}(x_t^i|x₀^i) · log s_θ(x_t,t)_{i,y} ]
#
# The K(·) normalising constant (Eq 5) is θ-independent and dropped.
def score_entropy_loss(
    score_net: nn.Module,
    x0: torch.Tensor,
    n: int,
    schedule,
    mode: Literal["absorb", "uniform"] = "absorb",
    t: Optional[torch.Tensor] = None,
):
    """
    One training step of Algorithm 1 (denoising weighted score entropy).

    L̂ = σ(t) Σ_i Σ_{y≠x_t^i} [ s_θ(x_t,t)_{i,y}
          − p_{t|0}(y|x₀^i)/p_{t|0}(x_t^i|x₀^i) · log s_θ(x_t,t)_{i,y} ]

    The θ-independent K(·) normalising term of Eq 5 is dropped — it does
    not affect gradients (see `estimate_elbo` for the full objective).

    Args:
        score_net: (x_t, σ̄) → (B, d, n) positive scores
        x0: (B, d) clean data tokens
        n: vocab size
        schedule: noise schedule object (sigma_bar / sigma)
        mode: "absorb" | "uniform"
        t: (B,) optional pre-sampled times; otherwise drawn U[0,1]
    Returns:
        scalar loss (mean over batch)
    """
    B, d = x0.shape
    device = x0.device
    # --- sample t ∼ U([0,1]) ---
    if t is None:
        t = torch.rand(B, device=device)
    sb = schedule.sigma_bar(t)  # σ̄(t)  (B,)
    st = schedule.sigma(t)      # σ(t)   (B,)
    # --- construct x_t (forward noising) ---
    xt, probs = sample_xt(x0, sb, n, mode)  # xt (B,d), probs (B,d,n)
    # --- score network forward pass ---
    scores = score_net(xt, sb)  # (B, d, n), positive by construction
    # --- target ratios r_y = p_{t|0}(y|x₀) / p_{t|0}(x_t|x₀) ---
    p_xt = probs.gather(2, xt.unsqueeze(-1))  # (B, d, 1) prob of sampled x_t
    ratios = probs / p_xt.clamp(min=1e-20)    # (B, d, n); clamp guards /0
    # --- mask the diagonal: only y ≠ x_t contribute to the sum ---
    mask = torch.ones_like(scores)
    mask.scatter_(2, xt.unsqueeze(-1), 0.0)
    # --- per-entry loss: s_y − r_y · log(s_y) ---
    log_s = torch.log(scores.clamp(min=1e-20))  # clamp keeps log finite
    loss_entries = (scores - ratios * log_s) * mask  # (B, d, n)
    # --- weight by σ(t) and average over the batch ---
    loss = loss_entries.sum(dim=(-1, -2)) * st  # (B,)
    return loss.mean()
# ============================================================
# ELBO evaluation (Theorem 3.6, Appendix C.6)
# ============================================================
# −log p₀^θ(x₀) ≤ L_DWDSE(x₀) + D_KL(p_{T|0}(·|x₀) ‖ π)
# Monte-Carlo over 1000 random timesteps as in the paper.
@torch.no_grad()
def estimate_elbo(
    score_net: nn.Module,
    x0: torch.Tensor,
    n: int,
    schedule,
    mode: Literal["absorb", "uniform"] = "absorb",
    num_t: int = 1000,
):
    """
    Estimate −ELBO (an upper bound on NLL, in nats) for a batch of data.

        −log p₀^θ(x₀) ≤ L_DWDSE(x₀) + D_KL(p_{T|0}(·|x₀) ‖ π)   (Thm 3.6)

    Unlike `score_entropy_loss`, the θ-independent term K(r) = r(log r − 1)
    is included so the bound is numerically meaningful. The time integral
    is Monte-Carlo averaged over `num_t` uniformly drawn timesteps.

    Args:
        score_net: (x_t, σ̄) → (B, d, n) positive scores
        x0: (B, d) clean data tokens
        n: vocab size
        schedule: noise schedule object
        mode: "absorb" | "uniform"
        num_t: number of Monte-Carlo timesteps
    Returns:
        (B,) per-sample upper bounds on −log p₀(x₀)
    """
    B, d = x0.shape
    device = x0.device
    # Monte-Carlo estimate of the integral ∫₀¹ L(t) dt
    ts = torch.rand(num_t, device=device)
    total = torch.zeros(B, device=device)
    for ti in ts:
        t_batch = ti.expand(B)
        sb = schedule.sigma_bar(t_batch)
        st = schedule.sigma(t_batch)
        xt, probs = sample_xt(x0, sb, n, mode)
        scores = score_net(xt, sb)
        p_xt = probs.gather(2, xt.unsqueeze(-1))
        ratios = probs / p_xt.clamp(min=1e-20)
        mask = torch.ones_like(scores)  # zero the diagonal y = x_t below
        mask.scatter_(2, xt.unsqueeze(-1), 0.0)
        log_s = torch.log(scores.clamp(min=1e-20))
        # full score entropy includes K(r) = r(log r − 1)
        log_r = torch.log(ratios.clamp(min=1e-20))
        K_r = ratios * (log_r - 1.0)
        entry = (scores - ratios * log_s + K_r) * mask
        total += entry.sum(dim=(-1, -2)) * st  # weighted by σ(t)
    integral = total / num_t  # Monte-Carlo mean (B,)
    # --- prior KL: D_KL(p_{T|0}(·|x₀) ‖ π) ---
    sb_T = schedule.sigma_bar(torch.ones(B, device=device))
    probs_T = forward_transition(sb_T, x0, n, mode)  # (B,d,n)
    if mode == "absorb":
        # π = point mass on MASK
        # NOTE(review): log_pi is -inf off MASK while probs_T keeps a small
        # e^{-σ̄(1)} mass on the original token, so the KL below evaluates
        # to +inf (and 0·∞ = NaN on the exactly-zero entries). The true KL
        # against a MASK point mass is indeed infinite unless σ̄(1) = ∞;
        # consider flooring log_pi or summing only over probs_T > 0 —
        # confirm the intended behaviour before relying on these numbers.
        log_pi = torch.zeros(n, device=device)
        log_pi[n - 1] = 0.0
        log_pi[:n - 1] = -float('inf')
    else:
        # π = uniform 1/n
        log_pi = torch.full((n,), -math.log(n), device=device)
    kl = (probs_T * (torch.log(probs_T.clamp(min=1e-20)) - log_pi)).sum(-1).sum(-1)
    return integral + kl  # (B,) upper bound on −log p₀(x₀)
# ============================================================
# Algorithm 2 — Unconditional Sampling
# ============================================================
# Two strategies: Euler (Eq 17) and Tweedie τ-leaping (Eq 19).
# Both reverse from t=1 (base) to t=0 (data).
def _sample_base(B, d, n, mode, device):
"""x_T ∼ p_base. Absorb → all MASK; Uniform → random."""
if mode == "absorb":
return torch.full((B, d), n - 1, dtype=torch.long, device=device)
return torch.randint(0, n, (B, d), device=device)
def _euler_transition(xt, scores, sigma_t, dt, n, mode):
"""
Euler reverse step (Eq 17).
p^i(y | x_t^i) = δ(y, x_t^i)
+ Δt · σ(t) · Q^{tok}(x_t^i, y) · s_θ(x_t, t)_{i,y}
Q(x,y) is row-x col-y of the forward rate matrix, i.e. the
forward rate of probability flowing *from state y to state x*.
The reverse rate from x to y is s_θ_y · Q(x,y) · σ(t).
But here y is the destination in reverse, so the off-diagonal
reverse-transition prob is Δt · σ(t) · Q(x_t, y) · s_θ_y.
Uniform: Q(x,y)=1 for x≠y → rate = σ·s_y for every y≠x.
Absorb: Q(MASK,y)=1 for y<MASK; all other off-diag = 0
→ only MASK tokens unmask; non-MASK tokens stay.
"""
B, d, _ = scores.shape
probs = F.one_hot(xt, n).float() # (B,d,n)
rate = sigma_t[:, None, None] * scores * dt # (B,d,n)
if mode == "uniform":
rate.scatter_(2, xt.unsqueeze(-1), 0.0) # zero diagonal
probs = probs + rate
elif mode == "absorb":
# Only MASK tokens get reverse transitions
is_mask = (xt == n - 1).unsqueeze(-1).float() # (B,d,1)
rate[:, :, n - 1] = 0.0 # no MASK→MASK
probs = probs + rate * is_mask
return probs
def _tweedie_transition(xt, scores, alpha, n, mode):
    """
    Tweedie τ-leaping denoising step (Eq 19, Theorem 4.2).

        p^i(y | x_t^i) = [ exp(−α Q) · s_θ ]_y × exp(α Q)_{x_t^i, y}

    where α = σ̄(t) − σ̄(t−Δt) > 0 is the cumulative noise removed this step.

    Closed forms for exp(±α Q) used below:
      Uniform (eigenvalues 0, −1 once n is absorbed into σ̄):
          exp(αQ)_{x,y}  = e^{-α} δ_{xy} + (1−e^{-α})/n
          exp(−αQ)_{x,y} = e^{α}  δ_{xy} + (1−e^{α})/n
      Absorb (MASK = n−1):
          exp(αQ): column y<MASK → row y: e^{-α}, row MASK: 1−e^{-α};
                   column MASK   → row MASK: 1
          exp(−αQ): same structure with α negated.

    Rows may be unnormalised / slightly negative; the caller clamps and
    renormalises before sampling.

    Args:
        xt: (B, d) current tokens
        scores: (B, d, n) score-network output s_θ(x_t, t)
        alpha: (B,) σ̄(t) − σ̄(t−Δt)
        n: vocab size
        mode: "uniform" | "absorb"
    Returns:
        (B, d, n) unnormalised reverse-transition weights
    """
    B, d, _ = scores.shape
    a = alpha[:, None, None]   # (B,1,1)
    ea = torch.exp(a)          # e^α
    ema = torch.exp(-a)        # e^{-α}
    if mode == "uniform":
        # --- left factor: [exp(−αQ) · s]_y ---
        s_sum = scores.sum(dim=-1, keepdim=True)      # (B,d,1)
        left = ea * scores + (1.0 - ea) / n * s_sum   # (B,d,n)
        # --- right factor: exp(αQ)_{x_t, y} ---
        right = ((1.0 - ema) / n).expand(B, d, n).clone()
        right.scatter_add_(2, xt.unsqueeze(-1),
                           ema.expand(B, d, 1))       # add e^{-α} on the diagonal
        probs = left * right
    elif mode == "absorb":
        # Non-MASK tokens always stay (their forward row is a point mass).
        # MASK tokens:
        #   y < n-1:  e^α · s_y · (1 − e^{-α})
        #   y = MASK: (1−e^α)·Σ_{z<MASK} s_z + s_MASK
        trans_mask = torch.zeros_like(scores)
        trans_mask[:, :, :n - 1] = (
            ea * scores[:, :, :n - 1] * (1.0 - ema)
        )
        s_nonmask_sum = scores[:, :, :n - 1].sum(-1, keepdim=True)
        trans_mask[:, :, n - 1:] = (
            (1.0 - ea) * s_nonmask_sum + scores[:, :, n - 1:]
        )
        trans_stay = F.one_hot(xt, n).float()
        is_mask = (xt == n - 1).unsqueeze(-1).float()
        # Blend: masked positions use the unmasking weights, others stay put.
        probs = is_mask * trans_mask + (1.0 - is_mask) * trans_stay
    return probs
#from catsample import sample_categorical
#return sample_categorical(probs)
def _clamp_and_sample(probs):
"""Clamp negatives, normalise, sample (Algorithm 2 post-processing)."""
probs = probs.clamp(min=0.0)
probs = probs / (probs.sum(-1, keepdim=True) + 1e-20)
return torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1
).reshape(probs.shape[:-1])
@torch.no_grad()
def sample(
    score_net: nn.Module,
    batch: int,
    d: int,
    n: int,
    schedule,
    num_steps: int = 256,
    method: Literal["euler", "tweedie"] = "tweedie",
    mode: Literal["absorb", "uniform"] = "absorb",
    device: torch.device = torch.device("cpu"),
):
    """
    Algorithm 2: unconditional sampling.

    Starts at the base distribution (t=1) and reverses to t=0 in
    `num_steps` uniform steps, using either the Euler (Eq 17) or
    Tweedie τ-leaping (Eq 19) transition.

    Returns:
        (batch, d) sampled token sequences
    """
    dt = 1.0 / num_steps
    xt = _sample_base(batch, d, n, mode, device)
    for step in range(num_steps):
        t_now = 1.0 - step * dt
        t_vec = torch.full((batch,), t_now, device=device)
        sb = schedule.sigma_bar(t_vec)
        scores = score_net(xt, sb)
        if method == "euler":
            probs = _euler_transition(xt, scores, schedule.sigma(t_vec), dt, n, mode)
        else:
            prev = torch.full((batch,), max(t_now - dt, 0.0), device=device)
            alpha = sb - schedule.sigma_bar(prev)  # σ̄(t) − σ̄(t−Δt)
            probs = _tweedie_transition(xt, scores, alpha, n, mode)
        xt = _clamp_and_sample(probs)
    return xt
# ============================================================
# Algorithm 3 — Conditional Sampling (infilling / prompting)
# ============================================================
# By Bayes' rule (Eq 22) the conditional and unconditional scores
# coincide when we only modify tokens at unfilled positions Ω.
# So we run normal reverse diffusion but freeze prompt positions.
@torch.no_grad()
def sample_conditional(
    score_net: nn.Module,
    batch: int,
    d: int,
    n: int,
    schedule,
    prompt_indices: torch.Tensor,
    prompt_tokens: torch.Tensor,
    num_steps: int = 256,
    method: Literal["euler", "tweedie"] = "tweedie",
    mode: Literal["absorb", "uniform"] = "absorb",
    device: torch.device = torch.device("cpu"),
):
    """
    Algorithm 3: conditional sampling (infilling / prompting).

    By Bayes' rule (Eq 22) the conditional and unconditional scores agree
    on the free positions, so ordinary reverse diffusion is run while the
    prompt positions are clamped after every step.

    Args:
        prompt_indices: (P,) int — positions that are given
        prompt_tokens:  (P,) int — token values at those positions
    Returns:
        (batch, d) generated sequences with the prompt inserted
    """
    dt = 1.0 / num_steps
    xt = _sample_base(batch, d, n, mode, device)
    # Ω = prompt positions (held fixed); Ω̄ = free positions.
    xt[:, prompt_indices] = prompt_tokens.unsqueeze(0).expand(batch, -1)
    free = torch.ones(d, dtype=torch.bool, device=device)
    free[prompt_indices] = False
    for step in range(num_steps):
        t_now = 1.0 - step * dt
        t_vec = torch.full((batch,), t_now, device=device)
        sb = schedule.sigma_bar(t_vec)
        scores = score_net(xt, sb)
        if method == "euler":
            probs = _euler_transition(xt, scores, schedule.sigma(t_vec), dt, n, mode)
        else:
            prev = torch.full((batch,), max(t_now - dt, 0.0), device=device)
            probs = _tweedie_transition(xt, scores, sb - schedule.sigma_bar(prev), n, mode)
        proposal = _clamp_and_sample(probs)
        # Only the free (non-prompt) positions are updated.
        xt[:, free] = proposal[:, free]
    return xt
# ============================================================
# Example score network (placeholder — not the paper's DiT)
# ============================================================
# The actual architecture is a DiT-style encoder-only transformer
# (Peebles & Xie 2023) with:
# - adaLN-zero time conditioning on σ̄(t) (not t itself)
# - rotary positional embeddings
# - separate input embedding and output projection matrices
# - output exponentiated for positivity; scaled by (e^σ̄ − 1)
# for absorb (Appendix C.2)
class ExampleScoreNetwork(nn.Module):
    """
    Minimal feed-forward score network for testing / illustration.

    This is NOT the paper's DiT transformer — replace it for real
    experiments. The output is exponentiated so scores are strictly
    positive, and in "absorb" mode scaled by (e^σ̄ − 1) (Appendix C.2).
    """

    def __init__(self, n: int, d: int, hidden: int = 256,
                 mode: Literal["absorb", "uniform"] = "absorb"):
        super().__init__()
        self.n = n
        self.mode = mode
        self.tok_embed = nn.Embedding(n, hidden)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
        self.body = nn.Sequential(
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, n))

    def forward(self, xt: torch.Tensor, sigma_bar: torch.Tensor):
        """
        Args:
            xt: (B, d) token indices
            sigma_bar: (B,) cumulative noise σ̄(t)
        Returns:
            (B, d, n) positive score estimates
        """
        time_feat = self.time_mlp(sigma_bar[:, None])    # (B, H)
        h = self.tok_embed(xt) + time_feat.unsqueeze(1)  # (B, d, H)
        # Exponentiate for positivity (Appendix C.2).
        scores = torch.exp(self.body(h))                 # (B, d, n)
        if self.mode == "absorb":
            # Absorb scaling: multiply by (e^σ̄ − 1).
            scores = scores * (torch.exp(sigma_bar) - 1.0)[:, None, None]
        return scores
# ============================================================
# Quick smoke test
# ============================================================
if __name__ == "__main__":
    # Smoke test: train the toy score network on Kalevala bytes and draw
    # unconditional + conditional samples every 8 steps.
    torch.manual_seed(42)
    device = torch.device("cpu")  # CPU is enough for the smoke test
    n, d, B = 256, 32, 4  # vocab (raw bytes + MASK slot), seq len, batch
    mode = "absorb"
    dataloader = create_dataloader(
        load_kalevala(),
        text_to_tensor,
        batch_size=B,
        length=d)
    net = ExampleScoreNetwork(n, d, hidden=128, mode=mode).to(device)
    schedule = GeometricSchedule(sigma_min=1e-4, sigma_max=20.0)
    # BUGFIX: create the optimizer ONCE. The original rebuilt Adam inside
    # the loop every iteration, discarding its moment estimates.
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4)
    # --- training loop ---
    for k, (x0, _) in enumerate(dataloader):
        net.train()
        x0 = x0.to(device)
        # BUGFIX: zero gradients each step; the original never called
        # zero_grad, so gradients accumulated across iterations.
        optimizer.zero_grad()
        loss = score_entropy_loss(net, x0, n, schedule, mode=mode)
        loss.backward()
        optimizer.step()
        print(f"training loss: {loss.item():.4f}")
        if k % 8 == 0:
            # --- unconditional sampling ---
            net.eval()
            samples = sample(net, batch=2, d=d, n=n, schedule=schedule,
                             num_steps=64, method="tweedie", mode=mode, device=device)
            print(repr(as_text(samples[0])))
            print(repr(as_text(samples[1])))
            print(f"samples shape: {samples.shape} range: [{samples.min()}, {samples.max()}]")
            # --- conditional sampling (infill positions 0..7) ---
            prompt_idx = torch.arange(8, device=device)
            prompt_tok = torch.randint(0, n - 1, (8,), device=device)
            infilled = sample_conditional(
                net, batch=2, d=d, n=n, schedule=schedule,
                prompt_indices=prompt_idx, prompt_tokens=prompt_tok,
                num_steps=64, method="tweedie", mode=mode, device=device)
            assert (infilled[:, :8] == prompt_tok).all()
            print(f"conditional: prompt preserved ✓ shape: {infilled.shape}")
    print("smoke test passed")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment