{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
"model_type_dict = {\n", | |
" \"classification\" : \"saprot/saprot_classification_model\",\n", | |
" \"token_classification\" : \"saprot/saprot_token_classification_model\",\n", | |
" \"regression\" : \"saprot/saprot_regression_model\",\n", | |
" \"pair_classification\" : \"saprot/saprot_pair_classification_model\",\n", | |
" \"pair_regression\" : \"saprot/saprot_pair_regression_model\",\n", | |
"}\n", | |
"multi_lora = False\n", | |
"task_type = 'regression'\n", | |
"\n", | |
"ADAPTER_HOME = Path('adapter_home')\n", | |
"LMDB_HOME = Path('lmdb_home')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"adapter_input(value='SaProtHub/Model-Thermostability-650M')" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from collections import namedtuple\n", | |
"# create a named tuple for the input to the adapter\n", | |
"adapter_input = namedtuple('adapter_input', ['value'])(value='SaProtHub/Model-Thermostability-650M')\n", | |
"adapter_input" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# add ../saprot to the path\n", | |
"import sys\n", | |
"sys.path.append('..')\n" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 582.50it/s]\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "import json\n",
    "from easydict import EasyDict\n",
    "\n",
    "\n",
    "#@title 3.1.2: Get your Result\n",
    "from transformers import EsmTokenizer\n",
    "import torch\n",
    "import copy\n",
    "import sys\n",
    "from saprot.scripts.training import my_load_model\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "\n",
    "################################################################################\n",
    "################################# 0. MARKDOWN ##################################\n",
    "################################################################################\n",
    "\n",
    "\n",
"# @markdown Click the run button to make prediction.\n", | |
"\n", | |
"# @markdown <font color=\"red\">**Note that:**</font> When predicting a category, the index of categories starts from zero.\n", | |
"\n", | |
"################################################################################\n", | |
"################################# 1. DATASET ##################################\n", | |
"################################################################################\n", | |
"\n", | |
"# if mode == \"Multiple Sequences\":\n", | |
"# csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)\n", | |
"# else:\n", | |
"# single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n", | |
"\n", | |
"\n", | |
"################################################################################\n", | |
"################################# 2. MODEL ##################################\n", | |
"################################################################################\n", | |
"def get_base_model(adapter_path):\n", | |
" adapter_config = Path(adapter_path) / \"adapter_config.json\"\n", | |
" with open(adapter_config, 'r') as f:\n", | |
" adapter_config_dict = json.load(f)\n", | |
" base_model = adapter_config_dict['base_model_name_or_path']\n", | |
" if 'SaProt_650M_AF2' in base_model:\n", | |
" base_model = \"westlake-repl/SaProt_650M_AF2\"\n", | |
" elif 'SaProt_35M_AF2' in base_model:\n", | |
" base_model = \"westlake-repl/SaProt_35M_AF2\"\n", | |
" else:\n", | |
" raise RuntimeError(\"Please ensure the base model is \\\"SaProt_650M_AF2\\\" or \\\"SaProt_35M_AF2\\\"\")\n", | |
" return base_model\n", | |
"\n", | |
"\n", | |
"# base_model = \"westlake-repl/SaProt_35M_AF2\"\n", | |
"\n", | |
"\n", | |
"snapshot_download(repo_id=adapter_input.value, repo_type=\"model\", local_dir=ADAPTER_HOME / task_type / adapter_input.value)\n", | |
"\n", | |
"adapter_path = ADAPTER_HOME / task_type / adapter_input.value\n", | |
"base_model = get_base_model(adapter_path)\n", | |
"\n", | |
"lora_kwargs = {\n", | |
" \"num_lora\": 1,\n", | |
" \"config_list\": [{\"lora_config_path\": adapter_path}]\n", | |
"}\n", | |
"\n" | |
] | |
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/krala/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at westlake-repl/SaProt_650M_AF2 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'esm.contact_head.regression.bias', 'esm.contact_head.regression.weight', 'esm.embeddings.position_embeddings.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "/Users/krala/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
      "  warnings.warn(*args, **kwargs)  # noqa: B028\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No optimizer_kwargs provided. The default optimizer is AdamW.\n",
      "Now active LoRA model: default\n",
      "trainable params: 7,723,521 || all params: 659,293,442 || trainable%: 1.1714845784860697\n"
     ]
    },
    {
     "data": {
      "text/plain": [
"SaprotRegressionModel(\n", | |
" (model): PeftModelForSequenceClassification(\n", | |
" (base_model): LoraModel(\n", | |
" (model): EsmForSequenceClassification(\n", | |
" (esm): EsmModel(\n", | |
" (embeddings): EsmEmbeddings(\n", | |
" (word_embeddings): Embedding(446, 1280, padding_idx=1)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (position_embeddings): None\n", | |
" )\n", | |
" (encoder): EsmEncoder(\n", | |
" (layer): ModuleList(\n", | |
" (0-32): 33 x EsmLayer(\n", | |
" (attention): EsmAttention(\n", | |
" (self): EsmSelfAttention(\n", | |
" (query): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (key): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (value): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (rotary_embeddings): RotaryEmbedding()\n", | |
" )\n", | |
" (output): EsmSelfOutput(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" (intermediate): EsmIntermediate(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=5120, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=5120, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" )\n", | |
" (output): EsmOutput(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=5120, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=5120, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" )\n", | |
" (emb_layer_norm_after): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" (contact_head): None\n", | |
" )\n", | |
" (classifier): ModulesToSaveWrapper(\n", | |
" (original_module): EsmClassificationHead(\n", | |
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n", | |
" )\n", | |
" (modules_to_save): ModuleDict(\n", | |
" (default): EsmClassificationHead(\n", | |
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (train_loss): MeanSquaredError()\n", | |
" (train_spearman): SpearmanCorrCoef()\n", | |
" (train_R2): R2Score()\n", | |
" (train_pearson): PearsonCorrCoef()\n", | |
" (valid_loss): MeanSquaredError()\n", | |
" (valid_spearman): SpearmanCorrCoef()\n", | |
" (valid_R2): R2Score()\n", | |
" (valid_pearson): PearsonCorrCoef()\n", | |
" (test_loss): MeanSquaredError()\n", | |
" (test_spearman): SpearmanCorrCoef()\n", | |
" (test_R2): R2Score()\n", | |
" (test_pearson): PearsonCorrCoef()\n", | |
")" | |
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# if use_existing_model:\n",
    "#     if adapter_combobox.value =='':\n",
    "#         print(\"Please select a model!\")\n",
    "#         sys.exit()\n",
    "\n",
    "#     if \". \" in adapter_combobox.value:\n",
    "#         adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value\n",
    "#     else:\n",
    "#         adapter_path = adapter_combobox.value\n",
    "\n",
    "################################################################################\n",
    "##################################### config ###################################\n",
    "################################################################################\n",
    "from saprot.config.config_dict import Default_config\n",
    "config = copy.deepcopy(Default_config)\n",
    "\n",
    "# base model\n",
    "config.model.model_py_path = model_type_dict[task_type]\n",
    "# config.model.save_path = model_save_path\n",
    "config.model.kwargs.config_path = base_model\n",
    "\n",
    "# lora\n",
    "# config.model.kwargs.lora_config_path = adapter_path\n",
    "# config.model.kwargs.use_lora = True\n",
    "# config.model.kwargs.lora_inference = True\n",
    "config.model.lora_kwargs = lora_kwargs\n",
    "\n",
    "################################################################################\n",
    "################################### inference ##################################\n",
    "################################################################################\n",
    "from peft import PeftModelForSequenceClassification\n",
    "\n",
    "config.model.lora_kwargs.is_trainable = True\n",
"model = my_load_model(config.model)\n", | |
"tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"model.to(device)\n" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': [0, 3, 2], 'attention_mask': [1, 1, 1]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "t = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n",
    "t(\"A\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
"SaprotRegressionModel(\n", | |
" (model): PeftModelForSequenceClassification(\n", | |
" (base_model): LoraModel(\n", | |
" (model): EsmForSequenceClassification(\n", | |
" (esm): EsmModel(\n", | |
" (embeddings): EsmEmbeddings(\n", | |
" (word_embeddings): Embedding(446, 1280, padding_idx=1)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (position_embeddings): None\n", | |
" )\n", | |
" (encoder): EsmEncoder(\n", | |
" (layer): ModuleList(\n", | |
" (0-32): 33 x EsmLayer(\n", | |
" (attention): EsmAttention(\n", | |
" (self): EsmSelfAttention(\n", | |
" (query): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (key): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (value): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (rotary_embeddings): RotaryEmbedding()\n", | |
" )\n", | |
" (output): EsmSelfOutput(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" (intermediate): EsmIntermediate(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=1280, out_features=5120, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=1280, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=5120, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" )\n", | |
" (output): EsmOutput(\n", | |
" (dense): lora.Linear(\n", | |
" (base_layer): Linear(in_features=5120, out_features=1280, bias=True)\n", | |
" (lora_dropout): ModuleDict(\n", | |
" (default): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (lora_A): ModuleDict(\n", | |
" (default): Linear(in_features=5120, out_features=8, bias=False)\n", | |
" )\n", | |
" (lora_B): ModuleDict(\n", | |
" (default): Linear(in_features=8, out_features=1280, bias=False)\n", | |
" )\n", | |
" (lora_embedding_A): ParameterDict()\n", | |
" (lora_embedding_B): ParameterDict()\n", | |
" )\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" )\n", | |
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" )\n", | |
" (emb_layer_norm_after): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
" )\n", | |
" (contact_head): None\n", | |
" )\n", | |
" (classifier): ModulesToSaveWrapper(\n", | |
" (original_module): EsmClassificationHead(\n", | |
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n", | |
" )\n", | |
" (modules_to_save): ModuleDict(\n", | |
" (default): EsmClassificationHead(\n", | |
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
" (dropout): Dropout(p=0.1, inplace=False)\n", | |
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (train_loss): MeanSquaredError()\n", | |
" (train_spearman): SpearmanCorrCoef()\n", | |
" (train_R2): R2Score()\n", | |
" (train_pearson): PearsonCorrCoef()\n", | |
" (valid_loss): MeanSquaredError()\n", | |
" (valid_spearman): SpearmanCorrCoef()\n", | |
" (valid_R2): R2Score()\n", | |
" (valid_pearson): PearsonCorrCoef()\n", | |
" (test_loss): MeanSquaredError()\n", | |
" (test_spearman): SpearmanCorrCoef()\n", | |
" (test_R2): R2Score()\n", | |
" (test_pearson): PearsonCorrCoef()\n", | |
")" | |
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 0%| | 0/89 [00:00<?, ?it/s]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'inputs': {'input_ids': tensor([[ 0, 235, 277, ..., 1, 1, 1],\n", | |
" [ 0, 235, 340, ..., 1, 1, 1],\n", | |
" [ 0, 223, 288, ..., 1, 1, 1],\n", | |
" ...,\n", | |
" [ 0, 235, 214, ..., 1, 1, 1],\n", | |
" [ 0, 235, 130, ..., 193, 193, 2],\n", | |
" [ 0, 223, 13, ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 1, 1, 1],\n", | |
" [1, 1, 1, ..., 0, 0, 0]])}}\n", | |
"{'labels': tensor([0.3935, 0.6543, 0.2917, 0.9112, 0.1567, 0.2465, 0.1562, 0.6610])}\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 1%| | 1/89 [00:33<48:40, 33.19s/it]" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor([0., 0., 0., 0., 0., 0., 0., 0.])\n", | |
"{'inputs': {'input_ids': tensor([[ 0, 235, 340, ..., 183, 204, 2],\n", | |
" [ 0, 235, 256, ..., 1, 1, 1],\n", | |
" [ 0, 235, 88, ..., 1, 1, 1],\n", | |
" ...,\n", | |
" [ 0, 235, 25, ..., 1, 1, 1],\n", | |
" [ 0, 235, 25, ..., 1, 1, 1],\n", | |
" [ 0, 235, 340, ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" ...,\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0],\n", | |
" [1, 1, 1, ..., 0, 0, 0]])}}\n", | |
"{'labels': tensor([0.5051, 0.2567, 0.7527, 0.0107, 0.2490, 0.6707, 0.5796, 0.2031])}\n" | |
] | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
" 1%| | 1/89 [00:40<1:00:03, 40.94s/it]\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[14], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(labels)\n\u001b[0;32m---> 29\u001b[0m model_output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(model_output)\n\u001b[1;32m 31\u001b[0m outputs[dl\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39mstage] \u001b[38;5;241m=\u001b[39m model_output\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/SaprotHub/saprot/model/saprot/saprot_regression_model.py:44\u001b[0m, in \u001b[0;36mSaprotRegressionModel.forward\u001b[0;34m(self, inputs, structure_info)\u001b[0m\n\u001b[1;32m 41\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mclassifier\u001b[38;5;241m.\u001b[39mout_proj(x)\u001b[38;5;241m.\u001b[39msqueeze(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 44\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits\u001b[38;5;241m.\u001b[39msqueeze(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# For ProtBERT\u001b[39;00m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbert\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/SaprotHub/saprot/model/saprot/self_peft/peft_model.py:936\u001b[0m, in \u001b[0;36mPeftModelForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)\u001b[0m\n\u001b[1;32m 934\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m peft_config\u001b[38;5;241m.\u001b[39mpeft_type \u001b[38;5;241m==\u001b[39m PeftType\u001b[38;5;241m.\u001b[39mPOLY:\n\u001b[1;32m 935\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtask_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m task_ids\n\u001b[0;32m--> 936\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 937\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 938\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 939\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 940\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 941\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 942\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 943\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 944\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 947\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m _get_batch_size(input_ids, inputs_embeds)\n\u001b[1;32m 948\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 949\u001b[0m \u001b[38;5;66;03m# concat prompt attention mask\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:179\u001b[0m, in \u001b[0;36mBaseTuner.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any):\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:1103\u001b[0m, in \u001b[0;36mEsmForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1096\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;124;03m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;124;03m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;124;03m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1101\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1103\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mesm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1104\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1105\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1106\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1107\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1108\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1109\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1110\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1112\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1113\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1114\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclassifier(sequence_output)\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:907\u001b[0m, in \u001b[0;36mEsmModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 898\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m 900\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(\n\u001b[1;32m 901\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m 902\u001b[0m position_ids\u001b[38;5;241m=\u001b[39mposition_ids,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 905\u001b[0m past_key_values_length\u001b[38;5;241m=\u001b[39mpast_key_values_length,\n\u001b[1;32m 906\u001b[0m )\n\u001b[0;32m--> 907\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 908\u001b[0m \u001b[43m \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 909\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 910\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 911\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 912\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 913\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 914\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 915\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 916\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 917\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 918\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 919\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 920\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m 
\u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:612\u001b[0m, in \u001b[0;36mEsmEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 601\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 602\u001b[0m layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 603\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 609\u001b[0m output_attentions,\n\u001b[1;32m 610\u001b[0m )\n\u001b[1;32m 611\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 612\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 613\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 614\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 616\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 617\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 618\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 619\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 620\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 622\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:502\u001b[0m, in \u001b[0;36mEsmLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 491\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 492\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 499\u001b[0m ):\n\u001b[1;32m 500\u001b[0m \u001b[38;5;66;03m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[39;00m\n\u001b[1;32m 501\u001b[0m self_attn_past_key_value \u001b[38;5;241m=\u001b[39m past_key_value[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 502\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 503\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 504\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 505\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 506\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 507\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_past_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 508\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 509\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 511\u001b[0m \u001b[38;5;66;03m# if decoder, the last output is tuple of self-attn cache\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:436\u001b[0m, in \u001b[0;36mEsmAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 426\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 427\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 433\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 434\u001b[0m ):\n\u001b[1;32m 435\u001b[0m hidden_states_ln \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(hidden_states)\n\u001b[0;32m--> 436\u001b[0m self_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states_ln\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 438\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 440\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 442\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(self_outputs[\u001b[38;5;241m0\u001b[39m], hidden_states)\n\u001b[1;32m 446\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (attention_output,) \u001b[38;5;241m+\u001b[39m self_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add attentions if we output them\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:316\u001b[0m, in \u001b[0;36mEsmSelfAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 315\u001b[0m key_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkey(hidden_states))\n\u001b[0;32m--> 316\u001b[0m value_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 318\u001b[0m query_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(mixed_query_layer)\n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,\u001b[39;00m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;66;03m# but not when rotary embeddings get involved. Therefore, we scale the query here to match the original\u001b[39;00m\n\u001b[1;32m 323\u001b[0m \u001b[38;5;66;03m# ESM code and fix rotary embeddings.\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/peft/tuners/lora/layer.py:557\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[1;32m 555\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_layer(x, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 556\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 557\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 558\u001b[0m torch_result_dtype \u001b[38;5;241m=\u001b[39m result\u001b[38;5;241m.\u001b[39mdtype\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m active_adapter \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactive_adapters:\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n", | |
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"from tqdm import tqdm\n", | |
"from saprot.dataset.saprot.saprot_regression_dataset import SaprotRegressionDataset\n", | |
"\n", | |
"dataset = SaprotRegressionDataset(config.model.kwargs.config_path,\n", | |
" train_lmdb=str(LMDB_HOME / \"Dataset-Thermostability-FLIP\" / \"train\"),\n", | |
" test_lmdb=str(LMDB_HOME / \"Dataset-Thermostability-FLIP\" / \"test\"),\n", | |
" valid_lmdb=str(LMDB_HOME / \"Dataset-Thermostability-FLIP\" / \"valid\"),\n", | |
" dataloader_kwargs=dict(batch_size=8),)\n", | |
"\n", | |
"outputs = {}\n", | |
"\n", | |
"with torch.no_grad():\n", | |
" for dl in (\n", | |
" # dataset.train_dataloader(),\n", | |
" dataset.val_dataloader(),\n", | |
" dataset.test_dataloader()\n", | |
" ):\n", | |
" y_pred = []\n", | |
" y = []\n", | |
"\n", | |
" # predict\n", | |
" for d in tqdm(dl):\n", | |
" inputs, labels = d\n", | |
"\n", | |
"\n", | |
" inputs = {k: v.to(device) for k, v in inputs.items()}\n", | |
" print(inputs)\n", | |
" print(labels)\n", | |
" model_output = model(**inputs)\n", | |
" print(model_output)\n", | |
" outputs[dl.dataset.stage] = model_output\n", | |
" y_pred.append(model_output)\n", | |
" y.append(labels['labels'])\n", | |
"\n", | |
" # save predictions\n", | |
" outputs[dl.dataset.stage] = (torch.cat(y_pred).cpu().numpy(), torch.cat(y).cpu().numpy()) \n", | |
"\n" | |
] | |
} | |
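, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"A minimal evaluation sketch, assuming the `outputs` dict built above holds `(predictions, labels)` NumPy arrays per split and that `scipy` is installed. Spearman rank correlation is just one common metric choice for a regression task such as thermostability." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from scipy.stats import spearmanr\n", | |
"\n", | |
"# Sketch (assumes scipy is installed): Spearman rank correlation per split.\n", | |
"for stage, (y_pred, y) in outputs.items():\n", | |
"    rho, _ = spearmanr(np.ravel(y_pred), np.ravel(y))\n", | |
"    print(f\"{stage}: Spearman rho = {rho:.3f}\")\n" | |
] | |
} | |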
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "dipl-saprot", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.12.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |