Use ESMFold to fold a directory of FASTA files, or a single FASTA file.
# Copyright (c) 2023 Meta Platforms, Inc. and affiliates
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by Xinyou Wang on Jul 21, 2024
#
# Original file was released under MIT, with the full license text
# available at https://github.com/facebookresearch/esm/blob/main/LICENSE
#
# This modified file is released under the same license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import math
import os
import re
import sys
import typing as T
from collections import Counter
from pathlib import Path
from timeit import default_timer as timer
from typing import Optional

import esm
import numpy as np
import pandas as pd
import torch

logger = logging.getLogger()
logger.setLevel(logging.INFO)
formatter = logging.Formatter(
    "%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
PathLike = T.Union[str, Path]


def calculate_entropy(sequence: str) -> float:
    """Shannon entropy (in bits) of the amino-acid composition of `sequence`."""
    amino_acid_counts = Counter(sequence)
    total_amino_acids = len(sequence)
    probabilities = (count / total_amino_acids for count in amino_acid_counts.values())
    return -sum(p * math.log2(p) for p in probabilities if p > 0)
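# Worked example (illustrative, not from the original script):
# calculate_entropy("AAAA") == 0.0 (one residue type, p=1), while
# calculate_entropy("ACDE") == 2.0 bits (four residue types, each p=0.25).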
def read_fasta(
    path,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    with open(path, "r") as f:
        for result in read_alignment_lines(
            f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
        ):
            yield result


def read_alignment_lines(
    lines,
    keep_gaps=True,
    keep_insertions=True,
    to_upper=False,
):
    seq = desc = None

    def parse(s):
        if not keep_gaps:
            s = re.sub("-", "", s)
        if not keep_insertions:
            s = re.sub("[a-z]", "", s)
        return s.upper() if to_upper else s

    for line in lines:
        # Line may be empty if seq % file_line_width == 0
        if len(line) > 0 and line[0] == ">":
            # Sequences containing unknown residues ('X') are skipped.
            if seq is not None and 'X' not in seq:
                yield desc, parse(seq)
            desc = line.strip().lstrip(">")
            seq = ""
        else:
            assert isinstance(seq, str)
            seq += line.strip()
    assert isinstance(seq, str) and isinstance(desc, str)
    if 'X' not in seq:
        yield desc, parse(seq)
def enable_cpu_offloading(model):
    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel
    from torch.distributed.fsdp.wrap import enable_wrap, wrap

    torch.distributed.init_process_group(
        backend="nccl", init_method="tcp://localhost:9999", world_size=1, rank=0
    )
    wrapper_kwargs = dict(cpu_offload=CPUOffload(offload_params=True))
    with enable_wrap(wrapper_cls=FullyShardedDataParallel, **wrapper_kwargs):
        for layer_name, layer in model.layers.named_children():
            wrapped_layer = wrap(layer)
            setattr(model.layers, layer_name, wrapped_layer)
        model = wrap(model)
    return model


def init_model_on_gpu_with_cpu_offloading(model):
    model = model.eval()
    model_esm = enable_cpu_offloading(model.esm)
    del model.esm
    model.cuda()
    model.esm = model_esm
    return model
def create_batched_sequence_dataset(
    sequences: T.List[T.Tuple[str, str]], max_tokens_per_batch: int = 1024
) -> T.Generator[T.Tuple[T.List[str], T.List[str]], None, None]:
    batch_headers, batch_sequences, num_tokens = [], [], 0
    for header, seq in sequences:
        if (len(seq) + num_tokens > max_tokens_per_batch) and num_tokens > 0:
            yield batch_headers, batch_sequences
            batch_headers, batch_sequences, num_tokens = [], [], 0
        batch_headers.append(header)
        batch_sequences.append(seq)
        num_tokens += len(seq)
    yield batch_headers, batch_sequences
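# Worked example (illustrative, not from the original script): with
# max_tokens_per_batch=1024 and sorted sequence lengths [100, 200, 800, 900],
# the generator yields three batches: [100, 200], [800], [900]. A new batch
# starts whenever adding the next sequence would exceed the token budget.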
def create_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--fasta",
        help="Path to a FASTA file or a directory of FASTA files",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--pdb",
        help="Path to the output PDB directory",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-m", "--model-dir", help="Parent path to the pretrained ESM data directory.", type=Path, default=None
    )
    parser.add_argument(
        "--num-recycles",
        type=int,
        default=None,
        help="Number of recycles to run. Defaults to the number used in training (4).",
    )
    parser.add_argument(
        "--max-tokens-per-batch",
        type=int,
        default=1024,
        help="Maximum number of tokens per GPU forward pass. This will group shorter sequences together "
        "for batched prediction. Lowering this can help with out-of-memory issues, if these occur on "
        "short sequences.",
    )
    parser.add_argument(
        "--chunk-size",
        type=int,
        default=None,
        help="Chunks axial attention computation to reduce memory usage from O(L^2) to O(L). "
        "Equivalent to running a for loop over chunks of each dimension. Lower values will "
        "result in lower memory usage at the cost of speed. Recommended values: 128, 64, 32. "
        "Default: None.",
    )
    parser.add_argument("--cpu-only", help="CPU only", action="store_true")
    parser.add_argument("--cpu-offload", help="Enable CPU offloading", action="store_true")
    return parser
def run(
    fasta: Path,
    pdb: Path,
    model_dir: Optional[Path] = None,
    num_recycles: Optional[int] = None,
    max_tokens_per_batch: int = 1024,
    chunk_size: Optional[int] = None,
    cpu_only: bool = False,
    cpu_offload: bool = False,
) -> pd.DataFrame:
    """
    Executes the folding pipeline and returns a DataFrame with the results.

    Args:
        fasta (Path): Path to a FASTA file or a directory of FASTA files.
        pdb (Path): Path to the output PDB directory.
        model_dir (Optional[Path]): Path to the pretrained ESM data directory.
        num_recycles (Optional[int]): Number of recycles to run.
        max_tokens_per_batch (int): Maximum number of tokens per batch.
        chunk_size (Optional[int]): Chunk size for axial attention.
        cpu_only (bool): Flag to use CPU only.
        cpu_offload (bool): Flag to enable CPU offloading.

    Returns:
        pd.DataFrame: DataFrame containing the evaluation results.
    """
    # Ensure the output directory exists
    pdb.mkdir(parents=True, exist_ok=True)
    logger.info("Output PDB directory is set to: %s", pdb)

    logger.info("Loading model...")
    # Set the ESM cache directory if model_dir is provided
    if model_dir is not None:
        torch.hub.set_dir(str(model_dir))

    # Load the ESMFold model
    model = esm.pretrained.esmfold_v1()
    model = model.eval()
    model.set_chunk_size(chunk_size)

    if cpu_only:
        model.esm.float()  # Convert to fp32 as ESM-2 in fp16 is not supported on CPU
        model.cpu()
        logger.info("Model moved to CPU only.")
    elif cpu_offload:
        model = init_model_on_gpu_with_cpu_offloading(model)
        logger.info("Model initialized with CPU offloading.")
    else:
        model.cuda()
        logger.info("Model moved to GPU.")
    # Accept either a single FASTA file or a directory of FASTA files,
    # matching the stated purpose of the script
    if fasta.is_file():
        fasta_dir, fasta_list = fasta.parent, [fasta.name]
    else:
        fasta_dir = fasta
        fasta_list = [
            f for f in os.listdir(fasta)
            if f.endswith('.fasta') and not os.path.isdir(os.path.join(fasta, f))
        ]
    logger.info("Found %d FASTA file(s) to process.", len(fasta_list))

    # Initialize a list to collect data for the DataFrame
    data_records = []

    for fasta_file in fasta_list:
        fasta_path = fasta_dir / fasta_file
        logger.info("Processing FASTA file: %s", fasta_path)

        # Define the output directory for this FASTA file
        pdbdir = pdb / fasta_file.replace('.fasta', '')
        pdbdir.mkdir(parents=True, exist_ok=True)
        logger.info("Output PDB directory for %s: %s", fasta_file, pdbdir)

        # Read and sort sequences by length
        logger.info("Reading sequences from %s", fasta_path)
        all_sequences = sorted(read_fasta(str(fasta_path)), key=lambda header_seq: len(header_seq[1]))
        logger.info("Loaded %d sequences from %s", len(all_sequences), fasta_file)
        if not all_sequences:
            logger.warning("No sequences found in %s. Skipping.", fasta_file)
            continue

        # Create batched sequences
        logger.info("Creating batched sequences...")
        batched_sequences = create_batched_sequence_dataset(all_sequences, max_tokens_per_batch)
        logger.info("Batched sequences created.")

        num_completed = 0
        num_sequences = len(all_sequences)
        # Iterate over each batch
        for headers, sequences in batched_sequences:
            start_time = timer()
            try:
                # Inference
                output = model.infer(sequences, num_recycles=num_recycles)
            except RuntimeError as e:
                if "CUDA out of memory" in str(e):
                    if len(sequences) > 1:
                        logger.warning(
                            "Failed to predict batch of size %d due to CUDA OOM. "
                            "Consider lowering `--max-tokens-per-batch`.", len(sequences)
                        )
                    else:
                        logger.warning(
                            "Failed to predict sequence %s of length %d due to CUDA OOM.",
                            headers[0], len(sequences[0])
                        )
                    continue
                else:
                    raise e

            # Move outputs to CPU
            output = {key: value.cpu() for key, value in output.items()}

            # Convert outputs to PDB format
            pdbs = model.output_to_pdb(output)

            # Mean PAE per sequence, computed as in the original ESM inference
            # script (64 distance bins scaled to ~31 Angstroms, averaged over
            # residue pairs)
            paes = (
                (output["aligned_confidence_probs"].numpy() * np.arange(64).reshape(1, 1, 1, 64))
                .mean(-1)
                * 31
            ).mean(-1).mean(-1)

            # Calculate elapsed time
            elapsed_time = timer() - start_time
            time_per_seq = elapsed_time / len(headers)
            time_info = f"{time_per_seq:0.1f}s per sequence"
            if len(sequences) > 1:
                time_info += f" (amortized, batch size {len(sequences)})"
            # Save each predicted PDB and collect data
            for header, seq, pdb_string, mean_plddt, ptm, pae in zip(
                headers, sequences, pdbs, output["mean_plddt"], output["ptm"], paes
            ):
                pdb_filename = f"{header}_plddt_{mean_plddt:.1f}_ptm_{ptm:.3f}_PAE_{pae:.3f}.pdb"
                output_file = pdbdir / pdb_filename
                output_file.write_text(pdb_string)

                # Calculate sequence entropy
                entropy = calculate_entropy(seq)

                # Collect data record
                data_records.append({
                    'FASTA_file': fasta_file,
                    'PDB_path': str(output_file.resolve()),
                    'sequence': seq,
                    'Length': len(seq),
                    'pLDDT': mean_plddt.item() if isinstance(mean_plddt, torch.Tensor) else mean_plddt,
                    'pTM': ptm.item() if isinstance(ptm, torch.Tensor) else ptm,
                    'pAE': pae.item() if isinstance(pae, torch.Tensor) else pae,
                    'Entropy': entropy
                })

                num_completed += 1
                logger.info(
                    "Predicted structure for %s (length %d): pLDDT=%.1f, pTM=%.3f, PAE=%.3f | Time: %s | Progress: %d/%d",
                    header, len(seq), mean_plddt, ptm, pae, time_info, num_completed, num_sequences
                )

        logger.info("Completed processing FASTA file: %s", fasta_file)

    # Convert the collected data into a DataFrame
    df = pd.DataFrame(data_records, columns=[
        'FASTA_file',
        'PDB_path',
        'sequence',
        'Length',
        'pLDDT',
        'pTM',
        'pAE',
        'Entropy'
    ])
    logger.info("Data collection complete. Returning DataFrame.")
    return df
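# Programmatic use (a sketch with hypothetical paths, not part of the original CLI):
#     df = run(fasta=Path("generation-results"), pdb=Path("generation-results/esmfold_pdb"))
#     df.to_csv("esmfold_metrics.csv", index=False)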
def main():
    parser = create_parser()
    args = parser.parse_args()
    run(
        fasta=args.fasta,
        pdb=args.pdb,
        model_dir=args.model_dir,
        num_recycles=args.num_recycles,
        max_tokens_per_batch=args.max_tokens_per_batch,
        chunk_size=args.chunk_size,
        cpu_only=args.cpu_only,
        cpu_offload=args.cpu_offload,
    )


if __name__ == "__main__":
    main()
Env:
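A plausible setup (an assumption; the gist does not pin an environment). The script needs fair-esm with its esmfold extra, which in turn requires OpenFold and its dependencies (see the ESM README), plus torch and pandas:

pip install "fair-esm[esmfold]" torch pandas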
Running:
Check out the args for detailed APIs. Below is a basic use:
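python esmfold_batch.py -i /path/to/fasta_or_dir -o /path/to/pdb_out_dir

(The script name and paths above are placeholders; use whatever filename you saved the gist under. Optional flags: --model-dir, --num-recycles, --max-tokens-per-batch, --chunk-size, --cpu-only, --cpu-offload.)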