Created
March 27, 2026 18:50
-
-
Save cheery/137e01b9da8259a0bf96f54f8a6e428a to your computer and use it in GitHub Desktop.
Claude's version of SEDD
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Score Entropy Discrete Diffusion (SEDD) | |
| ======================================== | |
| PyTorch implementation of the algorithms from: | |
| "Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution" | |
| (Lou, Meng, Ermon — ICML 2024) | |
| Implements: | |
| - Algorithm 1: Score Entropy Training (DWDSE loss) | |
| - Algorithm 2: Unconditional Sampling (Euler & Tweedie) | |
| - Algorithm 3: Conditional Sampling (infilling / prompting) | |
| The score network s_θ : X × R → R^{d×n} learns the concrete score, | |
| i.e. the ratios p_t(y)/p_t(x) for Hamming-distance-1 neighbours. | |
| Two transition matrices are supported: Q^uniform and Q^absorb (Eqs 15–16). | |
| """ | |
| from pathlib import Path | |
| from torch.utils.data import Dataset, DataLoader | |
| import torch | |
directory = Path(__file__).parent
def load_kalevala():
    """Read the Kalevala corpus file and collapse newlines into spaces."""
    path = (directory / "../../../data/kalevala.plain.txt").resolve()
    with path.open("r", encoding="utf-8") as handle:
        return handle.read().replace("\n", " ")
def create_dataloader(text,
                      encoder=None,
                      batch_size=4,
                      length=256,
                      stride=128,
                      shuffle=True,
                      drop_last=False,
                      num_workers=0):
    """Wrap `text` in a DataLoader of fixed-length sliding token windows.

    Falls back to the byte-level `text_to_tensor` encoder when `encoder`
    is None (or otherwise falsy).
    """
    dataset = KalevalaDataset(text, encoder or text_to_tensor, length, stride)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers)
class KalevalaDataset(Dataset):
    """Sliding-window dataset of (input, next-token target) chunks."""

    def __init__(self, text, encoder, length, stride):
        data = encoder(text)
        self.inputs = []
        self.targets = []
        # Each window is `length` tokens; the target is the same window
        # shifted right by one position (next-token prediction style).
        for start in range(0, len(data) - length - 1, stride):
            self.inputs.append(data[start:start + length])
            self.targets.append(data[start + 1:start + 1 + length])

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]
def text_to_tensor(text):
    """Encode a string as a 1-D LongTensor of its UTF-8 byte values."""
    encoded = bytearray(text.encode("utf-8"))
    return torch.frombuffer(encoded, dtype=torch.uint8).type(torch.long)
def as_text(p: torch.Tensor) -> str:
    """Decode a tensor of byte values back into a UTF-8 string."""
    raw = p.cpu().to(torch.uint8).numpy().tobytes()
    return raw.decode("utf-8", errors="replace")
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Literal | |
| # ============================================================ | |
| # Noise Schedules (Appendix C.1) | |
| # ============================================================ | |
| # σ̄(t) = cumulative noise = ∫₀ᵗ σ(s)ds | |
| # σ(t) = instantaneous rate = dσ̄/dt | |
| # t ∈ [0, 1]. At t=0 little noise; at t=1 base distribution. | |
class GeometricSchedule:
    r"""Geometric noise schedule (Appendix C.1).

    σ̄(t) = σ_min^{1-t} · σ_max^t      (cumulative noise)
    σ(t)  = σ̄(t) · ln(σ_max / σ_min)  (instantaneous rate, dσ̄/dt)
    """

    def __init__(self, sigma_min: float = 1e-5, sigma_max: float = 20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.log_ratio = math.log(sigma_max / sigma_min)

    def sigma_bar(self, t: torch.Tensor) -> torch.Tensor:
        """Cumulative noise σ̄(t); interpolates σ_min → σ_max geometrically."""
        return self.sigma_max ** t * self.sigma_min ** (1.0 - t)

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Instantaneous rate σ(t) = σ̄(t) · log(σ_max/σ_min)."""
        return self.log_ratio * self.sigma_bar(t)
class LogLinearSchedule:
    r"""Log-linear noise schedule (Appendix C.1).

    σ̄(t) = -log(1 - (1-ε)t)
    σ(t)  = (1-ε) / (1 - (1-ε)t)
    """

    def __init__(self, eps: float = 1e-3):
        self.eps = eps

    def sigma_bar(self, t: torch.Tensor) -> torch.Tensor:
        """Cumulative noise; log1p keeps precision near t = 0."""
        scale = 1.0 - self.eps
        return -torch.log1p(-scale * t)

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Instantaneous rate; diverges as t → 1/(1-ε)."""
        scale = 1.0 - self.eps
        return scale / (1.0 - scale * t)
| # ============================================================ | |
| # Forward Transition p_{t|0}(·|x₀) (Section 3.3, Algorithm 1) | |
| # ============================================================ | |
| # Each token is perturbed independently via | |
| # x_t^i ~ p_{t|0}^{tok}(·|x₀^i) = exp(σ̄(t) Q^{tok})_{x₀^i} (Eq 14) | |
| # Closed forms for the two Q matrices: | |
| def _forward_probs_absorb(sigma_bar: torch.Tensor, x0: torch.Tensor, n: int): | |
| """ | |
| Absorbing diffusion (Eq 16). MASK = token index n-1. | |
| p_{t|0}(y | x₀) = e^{-σ̄}·δ(y, x₀) + (1 - e^{-σ̄})·δ(y, MASK) | |
| Args: | |
| sigma_bar: (B,) | |
| x0: (B, d) tokens in {0, …, n-1} | |
| n: vocab size (last token = MASK) | |
| Returns: | |
| (B, d, n) transition probabilities | |
| """ | |
| B, d = x0.shape | |
| sb = sigma_bar[:, None, None] # (B,1,1) | |
| stay = torch.exp(-sb) # prob of no change | |
| probs = torch.zeros(B, d, n, device=x0.device, dtype=sb.dtype) | |
| probs.scatter_(2, x0.unsqueeze(-1), stay.expand(B, d, 1)) | |
| probs[:, :, n - 1] += (1.0 - stay).squeeze(-1) # go to MASK | |
| return probs | |
| def _forward_probs_uniform(sigma_bar: torch.Tensor, x0: torch.Tensor, n: int): | |
| """ | |
| Uniform diffusion (Eq 15). | |
| p_{t|0}(y | x₀) = (1 - e^{-σ̄})/n ∀y | |
| + e^{-σ̄} additionally when y = x₀ | |
| (Algorithm 1, "else if Q is Uniform" branch.) | |
| """ | |
| B, d = x0.shape | |
| sb = sigma_bar[:, None, None] | |
| unif = (1.0 - torch.exp(-sb)) / n # (B,1,1) | |
| probs = unif.expand(B, d, n).clone() | |
| probs.scatter_add_(2, x0.unsqueeze(-1), | |
| torch.exp(-sb).expand(B, d, 1)) | |
| return probs | |
def forward_transition(sigma_bar, x0, n, mode):
    """Dispatch to the absorbing or uniform closed-form transition."""
    build = _forward_probs_absorb if mode == "absorb" else _forward_probs_uniform
    return build(sigma_bar, x0, n)
def sample_xt(x0, sigma_bar, n, mode):
    """Draw x_t ~ p_{t|0}(·|x0) independently per token (Eq 13)."""
    probs = forward_transition(sigma_bar, x0, n, mode)
    draws = torch.multinomial(probs.reshape(-1, n), 1)
    return draws.reshape(x0.shape), probs
| # ============================================================ | |
| # Algorithm 1 — Score Entropy Training Loss (L̂_DWDSE) | |
| # ============================================================ | |
| # L̂ = σ(t) Σ_i Σ_{y≠x_t^i} [ s_θ(x_t,t)_{i,y} | |
| # - p_{t|0}(y|x₀^i) / p_{t|0}(x_t^i|x₀^i) · log s_θ(x_t,t)_{i,y} ] | |
| # | |
| # The K(·) normalising constant (Eq 5) is θ-independent and dropped. | |
def score_entropy_loss(
    score_net: nn.Module,
    x0: torch.Tensor,
    n: int,
    schedule,
    mode: Literal["absorb", "uniform"] = "absorb",
    t: Optional[torch.Tensor] = None,
):
    """
    One training step of Algorithm 1 (denoising weighted score entropy).

    Per entry (y ≠ x_t): s_θ(x_t,t)_y − r_y · log s_θ(x_t,t)_y, where
    r_y = p_{t|0}(y|x0) / p_{t|0}(x_t|x0). The θ-independent K(·)
    normaliser (Eq 5) is dropped.

    Args:
        score_net: callable (x_t, σ̄) → (B, d, n) positive scores
        x0: (B, d) clean token indices
        n: vocab size
        schedule: noise schedule exposing sigma_bar / sigma
        mode: "absorb" | "uniform"
        t: (B,) optional pre-sampled times; defaults to U[0,1]
    Returns:
        scalar loss (mean over the batch)
    """
    B, _ = x0.shape
    if t is None:
        t = torch.rand(B, device=x0.device)
    sb = schedule.sigma_bar(t)  # σ̄(t), shape (B,)
    st = schedule.sigma(t)      # σ(t),  shape (B,)
    # Forward-noise the data, then score the noisy sample.
    xt, probs = sample_xt(x0, sb, n, mode)  # xt (B,d), probs (B,d,n)
    scores = score_net(xt, sb)              # (B,d,n), positive
    # Target ratios r_y = p_{t|0}(y|x0) / p_{t|0}(x_t|x0).
    p_at_xt = probs.gather(2, xt.unsqueeze(-1)).clamp(min=1e-20)
    ratios = probs / p_at_xt
    # Exclude the diagonal y == x_t from the sum.
    off_diag = torch.ones_like(scores)
    off_diag.scatter_(2, xt.unsqueeze(-1), 0.0)
    # Per-entry loss s_y − r_y · log s_y, masked to y ≠ x_t.
    log_scores = torch.log(scores.clamp(min=1e-20))
    entries = (scores - ratios * log_scores) * off_diag
    # Weight by σ(t) and average over the batch.
    return (entries.sum(dim=(-1, -2)) * st).mean()
| # ============================================================ | |
| # ELBO evaluation (Theorem 3.6, Appendix C.6) | |
| # ============================================================ | |
| # −log p₀^θ(x₀) ≤ L_DWDSE(x₀) + D_KL(p_{T|0}(·|x₀) ‖ π) | |
| # Monte-Carlo over 1000 random timesteps as in the paper. | |
@torch.no_grad()
def estimate_elbo(
    score_net: nn.Module,
    x0: torch.Tensor,
    n: int,
    schedule,
    mode: Literal["absorb", "uniform"] = "absorb",
    num_t: int = 1000,
):
    """
    Estimate −ELBO (upper bound on NLL; Theorem 3.6, Appendix C.6).

    Restores the full score entropy with K(a) = a(log a − 1), Monte-Carlo
    averages it over `num_t` random timesteps, then adds the prior KL
    D_KL(p_{T|0}(·|x₀) ‖ π).

    Args:
        score_net: callable (x_t, σ̄) → (B, d, n) positive scores
        x0: (B, d) clean token indices
        n: vocab size
        schedule: noise schedule exposing sigma_bar / sigma
        mode: "absorb" | "uniform"
        num_t: number of Monte-Carlo timesteps
    Returns:
        (B,) per-sample upper bound on −log p₀(x₀), in nats
    """
    B, d = x0.shape
    device = x0.device
    # Monte-Carlo estimate of the integral ∫₀¹ L(t) dt.
    ts = torch.rand(num_t, device=device)
    total = torch.zeros(B, device=device)
    for ti in ts:
        t_batch = ti.expand(B)
        sb = schedule.sigma_bar(t_batch)
        st = schedule.sigma(t_batch)
        xt, probs = sample_xt(x0, sb, n, mode)
        scores = score_net(xt, sb)
        p_xt = probs.gather(2, xt.unsqueeze(-1))
        ratios = probs / p_xt.clamp(min=1e-20)
        mask = torch.ones_like(scores)
        mask.scatter_(2, xt.unsqueeze(-1), 0.0)
        log_s = torch.log(scores.clamp(min=1e-20))
        # Full score entropy includes K(r) = r(log r − 1).
        log_r = torch.log(ratios.clamp(min=1e-20))
        K_r = ratios * (log_r - 1.0)
        entry = (scores - ratios * log_s + K_r) * mask
        total += entry.sum(dim=(-1, -2)) * st  # weighted by σ(t)
    integral = total / num_t  # Monte-Carlo mean (B,)
    # --- prior KL: D_KL(p_{T|0}(·|x₀) ‖ π) ---
    sb_T = schedule.sigma_bar(torch.ones(B, device=device))
    probs_T = forward_transition(sb_T, x0, n, mode)  # (B,d,n)
    if mode == "absorb":
        # π puts all mass on MASK. Its true log-prob elsewhere is −inf,
        # but p_{T|0} keeps e^{-σ̄(1)} > 0 mass on the original token, so
        # the exact KL diverges (tiny × inf → inf for every sample).
        # Clamp π at 1e-20 — consistent with the clamps above — to get a
        # large-but-finite penalty instead of a useless inf.
        log_pi = torch.full((n,), math.log(1e-20), device=device)
        log_pi[n - 1] = 0.0
    else:
        # π = uniform 1/n.
        log_pi = torch.full((n,), -math.log(n), device=device)
    kl = (probs_T * (torch.log(probs_T.clamp(min=1e-20)) - log_pi)).sum(-1).sum(-1)
    return integral + kl  # (B,) upper bound on −log p₀(x₀)
| # ============================================================ | |
| # Algorithm 2 — Unconditional Sampling | |
| # ============================================================ | |
| # Two strategies: Euler (Eq 17) and Tweedie τ-leaping (Eq 19). | |
| # Both reverse from t=1 (base) to t=0 (data). | |
| def _sample_base(B, d, n, mode, device): | |
| """x_T ∼ p_base. Absorb → all MASK; Uniform → random.""" | |
| if mode == "absorb": | |
| return torch.full((B, d), n - 1, dtype=torch.long, device=device) | |
| return torch.randint(0, n, (B, d), device=device) | |
| def _euler_transition(xt, scores, sigma_t, dt, n, mode): | |
| """ | |
| Euler reverse step (Eq 17). | |
| p^i(y | x_t^i) = δ(y, x_t^i) | |
| + Δt · σ(t) · Q^{tok}(x_t^i, y) · s_θ(x_t, t)_{i,y} | |
| Q(x,y) is row-x col-y of the forward rate matrix, i.e. the | |
| forward rate of probability flowing *from state y to state x*. | |
| The reverse rate from x to y is s_θ_y · Q(x,y) · σ(t). | |
| But here y is the destination in reverse, so the off-diagonal | |
| reverse-transition prob is Δt · σ(t) · Q(x_t, y) · s_θ_y. | |
| Uniform: Q(x,y)=1 for x≠y → rate = σ·s_y for every y≠x. | |
| Absorb: Q(MASK,y)=1 for y<MASK; all other off-diag = 0 | |
| → only MASK tokens unmask; non-MASK tokens stay. | |
| """ | |
| B, d, _ = scores.shape | |
| probs = F.one_hot(xt, n).float() # (B,d,n) | |
| rate = sigma_t[:, None, None] * scores * dt # (B,d,n) | |
| if mode == "uniform": | |
| rate.scatter_(2, xt.unsqueeze(-1), 0.0) # zero diagonal | |
| probs = probs + rate | |
| elif mode == "absorb": | |
| # Only MASK tokens get reverse transitions | |
| is_mask = (xt == n - 1).unsqueeze(-1).float() # (B,d,1) | |
| rate[:, :, n - 1] = 0.0 # no MASK→MASK | |
| probs = probs + rate * is_mask | |
| return probs | |
| def _tweedie_transition(xt, scores, alpha, n, mode): | |
| """ | |
| Tweedie τ-leaping step (Eq 19, Theorem 4.2). | |
| p^i(y | x_t^i) = | |
| [ exp(-α Q) · s_θ ]_y × exp(α Q)_{x_t^i, y} | |
| where α = σ_t^{Δt} = σ̄(t) − σ̄(t−Δt) > 0. | |
| Closed forms for exp(±α Q): | |
| Uniform eigenvalues 0, −1 (after absorbing n into σ̄): | |
| exp(αQ)_{x,y} = e^{-α}δ_{xy} + (1−e^{-α})/n | |
| exp(−αQ)_{x,y} = e^{α}δ_{xy} + (1−e^{α})/n | |
| Absorb (MASK = n−1): | |
| exp(αQ): col y<M → row y: e^{-α}, row M: 1−e^{-α} | |
| col M → row M: 1 | |
| exp(−αQ): same structure with −α. | |
| """ | |
| B, d, _ = scores.shape | |
| a = alpha[:, None, None] # (B,1,1) | |
| ea = torch.exp(a) # e^α | |
| ema = torch.exp(-a) # e^{-α} | |
| if mode == "uniform": | |
| # --- left factor: [exp(−αQ) · s]_y --- | |
| s_sum = scores.sum(dim=-1, keepdim=True) # (B,d,1) | |
| left = ea * scores + (1.0 - ea) / n * s_sum # (B,d,n) | |
| # --- right factor: exp(αQ)_{x_t, y} --- | |
| right = ((1.0 - ema) / n).expand(B, d, n).clone() | |
| right.scatter_add_(2, xt.unsqueeze(-1), | |
| ema.expand(B, d, 1)) | |
| probs = left * right | |
| elif mode == "absorb": | |
| # Non-MASK tokens always stay (derivation in text). | |
| # MASK tokens: | |
| # y < n-1: e^α · s_y · (1 − e^{-α}) | |
| # y = MASK: (1−e^α)·Σ_{z<M} s_z + s_{MASK} | |
| trans_mask = torch.zeros_like(scores) | |
| trans_mask[:, :, :n - 1] = ( | |
| ea * scores[:, :, :n - 1] * (1.0 - ema) | |
| ) | |
| s_nonmask_sum = scores[:, :, :n - 1].sum(-1, keepdim=True) | |
| trans_mask[:, :, n - 1:] = ( | |
| (1.0 - ea) * s_nonmask_sum + scores[:, :, n - 1:] | |
| ) | |
| trans_stay = F.one_hot(xt, n).float() | |
| is_mask = (xt == n - 1).unsqueeze(-1).float() | |
| probs = is_mask * trans_mask + (1.0 - is_mask) * trans_stay | |
| return probs | |
| #from catsample import sample_categorical | |
| #return sample_categorical(probs) | |
| def _clamp_and_sample(probs): | |
| """Clamp negatives, normalise, sample (Algorithm 2 post-processing).""" | |
| probs = probs.clamp(min=0.0) | |
| probs = probs / (probs.sum(-1, keepdim=True) + 1e-20) | |
| return torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1 | |
| ).reshape(probs.shape[:-1]) | |
@torch.no_grad()
def sample(
    score_net: nn.Module,
    batch: int,
    d: int,
    n: int,
    schedule,
    num_steps: int = 256,
    method: Literal["euler", "tweedie"] = "tweedie",
    mode: Literal["absorb", "uniform"] = "absorb",
    device: torch.device = torch.device("cpu"),
):
    """
    Algorithm 2: unconditional sampling.

    Starts from the base distribution at t=1 and reverses to t=0 in
    `num_steps` uniform steps using Euler (Eq 17) or Tweedie (Eq 19)
    transitions.
    """
    dt = 1.0 / num_steps
    xt = _sample_base(batch, d, n, mode, device)
    for step in range(num_steps):
        t_val = 1.0 - step * dt
        t_vec = torch.full((batch,), t_val, device=device)
        sb = schedule.sigma_bar(t_vec)
        scores = score_net(xt, sb)
        if method == "euler":
            probs = _euler_transition(xt, scores, schedule.sigma(t_vec), dt, n, mode)
        else:
            t_prev = torch.full((batch,), max(t_val - dt, 0.0), device=device)
            alpha = sb - schedule.sigma_bar(t_prev)  # σ̄(t) − σ̄(t−Δt)
            probs = _tweedie_transition(xt, scores, alpha, n, mode)
        xt = _clamp_and_sample(probs)
    return xt
| # ============================================================ | |
| # Algorithm 3 — Conditional Sampling (infilling / prompting) | |
| # ============================================================ | |
| # By Bayes' rule (Eq 22) the conditional and unconditional scores | |
| # coincide when we only modify tokens at unfilled positions Ω. | |
| # So we run normal reverse diffusion but freeze prompt positions. | |
@torch.no_grad()
def sample_conditional(
    score_net: nn.Module,
    batch: int,
    d: int,
    n: int,
    schedule,
    prompt_indices: torch.Tensor,
    prompt_tokens: torch.Tensor,
    num_steps: int = 256,
    method: Literal["euler", "tweedie"] = "tweedie",
    mode: Literal["absorb", "uniform"] = "absorb",
    device: torch.device = torch.device("cpu"),
):
    """
    Algorithm 3: conditional sampling (infilling / prompting).

    By Bayes' rule (Eq 22) the conditional score equals the unconditional
    one on the free positions, so ordinary reverse diffusion is run while
    the prompt positions Ω stay clamped to their given tokens.

    Args:
        prompt_indices: (P,) int — positions that are given
        prompt_tokens: (P,) int — token values at those positions
    Returns:
        (B, d) generated sequences with the prompt inserted
    """
    dt = 1.0 / num_steps
    xt = _sample_base(batch, d, n, mode, device)
    # Ω = clamped prompt positions; Ω̄ = free positions.
    xt[:, prompt_indices] = prompt_tokens.unsqueeze(0).expand(batch, -1)
    free = torch.ones(d, dtype=torch.bool, device=device)
    free[prompt_indices] = False
    for step in range(num_steps):
        t_val = 1.0 - step * dt
        t_vec = torch.full((batch,), t_val, device=device)
        sb = schedule.sigma_bar(t_vec)
        scores = score_net(xt, sb)
        if method == "euler":
            probs = _euler_transition(xt, scores, schedule.sigma(t_vec), dt, n, mode)
        else:
            t_prev = torch.full((batch,), max(t_val - dt, 0.0), device=device)
            alpha = sb - schedule.sigma_bar(t_prev)
            probs = _tweedie_transition(xt, scores, alpha, n, mode)
        proposal = _clamp_and_sample(probs)
        xt[:, free] = proposal[:, free]  # prompt positions stay clamped
    return xt
| # ============================================================ | |
| # Example score network (placeholder — not the paper's DiT) | |
| # ============================================================ | |
| # The actual architecture is a DiT-style encoder-only transformer | |
| # (Peebles & Xie 2023) with: | |
| # - adaLN-zero time conditioning on σ̄(t) (not t itself) | |
| # - rotary positional embeddings | |
| # - separate input embedding and output projection matrices | |
| # - output exponentiated for positivity; scaled by (e^σ̄ − 1) | |
| # for absorb (Appendix C.2) | |
class ExampleScoreNetwork(nn.Module):
    """
    Minimal feed-forward score network for testing / illustration.

    Not the paper's DiT — replace with a proper transformer for real
    experiments. Conditions on σ̄(t) via a small MLP added to the token
    embeddings.
    """

    def __init__(self, n: int, d: int, hidden: int = 256,
                 mode: Literal["absorb", "uniform"] = "absorb"):
        super().__init__()
        self.n = n
        self.mode = mode
        self.tok_embed = nn.Embedding(n, hidden)
        self.time_mlp = nn.Sequential(
            nn.Linear(1, hidden), nn.SiLU(), nn.Linear(hidden, hidden))
        self.body = nn.Sequential(
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, n))

    def forward(self, xt: torch.Tensor, sigma_bar: torch.Tensor):
        """
        Args:
            xt: (B, d) token indices
            sigma_bar: (B,) cumulative noise σ̄(t)
        Returns:
            (B, d, n) positive score estimates
        """
        time_feat = self.time_mlp(sigma_bar[:, None]).unsqueeze(1)  # (B,1,H)
        hidden = self.tok_embed(xt) + time_feat                     # (B,d,H)
        # Exponentiate for positivity (Appendix C.2).
        scores = torch.exp(self.body(hidden))                       # (B,d,n)
        if self.mode == "absorb":
            # Absorb parameterisation: scale scores by (e^σ̄ − 1).
            scores = scores * (torch.exp(sigma_bar) - 1.0)[:, None, None]
        return scores
| # ============================================================ | |
| # Quick smoke test | |
| # ============================================================ | |
| if __name__ == "__main__": | |
| torch.manual_seed(42) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| device = torch.device("cpu") | |
| n, d, B = 256, 32, 4 # vocab, seq len, batch | |
| mode = "absorb" | |
| dataloader = create_dataloader( | |
| load_kalevala(), | |
| text_to_tensor, | |
| batch_size=B, | |
| length=d) | |
| net = ExampleScoreNetwork(n, d, hidden=128, mode=mode).to(device) | |
| schedule = GeometricSchedule(sigma_min=1e-4, sigma_max=20.0) | |
| # --- training step --- | |
| for k, (x0, _) in enumerate(dataloader): | |
| net.train() | |
| x0 = x0.to(device) | |
| #x0 = torch.randint(0, n - 1, (B, d), device=device) # avoid MASK in data | |
| optimizer = torch.optim.Adam(net.parameters(), lr=3e-4) | |
| loss = score_entropy_loss(net, x0, n, schedule, mode=mode) | |
| loss.backward() | |
| optimizer.step() | |
| print(f"training loss: {loss.item():.4f}") | |
| if k % 8 == 0: | |
| # --- unconditional sampling --- | |
| net.eval() | |
| samples = sample(net, batch=2, d=d, n=n, schedule=schedule, | |
| num_steps=64, method="tweedie", mode=mode, device=device) | |
| print(repr(as_text(samples[0]))) | |
| print(repr(as_text(samples[1]))) | |
| print(f"samples shape: {samples.shape} range: [{samples.min()}, {samples.max()}]") | |
| # --- conditional sampling (infill positions 0..7) --- | |
| prompt_idx = torch.arange(8, device=device) | |
| prompt_tok = torch.randint(0, n - 1, (8,), device=device) | |
| infilled = sample_conditional( | |
| net, batch=2, d=d, n=n, schedule=schedule, | |
| prompt_indices=prompt_idx, prompt_tokens=prompt_tok, | |
| num_steps=64, method="tweedie", mode=mode, device=device) | |
| assert (infilled[:, :8] == prompt_tok).all() | |
| print(f"conditional: prompt preserved ✓ shape: {infilled.shape}") | |
| print("smoke test passed") | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment