Skip to content

Instantly share code, notes, and snippets.

@nboyd
Created February 27, 2026 20:35
Show Gist options
  • Select an option

  • Save nboyd/84b188afe76f74a49e88a0c25cffcc98 to your computer and use it in GitHub Desktop.

Select an option

Save nboyd/84b188afe76f74a49e88a0c25cffcc98 to your computer and use it in GitHub Desktop.
Boltzgen finetuning loss
"""
Finetuning BoltzGen's structure module on hallucinated binder designs.
Simplified gist.
"""
import equinox as eqx
import jax
import jax.numpy as jnp
from joltzgen import AtomDiffusion, weighted_rigid_align
from mosaic.models.boltzgen import Sampler
# ---------- Core diffusion loss ----------
# Noise a ground-truth structure, denoise it, compute MSE + LDDT.
def compute_diffusion_loss(sm: AtomDiffusion, sampler: Sampler, coords, atom_mask, *, key):
    """Single-sample diffusion training loss: Karras-weighted MSE + (1 - smooth LDDT).

    Noises the ground-truth coordinates at a randomly drawn sigma, denoises
    them with the preconditioned structure module, rigid-aligns the ground
    truth onto the prediction, and combines a sigma-weighted coordinate MSE
    with a smooth-LDDT term (Algorithm 27 from AF3).

    Args:
        sm: diffusion structure module providing ``sigma_data``, ``P_mean``,
            ``P_std`` and ``preconditioned_network_forward``.
        sampler: carrier of trunk embeddings and conditioning tensors
            (``trunk_s``, ``s_inputs``, ``feats``, pairwise biases, ...).
        coords: ground-truth atom coordinates; assumed shape (n_atoms, 3)
            from the ``.sum(-1)`` / pairwise-distance usage — TODO confirm.
        atom_mask: per-atom resolved mask; used in boolean context, so it is
            assumed to be bool-typed — TODO confirm.
        key: JAX PRNG key.

    Returns:
        Scalar loss value.
    """
    # FIX: the original reused the same `key` for the sigma draw, the noise
    # draw, and the network forward, correlating the noise level with the
    # noise sample.  Split into independent subkeys instead.
    k_sigma, k_noise, k_net = jax.random.split(key, 3)

    # Random noise level from a log-normal (EDM / Karras et al. schedule).
    sigma = sm.sigma_data * jnp.exp(sm.P_mean + sm.P_std * jax.random.normal(key=k_sigma))
    noise = sigma * jax.random.normal(k_noise, coords.shape)

    # Denoise the perturbed structure; batch dim of 1 is added then stripped.
    denoised = sm.preconditioned_network_forward(
        (coords + noise)[None],
        sigma,
        network_condition_kwargs=dict(
            s_trunk=sampler.trunk_s,
            s_inputs=sampler.s_inputs,
            feats=sampler.feats,
            multiplicity=1,
            diffusion_conditioning={
                "q": sampler.q, "c": sampler.c, "to_keys": sampler.to_keys,
                "atom_enc_bias": sampler.atom_enc_bias,
                "atom_dec_bias": sampler.atom_dec_bias,
                "token_trans_bias": sampler.token_trans_bias,
            },
        ),
        key=k_net,
    )[0]

    # Rigid-align the ground truth onto the prediction so the MSE is
    # invariant to global rotation/translation.
    aligned = weighted_rigid_align(coords[None], denoised[None], atom_mask[None], atom_mask[None])[0]

    # Karras et al. loss weighting; MSE averaged over masked atoms x 3 dims.
    w = (sigma**2 + sm.sigma_data**2) / (sigma * sm.sigma_data) ** 2
    mse = (((denoised - aligned) ** 2).sum(-1) * atom_mask).sum() / (3 * atom_mask.sum())

    # Smooth LDDT (AF3 Algorithm 27): sigmoid-relaxed distance-error bins at
    # 0.5/1/2/4 A over pairs within 15 A, excluding self-pairs via ~eye.
    true_dists = jnp.sqrt(((aligned[:, None] - aligned[None, :]) ** 2).sum(-1) + 1e-8)
    pred_dists = jnp.sqrt(((denoised[:, None] - denoised[None, :]) ** 2).sum(-1) + 1e-8)
    pair_mask = (true_dists < 15.0) & atom_mask[:, None] & atom_mask[None, :] & ~jnp.eye(len(aligned), dtype=bool)
    dist_err = jnp.abs(true_dists - pred_dists)
    lddt = (
        (jax.nn.sigmoid(0.5 - dist_err) + jax.nn.sigmoid(1.0 - dist_err)
         + jax.nn.sigmoid(2.0 - dist_err) + jax.nn.sigmoid(4.0 - dist_err)) / 4.0
        * pair_mask
    ).sum() / (pair_mask.sum() + 1e-5)

    return w * mse + (1.0 - lddt)
@eqx.filter_value_and_grad
def diffusion_loss_and_grad(sm, *, coords, sampler, key, n_samples=4):
    """Mean diffusion loss over a batch of structures, with gradients w.r.t. `sm`.

    For each structure in `coords`, the single-sample loss is averaged over
    `n_samples` independent noise realizations; the batch mean is returned.
    The `eqx.filter_value_and_grad` decorator makes the call site receive
    both the scalar loss and its gradient with respect to `sm`.
    """
    # Atom mask comes from the sampler's features; same for every structure.
    resolved_mask = sampler.feats["atom_resolved_mask"][0]

    def _per_structure(structure, struct_key):
        # Average the loss over independent noise draws for one structure.
        sample_keys = jax.random.split(struct_key, n_samples)
        per_sample = jax.vmap(
            lambda sample_key: compute_diffusion_loss(
                sm, sampler, structure, resolved_mask, key=sample_key,
            )
        )(sample_keys)
        return per_sample.mean()

    structure_keys = jax.random.split(key, coords.shape[0])
    return jax.vmap(_per_structure)(coords, structure_keys).mean()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment