@pengzhangzhi
Created February 5, 2025 16:36
Use ESMFold to fold a directory of FASTA files (or a single FASTA file).
# Copyright (c) 2023 Meta Platforms, Inc. and affiliates
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by Xinyou Wang on Jul 21, 2024
#
# Original file was released under MIT, with the full license text
# available at https://github.com/facebookresearch/esm/blob/main/LICENSE
#
# This modified file is released under the same license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import math
import os
import re
import sys
import typing as T
from collections import Counter
from pathlib import Path
from timeit import default_timer as timer
from typing import List, Optional, Tuple

import esm
import numpy as np
import pandas as pd
import torch

logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

PathLike = T.Union[str, Path]


def calculate_entropy(sequence: str) -> float:
    """Shannon entropy (in bits) of the amino-acid composition of a sequence."""
    amino_acid_counts = Counter(sequence)
    total_amino_acids = len(sequence)
    probabilities = (count / total_amino_acids for count in amino_acid_counts.values())
    return -sum(p * math.log2(p) for p in probabilities if p > 0)
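
# Example (illustrative): calculate_entropy("AAAA") == 0.0 (a single residue type),
# while calculate_entropy("ACDE") == 2.0 bits (four equally frequent residues).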


def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        # Line may be empty if seq % file_line_width == 0
        if len(line) > 0 and line[0] == ">":
            # Sequences containing 'X' (unknown residues) are skipped
            if seq is not None and 'X' not in seq:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    if 'X' not in seq:
        yield desc, parse(seq)


def enable_cpu_offloading(model):
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import enable_wrap, wrap

    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0
    )

    wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))

    with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
        for layer_name, layer in model.layers.named_children():
            wrapped_layer = wrap(layer)
            setattr(model.layers, layer_name, wrapped_layer)
        model = wrap(model)

    return model


def init_model_on_gpu_with_cpu_offloading(model):
    model = model.eval()
    model_esm = enable_cpu_offloading(model.esm)
    del model.esm
    model.cuda()
    model.esm = model_esm
    return model


def create_batched_sequence_dataset(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    batch_headers, batch_sequences, num_tokens = [], [], 0
    for header, seq in sequences:
        if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)

    yield batch_headers, batch_sequences
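
# Example (illustrative): with max_tokens_per_batch=10 and sequence lengths [4, 4, 4],
# the generator yields [seq1, seq2] (8 tokens) and then [seq3]. A single sequence
# longer than the limit is still yielded on its own rather than dropped.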


def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--fasta",
        help="Path to the input directory of FASTA files",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--pdb",
        help="Path to the output PDB directory",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-m", "--model-dir", help="Parent path to the pretrained ESM data directory.", type=Path, default=None
    )
    parser.add_argument(
        "--num-recycles",
        type=int,
        default=None,
        help="Number of recycles to run. Defaults to the number used in training (4).",
    )
    parser.add_argument(
        "--max-tokens-per-batch",
        type=int,
        default=1024,
        help="Maximum number of tokens per GPU forward pass. This will group shorter sequences together "
        "for batched prediction. Lowering this can help with out-of-memory issues, if these occur on "
        "short sequences.",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=None,
        help="Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). "
        "Equivalent to running a for loop over chunks of each dimension. Lower values will "
        "result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. "
        "Default: None.",
    )
    parser.add_argument("--cpu-only", help="CPU only", action="store_true")
    parser.add_argument("--cpu-offload", help="Enable CPU offloading", action="store_true")
    return parser


def run(
    fasta: Path,
    pdb: Path,
    model_dir: Optional[Path] = None,
    num_recycles: Optional[int] = None,
    max_tokens_per_batch: int = 1024,
    chunk_size: Optional[int] = None,
    cpu_only: bool = False,
    cpu_offload: bool = False,
) -> pd.DataFrame:
    """
    Runs the folding pipeline over every FASTA file in `fasta` and returns a DataFrame with the results.

    Args:
        fasta (Path): Path to the input directory of FASTA files.
        pdb (Path): Path to the output PDB directory.
        model_dir (Optional[Path]): Path to the pretrained ESM data directory.
        num_recycles (Optional[int]): Number of recycles to run.
        max_tokens_per_batch (int): Maximum number of tokens per batch.
        chunk_size (Optional[int]): Chunk size for axial attention.
        cpu_only (bool): Flag to use CPU only.
        cpu_offload (bool): Flag to enable CPU offloading.

    Returns:
        pd.DataFrame: DataFrame containing the folding results.
    """
    # Ensure output directory exists
    pdb.mkdir(parents=True, exist_ok=True)
    logger.info("Output PDB directory is set to: %s", pdb)

    logger.info("Loading model...")
    # Set ESM cache directory if model_dir is provided
    if model_dir is not None:
        torch.hub.set_dir(str(model_dir))

    # Load the ESMFold model
    model = esm.pretrained.esmfold_v1()
    model = model.eval()
    model.set_chunk_size(chunk_size)

    if cpu_only:
        model.esm.float()  # Convert to fp32 as ESM-2 in fp16 is not supported on CPU
        model.cpu()
        logger.info("Model moved to CPU only.")
    elif cpu_offload:
        model = init_model_on_gpu_with_cpu_offloading(model)
        logger.info("Model initialized with CPU offloading.")
    else:
        model.cuda()
        logger.info("Model moved to GPU.")

    # List all FASTA files in the input directory
    fasta_list = [
        f for f in os.listdir(fasta)
        if f.endswith('.fasta') and not os.path.isdir(os.path.join(fasta, f))
    ]
    logger.info("Found %d FASTA files in the input directory.", len(fasta_list))

    # Initialize a list to collect data for the DataFrame
    data_records = []

    for fasta_file in fasta_list:
        fasta_path = fasta / fasta_file
        logger.info("Processing FASTA file: %s", fasta_path)

        # Define output directory for this FASTA file
        pdbdir = pdb / fasta_file.replace('.fasta', '')
        pdbdir.mkdir(parents=True, exist_ok=True)
        logger.info("Output PDB directory for %s: %s", fasta_file, pdbdir)

        # Read and sort sequences by length
        logger.info("Reading sequences from %s", fasta_path)
        all_sequences = sorted(read_fasta(str(fasta_path)), key=lambda header_seq: len(header_seq[1]))
        logger.info("Loaded %d sequences from %s", len(all_sequences), fasta_file)

        if not all_sequences:
            logger.warning("No sequences found in %s. Skipping.", fasta_file)
            continue

        # Create batched sequences
        logger.info("Creating batched sequences...")
        batched_sequences = create_batched_sequence_dataset(all_sequences, max_tokens_per_batch)
        logger.info("Batched sequences created.")

        num_completed = 0
        num_sequences = len(all_sequences)

        # Iterate over each batch
        for headers, sequences in batched_sequences:
            start_time = timer()
            try:
                # Inference
                output = model.infer(sequences, num_recycles=num_recycles)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    if len(sequences) > 1:
                        logger.warning(
                            "Failed to predict batch of size %d due to CUDA OOM. "
                            "Consider lowering `--max-tokens-per-batch`.", len(sequences)
                        )
                    else:
                        logger.warning(
                            "Failed to predict sequence %s of length %d due to CUDA OOM.",
                            headers[0], len(sequences[0])
                        )
                    continue
                else:
                    raise e

            # Move outputs to CPU
            output = {key: value.cpu() for key, value in output.items()}

            # Convert outputs to PDB format
            pdbs = model.output_to_pdb(output)

            # Mean predicted aligned error (PAE) per sequence, from the 64 aligned-confidence bins
            paes = (
                (output["aligned_confidence_probs"].numpy() * np.arange(64).reshape(1, 1, 1, 64))
                .mean(-1)
                * 31
            ).mean(-1).mean(-1)

            # Calculate elapsed time
            elapsed_time = timer() - start_time
            time_per_seq = elapsed_time / len(headers)
            time_info = f"{time_per_seq:0.1f}s per sequence"
            if len(sequences) > 1:
                time_info += f" (amortized, batch size {len(sequences)})"

            # Save each predicted PDB and collect data
            for header, seq, pdb_string, mean_plddt, ptm, pae in zip(
                headers, sequences, pdbs, output["mean_plddt"], output["ptm"], paes
            ):
                pdb_filename = f"{header}_plddt_{mean_plddt:.1f}_ptm_{ptm:.3f}_PAE_{pae:.3f}.pdb"
                output_file = pdbdir / pdb_filename
                output_file.write_text(pdb_string)

                # Calculate entropy
                entropy = calculate_entropy(seq)

                # Collect data record
                data_records.append({
                    'FASTA_file': fasta_file,
                    'PDB_path': str(output_file.resolve()),
                    'sequence': seq,
                    'Length': len(seq),
                    'pLDDT': mean_plddt.item() if isinstance(mean_plddt, torch.Tensor) else mean_plddt,
                    'pTM': ptm.item() if isinstance(ptm, torch.Tensor) else ptm,
                    'pAE': pae.item() if isinstance(pae, torch.Tensor) else pae,
                    'Entropy': entropy
                })

                num_completed += 1
                logger.info(
                    "Predicted structure for %s (length %d): pLDDT=%.1f, pTM=%.3f, PAE=%.3f | Time: %s | Progress: %d/%d",
                    header, len(seq), mean_plddt, ptm, pae, time_info, num_completed, num_sequences
                )

        logger.info("Completed processing FASTA file: %s", fasta_file)

    # Convert the collected data into a DataFrame
    df = pd.DataFrame(data_records, columns=[
        'FASTA_file',
        'PDB_path',
        'sequence',
        'Length',
        'pLDDT',
        'pTM',
        'pAE',
        'Entropy'
    ])

    logger.info("Data collection complete. Returning DataFrame.")
    return df


def main():
    parser = create_parser()
    args = parser.parse_args()
    run(
        fasta=args.fasta,
        pdb=args.pdb,
        model_dir=args.model_dir,
        num_recycles=args.num_recycles,
        max_tokens_per_batch=args.max_tokens_per_batch,
        chunk_size=args.chunk_size,
        cpu_only=args.cpu_only,
        cpu_offload=args.cpu_offload,
    )


if __name__ == "__main__":
    main()

@pengzhangzhi (Author):

Env:

pip install "fair-esm[esmfold]"
# OpenFold and its remaining dependency
pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'
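
A quick, optional sanity check after installing (a minimal sketch; esm.pretrained.esmfold_v1() is the same loader the script calls, and it downloads the model weights on first use, so it is left commented out here):

import torch
import esm

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
# model = esm.pretrained.esmfold_v1()  # uncomment to also verify the weights load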

Running:
Check the command-line arguments (python esmfold.py --help) for the full API. Below is a basic usage example.

python esmfold.py -i fasta_path -o output_pdb_dir
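
The pipeline can also be driven from Python: run() returns a pandas DataFrame with one row per folded sequence. A minimal sketch, assuming the script above is saved as esmfold.py; the chunk_size value and CSV filename are just examples:

from pathlib import Path
from esmfold import run  # assumes the script above is saved as esmfold.py

df = run(fasta=Path("fasta_path"), pdb=Path("output_pdb_dir"), chunk_size=128)
df.to_csv("esmfold_results.csv", index=False)  # FASTA_file, PDB_path, sequence, Length, pLDDT, pTM, pAE, Entropy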
