Skip to content

Instantly share code, notes, and snippets.

@harijay
Created September 30, 2025 09:58
Show Gist options
  • Save harijay/0d24ab59ab7a6fa43796df957a2bee86 to your computer and use it in GitHub Desktop.
Save harijay/0d24ab59ab7a6fa43796df957a2bee86 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
"""
Complete standalone implementation of sample.ipynb
This script performs protein structure prediction with SimpleFold and compares with ground truth.
For licensing see accompanying LICENSE file.
Copyright (c) 2025 Apple Inc. Licensed under MIT License.
"""
import sys
import numpy as np
from math import pow
import gc
import os
import psutil
from pathlib import Path
from io import StringIO
from Bio.PDB import PDBIO
from Bio.PDB import MMCIFParser, Superimposer
# Make the SimpleFold sources importable: the directory is resolved against
# the current working directory and appended to the module search path.
_SIMPLEFOLD_SRC = Path("./src/simplefold").resolve()
sys.path.append(str(_SIMPLEFOLD_SRC))
def print_memory_usage(stage):
    """Print this process's resident memory and the system's available memory.

    Args:
        stage: Label shown in square brackets at the start of the line.
    """
    bytes_per_mb = 1024 * 1024
    # Resident set size of the current process, as reported by psutil.
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    # System-wide memory that psutil reports as available.
    free_mb = psutil.virtual_memory().available / bytes_per_mb
    print(f"[{stage}] Process: {rss_mb:.1f}MB | System Available: {free_mb:.1f}MB")
def check_memory_safety():
    """Warn when the system reports less than 2 GB of available memory.

    Returns:
        bool: False (after printing a warning) when less than 2 GB is
        available, True otherwise.
    """
    free_gb = psutil.virtual_memory().available / (1024**3)
    has_headroom = free_gb >= 2.0  # 2 GB is the threshold below which we warn
    if not has_headroom:
        print(f"⚠️ WARNING: Only {free_gb:.1f}GB memory available. This may cause issues.")
    return has_headroom
def calculate_tm_score(coords1, coords2, L_target=None):
    """
    Compute TM-score for two aligned coordinate sets (numpy arrays).

    The inputs must already be superimposed and paired row by row; this
    function only evaluates the score, it does not search for an alignment.

    coords1, coords2: Nx3 numpy arrays (aligned atomic coordinates, e.g. CA atoms)
    L_target: length of target protein (default = len(coords1))

    Returns:
        The TM-score: sum over pairs of 1 / (1 + (d_i / d0)^2), divided by
        L_target.

    Raises:
        AssertionError: if the two coordinate arrays differ in shape.
        ValueError: if L_target (explicit or defaulted) is not positive.
    """
    assert coords1.shape == coords2.shape, "Aligned coords must have same shape"
    N = coords1.shape[0]
    if L_target is None:
        L_target = N
    if L_target <= 0:
        # Covers empty coordinate sets as well as a bad explicit value; the
        # score is undefined (division by zero) in that case.
        raise ValueError("L_target must be a positive number of residues")

    # distances between aligned atoms
    dists = np.linalg.norm(coords1 - coords2, axis=1)

    # scaling factor d0
    if L_target > 15:
        # Clamp to 0.5 as a safeguard, as in TM-align.
        d0 = max(1.24 * (L_target - 15) ** (1.0 / 3.0) - 1.8, 0.5)
    else:
        # The cube root of a non-positive number is not real-valued (math.pow
        # raises a domain error for it), so short chains use the floor directly.
        d0 = 0.5

    # TM-score
    score = np.sum(1.0 / (1.0 + (dists / d0) ** 2)) / L_target
    return score
def _compare_with_ground_truth(ground_truth_path, pdb_path, aligned_dir, seq_id):
    """Superimpose the prediction onto the reference structure and score the fit.

    Both files are read with MMCIFParser. CA atoms are paired by position; if
    the counts differ, both lists are truncated to the shorter length (a naive
    pairing, not a sequence alignment). The aligned structures are written as
    PDB files into ``aligned_dir``.

    Args:
        ground_truth_path: Path of the reference structure file.
        pdb_path: Path of the predicted structure file.
        aligned_dir: Directory (created if missing) for the aligned PDB files.
        seq_id: Sequence identifier used in the output file names.

    Returns:
        tuple: ``(tm_score, rmsd)`` for the paired CA atoms.

    Raises:
        ValueError: if either structure contains no CA atoms.
    """
    print(f"Comparing with ground truth: {ground_truth_path}")
    parser = MMCIFParser(QUIET=True)

    # Load two structures
    struct1 = parser.get_structure("ref", str(ground_truth_path))
    struct2 = parser.get_structure("prd", str(pdb_path))

    # Select CA atoms for alignment
    atoms1 = [a for a in struct1.get_atoms() if a.get_id() == 'CA']
    atoms2 = [a for a in struct2.get_atoms() if a.get_id() == 'CA']
    print(f"Ground truth CA atoms: {len(atoms1)}, Predicted CA atoms: {len(atoms2)}")

    if not atoms1 or not atoms2:
        # Fail with a clear message instead of a numerical error further down.
        raise ValueError("No CA atoms available for structure comparison")

    if len(atoms1) != len(atoms2):
        print("⚠️ Warning: Different number of CA atoms between structures")
        min_len = min(len(atoms1), len(atoms2))
        atoms1 = atoms1[:min_len]
        atoms2 = atoms2[:min_len]
        print(f"Using first {min_len} CA atoms for comparison")

    # Superimpose: fit the predicted CA atoms onto the reference CA atoms,
    # then move every atom of the prediction with the same transform.
    sup = Superimposer()
    sup.set_atoms(atoms1, atoms2)
    sup.apply(struct2.get_atoms())

    # Calculate TM-score (atoms2 were transformed in place by apply() above)
    coords1 = np.array([a.coord for a in atoms1])
    coords2 = np.array([a.coord for a in atoms2])
    tm_score = calculate_tm_score(coords1, coords2)

    print("\n=== Structure Comparison Results ===")
    print(f"TM-score (0-1, higher is better): {tm_score:.3f}")
    print(f"RMSD (lower is better): {sup.rms:.3f}")

    # Save aligned structures for inspection
    io = PDBIO()
    aligned_dir.mkdir(parents=True, exist_ok=True)

    # Save aligned ground truth
    io.set_structure(struct1)
    gt_aligned_path = aligned_dir / f"{seq_id}_ground_truth_aligned.pdb"
    io.save(str(gt_aligned_path))

    # Save aligned prediction
    io.set_structure(struct2)
    pred_aligned_path = aligned_dir / f"{seq_id}_predicted_aligned.pdb"
    io.save(str(pred_aligned_path))

    print(f"Aligned structures saved to: {aligned_dir}")
    print(f" Ground truth: {gt_aligned_path}")
    print(f" Predicted: {pred_aligned_path}")
    return tm_score, sup.rms


def main():
    """Predict one example structure with SimpleFold and compare it to ground truth.

    The sequence ID is read from the first command-line argument and defaults
    to "8cny_A". The models are loaded, inference is run, the result is saved,
    and, when ``assets/<seq_id>.cif`` exists, the prediction is superimposed on
    it and scored (TM-score and RMSD).

    Returns:
        int: 0 on success; 1 if the sequence ID is unknown or any step raised
        an exception (the traceback is printed).
    """
    print("=== SimpleFold CAS12L Protein Structure Prediction ===")
    print_memory_usage("Start")

    # Cell 1-2: Import and setup (already done above)
    # Cell 3: Example sequences
    print("\n=== Step 1: Setting up protein sequences ===")
    example_sequences = {
        "7ftv_A": "GASKLRAVLEKLKLSRDDISTAAGMVKGVVDHLLLRLKCDSAFRGVGLLNTGSYYEHVKISAPNEFDVMFKLEVPRIQLEEYSNTRAYYFVKFKRNPKENPLSQFLEGEILSASKMLSKFRKIIKEEINDDTDVIMKRKRGGSPAVTLLISEKISVDITLALESKSSWPASTQEGLRIQNWLSAKVRKQLRLKPFYLVPKHAEETWRLSFSHIEKEILNNHGKSKTCCENKEEKCCRKDCLKLMKYLLEQLKERFKDKKHLDKFSSYHVKTAFFHVCTQNPQDSQWDRKDLGLCFDNCVTYFLQCLRTEKLENYFIPEFNLFSSNLIDKRSKEFLTKQIEYERNNEFPVFD",
        "8cny_A": "MGPSLDFALSLLRRNIRQVQTDQGHFTMLGVRDRLAVLPRHSQPGKTIWVEHKLINILDAVELVDEQGVNLELTLVTLDTNEKFRDITKFIPENISAASDATLVINTEHMPSMFVPVGDVVQYGFLNLSGKPTHRTMMYNFPTKAGQCGGVVTSVGKVIGIHIGGNGRQGFCAGLKRSYFAS",
        "8g8r_A": "GTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFS",
        "8i85_A": "MGILQANRVLLSRLLPGVEPEGLTVRHGQFHQVVIASDRVVCLPRTAAAAARLPRRAAVMRVLAGLDLGCRTPRPLCEGSLPFLVLSRVPGAPLEADALEDSKVAEVVAAQYVTLLSGLASAGADEKVRAALPAPQGRWRQFAADVRAELFPLMSDGGCRQAERELAALDSLPDITEAVVHGNLGAENVLWVRDDGLPRLSGVIDWDEVSIGDPAEDLAAIGAGYGKDFLDQVLTLGGWSDRRMATRIATIRATFALQQALSACRDGDEEELADGLTGYR",
        "8g8r_A_x": "GTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFSGTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFSGTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFS",
    }

    # Allow user to choose sequence
    seq_id = "8cny_A"  # Default from notebook
    if len(sys.argv) > 1:
        seq_id = sys.argv[1]
    if seq_id not in example_sequences:
        print(f"❌ Unknown sequence ID: {seq_id}")
        print(f"Available sequences: {list(example_sequences.keys())}")
        return 1
    aa_sequence = example_sequences[seq_id]
    print(f"Predicting structure for {seq_id} with {len(aa_sequence)} amino acids.")

    # Cell 4: Configuration
    print("\n=== Step 2: Configuration ===")
    simplefold_model = "simplefold_100M"  # choose from 100M, 360M, 700M, 1.1B, 1.6B, 3B
    backend = "torch"  # choose from ["mlx", "torch"]
    ckpt_dir = "artifacts"
    output_dir = "artifacts"
    prediction_dir = f"predictions_{simplefold_model}_{backend}"
    output_name = f"{seq_id}"
    num_steps = 500  # number of inference steps for flow-matching
    tau = 0.05  # stochasticity scale
    plddt = True  # whether to use pLDDT confidence module
    nsample_per_protein = 1  # number of samples per protein
    print(f" Model: {simplefold_model}")
    print(f" Backend: {backend}")
    print(f" pLDDT: {plddt}")
    print(f" Output: {output_dir}/{prediction_dir}")

    # Pre-declare the heavy objects so the cleanup in `finally` can release
    # them no matter how far the pipeline got.
    model_wrapper = folding_model = plddt_model = inference_wrapper = None
    # Stay None when the comparison is skipped; the summary keys off this
    # instead of re-querying the filesystem.
    tm_score = rmsd = None

    try:
        # Cell 5: Model initialization
        print("\n=== Step 3: Loading models ===")
        check_memory_safety()
        print_memory_usage("Before model loading")
        # Imported here so an import failure is reported by the handler below.
        from src.simplefold.wrapper import ModelWrapper, InferenceWrapper

        # Initialize the folding model and pLDDT model
        model_wrapper = ModelWrapper(
            simplefold_model=simplefold_model,
            ckpt_dir=ckpt_dir,
            plddt=plddt,
            backend=backend,
        )
        device = model_wrapper.device
        print(f"Using device: {device}")
        folding_model = model_wrapper.from_pretrained_folding_model()
        print_memory_usage("After folding model")
        gc.collect()
        plddt_model = model_wrapper.from_pretrained_plddt_model()
        print_memory_usage("After pLDDT model")
        gc.collect()
        print("✓ Models loaded successfully")

        # Cell 6: Initialize inference wrapper
        print("\n=== Step 4: Initializing inference wrapper ===")
        check_memory_safety()
        inference_wrapper = InferenceWrapper(
            output_dir=output_dir,
            prediction_dir=prediction_dir,
            num_steps=num_steps,
            tau=tau,
            nsample_per_protein=nsample_per_protein,
            device=device,
            backend=backend
        )
        print("✓ Inference wrapper initialized")
        print_memory_usage("After inference wrapper")
        gc.collect()

        # Cell 7: Process input and run inference
        print("\n=== Step 5: Running inference ===")
        check_memory_safety()
        print("Processing input sequence...")
        batch, structure, record = inference_wrapper.process_input(aa_sequence)
        print("✓ Input processed")
        print("Running structure prediction...")
        results = inference_wrapper.run_inference(
            batch,
            folding_model,
            plddt_model,
            device=device,
        )
        print("✓ Inference completed")
        print("Saving results...")
        save_paths = inference_wrapper.save_result(
            structure,
            record,
            results,
            out_name=output_name
        )
        print(f"✓ Structure saved to: {save_paths[0]}")
        print_memory_usage("After inference")
        gc.collect()

        # Cell 8-11: Structure comparison and visualization
        print("\n=== Step 6: Structure comparison ===")
        pdb_path = save_paths[0]

        # Check if ground truth file exists
        ground_truth_path = Path(f"assets/{seq_id}.cif")
        if not ground_truth_path.exists():
            print(f"⚠️ Ground truth file not found: {ground_truth_path}")
            print("Skipping structure comparison.")
            print(f"✓ Predicted structure available at: {pdb_path}")
        else:
            aligned_dir = Path(output_dir) / prediction_dir / "aligned"
            tm_score, rmsd = _compare_with_ground_truth(
                ground_truth_path, pdb_path, aligned_dir, seq_id
            )

        print("\n=== Prediction Complete ===")
        print(f"✓ Sequence: {seq_id} ({len(aa_sequence)} amino acids)")
        print(f"✓ Model: {simplefold_model}")
        print(f"✓ Output: {pdb_path}")
        if tm_score is not None:
            print(f"✓ TM-score: {tm_score:.3f}")
            print(f"✓ RMSD: {rmsd:.3f}")
        print_memory_usage("Final")
        return 0
    except Exception as e:
        # Top-level boundary: report the failure and turn it into an exit code.
        print(f"\n❌ Error occurred: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        print_memory_usage("Error state")
        # Emergency cleanup
        gc.collect()
        print("\n💡 If errors persist, try:")
        print(" 1. Close other applications to free memory")
        print(" 2. Use a smaller model (simplefold_100M)")
        print(" 3. Disable pLDDT (set plddt=False)")
        print(" 4. Restart the terminal")
        return 1
    finally:
        # Cleanup: drop our references to the heavy objects before collecting.
        folding_model = plddt_model = model_wrapper = inference_wrapper = None
        try:
            gc.collect()
            print_memory_usage("After cleanup")
        except Exception as cleanup_error:
            # Best-effort: a failed memory report must not change the exit
            # status, but it should not vanish silently either.
            print(f"⚠️ Cleanup step failed: {type(cleanup_error).__name__}: {cleanup_error}")
if __name__ == "__main__":
    # Run the pipeline and hand its status code (0 or 1) back to the shell.
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment