Skip to content

Instantly share code, notes, and snippets.

@adam-kral
Created July 10, 2024 09:43
Show Gist options
  • Save adam-kral/d8c82f02f77ae0ec1c0f9255b29c3ab6 to your computer and use it in GitHub Desktop.
Save adam-kral/d8c82f02f77ae0ec1c0f9255b29c3ab6 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"model_type_dict = {\n",
" \"classification\" : \"saprot/saprot_classification_model\",\n",
" \"token_classification\" : \"saprot/saprot_token_classification_model\",\n",
" \"regression\" : \"saprot/saprot_regression_model\",\n",
" \"pair_classification\" : \"saprot/saprot_pair_classification_model\",\n",
" \"pair_regression\" : \"saprot/saprot_pair_regression_model\",\n",
"}\n",
"multi_lora = False\n",
"task_type = 'regression'\n",
"\n",
"ADAPTER_HOME = Path('adapter_home')\n",
"LMDB_HOME = Path('lmdb_home')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"adapter_input(value='SaProtHub/Model-Thermostability-650M')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import namedtuple\n",
"# create a named tuple for the input to the adapter\n",
"adapter_input = namedtuple('adapter_input', ['value'])(value='SaProtHub/Model-Thermostability-650M')\n",
"adapter_input"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# add ../saprot to the path\n",
"import sys\n",
"sys.path.append('..')\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Fetching 4 files: 100%|██████████| 4/4 [00:00<00:00, 582.50it/s]\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import json\n",
"from easydict import EasyDict\n",
"\n",
"\n",
"#@title 3.1.2: Get your Result\n",
"from transformers import EsmTokenizer\n",
"import torch\n",
"import copy\n",
"import sys\n",
"from saprot.scripts.training import my_load_model\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"\n",
"################################################################################\n",
"################################# 0. MARKDOWN ##################################\n",
"################################################################################\n",
"\n",
"\n",
"# @markdown Click the run button to make prediction.\n",
"\n",
"# @markdown <font color=\"red\">**Note that:**</font> When predicting a category, the index of categories starts from zero.\n",
"\n",
"################################################################################\n",
"################################# 1. DATASET ##################################\n",
"################################################################################\n",
"\n",
"# if mode == \"Multiple Sequences\":\n",
"# csv_dataset_path = get_SA_sequence_by_data_type(data_type, raw_data)\n",
"# else:\n",
"# single_sa_seq = get_SA_sequence_by_data_type(data_type, raw_data)\n",
"\n",
"\n",
"################################################################################\n",
"################################# 2. MODEL ##################################\n",
"################################################################################\n",
"def get_base_model(adapter_path):\n",
" adapter_config = Path(adapter_path) / \"adapter_config.json\"\n",
" with open(adapter_config, 'r') as f:\n",
" adapter_config_dict = json.load(f)\n",
" base_model = adapter_config_dict['base_model_name_or_path']\n",
" if 'SaProt_650M_AF2' in base_model:\n",
" base_model = \"westlake-repl/SaProt_650M_AF2\"\n",
" elif 'SaProt_35M_AF2' in base_model:\n",
" base_model = \"westlake-repl/SaProt_35M_AF2\"\n",
" else:\n",
" raise RuntimeError(\"Please ensure the base model is \\\"SaProt_650M_AF2\\\" or \\\"SaProt_35M_AF2\\\"\")\n",
" return base_model\n",
"\n",
"\n",
"# base_model = \"westlake-repl/SaProt_35M_AF2\"\n",
"\n",
"\n",
"snapshot_download(repo_id=adapter_input.value, repo_type=\"model\", local_dir=ADAPTER_HOME / task_type / adapter_input.value)\n",
"\n",
"adapter_path = ADAPTER_HOME / task_type / adapter_input.value\n",
"base_model = get_base_model(adapter_path)\n",
"\n",
"lora_kwargs = {\n",
" \"num_lora\": 1,\n",
" \"config_list\": [{\"lora_config_path\": adapter_path}]\n",
"}\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/krala/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at westlake-repl/SaProt_650M_AF2 and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'esm.contact_head.regression.bias', 'esm.contact_head.regression.weight', 'esm.embeddings.position_embeddings.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"/Users/krala/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"No optimizer_kwargs provided. The default optimizer is AdamW.\n",
"Now active LoRA model: default\n",
"trainable params: 7,723,521 || all params: 659,293,442 || trainable%: 1.1714845784860697\n"
]
},
{
"data": {
"text/plain": [
"SaprotRegressionModel(\n",
" (model): PeftModelForSequenceClassification(\n",
" (base_model): LoraModel(\n",
" (model): EsmForSequenceClassification(\n",
" (esm): EsmModel(\n",
" (embeddings): EsmEmbeddings(\n",
" (word_embeddings): Embedding(446, 1280, padding_idx=1)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (position_embeddings): None\n",
" )\n",
" (encoder): EsmEncoder(\n",
" (layer): ModuleList(\n",
" (0-32): 33 x EsmLayer(\n",
" (attention): EsmAttention(\n",
" (self): EsmSelfAttention(\n",
" (query): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (key): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (value): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (rotary_embeddings): RotaryEmbedding()\n",
" )\n",
" (output): EsmSelfOutput(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (intermediate): EsmIntermediate(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=5120, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=5120, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" )\n",
" (output): EsmOutput(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=5120, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=5120, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (emb_layer_norm_after): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (contact_head): None\n",
" )\n",
" (classifier): ModulesToSaveWrapper(\n",
" (original_module): EsmClassificationHead(\n",
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n",
" )\n",
" (modules_to_save): ModuleDict(\n",
" (default): EsmClassificationHead(\n",
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (train_loss): MeanSquaredError()\n",
" (train_spearman): SpearmanCorrCoef()\n",
" (train_R2): R2Score()\n",
" (train_pearson): PearsonCorrCoef()\n",
" (valid_loss): MeanSquaredError()\n",
" (valid_spearman): SpearmanCorrCoef()\n",
" (valid_R2): R2Score()\n",
" (valid_pearson): PearsonCorrCoef()\n",
" (test_loss): MeanSquaredError()\n",
" (test_spearman): SpearmanCorrCoef()\n",
" (test_R2): R2Score()\n",
" (test_pearson): PearsonCorrCoef()\n",
")"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# if use_existing_model:\n",
"# if adapter_combobox.value =='':\n",
"# print(\"Please select a model!\")\n",
"# sys.exit()\n",
"\n",
"# if \". \" in adapter_combobox.value:\n",
"# adapter_path = ADAPTER_HOME / task_type / adapter_combobox.value\n",
"# else:\n",
"# adapter_path = adapter_combobox.value\n",
"\n",
"################################################################################\n",
"##################################### config ###################################\n",
"################################################################################\n",
"from saprot.config.config_dict import Default_config\n",
"config = copy.deepcopy(Default_config)\n",
"\n",
"# base model\n",
"config.model.model_py_path = model_type_dict[task_type]\n",
"# config.model.save_path = model_save_path\n",
"config.model.kwargs.config_path = base_model\n",
"\n",
"# lora\n",
"# config.model.kwargs.lora_config_path = adapter_path\n",
"# config.model.kwargs.use_lora = True\n",
"# config.model.kwargs.lora_inference = True\n",
"config.model.lora_kwargs = lora_kwargs\n",
"\n",
"################################################################################\n",
"################################### inference ##################################\n",
"################################################################################\n",
"from peft import PeftModelForSequenceClassification\n",
"\n",
"config.model.lora_kwargs.is_trainable = True\n",
"model = my_load_model(config.model)\n",
"tokenizer = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model.to(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [0, 3, 2], 'attention_mask': [1, 1, 1]}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"t = EsmTokenizer.from_pretrained(config.model.kwargs.config_path)\n",
"t(\"A\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SaprotRegressionModel(\n",
" (model): PeftModelForSequenceClassification(\n",
" (base_model): LoraModel(\n",
" (model): EsmForSequenceClassification(\n",
" (esm): EsmModel(\n",
" (embeddings): EsmEmbeddings(\n",
" (word_embeddings): Embedding(446, 1280, padding_idx=1)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (position_embeddings): None\n",
" )\n",
" (encoder): EsmEncoder(\n",
" (layer): ModuleList(\n",
" (0-32): 33 x EsmLayer(\n",
" (attention): EsmAttention(\n",
" (self): EsmSelfAttention(\n",
" (query): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (key): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (value): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (rotary_embeddings): RotaryEmbedding()\n",
" )\n",
" (output): EsmSelfOutput(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (intermediate): EsmIntermediate(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=1280, out_features=5120, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=1280, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=5120, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" )\n",
" (output): EsmOutput(\n",
" (dense): lora.Linear(\n",
" (base_layer): Linear(in_features=5120, out_features=1280, bias=True)\n",
" (lora_dropout): ModuleDict(\n",
" (default): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lora_A): ModuleDict(\n",
" (default): Linear(in_features=5120, out_features=8, bias=False)\n",
" )\n",
" (lora_B): ModuleDict(\n",
" (default): Linear(in_features=8, out_features=1280, bias=False)\n",
" )\n",
" (lora_embedding_A): ParameterDict()\n",
" (lora_embedding_B): ParameterDict()\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" )\n",
" (emb_layer_norm_after): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
" )\n",
" (contact_head): None\n",
" )\n",
" (classifier): ModulesToSaveWrapper(\n",
" (original_module): EsmClassificationHead(\n",
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n",
" )\n",
" (modules_to_save): ModuleDict(\n",
" (default): EsmClassificationHead(\n",
" (dense): Linear(in_features=1280, out_features=1280, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (out_proj): Linear(in_features=1280, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (train_loss): MeanSquaredError()\n",
" (train_spearman): SpearmanCorrCoef()\n",
" (train_R2): R2Score()\n",
" (train_pearson): PearsonCorrCoef()\n",
" (valid_loss): MeanSquaredError()\n",
" (valid_spearman): SpearmanCorrCoef()\n",
" (valid_R2): R2Score()\n",
" (valid_pearson): PearsonCorrCoef()\n",
" (test_loss): MeanSquaredError()\n",
" (test_spearman): SpearmanCorrCoef()\n",
" (test_R2): R2Score()\n",
" (test_pearson): PearsonCorrCoef()\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/89 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'inputs': {'input_ids': tensor([[ 0, 235, 277, ..., 1, 1, 1],\n",
" [ 0, 235, 340, ..., 1, 1, 1],\n",
" [ 0, 223, 288, ..., 1, 1, 1],\n",
" ...,\n",
" [ 0, 235, 214, ..., 1, 1, 1],\n",
" [ 0, 235, 130, ..., 193, 193, 2],\n",
" [ 0, 223, 13, ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 1, 1, 1],\n",
" [1, 1, 1, ..., 0, 0, 0]])}}\n",
"{'labels': tensor([0.3935, 0.6543, 0.2917, 0.9112, 0.1567, 0.2465, 0.1562, 0.6610])}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 1/89 [00:33<48:40, 33.19s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([0., 0., 0., 0., 0., 0., 0., 0.])\n",
"{'inputs': {'input_ids': tensor([[ 0, 235, 340, ..., 183, 204, 2],\n",
" [ 0, 235, 256, ..., 1, 1, 1],\n",
" [ 0, 235, 88, ..., 1, 1, 1],\n",
" ...,\n",
" [ 0, 235, 25, ..., 1, 1, 1],\n",
" [ 0, 235, 25, ..., 1, 1, 1],\n",
" [ 0, 235, 340, ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" ...,\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0],\n",
" [1, 1, 1, ..., 0, 0, 0]])}}\n",
"{'labels': tensor([0.5051, 0.2567, 0.7527, 0.0107, 0.2490, 0.6707, 0.5796, 0.2031])}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 1/89 [00:40<1:00:03, 40.94s/it]\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[14], line 29\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs)\n\u001b[1;32m 28\u001b[0m \u001b[38;5;28mprint\u001b[39m(labels)\n\u001b[0;32m---> 29\u001b[0m model_output \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;28mprint\u001b[39m(model_output)\n\u001b[1;32m 31\u001b[0m outputs[dl\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39mstage] \u001b[38;5;241m=\u001b[39m model_output\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/SaprotHub/saprot/model/saprot/saprot_regression_model.py:44\u001b[0m, in \u001b[0;36mSaprotRegressionModel.forward\u001b[0;34m(self, inputs, structure_info)\u001b[0m\n\u001b[1;32m 41\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\u001b[38;5;241m.\u001b[39mclassifier\u001b[38;5;241m.\u001b[39mout_proj(x)\u001b[38;5;241m.\u001b[39msqueeze(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 44\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlogits\u001b[38;5;241m.\u001b[39msqueeze(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# For ProtBERT\u001b[39;00m\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbert\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/SaprotHub/saprot/model/saprot/self_peft/peft_model.py:936\u001b[0m, in \u001b[0;36mPeftModelForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)\u001b[0m\n\u001b[1;32m 934\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m peft_config\u001b[38;5;241m.\u001b[39mpeft_type \u001b[38;5;241m==\u001b[39m PeftType\u001b[38;5;241m.\u001b[39mPOLY:\n\u001b[1;32m 935\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtask_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m task_ids\n\u001b[0;32m--> 936\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 937\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 938\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 939\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 940\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 941\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 942\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 943\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 944\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 945\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 947\u001b[0m batch_size \u001b[38;5;241m=\u001b[39m _get_batch_size(input_ids, inputs_embeds)\n\u001b[1;32m 948\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attention_mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 949\u001b[0m \u001b[38;5;66;03m# concat prompt attention mask\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/peft/tuners/tuners_utils.py:179\u001b[0m, in \u001b[0;36mBaseTuner.forward\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any):\n\u001b[0;32m--> 179\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:1103\u001b[0m, in \u001b[0;36mEsmForSequenceClassification.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1095\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1096\u001b[0m \u001b[38;5;124;03mlabels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):\u001b[39;00m\n\u001b[1;32m 1097\u001b[0m \u001b[38;5;124;03m Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\u001b[39;00m\n\u001b[1;32m 1098\u001b[0m \u001b[38;5;124;03m config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\u001b[39;00m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;124;03m `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\u001b[39;00m\n\u001b[1;32m 1100\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1101\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[0;32m-> 1103\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mesm\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1104\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1105\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1106\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1107\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1108\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1109\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1110\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1111\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1112\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1113\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1114\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mclassifier(sequence_output)\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:907\u001b[0m, in \u001b[0;36mEsmModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 898\u001b[0m head_mask \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mget_head_mask(head_mask, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mnum_hidden_layers)\n\u001b[1;32m 900\u001b[0m embedding_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membeddings(\n\u001b[1;32m 901\u001b[0m input_ids\u001b[38;5;241m=\u001b[39minput_ids,\n\u001b[1;32m 902\u001b[0m position_ids\u001b[38;5;241m=\u001b[39mposition_ids,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 905\u001b[0m past_key_values_length\u001b[38;5;241m=\u001b[39mpast_key_values_length,\n\u001b[1;32m 906\u001b[0m )\n\u001b[0;32m--> 907\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 908\u001b[0m \u001b[43m \u001b[49m\u001b[43membedding_output\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 909\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mextended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 910\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 911\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 912\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoder_extended_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 913\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 914\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 915\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 916\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 917\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 918\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 919\u001b[0m sequence_output \u001b[38;5;241m=\u001b[39m encoder_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 920\u001b[0m pooled_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler(sequence_output) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpooler \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:612\u001b[0m, in \u001b[0;36mEsmEncoder.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 601\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_gradient_checkpointing_func(\n\u001b[1;32m 602\u001b[0m layer_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__call__\u001b[39m,\n\u001b[1;32m 603\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 609\u001b[0m output_attentions,\n\u001b[1;32m 610\u001b[0m )\n\u001b[1;32m 611\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 612\u001b[0m layer_outputs \u001b[38;5;241m=\u001b[39m \u001b[43mlayer_module\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 613\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 614\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 615\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayer_head_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 616\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 617\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 618\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 619\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 620\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 622\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m layer_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 623\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m use_cache:\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:502\u001b[0m, in \u001b[0;36mEsmLayer.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 491\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 492\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 499\u001b[0m ):\n\u001b[1;32m 500\u001b[0m \u001b[38;5;66;03m# decoder uni-directional self-attention cached key/values tuple is at positions 1,2\u001b[39;00m\n\u001b[1;32m 501\u001b[0m self_attn_past_key_value \u001b[38;5;241m=\u001b[39m past_key_value[:\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m past_key_value \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 502\u001b[0m self_attention_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mattention\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 503\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 504\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 505\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 506\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 507\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mself_attn_past_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 508\u001b[0m \u001b[43m 
\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 509\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m self_attention_outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 511\u001b[0m \u001b[38;5;66;03m# if decoder, the last output is tuple of self-attn cache\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:436\u001b[0m, in \u001b[0;36mEsmAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 426\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 427\u001b[0m hidden_states,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 433\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 434\u001b[0m ):\n\u001b[1;32m 435\u001b[0m hidden_states_ln \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mLayerNorm(hidden_states)\n\u001b[0;32m--> 436\u001b[0m self_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mself\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 437\u001b[0m \u001b[43m \u001b[49m\u001b[43mhidden_states_ln\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 438\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 439\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 440\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 441\u001b[0m \u001b[43m \u001b[49m\u001b[43mencoder_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 442\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_value\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 443\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 444\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 445\u001b[0m attention_output \u001b[38;5;241m=\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput(self_outputs[\u001b[38;5;241m0\u001b[39m], hidden_states)\n\u001b[1;32m 446\u001b[0m outputs \u001b[38;5;241m=\u001b[39m (attention_output,) \u001b[38;5;241m+\u001b[39m self_outputs[\u001b[38;5;241m1\u001b[39m:] \u001b[38;5;66;03m# add attentions if we output them\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/transformers/models/esm/modeling_esm.py:316\u001b[0m, in \u001b[0;36mEsmSelfAttention.forward\u001b[0;34m(self, hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions)\u001b[0m\n\u001b[1;32m 314\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 315\u001b[0m key_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkey(hidden_states))\n\u001b[0;32m--> 316\u001b[0m value_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalue\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 318\u001b[0m query_layer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtranspose_for_scores(mixed_query_layer)\n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).\u001b[39;00m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# ESM scales the query down by the same factor instead. Modulo numerical stability these are equivalent,\u001b[39;00m\n\u001b[1;32m 322\u001b[0m \u001b[38;5;66;03m# but not when rotary embeddings get involved. Therefore, we scale the query here to match the original\u001b[39;00m\n\u001b[1;32m 323\u001b[0m \u001b[38;5;66;03m# ESM code and fix rotary embeddings.\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/peft/tuners/lora/layer.py:557\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, x, *args, **kwargs)\u001b[0m\n\u001b[1;32m 555\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbase_layer(x, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 556\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 557\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbase_layer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 558\u001b[0m torch_result_dtype \u001b[38;5;241m=\u001b[39m result\u001b[38;5;241m.\u001b[39mdtype\n\u001b[1;32m 559\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m active_adapter \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mactive_adapters:\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1532\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1531\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/module.py:1541\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1536\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1539\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1544\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/miniforge3/envs/dipl-saprot/lib/python3.12/site-packages/torch/nn/modules/linear.py:116\u001b[0m, in \u001b[0;36mLinear.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlinear\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"from tqdm import tqdm\n",
"from saprot.dataset.saprot.saprot_regression_dataset import SaprotRegressionDataset\n",
"\n",
"# Run the already-loaded model over the validation and test dataloaders of the\n",
"# Thermostability-FLIP LMDB dataset and collect (predictions, labels) per split.\n",
"# Relies on `config`, `LMDB_HOME`, `torch`, `model` and `device` from earlier cells.\n",
"thermostability_lmdb = LMDB_HOME / \"Dataset-Thermostability-FLIP\"\n",
"\n",
"dataset = SaprotRegressionDataset(\n",
"    config.model.kwargs.config_path,\n",
"    train_lmdb=str(thermostability_lmdb / \"train\"),\n",
"    test_lmdb=str(thermostability_lmdb / \"test\"),\n",
"    valid_lmdb=str(thermostability_lmdb / \"valid\"),\n",
"    dataloader_kwargs=dict(batch_size=8),\n",
")\n",
"\n",
"# Maps dl.dataset.stage -> (predictions, labels), both as NumPy arrays.\n",
"outputs = {}\n",
"\n",
"# NOTE(review): assumes the model was already switched to eval mode in an\n",
"# earlier cell -- confirm, otherwise dropout would make predictions noisy.\n",
"with torch.no_grad():\n",
"    for dl in (\n",
"        # dataset.train_dataloader(),  # uncomment to also predict on the train split\n",
"        dataset.val_dataloader(),\n",
"        dataset.test_dataloader(),\n",
"    ):\n",
"        y_pred = []\n",
"        y = []\n",
"\n",
"        # predict\n",
"        for inputs, labels in tqdm(dl):\n",
"            inputs = {k: v.to(device) for k, v in inputs.items()}\n",
"            model_output = model(**inputs)\n",
"            # Move each batch off the device immediately so predictions do not\n",
"            # accumulate in device memory over the whole split.\n",
"            y_pred.append(model_output.cpu())\n",
"            y.append(labels['labels'])\n",
"\n",
"        # save predictions (written once per split, after all batches are done)\n",
"        outputs[dl.dataset.stage] = (torch.cat(y_pred).numpy(), torch.cat(y).cpu().numpy())\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dipl-saprot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment