Created
September 30, 2025 09:58
-
-
Save harijay/0d24ab59ab7a6fa43796df957a2bee86 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| """ | |
| Complete standalone implementation of sample.ipynb | |
| This script performs protein structure prediction with SimpleFold and compares with ground truth. | |
| For licensing see accompanying LICENSE file. | |
| Copyright (c) 2025 Apple Inc. Licensed under MIT License. | |
| """ | |
| import sys | |
| import numpy as np | |
| from math import pow | |
| import gc | |
| import os | |
| import psutil | |
| from pathlib import Path | |
| from io import StringIO | |
| from Bio.PDB import PDBIO | |
| from Bio.PDB import MMCIFParser, Superimposer | |
# Make the modules under ./src/simplefold importable.
# NOTE(review): the relative path is resolved against the current working
# directory (not this script's folder), so this presumably expects the script
# to be launched from the repository root — confirm.
sys.path.append(os.fspath(Path("src", "simplefold").resolve()))
def print_memory_usage(stage):
    """Print this process's resident memory and the system's available memory.

    Args:
        stage: Label shown in square brackets to identify the checkpoint.
    """
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    available_mb = psutil.virtual_memory().available / 1024 / 1024
    print(f"[{stage}] Process: {rss_mb:.1f}MB | System Available: {available_mb:.1f}MB")
def check_memory_safety(min_available_gb=2.0):
    """Check whether enough system memory is available to continue.

    Args:
        min_available_gb: Threshold in GiB (1024**3 bytes) below which a
            warning is printed. Defaults to 2.0, the previously hard-coded value.

    Returns:
        True if at least ``min_available_gb`` GiB is available; otherwise
        prints a warning and returns False.
    """
    available_gb = psutil.virtual_memory().available / (1024 ** 3)
    if available_gb < min_available_gb:
        print(f"⚠️ WARNING: Only {available_gb:.1f}GB memory available. This may cause issues.")
        return False
    return True
def calculate_tm_score(coords1, coords2, L_target=None):
    """
    Compute the TM-score of two already-superimposed coordinate sets.

    The score is evaluated for the superposition given; no search for the
    TM-score-maximising superposition is done. It can therefore be lower than
    the value reported by TM-score/TM-align for the same residue pairing.

    Args:
        coords1, coords2: (N, 3) array-likes of paired coordinates
            (e.g. C-alpha atoms), already in a common frame.
        L_target: Length used both for the d0 scale and for normalisation
            (default: N, the number of paired positions).

    Returns:
        The TM-score as a float (1.0 for identical positions).

    Raises:
        ValueError: if the two coordinate sets differ in shape, or if
            L_target is not positive (e.g. empty input with no L_target).
    """
    coords1 = np.asarray(coords1)
    coords2 = np.asarray(coords2)
    # Raise instead of assert so the check survives `python -O`.
    if coords1.shape != coords2.shape:
        raise ValueError(
            f"Aligned coords must have same shape, got {coords1.shape} and {coords2.shape}"
        )
    if L_target is None:
        L_target = coords1.shape[0]
    if L_target <= 0:
        raise ValueError(f"L_target must be positive, got {L_target}")

    # Per-pair distances between the aligned atoms.
    dists = np.linalg.norm(coords1 - coords2, axis=1)

    # Scale factor d0. The cube root of a negative number is undefined for
    # math.pow, so short chains (L_target <= 15) use the 0.5 floor directly,
    # matching TM-align's minimum d0.
    if L_target > 15:
        d0 = max(1.24 * (L_target - 15) ** (1.0 / 3.0) - 1.8, 0.5)
    else:
        d0 = 0.5

    return float(np.sum(1.0 / (1.0 + (dists / d0) ** 2)) / L_target)
def _select_ca_atoms(structure):
    """Return the C-alpha atoms of the first model of a Biopython structure.

    Atoms are matched by the name "CA". Atoms of residues named "CA" are
    excluded: that residue is the calcium ion, whose single atom is also
    named "CA" and would otherwise leak into the C-alpha list of crystal
    structures. Only the first model is used, so multi-model files are not
    concatenated into one list.
    """
    first_model = next(structure.get_models(), None)
    if first_model is None:
        return []
    return [
        atom
        for atom in first_model.get_atoms()
        if atom.get_id() == "CA" and atom.get_parent().get_resname().strip() != "CA"
    ]


def _compare_with_ground_truth(seq_id, pdb_path, output_dir, prediction_dir):
    """Superimpose a predicted structure on its ground truth and report scores.

    The ground truth is read from ``assets/{seq_id}.cif``. C-alpha atoms are
    paired by position only (the first N of each structure, with no sequence
    alignment). Scores are therefore meaningful only when both structures
    cover the same residues in the same order. The aligned structures are
    written as PDB files under ``{output_dir}/{prediction_dir}/aligned``.

    Returns:
        ``(tm_score, rmsd)`` if a comparison was made, or None if the ground
        truth file is missing or no C-alpha atoms could be paired.
    """
    ground_truth_path = Path(f"assets/{seq_id}.cif")
    if not ground_truth_path.exists():
        print(f"⚠️ Ground truth file not found: {ground_truth_path}")
        print("Skipping structure comparison.")
        print(f"✓ Predicted structure available at: {pdb_path}")
        return None

    print(f"Comparing with ground truth: {ground_truth_path}")
    parser = MMCIFParser(QUIET=True)
    struct1 = parser.get_structure("ref", str(ground_truth_path))
    # NOTE(review): the prediction is parsed as mmCIF, so save_result is
    # assumed to write mmCIF despite the `pdb_path` name — confirm.
    struct2 = parser.get_structure("prd", str(pdb_path))

    atoms1 = _select_ca_atoms(struct1)
    atoms2 = _select_ca_atoms(struct2)
    print(f"Ground truth CA atoms: {len(atoms1)}, Predicted CA atoms: {len(atoms2)}")
    if len(atoms1) != len(atoms2):
        print("⚠️ Warning: Different number of CA atoms between structures")
        min_len = min(len(atoms1), len(atoms2))
        atoms1 = atoms1[:min_len]
        atoms2 = atoms2[:min_len]
        print(f"Using first {min_len} CA atoms for comparison")
    if not atoms1:
        # Nothing to superimpose; skip rather than fail the whole run.
        print("⚠️ No CA atoms to compare; skipping structure comparison.")
        return None

    # Fit on the paired C-alpha atoms, then move every atom of the prediction
    # into the ground-truth frame (this also updates atoms2 in place).
    sup = Superimposer()
    sup.set_atoms(atoms1, atoms2)
    sup.apply(list(struct2.get_atoms()))

    coords1 = np.array([a.coord for a in atoms1])
    coords2 = np.array([a.coord for a in atoms2])
    tm_score = calculate_tm_score(coords1, coords2)
    print("\n=== Structure Comparison Results ===")
    print(f"TM-score (0-1, higher is better): {tm_score:.3f}")
    print(f"RMSD (lower is better): {sup.rms:.3f}")

    # Save aligned structures for inspection.
    io = PDBIO()
    aligned_dir = Path(output_dir) / prediction_dir / "aligned"
    aligned_dir.mkdir(parents=True, exist_ok=True)
    io.set_structure(struct1)
    gt_aligned_path = aligned_dir / f"{seq_id}_ground_truth_aligned.pdb"
    io.save(str(gt_aligned_path))
    io.set_structure(struct2)
    pred_aligned_path = aligned_dir / f"{seq_id}_predicted_aligned.pdb"
    io.save(str(pred_aligned_path))
    print(f"Aligned structures saved to: {aligned_dir}")
    print(f" Ground truth: {gt_aligned_path}")
    print(f" Predicted: {pred_aligned_path}")
    return tm_score, sup.rms


def main():
    """Predict a structure with SimpleFold and compare it with ground truth.

    The sequence ID is taken from ``sys.argv[1]`` (default ``"8cny_A"``) and
    must be one of the built-in example sequences.

    Returns:
        Process exit code: 0 on success, 1 for an unknown sequence ID or any
        error raised during model loading, inference or comparison.
    """
    print("=== SimpleFold CAS12L Protein Structure Prediction ===")
    print_memory_usage("Start")

    # Cell 3: Example sequences
    print("\n=== Step 1: Setting up protein sequences ===")
    example_sequences = {
        "7ftv_A": "GASKLRAVLEKLKLSRDDISTAAGMVKGVVDHLLLRLKCDSAFRGVGLLNTGSYYEHVKISAPNEFDVMFKLEVPRIQLEEYSNTRAYYFVKFKRNPKENPLSQFLEGEILSASKMLSKFRKIIKEEINDDTDVIMKRKRGGSPAVTLLISEKISVDITLALESKSSWPASTQEGLRIQNWLSAKVRKQLRLKPFYLVPKHAEETWRLSFSHIEKEILNNHGKSKTCCENKEEKCCRKDCLKLMKYLLEQLKERFKDKKHLDKFSSYHVKTAFFHVCTQNPQDSQWDRKDLGLCFDNCVTYFLQCLRTEKLENYFIPEFNLFSSNLIDKRSKEFLTKQIEYERNNEFPVFD",
        "8cny_A": "MGPSLDFALSLLRRNIRQVQTDQGHFTMLGVRDRLAVLPRHSQPGKTIWVEHKLINILDAVELVDEQGVNLELTLVTLDTNEKFRDITKFIPENISAASDATLVINTEHMPSMFVPVGDVVQYGFLNLSGKPTHRTMMYNFPTKAGQCGGVVTSVGKVIGIHIGGNGRQGFCAGLKRSYFAS",
        "8g8r_A": "GTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFS",
        "8i85_A": "MGILQANRVLLSRLLPGVEPEGLTVRHGQFHQVVIASDRVVCLPRTAAAAARLPRRAAVMRVLAGLDLGCRTPRPLCEGSLPFLVLSRVPGAPLEADALEDSKVAEVVAAQYVTLLSGLASAGADEKVRAALPAPQGRWRQFAADVRAELFPLMSDGGCRQAERELAALDSLPDITEAVVHGNLGAENVLWVRDDGLPRLSGVIDWDEVSIGDPAEDLAAIGAGYGKDFLDQVLTLGGWSDRRMATRIATIRATFALQQALSACRDGDEEELADGLTGYR",
        "8g8r_A_x": "GTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFSGTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFSGTVNWSVEDIVKGINSNNLESQLQATQAARKLLSREKQPPIDNIIRAGLIPKFVSFLGKTDCSPIQFESAWALTNIASGTSEQTKAVVDGGAIPAFISLLASPHAHISEQAVWALGNIAGDGSAFRDLVIKHGAIDPLLALLAVPDLSTLACGYLRNLTWTLSNLCRNKNPAPPLDAVEQILPTLVRLLHHNDPEVLADSCWAISYLTDGPNERIEMVVKKGVVPQLVKLLGATELPIVTPALRAIGNIVTGTDEQTQKVIDAGALAVFPSLLTNPKTNIQKEATWTMSNITAGRQDQIQQVVNHGLVPFLVGVLSKADFKTQKEAAWAITNYTSGGTVEQIVYLVHCGIIEPLMNLLSAKDTKIIQVILDAISNIFQAAEKLGETEKLSIMIEECGGLDKIEALQRHENESVYKASLNLIEKYFS",
    }

    # Allow user to choose sequence
    seq_id = "8cny_A"  # Default from notebook
    if len(sys.argv) > 1:
        seq_id = sys.argv[1]
    if seq_id not in example_sequences:
        print(f"❌ Unknown sequence ID: {seq_id}")
        print(f"Available sequences: {list(example_sequences.keys())}")
        return 1
    aa_sequence = example_sequences[seq_id]
    print(f"Predicting structure for {seq_id} with {len(aa_sequence)} amino acids.")

    # Cell 4: Configuration
    print("\n=== Step 2: Configuration ===")
    simplefold_model = "simplefold_100M"  # choose from 100M, 360M, 700M, 1.1B, 1.6B, 3B
    backend = "torch"  # choose from ["mlx", "torch"]
    ckpt_dir = "artifacts"
    output_dir = "artifacts"
    prediction_dir = f"predictions_{simplefold_model}_{backend}"
    output_name = f"{seq_id}"
    num_steps = 500  # number of inference steps for flow-matching
    tau = 0.05  # stochasticity scale
    plddt = True  # whether to use pLDDT confidence module
    nsample_per_protein = 1  # number of samples per protein
    print(f" Model: {simplefold_model}")
    print(f" Backend: {backend}")
    print(f" pLDDT: {plddt}")
    print(f" Output: {output_dir}/{prediction_dir}")

    try:
        # Cell 5: Model initialization
        print("\n=== Step 3: Loading models ===")
        check_memory_safety()  # warning only; the run continues either way
        print_memory_usage("Before model loading")
        # Imported lazily so the heavy model stack loads only after arg checks.
        from src.simplefold.wrapper import ModelWrapper, InferenceWrapper

        # Initialize the folding model and pLDDT model
        model_wrapper = ModelWrapper(
            simplefold_model=simplefold_model,
            ckpt_dir=ckpt_dir,
            plddt=plddt,
            backend=backend,
        )
        device = model_wrapper.device
        print(f"Using device: {device}")
        folding_model = model_wrapper.from_pretrained_folding_model()
        print_memory_usage("After folding model")
        gc.collect()
        plddt_model = model_wrapper.from_pretrained_plddt_model()
        print_memory_usage("After pLDDT model")
        gc.collect()
        print("✓ Models loaded successfully")

        # Cell 6: Initialize inference wrapper
        print("\n=== Step 4: Initializing inference wrapper ===")
        check_memory_safety()
        inference_wrapper = InferenceWrapper(
            output_dir=output_dir,
            prediction_dir=prediction_dir,
            num_steps=num_steps,
            tau=tau,
            nsample_per_protein=nsample_per_protein,
            device=device,
            backend=backend,
        )
        print("✓ Inference wrapper initialized")
        print_memory_usage("After inference wrapper")
        gc.collect()

        # Cell 7: Process input and run inference
        print("\n=== Step 5: Running inference ===")
        check_memory_safety()
        print("Processing input sequence...")
        batch, structure, record = inference_wrapper.process_input(aa_sequence)
        print("✓ Input processed")
        print("Running structure prediction...")
        results = inference_wrapper.run_inference(
            batch,
            folding_model,
            plddt_model,
            device=device,
        )
        print("✓ Inference completed")
        print("Saving results...")
        save_paths = inference_wrapper.save_result(
            structure,
            record,
            results,
            out_name=output_name,
        )
        print(f"✓ Structure saved to: {save_paths[0]}")
        print_memory_usage("After inference")
        gc.collect()

        # Cell 8-11: Structure comparison and visualization
        print("\n=== Step 6: Structure comparison ===")
        pdb_path = save_paths[0]
        comparison = _compare_with_ground_truth(seq_id, pdb_path, output_dir, prediction_dir)

        print("\n=== Prediction Complete ===")
        print(f"✓ Sequence: {seq_id} ({len(aa_sequence)} amino acids)")
        print(f"✓ Model: {simplefold_model}")
        print(f"✓ Output: {pdb_path}")
        if comparison is not None:
            tm_score, rmsd = comparison
            print(f"✓ TM-score: {tm_score:.3f}")
            print(f"✓ RMSD: {rmsd:.3f}")
        print_memory_usage("Final")
        return 0

    except Exception as e:
        print(f"\n❌ Error occurred: {type(e).__name__}: {e}")
        import traceback

        traceback.print_exc()
        print_memory_usage("Error state")
        # Emergency cleanup
        gc.collect()
        print("\n💡 If errors persist, try:")
        print(" 1. Close other applications to free memory")
        print(" 2. Use a smaller model (simplefold_100M)")
        print(" 3. Disable pLDDT (set plddt=False)")
        print(" 4. Restart the terminal")
        return 1

    finally:
        # Drop this function's references to the model objects before
        # collecting. Rebinding to None works even if an early failure meant a
        # name was never assigned.
        folding_model = plddt_model = model_wrapper = inference_wrapper = None
        try:
            gc.collect()
            print_memory_usage("After cleanup")
        except Exception as cleanup_error:
            # Cleanup is best-effort: report it, but never mask the exit code
            # or swallow KeyboardInterrupt/SystemExit like a bare except would.
            print(f"⚠️ Cleanup failed: {type(cleanup_error).__name__}: {cleanup_error}")
if __name__ == "__main__":
    # Use main()'s return value (0 or 1) as the process exit status.
    sys.exit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment