Created
March 2, 2026 01:14
-
-
Save nboyd/1ca6acdb5a6f2676f1e1a0cf6e3459c2 to your computer and use it in GitHub Desktop.
Hallucination with protenix + mosaic.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # /// script | |
| # requires-python = ">=3.12" | |
| # dependencies = [ | |
| # "mosaic @ git+https://github.com/escalante-bio/mosaic.git", | |
| # "gemmi>=0.6.5", | |
| # "jax[cuda12]", | |
| # "numpy", | |
| # ] | |
| # | |
| # [tool.uv] | |
| # override-dependencies = ["scipy>=1.15.3", "numpy>=2.1"] | |
| # /// | |
| """ | |
| Minimal protein binder hallucination with Protenix + ProteinMPNN. | |
| Optimizer hyperparameters are randomized per design for diversity. | |
| """ | |
| import os | |
| os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95" | |
| import gemmi | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| from mosaic.common import TOKENS | |
| from mosaic.losses.protein_mpnn import InverseFoldingSequenceRecovery | |
| from mosaic.losses.transformations import NoCys | |
| import mosaic.losses.structure_prediction as sp | |
| from mosaic.models.protenix import Protenix2025 | |
| from mosaic.optimizers import simplex_APGM | |
| from mosaic.proteinmpnn.mpnn import load_mpnn_sol | |
| from mosaic.structure_prediction import TargetChain | |
| from dataclasses import dataclass | |
| # ── Hyperparameter distributions ────────────────────────────────────────── | |
@dataclass(frozen=True)
class Uniform:
    """Continuous uniform distribution over the half-open interval [lo, hi)."""

    lo: float
    hi: float

    def sample(self, rng):
        """Draw a single float from ``rng`` (a numpy Generator) in [lo, hi)."""
        draw = rng.uniform(self.lo, self.hi)
        return float(draw)
@dataclass(frozen=True)
class UniformInt:
    """Discrete uniform distribution over the inclusive integer range [lo, hi]."""

    lo: int
    hi: int

    def sample(self, rng):
        """Draw a single int from ``rng`` (a numpy Generator) in [lo, hi], endpoints included."""
        # rng.integers uses an exclusive upper bound, hence the +1.
        draw = rng.integers(self.lo, self.hi + 1)
        return int(draw)
# ── Configuration ─────────────────────────────────────────────────────────
CIF_PATH = "target.cif"  # mmCIF file holding the target structure (must exist on disk)
CHAIN_ID = "A"           # chain in the CIF to design a binder against
BINDER_LENGTH = 100      # number of residues in the hallucinated binder
SEED = 0                 # master seed; feeds both the numpy rng and the jax PRNG key below
# ── Load models ───────────────────────────────────────────────────────────
folder = Protenix2025()     # structure-prediction model used for folding/scoring
mpnn = load_mpnn_sol(0.05)  # soluble ProteinMPNN; 0.05 presumably a noise/temperature knob — TODO confirm against mosaic docs
# ── Parse target ──────────────────────────────────────────────────────────
st = gemmi.read_structure(CIF_PATH)
st.remove_ligands_and_waters()  # keep only the polymer; templates below use protein atoms only
chain = st[0].find_chain(CHAIN_ID)  # st[0] = first model in the structure
# One-letter sequence of the target's polymer residues.
target_seq = gemmi.one_letter_code([r.name for r in chain.get_polymer()])
# Build folding features for a binder of BINDER_LENGTH residues against the
# target chain (with MSA and the parsed chain as a structural template).
# NOTE(review): the discarded second return value appears to be a writer/
# metadata object, as in the target_only_features call below — verify.
features, _ = folder.binder_features(
    binder_length=BINDER_LENGTH,
    chains=[TargetChain(sequence=target_seq, use_msa=True, template_chain=chain)],
)
# ── Design loss ───────────────────────────────────────────────────────────
# Large negative bias on the cysteine logit so ProteinMPNN essentially never
# proposes C in the binder.  NOTE(review): bias is (L, 20) while the PSSM
# optimized below is (L, 19) — presumably MPNN works in the full 20-letter
# alphabet while NoCys restricts the design simplex; confirm.
cys_idx = TOKENS.index("C")
bias = jnp.zeros((BINDER_LENGTH, 20)).at[:, cys_idx].set(-1e6)
# Weighted sum of structure-prediction losses, averaged over num_samples
# stochastic folds with 6 recycling steps each; weights look empirically
# tuned (sequence recovery dominates at 10.0).  Wrapped in NoCys so the
# optimizer's sequence space excludes cysteine entirely.
design_loss = NoCys(
    loss=folder.build_multisample_loss(
        loss=(
            1.0 * sp.BinderTargetContact()
            + 1.0 * sp.WithinBinderContact()
            + 10.0 * InverseFoldingSequenceRecovery(mpnn, temp=jnp.array(0.001), bias=bias)
            + 0.05 * sp.TargetBinderPAE()
            + 0.05 * sp.BinderTargetPAE()
            + 0.025 * sp.IPTMLoss()
            + 0.4 * sp.WithinBinderPAE()
            + 0.025 * sp.pTMEnergy()
            + 0.1 * sp.PLDDTLoss()
        ),
        features=features,
        recycling_steps=6,
        num_samples=4,
    )
)
# ── Three-phase optimization ──────────────────────────────────────────────
rng = np.random.default_rng(SEED)
# Step sizes scale with sqrt(L) so the per-residue update magnitude is
# roughly independent of binder length.
stepsize_base = np.sqrt(BINDER_LENGTH)
# Initialize the (BINDER_LENGTH, 19) logits with Gumbel noise at a random
# temperature — per-design randomized hyperparameters drive diversity.
# 19 columns: presumably the 20 amino acids minus cysteine (NoCys) — confirm.
pssm = Uniform(0.75, 5.0).sample(rng) * jax.random.gumbel(
    jax.random.key(SEED), shape=(BINDER_LENGTH, 19)
)
# Phase 1: explore (softmax space, with momentum, keep best)
# NOTE(review): phase 1 keeps the SECOND element of simplex_APGM's return
# pair, phases 2/3 keep the FIRST — verify which slot is "best iterate"
# versus "final iterate" in the mosaic optimizer API.
_, pssm = simplex_APGM(
    loss_function=design_loss,
    x=jax.nn.softmax(pssm),
    n_steps=UniformInt(100, 110).sample(rng),
    stepsize=Uniform(0.08, 0.30).sample(rng) * stepsize_base,
    momentum=Uniform(0.1, 0.5).sample(rng),
    scale=1.0,
    logspace=False,
    max_gradient_norm=1.0,
)
# Phase 2: refine (log space)
# 1e-5 guards jnp.log against exact zeros in the phase-1 simplex point.
pssm, _ = simplex_APGM(
    loss_function=design_loss,
    x=jnp.log(pssm + 1e-5),
    n_steps=UniformInt(30, 70).sample(rng),
    stepsize=Uniform(0.3, 0.7).sample(rng) * stepsize_base,
    scale=1.25,
    logspace=True,
    max_gradient_norm=1.0,
)
# Phase 3: polish (log space, higher scale — sharper distributions, fewer steps)
pssm, _ = simplex_APGM(
    loss_function=design_loss,
    x=jnp.log(pssm + 1e-5),
    n_steps=UniformInt(10, 25).sample(rng),
    stepsize=Uniform(0.3, 0.7).sample(rng) * stepsize_base,
    scale=1.4,
    logspace=True,
    max_gradient_norm=1.0,
)
# Hard argmax of the optimized PSSM → final designed sequence (NoCys maps
# the 19-column simplex back to the full token alphabet, excluding C).
sequence = "".join(TOKENS[int(j)] for j in NoCys.sequence(pssm).argmax(-1))
# ── Score with full sidechain information + more recycles ─────────────────
# Re-fold the designed binder + target as plain chains (no binder features,
# no MSA for the binder) for an independent ranking prediction.
rank_features, rank_writer = folder.target_only_features(
    chains=[
        TargetChain(sequence=sequence, use_msa=False),
        TargetChain(sequence=target_seq, use_msa=True, template_chain=chain),
    ]
)
# One-hot PSSM pins the fold to the argmax sequence; 10 recycles (vs 6
# during design) for a higher-quality confidence estimate.
pred = folder.predict(
    PSSM=jax.nn.one_hot(NoCys.sequence(pssm).argmax(-1), num_classes=20),
    features=rank_features,
    writer=rank_writer,
    recycling_steps=10,
    key=jax.random.key(0),
)
print(f"Sequence: {sequence}")
print(f"iPTM: {float(pred.iptm):.4f}, pLDDT: {float(pred.plddt.mean()):.2f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment