Created
May 7, 2023 21:19
-
-
Save dienhoa/c391e69128ee5d5b35cc419092167741 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "bfc2065e-6db9-4acf-92d8-215d37c02cdb", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", | |
"torchvision 0.12.0a0 requires torch==1.11.0a0+17540c5, but you have torch 2.0.0 which is incompatible.\n", | |
"torchtext 0.12.0a0 requires torch==1.11.0a0+17540c5, but you have torch 2.0.0 which is incompatible.\n", | |
"fastai 2.7.4 requires torch<1.12,>=1.7.0, but you have torch 2.0.0 which is incompatible.\u001b[0m\u001b[31m\n", | |
"\u001b[0m" | |
] | |
} | |
], | |
"source": [ | |
"# Use the %pip magic (not !pip) so packages are installed into the environment\n", | |
"# of the running kernel rather than whichever pip is first on the shell PATH.\n", | |
"# NOTE(review): in the recorded run these installs left torch at 2.0.0, which\n", | |
"# conflicts with the preinstalled torchvision/torchtext/fastai (see the cell\n", | |
"# output); consider pinning versions once a known-good set is confirmed.\n", | |
"%pip install -qq speechbrain\n", | |
"%pip install -qq seaborn\n", | |
"%pip install -Uqq huggingface_hub" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "93ac4ec0-0c4a-4e76-a8b6-0b7c5d24e985", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.8/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /opt/conda/lib/python3.8/site-packages/torchvision/image.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEE\n", | |
" warn(f\"Failed to load image Python extension: {e}\")\n" | |
] | |
} | |
], | |
"source": [ | |
"import speechbrain as sb\n", | |
"from speechbrain.pretrained import SpeakerRecognition, EncoderClassifier\n", | |
"from speechbrain.dataio.dataio import read_audio\n", | |
"from speechbrain.utils.metric_stats import EER\n", | |
"import pandas as pd\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import seaborn as sns\n", | |
"import matplotlib.pyplot as plt\n", | |
"from huggingface_hub import hf_hub_download\n", | |
"from tqdm.notebook import tqdm\n", | |
"import random" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "c1349881-7dfd-4fa5-b1e2-b9e80c9cf40e", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"#T_36771 caption {\n", | |
" font-size: 20px;\n", | |
"}\n", | |
"</style>\n", | |
"<table id=\"T_36771\">\n", | |
" <caption>metadata</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"blank level0\" > </th>\n", | |
" <th id=\"T_36771_level0_col0\" class=\"col_heading level0 col0\" >baby_id</th>\n", | |
" <th id=\"T_36771_level0_col1\" class=\"col_heading level0 col1\" >period</th>\n", | |
" <th id=\"T_36771_level0_col2\" class=\"col_heading level0 col2\" >duration</th>\n", | |
" <th id=\"T_36771_level0_col3\" class=\"col_heading level0 col3\" >split</th>\n", | |
" <th id=\"T_36771_level0_col4\" class=\"col_heading level0 col4\" >chronological_index</th>\n", | |
" <th id=\"T_36771_level0_col5\" class=\"col_heading level0 col5\" >file_name</th>\n", | |
" <th id=\"T_36771_level0_col6\" class=\"col_heading level0 col6\" >file_id</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th id=\"T_36771_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n", | |
" <td id=\"T_36771_row0_col0\" class=\"data row0 col0\" >0694</td>\n", | |
" <td id=\"T_36771_row0_col1\" class=\"data row0 col1\" >B</td>\n", | |
" <td id=\"T_36771_row0_col2\" class=\"data row0 col2\" >1.320000</td>\n", | |
" <td id=\"T_36771_row0_col3\" class=\"data row0 col3\" >dev</td>\n", | |
" <td id=\"T_36771_row0_col4\" class=\"data row0 col4\" >000</td>\n", | |
" <td id=\"T_36771_row0_col5\" class=\"data row0 col5\" >audio/dev/0694/B/0694_B_000.wav</td>\n", | |
" <td id=\"T_36771_row0_col6\" class=\"data row0 col6\" >0694_B_000</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_36771_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n", | |
" <td id=\"T_36771_row1_col0\" class=\"data row1 col0\" >0694</td>\n", | |
" <td id=\"T_36771_row1_col1\" class=\"data row1 col1\" >B</td>\n", | |
" <td id=\"T_36771_row1_col2\" class=\"data row1 col2\" >0.940000</td>\n", | |
" <td id=\"T_36771_row1_col3\" class=\"data row1 col3\" >dev</td>\n", | |
" <td id=\"T_36771_row1_col4\" class=\"data row1 col4\" >001</td>\n", | |
" <td id=\"T_36771_row1_col5\" class=\"data row1 col5\" >audio/dev/0694/B/0694_B_001.wav</td>\n", | |
" <td id=\"T_36771_row1_col6\" class=\"data row1 col6\" >0694_B_001</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_36771_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n", | |
" <td id=\"T_36771_row2_col0\" class=\"data row2 col0\" >0694</td>\n", | |
" <td id=\"T_36771_row2_col1\" class=\"data row2 col1\" >B</td>\n", | |
" <td id=\"T_36771_row2_col2\" class=\"data row2 col2\" >0.880000</td>\n", | |
" <td id=\"T_36771_row2_col3\" class=\"data row2 col3\" >dev</td>\n", | |
" <td id=\"T_36771_row2_col4\" class=\"data row2 col4\" >002</td>\n", | |
" <td id=\"T_36771_row2_col5\" class=\"data row2 col5\" >audio/dev/0694/B/0694_B_002.wav</td>\n", | |
" <td id=\"T_36771_row2_col6\" class=\"data row2 col6\" >0694_B_002</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_36771_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n", | |
" <td id=\"T_36771_row3_col0\" class=\"data row3 col0\" >0694</td>\n", | |
" <td id=\"T_36771_row3_col1\" class=\"data row3 col1\" >B</td>\n", | |
" <td id=\"T_36771_row3_col2\" class=\"data row3 col2\" >1.130000</td>\n", | |
" <td id=\"T_36771_row3_col3\" class=\"data row3 col3\" >dev</td>\n", | |
" <td id=\"T_36771_row3_col4\" class=\"data row3 col4\" >003</td>\n", | |
" <td id=\"T_36771_row3_col5\" class=\"data row3 col5\" >audio/dev/0694/B/0694_B_003.wav</td>\n", | |
" <td id=\"T_36771_row3_col6\" class=\"data row3 col6\" >0694_B_003</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_36771_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n", | |
" <td id=\"T_36771_row4_col0\" class=\"data row4 col0\" >0694</td>\n", | |
" <td id=\"T_36771_row4_col1\" class=\"data row4 col1\" >B</td>\n", | |
" <td id=\"T_36771_row4_col2\" class=\"data row4 col2\" >1.180000</td>\n", | |
" <td id=\"T_36771_row4_col3\" class=\"data row4 col3\" >dev</td>\n", | |
" <td id=\"T_36771_row4_col4\" class=\"data row4 col4\" >004</td>\n", | |
" <td id=\"T_36771_row4_col5\" class=\"data row4 col5\" >audio/dev/0694/B/0694_B_004.wav</td>\n", | |
" <td id=\"T_36771_row4_col6\" class=\"data row4 col6\" >0694_B_004</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
], | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7fe7a30fd9a0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"#T_e6b97 caption {\n", | |
" font-size: 20px;\n", | |
"}\n", | |
"</style>\n", | |
"<table id=\"T_e6b97\">\n", | |
" <caption>dev_pairs</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"blank level0\" > </th>\n", | |
" <th id=\"T_e6b97_level0_col0\" class=\"col_heading level0 col0\" >baby_id_B</th>\n", | |
" <th id=\"T_e6b97_level0_col1\" class=\"col_heading level0 col1\" >baby_id_D</th>\n", | |
" <th id=\"T_e6b97_level0_col2\" class=\"col_heading level0 col2\" >id</th>\n", | |
" <th id=\"T_e6b97_level0_col3\" class=\"col_heading level0 col3\" >label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th id=\"T_e6b97_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n", | |
" <td id=\"T_e6b97_row0_col0\" class=\"data row0 col0\" >0133</td>\n", | |
" <td id=\"T_e6b97_row0_col1\" class=\"data row0 col1\" >0611</td>\n", | |
" <td id=\"T_e6b97_row0_col2\" class=\"data row0 col2\" >0133B_0611D</td>\n", | |
" <td id=\"T_e6b97_row0_col3\" class=\"data row0 col3\" >0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_e6b97_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n", | |
" <td id=\"T_e6b97_row1_col0\" class=\"data row1 col0\" >0593</td>\n", | |
" <td id=\"T_e6b97_row1_col1\" class=\"data row1 col1\" >0584</td>\n", | |
" <td id=\"T_e6b97_row1_col2\" class=\"data row1 col2\" >0593B_0584D</td>\n", | |
" <td id=\"T_e6b97_row1_col3\" class=\"data row1 col3\" >0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_e6b97_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n", | |
" <td id=\"T_e6b97_row2_col0\" class=\"data row2 col0\" >0094</td>\n", | |
" <td id=\"T_e6b97_row2_col1\" class=\"data row2 col1\" >0292</td>\n", | |
" <td id=\"T_e6b97_row2_col2\" class=\"data row2 col2\" >0094B_0292D</td>\n", | |
" <td id=\"T_e6b97_row2_col3\" class=\"data row2 col3\" >0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_e6b97_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n", | |
" <td id=\"T_e6b97_row3_col0\" class=\"data row3 col0\" >0563</td>\n", | |
" <td id=\"T_e6b97_row3_col1\" class=\"data row3 col1\" >0094</td>\n", | |
" <td id=\"T_e6b97_row3_col2\" class=\"data row3 col2\" >0563B_0094D</td>\n", | |
" <td id=\"T_e6b97_row3_col3\" class=\"data row3 col3\" >0</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_e6b97_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n", | |
" <td id=\"T_e6b97_row4_col0\" class=\"data row4 col0\" >0122</td>\n", | |
" <td id=\"T_e6b97_row4_col1\" class=\"data row4 col1\" >0694</td>\n", | |
" <td id=\"T_e6b97_row4_col2\" class=\"data row4 col2\" >0122B_0694D</td>\n", | |
" <td id=\"T_e6b97_row4_col3\" class=\"data row4 col3\" >0</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
], | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7fe7f7a24e80>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"#T_d5c4c caption {\n", | |
" font-size: 20px;\n", | |
"}\n", | |
"</style>\n", | |
"<table id=\"T_d5c4c\">\n", | |
" <caption>test_pairs</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"blank level0\" > </th>\n", | |
" <th id=\"T_d5c4c_level0_col0\" class=\"col_heading level0 col0\" >baby_id_B</th>\n", | |
" <th id=\"T_d5c4c_level0_col1\" class=\"col_heading level0 col1\" >baby_id_D</th>\n", | |
" <th id=\"T_d5c4c_level0_col2\" class=\"col_heading level0 col2\" >id</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th id=\"T_d5c4c_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n", | |
" <td id=\"T_d5c4c_row0_col0\" class=\"data row0 col0\" >anonymous027</td>\n", | |
" <td id=\"T_d5c4c_row0_col1\" class=\"data row0 col1\" >anonymous212</td>\n", | |
" <td id=\"T_d5c4c_row0_col2\" class=\"data row0 col2\" >anonymous027B_anonymous212D</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_d5c4c_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n", | |
" <td id=\"T_d5c4c_row1_col0\" class=\"data row1 col0\" >anonymous035</td>\n", | |
" <td id=\"T_d5c4c_row1_col1\" class=\"data row1 col1\" >anonymous225</td>\n", | |
" <td id=\"T_d5c4c_row1_col2\" class=\"data row1 col2\" >anonymous035B_anonymous225D</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_d5c4c_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n", | |
" <td id=\"T_d5c4c_row2_col0\" class=\"data row2 col0\" >anonymous029</td>\n", | |
" <td id=\"T_d5c4c_row2_col1\" class=\"data row2 col1\" >anonymous288</td>\n", | |
" <td id=\"T_d5c4c_row2_col2\" class=\"data row2 col2\" >anonymous029B_anonymous288D</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_d5c4c_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n", | |
" <td id=\"T_d5c4c_row3_col0\" class=\"data row3 col0\" >anonymous001</td>\n", | |
" <td id=\"T_d5c4c_row3_col1\" class=\"data row3 col1\" >anonymous204</td>\n", | |
" <td id=\"T_d5c4c_row3_col2\" class=\"data row3 col2\" >anonymous001B_anonymous204D</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_d5c4c_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n", | |
" <td id=\"T_d5c4c_row4_col0\" class=\"data row4 col0\" >anonymous075</td>\n", | |
" <td id=\"T_d5c4c_row4_col1\" class=\"data row4 col1\" >anonymous244</td>\n", | |
" <td id=\"T_d5c4c_row4_col2\" class=\"data row4 col2\" >anonymous075B_anonymous244D</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
], | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7fe7a30fd9a0>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"<style type=\"text/css\">\n", | |
"#T_0c218 caption {\n", | |
" font-size: 20px;\n", | |
"}\n", | |
"</style>\n", | |
"<table id=\"T_0c218\">\n", | |
" <caption>sample_submission</caption>\n", | |
" <thead>\n", | |
" <tr>\n", | |
" <th class=\"blank level0\" > </th>\n", | |
" <th id=\"T_0c218_level0_col0\" class=\"col_heading level0 col0\" >id</th>\n", | |
" <th id=\"T_0c218_level0_col1\" class=\"col_heading level0 col1\" >score</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th id=\"T_0c218_level0_row0\" class=\"row_heading level0 row0\" >0</th>\n", | |
" <td id=\"T_0c218_row0_col0\" class=\"data row0 col0\" >anonymous027B_anonymous212D</td>\n", | |
" <td id=\"T_0c218_row0_col1\" class=\"data row0 col1\" >0.548814</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_0c218_level0_row1\" class=\"row_heading level0 row1\" >1</th>\n", | |
" <td id=\"T_0c218_row1_col0\" class=\"data row1 col0\" >anonymous035B_anonymous225D</td>\n", | |
" <td id=\"T_0c218_row1_col1\" class=\"data row1 col1\" >0.715189</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_0c218_level0_row2\" class=\"row_heading level0 row2\" >2</th>\n", | |
" <td id=\"T_0c218_row2_col0\" class=\"data row2 col0\" >anonymous029B_anonymous288D</td>\n", | |
" <td id=\"T_0c218_row2_col1\" class=\"data row2 col1\" >0.602763</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_0c218_level0_row3\" class=\"row_heading level0 row3\" >3</th>\n", | |
" <td id=\"T_0c218_row3_col0\" class=\"data row3 col0\" >anonymous001B_anonymous204D</td>\n", | |
" <td id=\"T_0c218_row3_col1\" class=\"data row3 col1\" >0.544883</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th id=\"T_0c218_level0_row4\" class=\"row_heading level0 row4\" >4</th>\n", | |
" <td id=\"T_0c218_row4_col0\" class=\"data row4 col0\" >anonymous075B_anonymous244D</td>\n", | |
" <td id=\"T_0c218_row4_col1\" class=\"data row4 col1\" >0.423655</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n" | |
], | |
"text/plain": [ | |
"<pandas.io.formats.style.Styler at 0x7fe7f7a24e80>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"# read metadata\n", | |
"# ids are read as strings so values such as '0694' / '000' keep their leading zeros\n", | |
"metadata = pd.read_csv('metadata.csv', dtype={'baby_id': str, 'chronological_index': str})\n", | |
"dev_metadata = metadata.loc[metadata['split'] == 'dev'].copy()\n", | |
"# read sample submission\n", | |
"sample_submission = pd.read_csv(\"sample_submission.csv\")  # scores are uniform random\n", | |
"# read verification pairs\n", | |
"dev_pairs = pd.read_csv(\"dev_pairs.csv\", dtype={'baby_id_B': str, 'baby_id_D': str})\n", | |
"test_pairs = pd.read_csv(\"test_pairs.csv\")\n", | |
"\n", | |
"# Preview the first rows of every table, each under an enlarged caption.\n", | |
"caption_styles = [{'selector': 'caption', 'props': [('font-size', '20px')]}]\n", | |
"for caption, frame in [\n", | |
"    (\"metadata\", metadata),\n", | |
"    (\"dev_pairs\", dev_pairs),\n", | |
"    (\"test_pairs\", test_pairs),\n", | |
"    (\"sample_submission\", sample_submission),\n", | |
"]:\n", | |
"    display(frame.head().style.set_caption(caption).set_table_styles(caption_styles))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "9896df28-2a89-4a6b-97fb-f928553f6836", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Use the GPU when torch can see one, otherwise fall back to the CPU so this\n", | |
"# cell runs unmodified on machines without CUDA.\n", | |
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | |
"encoder = SpeakerRecognition.from_hparams(\n", | |
"    source=\"Ubenwa/ecapa-voxceleb-ft-cryceleb\",\n", | |
"    savedir=\"ecapa-voxceleb-ft-cryceleb\",\n", | |
"    run_opts={\"device\": device},\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "61824528-fb66-4201-835f-9ef80d02d115", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def shuffle_group_and_concat(x, n=5, seed=None):\n", | |
"    \"\"\"Build `n` randomly ordered concatenations of the arrays held in `x`.\n", | |
"\n", | |
"    Args:\n", | |
"        x: pandas Series (its `.values` are read) whose elements are arrays that\n", | |
"            `np.concatenate` can join; in this notebook, the cry segments of one\n", | |
"            (baby_id, period) group.\n", | |
"        n: number of shuffled concatenations to produce.\n", | |
"        seed: optional seed for a private `random.Random`, making the result\n", | |
"            reproducible. With the default `None` the module-level `random`\n", | |
"            state is used, exactly as before this parameter existed.\n", | |
"\n", | |
"    Returns:\n", | |
"        list of `n` arrays; each one joins every element of `x` in an\n", | |
"        independently shuffled order.\n", | |
"    \"\"\"\n", | |
"    rng = random if seed is None else random.Random(seed)\n", | |
"    concatenated_results = []\n", | |
"    for _ in range(n):\n", | |
"        # Start from the original order every time and shuffle a plain list,\n", | |
"        # so the caller's Series is never modified.\n", | |
"        shuffled_values = list(x.values)\n", | |
"        rng.shuffle(shuffled_values)\n", | |
"        concatenated_results.append(np.concatenate(shuffled_values))\n", | |
"    return concatenated_results" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "175d6912-30f2-4678-93e7-96caa032f89b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def compute_cosine_similarity_score(row, cry_dict):\n", | |
"    \"\"\"Average the cosine similarity over all possible embedding pairs.\n", | |
"\n", | |
"    Args:\n", | |
"        row: mapping with 'baby_id_B' and 'baby_id_D' entries (one row of a\n", | |
"            pairs table).\n", | |
"        cry_dict: dict keyed by (baby_id, period); each value holds a\n", | |
"            'cry_encoded' list of embedding tensors.\n", | |
"\n", | |
"    Returns:\n", | |
"        float: mean cosine similarity over every combination of one 'B'\n", | |
"        embedding with one 'D' embedding.\n", | |
"\n", | |
"    Raises:\n", | |
"        KeyError: if either side of the pair is missing from `cry_dict`.\n", | |
"        ValueError: if either side has no embeddings to compare.\n", | |
"    \"\"\"\n", | |
"    cos = torch.nn.CosineSimilarity(dim=-1)\n", | |
"    encoded_cry_B = cry_dict[(row['baby_id_B'], 'B')]['cry_encoded']\n", | |
"    encoded_cry_D = cry_dict[(row['baby_id_D'], 'D')]['cry_encoded']\n", | |
"\n", | |
"    # Fail with a clear message instead of a bare ZeroDivisionError below.\n", | |
"    if len(encoded_cry_B) == 0 or len(encoded_cry_D) == 0:\n", | |
"        raise ValueError('no cry embeddings available for this pair')\n", | |
"\n", | |
"    # .item() needs every similarity to be a single-element tensor, i.e. each\n", | |
"    # list entry must hold exactly one embedding.\n", | |
"    similarity_scores = [\n", | |
"        cos(tensor_B, tensor_D).item()\n", | |
"        for tensor_B in encoded_cry_B\n", | |
"        for tensor_D in encoded_cry_D\n", | |
"    ]\n", | |
"    return sum(similarity_scores) / len(similarity_scores)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "b4358175-b830-44dd-8627-6ecf4aba7b4d", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "4c6d959ad6cc4cc3baeb852068f645cb", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/80 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/lib/python3.8/site-packages/torch/functional.py:641: UserWarning: stft with return_complex=False is deprecated. In a future pytorch release, stft will return complex tensors for all inputs, and return_complex=False will raise an error.\n", | |
"Note: you can still call torch.view_as_real on the complex output to recover the old return format. (Triggered internally at ../aten/src/ATen/native/SpectralOps.cpp:862.)\n", | |
" return _VF.stft(input, n_fft, hop_length, win_length, window, # type: ignore[attr-defined]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Seed the shuffling done by shuffle_group_and_concat so the dev scores (and\n", | |
"# the EER computed from them) are the same on every run of this cell.\n", | |
"random.seed(42)\n", | |
"\n", | |
"# Load every dev recording as a numpy waveform.\n", | |
"dev_metadata = metadata.loc[metadata['split']=='dev'].copy()\n", | |
"dev_metadata['cry'] = dev_metadata.apply(lambda row: read_audio(row['file_name']).numpy(), axis=1)\n", | |
"\n", | |
"# For each (baby_id, period) group, build 5 shuffled concatenations of its cries.\n", | |
"grouped_data = dev_metadata.groupby(['baby_id', 'period'])['cry']\n", | |
"cry_dict = {}\n", | |
"for key, group in grouped_data:\n", | |
"    cry_dict[key] = {'cry': shuffle_group_and_concat(group, 5)}\n", | |
"\n", | |
"# Embed each concatenated waveform with the pretrained encoder.\n", | |
"for (baby_id, period), d in tqdm(cry_dict.items()):\n", | |
"    cry_array = d['cry']\n", | |
"    cry_encoded_list = []\n", | |
"\n", | |
"    for row in cry_array:\n", | |
"        encoded_row = encoder.encode_batch(torch.tensor(row), normalize=False)\n", | |
"        cry_encoded_list.append(encoded_row)\n", | |
"\n", | |
"    d['cry_encoded'] = cry_encoded_list\n", | |
"\n", | |
"# Score every dev pair from the embeddings of its two recordings.\n", | |
"dev_pairs['score'] = dev_pairs.apply(lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict), axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "a1dc1403-aa80-44f6-a063-2e4134eec9d1", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEWCAYAAABhffzLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAhAklEQVR4nO3deZgdVZ3/8fcHEmiQsGUD6YQAgUQSQuSXqEMQUBYjIIswIwpKfgEzRmBg4OewzqAjjw8gOrJGI2BAQHSiGRGVRSBkWASSGAJhVYGkszYhLGFNd76/P6pCbi59u2+6+1Z1d31ez1PPrVunbp3vre7+9rmnzj2liMDMzIpjk7wDMDOzbDnxm5kVjBO/mVnBOPGbmRWME7+ZWcE48ZuZFYwTv1kNSPq2pJszqGeCpAdrXY/1LE78PZyklyS9I2l1yXJ1WjZBUnNZ2WpJH23htcskTZO0VTtikKRLJa1Ml0slqcK+h0t6UNJraZ3XSepTUn65pBckvSnpWUlfK3t9SHqr5L1cV1a+j6RZadlySWds7PtJj1N6vtaWneMT2nPMPEnaXNINkt5Iz/tZrew7UtJdkl6R9KEvAknaXtKM9OfwsqSvlJSdX3bu3knPX7+0fJqk98v22bQ277q4nPiL4QsRsVXJclpJ2SNlZVtFxJLy1wKjgY8D57Wj/knA0cDewCjgC8A/V9h3G+Bi4KPAx4CdgO+XlL+Vvn4b4CTgCkn7lh1j75L3csq6jWlyuRP4CdAXGArc3Y73Q+n5Ahay4Tm+ZWOOJalXe2LoZN8Gdgd2Bj4D/Juk8RX2XQP8Cji5Qvk1wPvAQOAEYIqkEQAR8b2yc3cpMDMiXil5/WVlv4/NHX1ztiEnfqtKRCwD7iL5B7CxTgJ+EBENEbEY+AEwoUI9t0bEnRHxdkSsAn4KjCspvygino2ItRHxKPC/wD9UGcdZwF0RcUtEvBcRb0bEM+14P9XaTNJN6aeTBZLGrCtIP02dI2k+8JakXpI+Jenh9NPOE5IOLNl/gqS/p8d6sfxTRfpJaFVa9vl2xHoS8N2IWJWek59S+Wf0XERcDywoL5P0EeBY4N8jYnVEPAjcDny1hX0FfA24sR3xWgc48VtVJNUDnwf+WrLt3DRJtbiUvHwE8ETJ8yfSbdXYnxYSTFr/FsDYFspnpd0Vv5E0pGT7p4BX0+S6QtLvJA2uMo72OBK4DdiWJPldXVb+ZeDwtHwg8HuSTzvbA/8P+LWk/mkyvRL4fET0AfYF5pUc55PAc0A/4DLg+nVdaZKubeVnND/dZztgR9r/Myq1B9AUEc9XcaxPAwOAX5dt/6akVyXNkXRsO2KwtkSElx68AC8Bq4HXSpavp2UTgKaysr+18No3gQDuBbZtRwzNwPCS57unx1MbrzsEWAXsUaH8RpKuG5Vs2x/YjCSZXg08BfRKy55P3+NYoI4kmT7USef44LJt3wb+VPJ8T+CdstdMLHl+DvDzsmPcRdIS/0ga97HAFmX7TAD+WvJ8y/Tc7rAR8Q9KX1NXdu5fauN1Q5MUssG2TwPLyrZ9naQ7p/z11wPTyrbtQ9IN1ws4LP3dG5fX309PXdziL4ajI2LbkuWnJWV/LivbrYXX9gEOBIaTtCo31mpg65LnWwOrI/1Lb4mkTwG3AsfFhq3HdeXfB0YC/1R6nIiYFRHvR8RrwBnALiTXCgDeAWZExOMR8S7wHWBfSdu0cPwfl1xcPH8j3+86y0rW3wbqyvrzF5Ws7wz8Y9knpv2AHSPiLeBLwDeApZJ+L2l4S/VExNvp6sZchF+dPpb/jN7ciGOUHmvrsm0fOpakLYF/pKybJyLmRsTKiGiKiD8AtwBfbEcc1gonfqtKRDwATAMuX7ethREaGywlL19AcmF3nb2p0H2THvfjJF0jEyPi3hbKv0PS7XRoRLzRVujAuhFE89PnpWUtvyjiG7H+4uL32qijvUrrX0TS4i/9J/yRiLgkjeeuiDiEpEvmWZ
I++DaV/QMrXxakx14FLGUjfkateB7oJWn3No51DPAqMLON45X+/KyTOPHbxvgRcIikveHDIzTKl5LX3QScJWknJUNFzyb5J/IhkkaSdN+cHhG/a6H8POArJF0rK8vKRkgaLWlTJcNOfwAsBtZdwP0ZcEy6T2/g34EHI+L19p6QTnQz8AVJn0vjr5N0oKR6SQMlHZX29b9H0qpeW81By/6BlS+l/e43ARdK2i79NPF1Kv+MJKmOpEuNNNbN0/reAn4D/Kekj0gaBxwF/LzsMCcBN5V/6pN0nKStJG0i6VDgRJJGgHWmvPuavNR2IelLfockWaxbZqRlE0j631eXLWNLXlvedz0F+PVGxiCSi46vpstlbNgvvxr4dLr+M5KkVhrPgpJ9g/XJb91yflr2WZKLnG8BK4D/AXYvi2UyyT+DVcDvgEGddI5b6uO/ueT5kDT2Xq285pPAA+k5aiS52DuYpJX/APA6SV//TGDPkp/hg2XHCWDoRr6HzYEbgDeA5cBZJWWD0/M8uOy9lC4vley/fXru3yIZ6vqVsrp2Irm29KEYSUZpvZ7G8QRwfN5/Qz1xUXqyzcysINzVY2ZWME78ZmYF48RvZlYwTvxmZgXTFSaHalO/fv1iyJAheYdhPdlzzyWPw4blG4dZJ5ozZ84rEdG/fHu3SPxDhgxh9uzZeYdhPdmBByaPM2fmGYVZp5L0ckvb3dVjZlYw3aLFb1ZzF16YdwRmmXHiNwM4+OC8IzDLjBO/GcC8ecnj6NF5RmFdzJo1a2hoaODdd9/NO5RW1dXVUV9fT+/evava34nfDODMM5NHX9y1Eg0NDfTp04chQ4aglm8TnbuIYOXKlTQ0NLDLLrtU9Rpf3DUzq+Ddd9+lb9++XTbpA0iib9++G/WpxInfzKwVXTnpr7OxMTrxm5kVjBO/mVkHbbVV63e6fOmllxg5cuRGHXPChAlMnz69I2FV5Iu7ZgDfS+6ueNAB42hcvqzibv0H7sC9DzyUVVRmNeHEbwaw774ANC5fxvwpp1TcbdTk67KKyLqh1atXc9RRR7Fq1SrWrFnDxRdfzFFHHQVAU1MTJ5xwAnPnzmXEiBHcdNNNbLnllsyZM4ezzjqL1atX069fP6ZNm8aOO+5Y0zjd1WMG8PDDyWLWAXV1dcyYMYO5c+dy//33c/bZZ6+7pSTPPfcc3/zmN3nmmWfYeuutufbaa1mzZg2nn34606dPZ86cOUycOJELLrig5nG6xW8GcP75eUdgPUBEcP755zNr1iw22WQTFi9ezPLlywEYNGgQ48aNA+DEE0/kyiuvZPz48Tz11FMccsghADQ3N9e8tQ9O/GZmneaWW26hsbGROXPm0Lt3b4YMGfLB+PryIZeSiAhGjBjBI488kmmcNevqkXSDpBWSnirZ9n1Jz0qaL2mGpG1rVb+ZWdZef/11BgwYQO/evbn//vt5+eX1syIvXLjwgwR/6623st9++zFs2DAaGxs/2L5mzRoWLFhQ8zhr2cc/DRhftu0eYGREjAKeB86rYf1mZpk64YQTmD17NnvttRc33XQTw4cP/6Bs2LBhXHPNNXzsYx9j1apVTJ48mc0224zp06dzzjnnsPfeezN69GgezuBaU826eiJilqQhZdvuLnn6Z+C4WtVvZpaV1atXA9CvX7+K3TbPPvtsi9tHjx7NrFmzPrR92rRpnRZfuTz7+CcCv6xUKGkSMAlg8ODBWcVkRfWjHyWPxx+baxhmWcgl8Uu6AGgCbqm0T0RMBaYCjBkzJjIKzYrK0zFbgWSe+CVNAI4ADop1A1zN8vanP+UdgVlmMk38ksYD/wYcEBFvZ1m3WasuvjjvCMwyU8vhnL8AHgGGSWqQdDJwNdAHuEfSPEk/rlX9ZmbWslqO6vlyC5uvr1V9ZmZWHc/VY2ZWpUGDd0ZSpy2DBu9cVb133nknw4YNY+jQoVxyySUdfh+essHMrEoNixbyw7uf67TjnXXosDb3aW5u5tRTT+Wee+6hvr
6esWPHcuSRR7Lnnnu2u163+M0AfvKTZDHrYh577DGGDh3Krrvuymabbcbxxx/Pb3/72w4d04nfDGDYsGQx62IWL17MoEGDPnheX1/P4sWLO3RMd/WYAfzud3lHYJYZJ34zgB/8IO8IzFq00047sWjRog+eNzQ0sNNOO3XomO7qMTPrwsaOHcsLL7zAiy++yPvvv89tt93GkUce2aFjusVvZlal+kGDqxqJszHHa0uvXr24+uqr+dznPkdzczMTJ05kxIgRHarXid/MrEqLFr7c9k41cNhhh3HYYYd12vHc1WNmVjBu8ZsB/PznyeMhB+YahlkWnPjNAErGSZv1dE78ZgC/rHgzOLMex4nfDGDKlLwjMMuML+6amRWME7+ZWZWGDK7v1GmZhwyur6reiRMnMmDAAEaOHNkp78NdPWZmVXp50WLivu912vH02fOr2m/ChAmcdtppfO1rX+uUet3iNzPr4vbff3+23377TjueW/xmANOnJ4/7fTLfOMwy4MRvBtCvX94RmGXGid8MYNq0vCMwy4wTvxk48VuhOPGbmVVp50E7VT0Sp9rjVePLX/4yM2fO5JVXXqG+vp7vfOc7nHzyye2u14nfzKxKLy1syKXeX/ziF516vJoN55R0g6QVkp4q2ba9pHskvZA+bler+s3MrGW1HMc/DRhftu1c4N6I2B24N31uZmYZqllXT0TMkjSkbPNRwIHp+o3ATOCcWsVg3cu4/Q9k+YrGiuUDB/TnoVkza1P5H/6QPO6zV22Ob91WRCAp7zBaFREbtX/WffwDI2Jpur4MGFhpR0mTgEkAgwe3fV9K6/6Wr2jk1KtmVCy/5vRjalf5lltWtdviJUsYNXy3iuX9B+7AvQ881KFQDjpgHI3Ll9W0DqtOXV0dK1eupG/fvl02+UcEK1eupK6ururX5HZxNyJCUsV/UxExFZgKMGbMmI37d2a2sa69tqrdormZ+VNOqVg+avJ1HQ6lcfmymtdh1amvr6ehoYHGxsqfRLuCuro66uurm/ANsk/8yyXtGBFLJe0IrMi4frOW/epXeUdgXVDv3r3ZZZdd8g6j02U9SdvtwEnp+knAbzOu38ys8Go5nPMXwCPAMEkNkk4GLgEOkfQCcHD63MzMMlTLUT1frlB0UK3qNDOztnk+fjOzgvGUDWYAM2cmj60M1TTrKdziNzMrGLf4zQAuvzzvCMwy48RvBnDHHXlHYJYZd/WYmRWME7+ZWcE48ZuZFYz7+M0Attgi7wjMMuPEbz1Ku+f0/+Mfk0eP47cCcOK3HiXXOf3NugknfjOA73437wjMMuPEbwZw7715R2CWGSd+6zYWL17M0OEjWt1nyZIlGUVj1n058Vu30bx2bav99wDnHD0mo2jMui8nfstMWyNu3Fo3y4YTv2WmrRE3ubbW+/ZNHpctyi8Gs4w48ZsB/PrXyaPH8VsBeMoGM7OCcYvfDOC88/KOwCwzTvxmAI88kncEZplxV4+ZWcE48ZuZFYwTv5lZweSS+CX9q6QFkp6S9AtJdXnEYfaB+vpkMSuAzBO/pJ2AfwHGRMRIYFPg+KzjMNvAzTcni1kB5NXV0wvYQlIvYEvA39U3M8tI5sM5I2KxpMuBhcA7wN0RcXf5fpImAZMABg8enG2QVjxnnpl3BGaZyaOrZzvgKGAX4KPARySdWL5fREyNiDERMaZ///5Zh2lFM29espgVQB5dPQcDL0ZEY0SsAX4D7JtDHGZmhZRH4l8IfErSlpIEHAQ8k0McZmaFlEcf/6OSpgNzgSbgL8DUrOOwjdPWXPoAAwf056FZM7MJyMzaLZe5eiLiIuCiPOq29mlrLn2Aa04/JqNoamCPPZJHz8dvBeBJ2swApqYfOj0fvxWAp2wwMyuYqlr8ksZFxENtbTPrtiZNyjsCs8xU29VzFbBPFdvMctXc1MyUa66sWP7ySy8yqoXunOsXJl8eX5pB5+dBB4yjcfmyiuVLfdN5q7FWf80l/QPJGPv+ks4qKd
qaZI4dsy4mmHx45fbIt/54PfOnnPLhgn/9KQDNLzTUKrAPNC5f1nIMqe2P8LgHq6222jebAVul+/Up2f4GcFytgrLuafHixQwdPqJi+RK3ZM26hFYTf0Q8ADwgaVpEvJxRTNZNNa9d2+qQz3OOHpNhNGZWSbU9mptLmgoMKX1NRHy2FkGZ1craCIZOuOpD2y9Y9ToAbzYp65DMMldt4v9v4MfAdUBz7cIxqzFtwqkXXvyhza+lj3HaqZmGY5aHahN/U0RMqWkkZmaWiWoT/+8kfROYAby3bmNEvFqTqMwyNv76ZFTPt3KOwywL1Sb+k9LH0r+LAHbt3HDM8rHVqlV5h2CWmaoSf0TsUutAzMwsG1XN1ZPOnX9hOrIHSbtLOqK2oZmZWS1UO0nbz4D3WX+nrMXAh4dGmJlZl1dt4t8tIi4D1gBExNuABzxbj7F0191YuqunZLZiqPbi7vuStiC5oIuk3SgZ3WPW3T30xWOTlfvuyzcQswxUm/gvAu4EBkm6BRgHTKhVUGZmVjvVjuq5R9Jc4FMkXTxnRMQrNY3MLENHTLkG8Dh+K4ZqR/UcQ/Lt3d9HxB1Ak6SjaxqZWYbq3nqLurfeyjsMs0xUe3H3ooh4fd2TiHgN3yzdzKxbqjbxt7Sfb9RuZtYNVZv4Z0v6oaTd0uWHwJxaBmZmZrVRbeI/neQLXL8EbgPeBTx/rfUYi4Z/jEXDP5Z3GGaZaLO7RtKmwB0R8ZnOqlTStiRz+48k+W7AxIh4pLOOb7axHj3iC8nKnXfmG4hZBtpM/BHRLGmtpG1KL/B20BXAnRFxnKTNgC076bhmZtaGai/QrgaelHQP8MGYt4j4l42tUNI2wP6kXwCLiPdJupHMcnP0Ff8FeBy/FUO1if836dIZdgEagZ9J2pvkIvEZEbHBIGpJk4BJAIMHD+6kqq2ScfsfyPIVjRXLlyxZkmE02eu1Zk3eIZhlptpv7t6YztUzOCKe64Q69wFOj4hHJV0BnAv8e1mdU4GpAGPGjIkO1mltWL6ikVOvmlGx/Jyjx2QYjZnVUrXf3P0CMI9kvh4kjZZ0ezvrbAAaIuLR9Pl0kn8EZmaWgWq7er4NfAKYCRAR8yS167aLEbFM0iJJw9JPDwcBT7fnWGadbW0EQydcVbH8zSbPRm7dX7WJf01EvC5t8Eu/tgP1ng7cko7o+TvwfztwLLMOe3HUqGTlpYWcemHlewx96zR/fcW6v2oT/wJJXwE2lbQ78C/Aw+2tNCLmAe40ti5jzqHjk5Xb78g3ELMMbMw3d0eQ3HzlVuB14MwaxWRmZjXUaotfUh3wDWAo8CTwDxHRlEVgZlk67vLLADg75zjMstBWi/9Gki6ZJ4HPA5fXPCIzM6uptvr494yIvQAkXQ88VvuQzMysltpq8X/wdUZ38ZiZ9Qxttfj3lvRGui5gi/S5gIiIrWsanZmZdbpWE39EbJpVIGZ5en5MOrr4pYX5BmKWAd8+0QyYf+Bnk5XpnTUXoVnX5cRvBvR67728QzDLjBO/GXD0VVcAcEYb+7U1l8+yV1d3YlRmteHEb7YxtEmrc/mcd2Zb/zrM8ufEb11Gc1MzU665smJ5U1NTq+WQ3MC5p1u8ZAmjhu9WsXxFYyMD+vdv9Rj9B+7AvQ881NmhfeCgA8bRuHxZbvVb65z4rQsJJh9e+dYMZ//xhlbL1+3T00VzM/OnnFKxfPsjLmq1HGDU5Os6O6wNNC5f1moMta7fWufEXxBFv7Wima3nxF8QvrVi657ed99kxeP4rQCqnZbZrEd7et/9eHrf/fIOwywTbvGbAXVvvpl3CGaZceI3A474yRQAfGNFKwJ39ZiZFYwTv5lZwTjxm5kVjBO/mVnB+OKuGTD/gAOTlZduyjUOsyy4xW8GPD/2Ezw/9hN5h2GWidxa/JI2BWYDiyPiiLziMAPY6tVX8w7BLDN5dvWcATwD+L69lrvxN3jSMCuOXLp6JNUDhwP+azMzy1heffw/Av4NWJtT/W
ZmhZV54pd0BLAiIua0sd8kSbMlzW5srDydsJmZbZw8WvzjgCMlvQTcBnxW0s3lO0XE1IgYExFj+rdxNyEzM6te5ok/Is6LiPqIGAIcD9wXESdmHYdZqbmHHMrcQw7NOwyzTPgLXGbA3/cenXcIZpnJNfFHxExgZp4xmAFst6zyjcHNehq3+M2Ag272VA1WHJ6ywcysYJz4zcwKxonfzKxgnPjNzArGid8MeOywI3jsME8Sa8XgUT1mwMI998w7BLPMOPGbAf0XLcw7BLPMOPH3EOP2P5DlKypPZrdkyZIMo+l+DvjlbXmHYJYZJ/4eYvmKRk69akbF8nOOHpNhNGbWlfnirplZwTjxm5kVjBO/mVnBOPGbAQ8d80UeOuaLeYdhlglf3DUDlu42tFOOs6apiaHDR7S6z7LFvpWo5cuJ3wzY8W9/7aQjqdXRVQDnHf1/Oqkus/Zx4jcDxs34Td4hmGXGffxmZgXjxG9mVjDu6jHrRAFMuebKVvdpamrOJphWLF6yhFHDd6tYvqKxkQH9+7e7fGkbU4S0VT9A/4E7cO8DD1UsP+iAcTQur3yv5I6+vppjdFReMTjxm3WyyYfv02r5t/54fUaRVBbNzcyfckrF8u2PuKjD5R2pH2DU5OtaLW9cvqzVY3T09dUco6PyisGJ3wx44EvHJyuXXpZvIGYZcOI3AxoHDc47BLPMOPGbAYOffjqzutZGMHTCVRXL32xSZrFYMTnxdxOeb7+2PvGHO7KrTJtw6oUXVyz+1mmnZheLFVLmiV/SIOAmYCDJIIipEXFF1nF0N55v38w6Sx4t/ibg7IiYK6kPMEfSPRGR3WdtM7MCy/wLXBGxNCLmputvAs8AO2Udh5lZUeXaxy9pCPBx4NEWyiYBkwAGD/aICysOX/y1Wsst8UvaCvg1cGZEvFFeHhFTgakAY8aMiYzDs4K598SvJSsXV77omhlf/LUayyXxS+pNkvRviQhPi2i5W7XDDnmHYJaZPEb1CLgeeCYifph1/WYt2fWJeXmHYJaZPFr844CvAk9KmpduOz8i/pBDLGYA7HPP3XmHYJaZzBN/RDwI+OqUmVlOPB+/mVnBOPGbmRWME7+ZWcF4kjYz4M6J6c0w/uM/8g2kCm19weu1NZu0Wg7+EljROfGbAau33z7vEKrXxhe8zj7ttFbLwV8CKzonfjNgj8cfyzsEs8w48ZsBox6YmXcIZpnxxV0zs4Jxiz8Dbd09q3HFCvoPGNDqMXyHLTPrLE78Gajm7lmtla/bx8ysM7irx8ysYNziNwPu+OfJycp55+UbiFkGnPjNgHf79Mk7BLPMOPGbAXs+/GDeIZhlxonfDNjz4YfzDsEsM764a2ZWME78ZmYF48RvZlYwTvxmZgXji7tmwP+cfkaycvbZ+QZilgEnfjOgafPN8w7BLDNO/GbAqJn35R1Cpjp6Fy/fwat7c+I3A/aYPTvvELLVwbt4+Q5e3ZsTfydoa9plT6lsZl1JLolf0njgCmBT4LqIuCSPODpLNdMum/UkbXUVuSuoa8s88UvaFLgGOARoAB6XdHtEPJ11LGbWTm10FbkrqGvLo8X/CeCvEfF3AEm3AUcBNUn8nXH3q7b2cVeO2Yb8iaBrU0RkW6F0HDA+Ik5Jn38V+GREnFa23yRgUvp0GPBcpoHWTj/glbyD6CJ8LtbzuVjP52K9jp6LnSOif/nGLntxNyKmAlPzjqOzSZodEe70x+eilM/Fej4X69XqXOQxZcNiYFDJ8/p0m5mZZSCPxP84sLukXSRtBhwP3J5DHGZmhZR5V09ENEk6DbiLZDjnDRGxIOs4ctTjuq86wOdiPZ+L9Xwu1qvJucj84q6ZmeXL0zKbmRWME7+ZWcE48deYpO0l3SPphfRxuxb2GS3pEUkLJM2X9KU8Yq0FSeMlPSfpr5LObaF8c0m/TMsflTQkhzAzUcW5OEvS0+nvwL2Sds4jziy0dS5K9jtWUkjqscM7qzkXkv4p/d1YIOnWDl
caEV5quACXAeem6+cCl7awzx7A7un6R4GlwLZ5x94J731T4G/ArsBmwBPAnmX7fBP4cbp+PPDLvOPO8Vx8BtgyXZ9c5HOR7tcHmAX8GRiTd9w5/l7sDvwF2C59PqCj9brFX3tHATem6zcCR5fvEBHPR8QL6foSYAXwoW/bdUMfTM8REe8D66bnKFV6fqYDB0nqid/nb/NcRMT9EfF2+vTPJN9x6Ymq+b0A+C5wKfBulsFlrJpz8XXgmohYBRARKzpaqRN/7Q2MiKXp+jJgYGs7S/oEyX/+v9U6sAzsBCwqed6Qbmtxn4hoAl4H+mYSXbaqORelTgb+WNOI8tPmuZC0DzAoIn6fZWA5qOb3Yg9gD0kPSfpzOrtxh3TZKRu6E0l/AnZooeiC0icREZIqjp+VtCPwc+CkiFjbuVFadyHpRGAMcEDeseRB0ibAD4EJOYfSVfQi6e45kORT4CxJe0XEax05oHVQRBxcqUzSckk7RsTSNLG3+DFN0tbA74ELIuLPNQo1a9VMz7FunwZJvYBtgJXZhJepqqYqkXQwSYPhgIh4L6PYstbWuegDjARmpr1+OwC3SzoyInrardKq+b1oAB6NiDXAi5KeJ/lH8Hh7K3VXT+3dDpyUrp8E/LZ8h3TqihnATRExPcPYaq2a6TlKz89xwH2RXsHqYdo8F5I+DvwEOLIz+nG7sFbPRUS8HhH9ImJIRAwhud7RE5M+VPc38j8krX0k9SPp+vl7Ryp14q+9S4BDJL0AHJw+R9IYSdel+/wTsD8wQdK8dBmdS7SdKO2zXzc9xzPAryJigaT/lHRkutv1QF9JfwXOIhn51ONUeS6+D2wF/Hf6O9Aj57Cq8lwUQpXn4i5gpaSngfuBb0VEhz4Ve8oGM7OCcYvfzKxgnPjNzArGid/MrGCc+M3MCsaJ38ysYJz4zcwKxonfrEbSbyKbdTlO/GYlJH1E0u8lPSHpKUlfkjRW0sPptsck9ZFUJ+lnkp6U9BdJn0lfP0HS7ZLuA+5Nj3dD+rq/SGppFkqzTLlFYrah8cCSiDgcQNI2JHOhfykiHk/nVHoHOINk3r29JA0H7pa0R3qMfYBREfGqpO+RTEMxUdK2wGOS/hQRb2X9xszWcYvfbENPkkyxcamkTwODgaUR8ThARLyRfs1+P+DmdNuzwMskc6gA3BMRr6brhwLnSpoHzATq0mOa5cYtfrMSEfF8Ohf8YcDFwH3tOExpa17AsRHxXGfEZ9YZ3OI3KyHpo8DbEXEzyaRpnwR2lDQ2Le+TXrT9X+CEdNseJK34lpL7XcDp6+4qls7AaZYrt/jNNrQX8H1Ja4E1JPe+FXCVpC1I+vcPBq4Fpkh6EmgCJkTEey3cNfK7wI+A+ekNRl4EjsjijZhV4tk5zcwKxl09ZmYF48RvZlYwTvxmZgXjxG9mVjBO/GZmBePEb2ZWME78ZmYF8/8BioNDYvaUw1EAAAAASUVORK5CYII=\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"def compute_eer_and_plot_verification_scores(pairs_df):\n", | |
"    \"\"\"Compute the equal error rate of scored pairs and plot their score histograms.\n", | |
"\n", | |
"    Args:\n", | |
"        pairs_df: DataFrame that must have 'score' and 'label' columns; rows\n", | |
"            with label 1 are treated as positive pairs, label 0 as negative.\n", | |
"\n", | |
"    Returns:\n", | |
"        tuple: (eer, threshold) exactly as returned by EER.\n", | |
"    \"\"\"\n", | |
"    positive_scores = pairs_df.loc[pairs_df['label'] == 1, 'score'].values\n", | |
"    negative_scores = pairs_df.loc[pairs_df['label'] == 0, 'score'].values\n", | |
"    eer, threshold = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))\n", | |
"    # One histogram per label, each normalised on its own (common_norm=False).\n", | |
"    ax = sns.histplot(pairs_df, x='score', hue='label', stat='percent', common_norm=False)\n", | |
"    ax.set_title(f'EER={round(eer, 4)} - Thresh={round(threshold, 4)}')\n", | |
"    # Mark the EER threshold; axvline expects a scalar x, not a list.\n", | |
"    ax.axvline(x=threshold, color='red', ls='--')\n", | |
"    return eer, threshold\n", | |
"\n", | |
"\n", | |
"eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=dev_pairs)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "6d371a14-22b0-4e71-9edf-b7a2b76567c5", | |
"metadata": {}, | |
"source": [ | |
"## Test: encode the test cries, score the pairs, and write the submission" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "e1ad25d3-7757-4842-a61e-275c7c0ca222", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "833ed79f6b2e48d799c190cee88b0dc2", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
" 0%| | 0/320 [00:00<?, ?it/s]" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"test_metadata = metadata.loc[metadata['split']=='test'].copy()\n", | |
"test_metadata['cry'] = test_metadata.apply(lambda row: read_audio(row['file_name']).numpy(), axis=1)\n", | |
"grouped_data = test_metadata.groupby(['baby_id', 'period'])['cry']\n", | |
"cry_dict_test = {}\n", | |
"for key, group in grouped_data:\n", | |
" cry_dict_test[key] = {'cry': shuffle_group_and_concat(group, 5)}\n", | |
"for (baby_id, period), d in tqdm(cry_dict_test.items()):\n", | |
" cry_array = d['cry']\n", | |
" cry_encoded_list = []\n", | |
"\n", | |
" for row in cry_array:\n", | |
" encoded_row = encoder.encode_batch(torch.tensor(row), normalize=False)\n", | |
" cry_encoded_list.append(encoded_row)\n", | |
"\n", | |
" d['cry_encoded'] = cry_encoded_list\n", | |
"test_pairs['score'] = test_pairs.apply(lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict_test), axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 90, | |
"id": "206bb9f2-528a-4fa2-9a54-08328c605d0c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>baby_id_B</th>\n", | |
" <th>baby_id_D</th>\n", | |
" <th>id</th>\n", | |
" <th>score</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>anonymous027</td>\n", | |
" <td>anonymous212</td>\n", | |
" <td>anonymous027B_anonymous212D</td>\n", | |
" <td>-0.149320</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>anonymous035</td>\n", | |
" <td>anonymous225</td>\n", | |
" <td>anonymous035B_anonymous225D</td>\n", | |
" <td>0.058973</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>anonymous029</td>\n", | |
" <td>anonymous288</td>\n", | |
" <td>anonymous029B_anonymous288D</td>\n", | |
" <td>0.010437</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>anonymous001</td>\n", | |
" <td>anonymous204</td>\n", | |
" <td>anonymous001B_anonymous204D</td>\n", | |
" <td>-0.087753</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>anonymous075</td>\n", | |
" <td>anonymous244</td>\n", | |
" <td>anonymous075B_anonymous244D</td>\n", | |
" <td>0.040881</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" baby_id_B baby_id_D id score\n", | |
"0 anonymous027 anonymous212 anonymous027B_anonymous212D -0.149320\n", | |
"1 anonymous035 anonymous225 anonymous035B_anonymous225D 0.058973\n", | |
"2 anonymous029 anonymous288 anonymous029B_anonymous288D 0.010437\n", | |
"3 anonymous001 anonymous204 anonymous001B_anonymous204D -0.087753\n", | |
"4 anonymous075 anonymous244 anonymous075B_anonymous244D 0.040881" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"display(test_pairs.head())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"id": "23404c1f-a4ee-46cc-9a13-65aa2fd057db", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>id</th>\n", | |
" <th>score</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>anonymous027B_anonymous212D</td>\n", | |
" <td>-0.149320</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>anonymous035B_anonymous225D</td>\n", | |
" <td>0.058973</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>anonymous029B_anonymous288D</td>\n", | |
" <td>0.010437</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>anonymous001B_anonymous204D</td>\n", | |
" <td>-0.087753</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>anonymous075B_anonymous244D</td>\n", | |
" <td>0.040881</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" id score\n", | |
"0 anonymous027B_anonymous212D -0.149320\n", | |
"1 anonymous035B_anonymous225D 0.058973\n", | |
"2 anonymous029B_anonymous288D 0.010437\n", | |
"3 anonymous001B_anonymous204D -0.087753\n", | |
"4 anonymous075B_anonymous244D 0.040881" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"#submission must match the 'sample_submission.csv' format exactly\n", | |
"my_submission= test_pairs[['id', 'score']]\n", | |
"my_submission.to_csv('my_submission.csv', index=False)\n", | |
"display(my_submission.head())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "330d31a0-469b-4cfc-bbb3-c00e6c7970f2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.12" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment