Use transformers and ESMFold to predict protein structure.
import os
from typing import Dict, List

import torch
import pandas as pd
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
import biotite.structure.io as bsio


def read_fasta(file_path: str) -> Dict[str, str]:
    """
    Read a FASTA file and return a dictionary mapping sequence IDs to sequences.
    """
    sequences = {}
    current_sequence_id = None
    current_sequence = []
    with open(file_path, 'r') as file:
        for line in file:
            line = line.strip()
            if line.startswith('>'):
                if current_sequence_id:
                    sequences[current_sequence_id] = ''.join(current_sequence)
                current_sequence_id = line[1:]
                current_sequence = []
            else:
                current_sequence.append(line)
    if current_sequence_id:
        sequences[current_sequence_id] = ''.join(current_sequence)
    return sequences


def process_sequences(protein_sequences: Dict[str, str], model, tokenizer, output_path: str) -> None:
    """
    Fold each protein sequence and save the prediction as a PDB file.
    """
    for sequence_id, sequence in protein_sequences.items():
        tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)['input_ids'].cuda()
        with torch.no_grad():
            output = model(tokenized_input)
        pdb = convert_outputs_to_pdb(output)
        file_name = os.path.join(output_path, f"{sequence_id}.pdb")
        with open(file_name, "w") as f:
            f.write("".join(pdb))
        print(f"Processed and saved: {file_name}")


def convert_outputs_to_pdb(outputs) -> List[str]:
    """
    Convert model outputs to PDB-format strings.
    """
    final_atom_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {k: v.to("cpu").numpy() for k, v in outputs.items()}
    final_atom_positions = final_atom_positions.cpu().numpy()
    final_atom_mask = outputs["atom37_atom_exists"]
    pdbs = []
    for i in range(outputs["aatype"].shape[0]):
        pred = OFProtein(
            aatype=outputs["aatype"][i],
            atom_positions=final_atom_positions[i],
            atom_mask=final_atom_mask[i],
            residue_index=outputs["residue_index"][i] + 1,
            b_factors=outputs["plddt"][i],  # per-residue pLDDT is stored in the B-factor column
            chain_index=outputs["chain_index"][i] if "chain_index" in outputs else None,
        )
        pdbs.append(to_pdb(pred))
    return pdbs


def process_pdb_files(directory: str) -> pd.DataFrame:
    """
    Process the PDB files in a directory and return the B-factor results as a DataFrame.
    """
    results = []
    for filename in os.listdir(directory):
        if filename.endswith(".pdb"):
            file_path = os.path.join(directory, filename)
            struct = bsio.load_structure(file_path, extra_fields=["b_factor"])
            # Since pLDDT was written into the B-factor column, this mean is a
            # per-structure confidence score for the prediction.
            mean_b_factor = struct.b_factor.mean()
            results.append({
                "filename": filename,
                "mean_b_factor": mean_b_factor
            })
    return pd.DataFrame(results)


def main(input_path: str, output_path: str, model, tokenizer):
    """
    Main entry point.
    """
    protein_sequences = read_fasta(input_path)
    for sequence_id, sequence in protein_sequences.items():
        print(f"Sequence ID: {sequence_id}")
        print(f"Sequence: {sequence[:50]}...")  # print only the first 50 characters
    # Make sure the output directory exists before writing PDB files.
    os.makedirs(output_path, exist_ok=True)
    process_sequences(protein_sequences, model, tokenizer, output_path)
    results_df = process_pdb_files(output_path)
    print(results_df)
    results_df.to_csv(os.path.join(output_path, "b_factor_results.csv"), index=False)


if __name__ == "__main__":
    # Initialize the model and tokenizer.
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = model.cuda()
    # Switch the ESM stem to float16 to reduce GPU memory usage.
    model.esm = model.esm.half()
    # Enable TensorFloat32 computation for a general speedup if your hardware supports it.
    torch.backends.cuda.matmul.allow_tf32 = True
    input_path = "../input/sample.fasta"
    output_path = "../output"
    main(input_path, output_path, model, tokenizer)
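For reference, `input_path` should point to a plain FASTA file whose header lines (after `>`) become the output PDB filenames. A minimal sketch of what `../input/sample.fasta` could contain; the ID and sequence below are illustrative placeholders, not part of this gist:

>example_protein
MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG

Run with an input like this, the script writes `example_protein.pdb` to `../output` and a `b_factor_results.csv` listing the mean B-factor (which here holds the model's per-residue pLDDT confidence) for each predicted structure.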