Skip to content

Instantly share code, notes, and snippets.

@nboyd
Created March 2, 2026 01:14
Show Gist options
  • Select an option

  • Save nboyd/1ca6acdb5a6f2676f1e1a0cf6e3459c2 to your computer and use it in GitHub Desktop.

Select an option

Save nboyd/1ca6acdb5a6f2676f1e1a0cf6e3459c2 to your computer and use it in GitHub Desktop.
Hallucination with Protenix + Mosaic.
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "mosaic @ git+https://github.com/escalante-bio/mosaic.git",
# "gemmi>=0.6.5",
# "jax[cuda12]",
# "numpy",
# ]
#
# [tool.uv]
# override-dependencies = ["scipy>=1.15.3", "numpy>=2.1"]
# ///
"""
Minimal protein binder hallucination with Protenix + ProteinMPNN.
Optimizer hyperparameters are randomized per design for diversity.
"""
import os
# Let JAX's XLA client pre-allocate 95% of GPU memory. This environment
# variable is only read at JAX initialization, so it must be set before
# `import jax` below.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
import gemmi
import jax
import jax.numpy as jnp
import numpy as np
from mosaic.common import TOKENS
from mosaic.losses.protein_mpnn import InverseFoldingSequenceRecovery
from mosaic.losses.transformations import NoCys
import mosaic.losses.structure_prediction as sp
from mosaic.models.protenix import Protenix2025
from mosaic.optimizers import simplex_APGM
from mosaic.proteinmpnn.mpnn import load_mpnn_sol
from mosaic.structure_prediction import TargetChain
from dataclasses import dataclass
# ── Hyperparameter distributions ──────────────────────────────────────────
@dataclass(frozen=True)
class Uniform:
    """Continuous uniform distribution over [lo, hi), sampled via numpy."""

    lo: float  # lower bound (inclusive)
    hi: float  # upper bound

    def sample(self, rng) -> float:
        """Draw a single float from this interval using *rng* (numpy Generator)."""
        draw = rng.uniform(self.lo, self.hi)
        return float(draw)
@dataclass(frozen=True)
class UniformInt:
    """Discrete uniform distribution over the inclusive range [lo, hi]."""

    lo: int  # smallest value (inclusive)
    hi: int  # largest value (inclusive)

    def sample(self, rng) -> int:
        """Draw a single int from [lo, hi] inclusive using *rng* (numpy Generator)."""
        # numpy's Generator.integers is exclusive at the top, hence hi + 1.
        return int(rng.integers(self.lo, self.hi + 1))
# ── Configuration ─────────────────────────────────────────────────────────
CIF_PATH = "target.cif"  # mmCIF structure file containing the binding target
CHAIN_ID = "A"           # chain within CIF_PATH to design a binder against
BINDER_LENGTH = 100      # number of residues in the hallucinated binder
SEED = 0                 # master seed for every stochastic choice below
# ── Load models ───────────────────────────────────────────────────────────
# Protenix structure-prediction model used as the hallucination oracle.
folder = Protenix2025()
# Soluble ProteinMPNN weights; 0.05 presumably selects a noise-level
# checkpoint — TODO confirm against mosaic's load_mpnn_sol signature.
mpnn = load_mpnn_sol(0.05)
# ── Parse target ──────────────────────────────────────────────────────────
st = gemmi.read_structure(CIF_PATH)
st.remove_ligands_and_waters()  # keep only polymer content for templating
chain = st[0].find_chain(CHAIN_ID)  # first model, requested chain
# One-letter sequence of the target polymer (gemmi maps residue names).
target_seq = gemmi.one_letter_code([r.name for r in chain.get_polymer()])
# Folding features for a BINDER_LENGTH-residue binder plus the target chain
# (target gets an MSA and the parsed chain as a structural template).
features, _ = folder.binder_features(
binder_length=BINDER_LENGTH,
chains=[TargetChain(sequence=target_seq, use_msa=True, template_chain=chain)],
)
# ── Design loss ───────────────────────────────────────────────────────────
cys_idx = TOKENS.index("C")
# Large negative logit bias on cysteine so ProteinMPNN never proposes it.
bias = jnp.zeros((BINDER_LENGTH, 20)).at[:, cys_idx].set(-1e6)
# Weighted sum of structure-prediction losses, evaluated over num_samples
# stochastic folds per step; NoCys additionally masks cysteine in the
# design PSSM itself.
design_loss = NoCys(
loss=folder.build_multisample_loss(
loss=(
1.0 * sp.BinderTargetContact()  # encourage binder-target interface contacts
+ 1.0 * sp.WithinBinderContact()  # encourage a compact, self-contacting binder
# Dominant term: agreement with ProteinMPNN's sequence preferences at
# near-greedy temperature, with the cysteine bias applied.
+ 10.0 * InverseFoldingSequenceRecovery(mpnn, temp=jnp.array(0.001), bias=bias)
+ 0.05 * sp.TargetBinderPAE()  # cross-chain predicted aligned error
+ 0.05 * sp.BinderTargetPAE()  # cross-chain PAE, other direction
+ 0.025 * sp.IPTMLoss()        # interface pTM confidence
+ 0.4 * sp.WithinBinderPAE()   # intra-binder PAE (structural confidence)
+ 0.025 * sp.pTMEnergy()       # global pTM-derived energy
+ 0.1 * sp.PLDDTLoss()         # per-residue confidence
),
features=features,
recycling_steps=6,
num_samples=4,
)
)
# ── Three-phase optimization ──────────────────────────────────────────────
rng = np.random.default_rng(SEED)  # numpy rng for per-design hyperparameter draws
stepsize_base = np.sqrt(BINDER_LENGTH)  # scale step sizes with binder length
# Gumbel-noise initialization of the binder PSSM logits, scaled by a random
# temperature. NOTE(review): shape is (BINDER_LENGTH, 19) while the MPNN
# bias above has 20 amino-acid columns — confirm 19 is intentional (e.g.
# cysteine column already excluded by NoCys) and not an off-by-one.
pssm = Uniform(0.75, 5.0).sample(rng) * jax.random.gumbel(
jax.random.key(SEED), shape=(BINDER_LENGTH, 19)
)
# Phase 1: explore (softmax space, with momentum, keep best)
# NOTE(review): phase 1 keeps the SECOND return value of simplex_APGM while
# phases 2-3 keep the first — presumably (final, best); verify against
# mosaic.optimizers.simplex_APGM before changing anything here. The order
# of .sample(rng) calls also fixes the hyperparameter draws — do not reorder.
_, pssm = simplex_APGM(
loss_function=design_loss,
x=jax.nn.softmax(pssm),
n_steps=UniformInt(100, 110).sample(rng),
stepsize=Uniform(0.08, 0.30).sample(rng) * stepsize_base,
momentum=Uniform(0.1, 0.5).sample(rng),
scale=1.0,
logspace=False,
max_gradient_norm=1.0,
)
# Phase 2: refine (log space)
pssm, _ = simplex_APGM(
loss_function=design_loss,
x=jnp.log(pssm + 1e-5),  # epsilon avoids log(0) on zeroed simplex entries
n_steps=UniformInt(30, 70).sample(rng),
stepsize=Uniform(0.3, 0.7).sample(rng) * stepsize_base,
scale=1.25,
logspace=True,
max_gradient_norm=1.0,
)
# Phase 3: polish (log space, higher scale)
pssm, _ = simplex_APGM(
loss_function=design_loss,
x=jnp.log(pssm + 1e-5),
n_steps=UniformInt(10, 25).sample(rng),
stepsize=Uniform(0.3, 0.7).sample(rng) * stepsize_base,
scale=1.4,
logspace=True,
max_gradient_norm=1.0,
)
# Decode the designed sequence: argmax over the cysteine-masked PSSM.
sequence = "".join(TOKENS[int(j)] for j in NoCys.sequence(pssm).argmax(-1))
# ── Score with full sidechain information + more recycles ─────────────────
# Rebuild features from scratch for ranking: the binder is now a concrete
# sequence (no MSA), the target keeps its MSA and structural template.
rank_features, rank_writer = folder.target_only_features(
chains=[
TargetChain(sequence=sequence, use_msa=False),
TargetChain(sequence=target_seq, use_msa=True, template_chain=chain),
]
)
# Final prediction with a hard one-hot PSSM and extra recycling steps
# (10 vs the 6 used during optimization) for a more reliable score.
pred = folder.predict(
PSSM=jax.nn.one_hot(NoCys.sequence(pssm).argmax(-1), num_classes=20),
features=rank_features,
writer=rank_writer,
recycling_steps=10,
key=jax.random.key(0),
)
print(f"Sequence: {sequence}")
print(f"iPTM: {float(pred.iptm):.4f}, pLDDT: {float(pred.plddt.mean()):.2f}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment