Skip to content

Instantly share code, notes, and snippets.

@nboyd
Created February 18, 2026 19:44
Show Gist options
  • Select an option

  • Save nboyd/8e4f3243243fab4f4b0467ebc5d63447 to your computer and use it in GitHub Desktop.

Select an option

Save nboyd/8e4f3243243fab4f4b0467ebc5d63447 to your computer and use it in GitHub Desktop.
hallucination + ranking with Protenix v1.0
from math import sqrt
import equinox as eqx
import gemmi
import jax
import jax.numpy as jnp
import mosaic.losses.structure_prediction as sp
from mosaic.common import TOKENS
from mosaic.losses.protein_mpnn import InverseFoldingSequenceRecovery
from mosaic.losses.transformations import NoCys
from mosaic.models.protenix import Protenix2025
from mosaic.optimizers import simplex_APGM
from mosaic.proteinmpnn.mpnn import load_mpnn_sol
from mosaic.structure_prediction import TargetChain
# Load the target structure (PD-L1) and extract chain A of the first model
# as both the sequence source and the structural template for prediction.
st = gemmi.read_structure("targets/PDL1.cif")
st.remove_ligands_and_waters()
template_chain = st[0]["A"]
# One-letter sequence of the polymer residues in the template chain.
target_seq = gemmi.one_letter_code([r.name for r in template_chain.get_polymer()])
BINDER_LENGTH = 200  # number of designed binder positions
INIT_SCALE = 2.0  # scale applied to the Gumbel noise used to initialize the PSSM
KEY = jax.random.key(0)  # PRNG key consumed by the initial PSSM draw below
folder = Protenix2025()  # structure-prediction model (Protenix v1.0 wrapper)
mpnn = load_mpnn_sol(0.05)  # soluble ProteinMPNN; 0.05 is presumably a noise level -- TODO confirm
# Build model input features for the binder + target complex: a free binder
# chain of BINDER_LENGTH residues plus the target chain with MSA and the
# structural template extracted above.
features, _ = folder.binder_features(
    binder_length=BINDER_LENGTH,
    chains=[
        TargetChain(sequence=target_seq, use_msa=True, template_chain=template_chain)
    ],
)
# Multi-term design loss: weighted sum of structure-prediction confidence
# terms plus an inverse-folding (ProteinMPNN) sequence-recovery term.
# Weights are relative scalings of each term; the InverseFoldingSequenceRecovery
# term dominates at weight 10.0, with temp=0.001 giving near-argmax recovery.
loss = (
    1.0 * sp.BinderTargetContact()  # binder-target interface contacts
    + 1.0 * sp.WithinBinderContact()  # intra-binder compactness
    + 10.0 * InverseFoldingSequenceRecovery(mpnn, temp=0.001)
    + 0.05 * sp.TargetBinderPAE()  # cross-chain aligned-error terms (both directions)
    + 0.05 * sp.BinderTargetPAE()
    + 0.025 * sp.IPTMLoss()
    + 0.4 * sp.WithinBinderPAE()
    + 0.025 * sp.pTMEnergy()
    + 0.1 * sp.PLDDTLoss()
)
# Wrap the loss in a multisample folding objective (4 prediction samples,
# 6 recycling steps per evaluation), then exclude cysteine from the design
# alphabet via the NoCys transformation (the PSSM below has 19 columns).
design_loss = NoCys(
    loss=folder.build_multisample_loss(
        loss=loss, features=features, recycling_steps=6, num_samples=4
    )
)
# 3-phase simplex optimization of the binder PSSM (19 cols, Cys excluded).
# Phase 1: APGM directly on the probability simplex, starting from a
# softmax of scaled Gumbel noise.
pssm = INIT_SCALE * jax.random.gumbel(KEY, shape=(BINDER_LENGTH, 19))
_, pssm = simplex_APGM(
    loss_function=design_loss,
    x=jax.nn.softmax(pssm),
    n_steps=100,
    stepsize=0.15 * sqrt(BINDER_LENGTH),
    logspace=False,
)
# Phases 2-3: continue in log space (logspace=True, so x is log-probabilities;
# the 1e-5 floor avoids log(0)) with a larger step size and fewer steps.
# NOTE(review): phase 1 keeps the SECOND return value while phases 2-3 keep
# the FIRST -- presumably tied to simplex_APGM's return convention per mode;
# verify against the optimizer's signature.
pssm, _ = simplex_APGM(
    loss_function=design_loss,
    x=jnp.log(pssm + 1e-5),
    n_steps=50,
    stepsize=0.5 * sqrt(BINDER_LENGTH),
    logspace=True,
)
pssm, _ = simplex_APGM(
    loss_function=design_loss,
    x=jnp.log(pssm + 1e-5),
    n_steps=15,
    stepsize=0.5 * sqrt(BINDER_LENGTH),
    logspace=True,
)
# Extract sequence (reinsert Cys column, then argmax over 20 AAs).
full_pssm = NoCys.sequence(pssm)  # 19-col PSSM -> full 20-col PSSM
seq_indices = full_pssm.argmax(-1)  # per-position amino-acid index
seq_str = "".join(TOKENS[int(j)] for j in seq_indices)  # index -> one-letter code
# Rank: re-predict the designed sequence with more recycling (10 steps) and
# more samples (6), scoring by iPTM plus interface PSAE terms in both
# chain directions. The designed binder is passed as a plain sequence with
# no MSA; the target keeps its MSA.
seq_oh = jax.nn.one_hot(seq_indices, 20)  # one-hot "PSSM" of the final sequence
rank_features, writer = folder.target_only_features(
    chains=[
        TargetChain(sequence=seq_str, use_msa=False),
        TargetChain(sequence=target_seq, use_msa=True),
    ],
)
ranking_loss = folder.build_multisample_loss(
    loss=1.0 * sp.IPTMLoss()
    + 0.5 * sp.TargetBinderIPSAE()
    + 0.5 * sp.BinderTargetIPSAE(),
    features=rank_features,
    recycling_steps=10,
    num_samples=6,
)
@eqx.filter_jit
def eval_ranking(loss_fn, probs, key):
    """JIT-compiled scoring call: evaluate `loss_fn` on `probs` under PRNG `key`."""
    return loss_fn(probs, key=key)
# Score the one-hot designed sequence with the high-recycling ranking loss.
# Fix: the original passed jax.random.key(0) here, which is byte-identical to
# KEY -- the key already consumed by the PSSM initialization. Reusing a PRNG
# key across independent sampling stages is a JAX anti-pattern (it correlates
# the randomness); use a distinct seed for the ranking evaluation.
rank_score, _ = eval_ranking(ranking_loss, seq_oh, key=jax.random.key(1))
print(f"Sequence: {seq_str}")
print(f"Rank score: {float(rank_score):.4f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment