Skip to content

Instantly share code, notes, and snippets.

@ljmartin
Created November 27, 2024 22:23
Show Gist options
  • Save ljmartin/cccf3778f956f7af7f9289c979bc86c4 to your computer and use it in GitHub Desktop.
import modal
# Container image: debian-slim with chai_lab installed via uv, and the Chai-1
# model weights / conformer data baked in at build time so the GPU function
# starts without downloading anything.
_SITE_DOWNLOADS = "/usr/local/lib/python3.11/site-packages/downloads"
# NOTE: "depencencies" is the actual (misspelled) path on the Chai asset host.
_ASSET_HOST = "https://chaiassets.com/chai1-inference-depencencies"

image = (
    modal.Image.debian_slim()
    .pip_install("uv")
    .run_commands("uv pip install --system --compile-bytecode chai_lab==0.4.2")
    .run_commands(f"mkdir -p {_SITE_DOWNLOADS}/models_v2")
    .apt_install("wget")
)
# One image layer per downloaded asset, exactly as the original chained calls did.
for _asset in (
    "conformers_v1.apkl",
    "models_v2/trunk.pt",
    "models_v2/token_embedder.pt",
    "models_v2/feature_embedding.pt",
    "models_v2/diffusion_module.pt",
    "models_v2/confidence_head.pt",
):
    image = image.run_commands(
        f"wget -O {_SITE_DOWNLOADS}/{_asset} {_ASSET_HOST}/{_asset}"
    )
# Modal application object; the functions below are registered on it via decorators.
app = modal.App(name="chairun", image=image)
@app.function(
    timeout=60 * 15,  # 15 mins - increase obviously for bigger or multiple runs.
    gpu="A100",
)
def chairun():
    """Run Chai-1 inference on a built-in example FASTA on a remote A100.

    Returns:
        tuple[list[bytes], list[str]]: the raw ``.npz`` score archives and the
        ``.cif`` structure files, one of each per predicted sample, so the
        caller can persist them locally.
    """
    from pathlib import Path

    import torch
    from chai_lab.chai1 import run_inference

    # We use fasta-like format for inputs.
    # - each entity encodes protein, ligand, RNA or DNA
    # - each entity is labeled with unique name;
    # - ligands are encoded with SMILES; modified residues encoded like AAA(SEP)AAA
    # Example given below, just modify it
    example_fasta = """
>protein|name=example-of-long-protein
AGSHSMRYFSTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRGEPRAPWVEQEGPEYWDRETQKYKRQAQTDRVSLRNLRGYYNQSEAGSHTLQWMFGCDLGPDGRLLRGYDQSAYDGKDYIALNEDLRSWTAADTAAQITQRKWEAAREAEQRRAYLEGTCVEWLRRYLENGKETLQRAEHPKTHVTHHPVSDHEATLRCWALGFYPAEITLTWQWDGEDQTQDTELVETRPAGDGTFQKWAAVVVPSGEEQRYTCHVQHEGLPEPLTLRWEP
>protein|name=example-of-short-protein
AIQRTPKIQVYSRHPAENGKSNFLNCYVSGFHPSDIEVDLLKNGERIEKVEHSDLSFSKDWSFYLLYYTEFTPTEKDEYACRVNHVTLSQPKIVKWDRDM
>protein|name=example-peptide
GAAL
>ligand|name=example-ligand-as-smiles
CCCCCCCCCCCCCC(=O)O
""".strip()

    fasta_path = Path("/tmp/example.fasta")
    fasta_path.write_text(example_fasta)
    output_dir = Path("/tmp/outputs")

    candidates = run_inference(
        fasta_file=fasta_path,
        output_dir=output_dir,
        # 'default' setup
        num_trunk_recycles=3,
        num_diffn_timesteps=200,
        seed=42,
        device=torch.device("cuda:0"),
        use_esm_embeddings=True,
    )

    # One output set per diffusion sample; 5 is the chai-1 default sample count
    # (matches the number of ranking_data entries / cif_paths produced above).
    num_samples = 5

    # Scores (pTM, ipTM, pLDDTs, clash) saved as scores.model_idx_N.npz, and
    # structures saved as pred.model_idx_N.cif. Read both via pathlib so the
    # file handles are closed immediately (the original bare open().read()
    # calls leaked handles).
    npzs = [
        (output_dir / f"scores.model_idx_{i}.npz").read_bytes()
        for i in range(num_samples)
    ]
    cifs = [
        (output_dir / f"pred.model_idx_{i}.cif").read_text()
        for i in range(num_samples)
    ]
    return npzs, cifs
@app.local_entrypoint()
def main():
    """Run the remote job, then save each returned score archive and CIF
    structure into the current directory as resultN.npz / resultN.cif."""
    npzs, cifs = chairun.remote()
    for idx, score_blob in enumerate(npzs):
        with open(f"./result{idx}.npz", "wb") as out:
            out.write(score_blob)
    for idx, cif_text in enumerate(cifs):
        with open(f"./result{idx}.cif", "w") as out:
            out.write(cif_text)
    print("done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment