Created
February 27, 2026 20:35
-
-
Save nboyd/84b188afe76f74a49e88a0c25cffcc98 to your computer and use it in GitHub Desktop.
BoltzGen finetuning loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Finetuning BoltzGen's structure module on hallucinated binder designs. | |
| Simplified gist. | |
| """ | |
| import equinox as eqx | |
| import jax | |
| import jax.numpy as jnp | |
| from joltzgen import AtomDiffusion, weighted_rigid_align | |
| from mosaic.models.boltzgen import Sampler | |
| # ---------- Core diffusion loss ---------- | |
| # Noise a ground-truth structure, denoise it, compute MSE + LDDT. | |
def compute_diffusion_loss(sm: AtomDiffusion, sampler: Sampler, coords, atom_mask, *, key):
    """Return the diffusion training loss for one noised copy of ``coords``.

    Draws a noise level, perturbs the ground-truth coordinates, denoises them
    with the structure module, rigidly aligns the ground truth onto the
    prediction, and combines a noise-weighted MSE with a smooth-LDDT term.

    Args:
        sm: Structure module. Reads ``sigma_data``, ``P_mean``, ``P_std`` and
            calls ``preconditioned_network_forward``.
        sampler: Supplies the conditioning inputs (``trunk_s``, ``s_inputs``,
            ``feats``, ``q``, ``c``, ``to_keys`` and the three bias tensors).
        coords: Ground-truth atom coordinates for a single structure;
            presumably ``(n_atoms, 3)`` given the indexing and the factor of
            3 in the MSE normaliser -- TODO confirm.
        atom_mask: Per-atom mask, presumably ``(n_atoms,)``. It is combined
            with boolean arrays via ``&``, so a boolean (or integer) dtype is
            expected.
        key: JAX PRNG key; split internally into independent subkeys.

    Returns:
        Scalar loss ``w * mse + (1 - lddt)``.
    """
    # Each random draw gets its own subkey. Reusing one key for the noise
    # level, the coordinate noise and the network makes those draws
    # deterministically related, which JAX's PRNG design forbids.
    sigma_key, noise_key, net_key = jax.random.split(key, 3)

    # random noise level from log-normal
    sigma = sm.sigma_data * jnp.exp(sm.P_mean + sm.P_std * jax.random.normal(sigma_key))
    noise = sigma * jax.random.normal(noise_key, coords.shape)

    # denoise (the network works on a leading batch axis of size 1)
    denoised = sm.preconditioned_network_forward(
        (coords + noise)[None],
        sigma,
        network_condition_kwargs=dict(
            s_trunk=sampler.trunk_s,
            s_inputs=sampler.s_inputs,
            feats=sampler.feats,
            multiplicity=1,
            diffusion_conditioning={
                "q": sampler.q, "c": sampler.c, "to_keys": sampler.to_keys,
                "atom_enc_bias": sampler.atom_enc_bias,
                "atom_dec_bias": sampler.atom_dec_bias,
                "token_trans_bias": sampler.token_trans_bias,
            },
        ),
        key=net_key,
    )[0]

    # align ground truth to prediction
    # NOTE(review): AF3 treats the aligned target as a constant
    # (stop-gradient). Confirm weighted_rigid_align does this internally;
    # otherwise gradients also flow through the alignment.
    aligned = weighted_rigid_align(coords[None], denoised[None], atom_mask[None], atom_mask[None])[0]

    # weighted MSE (Karras et al. weighting)
    w = (sigma**2 + sm.sigma_data**2) / (sigma * sm.sigma_data) ** 2
    mse = (((denoised - aligned) ** 2).sum(-1) * atom_mask).sum() / (3 * atom_mask.sum())

    # smooth LDDT loss (Algorithm 27 from AF3)
    # The 1e-8 inside the sqrt keeps the gradient finite at zero distance
    # (the diagonal).
    true_dists = jnp.sqrt(((aligned[:, None] - aligned[None, :]) ** 2).sum(-1) + 1e-8)
    pred_dists = jnp.sqrt(((denoised[:, None] - denoised[None, :]) ** 2).sum(-1) + 1e-8)
    # Score only distinct, masked-in atom pairs whose true distance is < 15.
    pair_mask = (true_dists < 15.0) & atom_mask[:, None] & atom_mask[None, :] & ~jnp.eye(len(aligned), dtype=bool)
    dist_err = jnp.abs(true_dists - pred_dists)
    lddt = (
        (jax.nn.sigmoid(0.5 - dist_err) + jax.nn.sigmoid(1.0 - dist_err)
         + jax.nn.sigmoid(2.0 - dist_err) + jax.nn.sigmoid(4.0 - dist_err)) / 4.0
        * pair_mask
    ).sum() / (pair_mask.sum() + 1e-5)  # epsilon guards against an empty pair set

    return w * mse + (1.0 - lddt)
@eqx.filter_value_and_grad
def diffusion_loss_and_grad(sm, *, coords, sampler, key, n_samples=4):
    """Average ``compute_diffusion_loss`` over structures and noise draws.

    For every entry along the leading axis of ``coords`` the loss is
    evaluated ``n_samples`` times with independent PRNG keys. The result is
    the mean over draws, then the mean over structures. The
    ``eqx.filter_value_and_grad`` decorator differentiates with respect to
    the first positional argument, ``sm``.

    Args:
        sm: Structure module being finetuned.
        coords: Stack of ground-truth structures; the leading axis is mapped
            over.
        sampler: Conditioning container forwarded to
            ``compute_diffusion_loss``; also provides
            ``feats["atom_resolved_mask"]``, whose first entry is used as the
            atom mask for every structure.
        key: JAX PRNG key, split once per structure and again per draw.
        n_samples: Number of independent noise draws per structure.
    """
    resolved_mask = sampler.feats["atom_resolved_mask"][0]

    def loss_for_structure(structure, structure_key):
        def loss_for_draw(draw_key):
            return compute_diffusion_loss(
                sm, sampler, structure, resolved_mask, key=draw_key,
            )

        draw_keys = jax.random.split(structure_key, n_samples)
        return jax.vmap(loss_for_draw)(draw_keys).mean()

    structure_keys = jax.random.split(key, coords.shape[0])
    per_structure = jax.vmap(loss_for_structure)(coords, structure_keys)
    return per_structure.mean()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment