Last active
June 11, 2023 14:43
laser_2_3_speech.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyMf0rQoBkligPBHCBVLz63m", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/avidale/de8145a75331be2e3b6d9b6729b711f1/laser_2_3_speech.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"This notebook shows how to apply LASER (v2, v3 and speech embeddings) with minimal code. \n", | |
"\n", | |
"Based on https://github.com/facebookresearch/laser" | |
], | |
"metadata": { | |
"id": "aUBI3E_mywlX" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 0: install the Python packages\n", | |
"\n", | |
"To run LASER, we need the code from the LASER repository, plus Fairseq and SentencePiece. We also install Sacremoses, which is used for punctuation normalization in Step 3."
], | |
"metadata": { | |
"id": "054fvVpBzhGm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!pip install fairseq sentencepiece sacremoses" | |
], | |
"metadata": { | |
"id": "eqFM19Zz0sd-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!git clone https://github.com/facebookresearch/LASER" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "JLwBuv2JzkF6", | |
"outputId": "b7872171-f2c3-44d4-8ab4-d0b26403b779" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Cloning into 'LASER'...\n", | |
"remote: Enumerating objects: 938, done.\u001b[K\n", | |
"remote: Counting objects: 100% (151/151), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (94/94), done.\u001b[K\n", | |
"remote: Total 938 (delta 70), reused 125 (delta 55), pack-reused 787\u001b[K\n", | |
"Receiving objects: 100% (938/938), 2.90 MiB | 5.92 MiB/s, done.\n", | |
"Resolving deltas: 100% (373/373), done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import os\n", | |
"import sys\n", | |
"LASER_MODELS_DIR = os.path.abspath('.')\n", | |
"LASER_CODE_DIR = os.path.abspath('./LASER')\n", | |
"USE_CUDA = False" | |
], | |
"metadata": { | |
"id": "v_H804xu0EWD" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"LASER is not a Python package, so to \"install\" it, we just add its `source` directory to Python's module search path (`sys.path`)."
], | |
"metadata": { | |
"id": "q0AQeQUWzr2M" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"os.environ['LASER'] = LASER_CODE_DIR\n", | |
"sys.path.append(os.path.join(LASER_CODE_DIR, 'source'))" | |
], | |
"metadata": { | |
"id": "yeFjaG5Uzto6" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Text encoders" | |
], | |
"metadata": { | |
"id": "LVo_OELQ4R0m" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 1: download the models.\n", | |
"\n", | |
"The first one (LASER-2) is multilingual.\n", | |
"It supports [about 100 languages](https://github.com/facebookresearch/LASER/#supported-languages). It is a biLSTM model." | |
], | |
"metadata": { | |
"id": "tBGK9KYvy-NO" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"id": "AYGc43fCxFIp" | |
}, | |
"outputs": [], | |
"source": [ | |
"!wget --trust-server-names -q https://tinyurl.com/nllblaser2\n", | |
"!wget --trust-server-names -q https://dl.fbaipublicfiles.com/nllb/laser/laser2.spm\n", | |
"!wget --trust-server-names -q https://dl.fbaipublicfiles.com/nllb/laser/laser2.cvocab" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"The second model is an example of the LASER-3 family. This family supports [about 100 more languages](https://github.com/facebookresearch/LASER/tree/main/nllb), with a separate model for each of them. Each model is a transformer encoder.\n",
"\n",
"We'll download the one for the Bashkir language (`bak_Cyrl`) as an example."
], | |
"metadata": { | |
"id": "idE3abUd2QCm" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!wget --trust-server-names -q https://dl.fbaipublicfiles.com/nllb/laser/laser3-bak_Cyrl.v1.pt\n", | |
"!wget --trust-server-names -q https://dl.fbaipublicfiles.com/nllb/laser/laser3-bak_Cyrl.v1.spm\n", | |
"!wget --trust-server-names -q https://dl.fbaipublicfiles.com/nllb/laser/laser3-bak_Cyrl.v1.cvocab" | |
], | |
"metadata": { | |
"id": "EUnh5Knn2RIs" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Ironically, the monolingual model is more than three times larger than the model for 100 languages (635 MB vs. 172 MB)."
], | |
"metadata": { | |
"id": "xgqFgKut8AQP" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"!ls -alsh" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1utR-FvkzE6N", | |
"outputId": "2cb65373-ab8b-45f2-d96e-1ca2ef29c3d9" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"total 810M\n", | |
"4.0K drwxr-xr-x 1 root root 4.0K Jun 11 14:41 .\n", | |
"4.0K drwxr-xr-x 1 root root 4.0K Jun 11 14:36 ..\n", | |
"4.0K drwxr-xr-x 4 root root 4.0K Jun 8 18:17 .config\n", | |
"4.0K drwxr-xr-x 9 root root 4.0K Jun 11 14:40 LASER\n", | |
"460K -rw-r--r-- 1 root root 460K Jun 24 2022 laser2.cvocab\n", | |
"172M -rw-r--r-- 1 root root 172M Jun 24 2022 laser2.pt\n", | |
"988K -rw-r--r-- 1 root root 985K Jun 24 2022 laser2.spm\n", | |
"872K -rw-r--r-- 1 root root 869K Jun 24 2022 laser3-bak_Cyrl.v1.cvocab\n", | |
"635M -rw-r--r-- 1 root root 635M Jun 24 2022 laser3-bak_Cyrl.v1.pt\n", | |
"1.5M -rw-r--r-- 1 root root 1.5M Jun 24 2022 laser3-bak_Cyrl.v1.spm\n", | |
"4.0K drwxr-xr-x 1 root root 4.0K Jun 8 18:18 sample_data\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 2: loading the models" | |
], | |
"metadata": { | |
"id": "m7qDKPV6zNEg" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# `embed` is a module (source/embed.py) in the LASER repo\n",
"from embed import LaserTransformerEncoder, load_model\n", | |
"import sentencepiece as spm" | |
], | |
"metadata": { | |
"id": "jtKhJaaSz5uS" | |
}, | |
"execution_count": 8, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Loading the multilingual model" | |
], | |
"metadata": { | |
"id": "uooPrs2Y2B6u" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"laser2_encoder = load_model(\n", | |
" LASER_MODELS_DIR + '/laser2.pt', \n", | |
" LASER_MODELS_DIR + '/laser2.spm', \n", | |
" LASER_MODELS_DIR + '/laser2.cvocab',\n", | |
" cpu=not USE_CUDA,\n", | |
")\n", | |
"laser2_tokenizer = spm.SentencePieceProcessor()\n", | |
"laser2_tokenizer.Load(LASER_MODELS_DIR + '/laser2.spm')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "kCk279dG0RI8", | |
"outputId": "5288751e-bac5-4b70-9b4c-1c0b07cf14ed" | |
}, | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"encoder_ba = load_model(\n", | |
" LASER_MODELS_DIR + '/laser3-bak_Cyrl.v1.pt', \n", | |
" LASER_MODELS_DIR + '/laser3-bak_Cyrl.v1.spm', \n", | |
" LASER_MODELS_DIR + '/laser3-bak_Cyrl.v1.cvocab',\n", | |
" cpu=not USE_CUDA,\n", | |
")\n", | |
"\n", | |
"ba_tokenizer = spm.SentencePieceProcessor()\n", | |
"ba_tokenizer.Load(LASER_MODELS_DIR + '/laser3-bak_Cyrl.v1.spm')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6Ix6n-DJ18uo", | |
"outputId": "fdbceec6-e9fa-40e2-90f0-6c928e4dbc47" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 3: normalization" | |
], | |
"metadata": { | |
"id": "PaVqz_-LuGzv" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"The original code from the LASER repository [includes](https://github.com/facebookresearch/LASER/blob/main/source/lib/text_processing.py#L136) a number of Perl scripts (Moses) for text preprocessing:\n",
"- REM_NON_PRINT_CHAR: we remove non-printable characters in Python instead\n",
"- NORM_PUNC: we use the Python implementation from Sacremoses\n",
"- ROMAN_LC: we don't romanize, and we lowercase in Python\n",
"\n",
"As we don't use the original scripts, the results might slightly differ from the original ones."
], | |
"metadata": { | |
"id": "lz5ieA9s8bN1" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import re\n", | |
"from sacremoses import MosesPunctNormalizer\n", | |
"mpn = MosesPunctNormalizer()\n", | |
"\n", | |
"def laser2_preprocess(text):\n",
"    # removing all nonprintable characters\n",
"    # first, we replace whitespace characters with a plain space, or they would be removed\n",
"    text = re.sub(r'\\s+', ' ', text)\n",
"    # In Python, nonprintable characters are those defined in the Unicode character database as “Other” or “Separator”,\n",
"    # except the ASCII space (0x20), which is considered printable\n",
"    text = ''.join(c for c in text if c.isprintable())\n",
"\n",
"    # normalizing punctuation\n",
"    text = mpn.normalize(text)\n",
"\n",
"    # lowercasing\n",
"    text = text.lower()\n",
"    return text\n",
"\n", | |
"print(laser2_preprocess('ПРивет, э\\x00то «Я»!'))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "CMolw3lTov-O", | |
"outputId": "d69dbedd-52c1-4cee-c52d-5360fd69e4aa" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"привет, это \"я\"!\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"\n", | |
"def embed(texts, model, tokenizer, normalize=True):\n",
"    \"\"\"Using a LASER model and tokenizer, compute embeddings for texts.\"\"\"\n",
"    texts_tok = [\n",
"        ' '.join(tokenizer.encode_as_pieces(laser2_preprocess(text)))\n",
"        for text in texts\n",
"    ]\n",
"    with torch.inference_mode():\n",
"        emb = model.encode_sentences(texts_tok)\n",
"    if normalize:\n",
"        # L2-normalize each row, so that dot products equal cosine similarities\n",
"        emb = emb / (emb**2).sum(1, keepdims=True) ** 0.5\n",
"    return emb"
], | |
"metadata": { | |
"id": "0ZHztAC428Di" | |
}, | |
"execution_count": 14, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Step 4: embedding" | |
], | |
"metadata": { | |
"id": "HcKHMXsj22k7" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"ru_sents = ['привет мир!', 'привет, как дела?', 'Мама мыла раму']\n", | |
"en_sents = ['Hello world!', 'Hi, how are you?', 'Mom was washing the frame']\n", | |
"ba_sents = ['сәләм тыныслыҡ!', 'сәләм, хәлдәр нисек?', 'Әсәй рамаҙан айы']" | |
], | |
"metadata": { | |
"id": "KMXd2gXa3FUx" | |
}, | |
"execution_count": 15, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_ru = embed(ru_sents, laser2_encoder, laser2_tokenizer)\n", | |
"b_en = embed(en_sents, laser2_encoder, laser2_tokenizer)\n", | |
"b_ba = embed(ba_sents, encoder_ba, ba_tokenizer)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "qOllQxBh3IRq", | |
"outputId": "b7591f01-9029-47cd-aa8c-f36fed3e0709" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/usr/local/lib/python3.10/dist-packages/fairseq/models/transformer/transformer_encoder.py:281: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)\n", | |
" x = torch._nested_tensor_from_mask(\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Let's compute cosine similarity between each pair of embeddings.\n", | |
"\n", | |
"The embeddings are L2-normalized, so the cosine similarity is simply their dot product.\n",
"\n",
"We can see that within each language, sentence 1 and sentence 2 are more similar to each other than to sentence 3. Also, for each pair of languages, the sentences with the same number are the most similar to each other."
], | |
"metadata": { | |
"id": "irPXF2T53uQi" | |
} | |
}, | |
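{
"cell_type": "markdown",
"source": [
"As a quick sanity check of the statement above, here is a minimal sketch with toy 2-dimensional vectors (not real LASER embeddings; the helper names `cosine` and `l2_normalize` are made up for this illustration). It compares the cosine similarity computed from its definition with the dot product of the L2-normalized vectors; the two numbers should agree up to floating-point error."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import math\n",
"\n",
"def cosine(u, v):\n",
"    \"\"\"Cosine similarity computed from the definition.\"\"\"\n",
"    dot = sum(a * b for a, b in zip(u, v))\n",
"    return dot / (math.hypot(*u) * math.hypot(*v))\n",
"\n",
"def l2_normalize(u):\n",
"    \"\"\"Scale a vector to unit Euclidean length.\"\"\"\n",
"    norm = math.hypot(*u)\n",
"    return [a / norm for a in u]\n",
"\n",
"# toy vectors, chosen only for illustration\n",
"u, v = [3.0, 4.0], [4.0, 3.0]\n",
"u_n, v_n = l2_normalize(u), l2_normalize(v)\n",
"\n",
"# after normalization, a plain dot product gives the cosine similarity\n",
"dot_normalized = sum(a * b for a, b in zip(u_n, v_n))\n",
"print(cosine(u, v), dot_normalized)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},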
{ | |
"cell_type": "code", | |
"source": [ | |
"b_ru.dot(b_ru.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "93e2ce9a-674a-4709-dd97-7c2d8b34f622", | |
"id": "KVtA1AOR5ouV" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.99999994, 0.61413217, 0.3357919 ],\n", | |
" [0.61413217, 1. , 0.20773266],\n", | |
" [0.3357919 , 0.20773266, 1. ]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_en.dot(b_en.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "6f8ebcc0-6ea7-498c-ea23-c13d0fbb5725", | |
"id": "8D26KhEY5ouV" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.99999994, 0.62056834, 0.362594 ],\n", | |
" [0.62056834, 1. , 0.25831014],\n", | |
" [0.362594 , 0.25831014, 0.9999999 ]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_ba.dot(b_ba.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "9a9ad372-8e9a-45ea-a20d-65e9729e0562", | |
"id": "coErpi7J5ouW" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[1. , 0.6931297 , 0.5284328 ],\n", | |
" [0.6931297 , 1.0000002 , 0.44858515],\n", | |
" [0.5284328 , 0.44858515, 0.99999994]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_ru.dot(b_en.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "0e5a3db3-2b8d-435f-d146-456bdc7811bb", | |
"id": "2RHyhb225ouW" | |
}, | |
"execution_count": 20, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.82902086, 0.606028 , 0.33507192],\n", | |
" [0.6140566 , 0.91600704, 0.26718485],\n", | |
" [0.37389138, 0.22469705, 0.6815194 ]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 20 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_ru.dot(b_ba.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "22a6e260-6fdd-4762-8094-53be60e99e72", | |
"id": "guILB2sy5ouX" | |
}, | |
"execution_count": 21, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.86894596, 0.6789486 , 0.43244153],\n", | |
" [0.59563744, 0.8806851 , 0.29750603],\n", | |
" [0.3929627 , 0.34107754, 0.8217282 ]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"b_en.dot(b_ba.T)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "e4c32d9f-3069-4a8b-96c4-19b6ffd48adf", | |
"id": "gp9JZlk25ouX" | |
}, | |
"execution_count": 22, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[0.7456117 , 0.68734574, 0.4956516 ],\n", | |
" [0.61173916, 0.8490674 , 0.31819192],\n", | |
" [0.37096176, 0.4003557 , 0.6890281 ]], dtype=float32)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 22 | |
} | |
] | |
}, | |
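{
"cell_type": "markdown",
"source": [
"The cross-lingual matrices above can also be used to match sentences: for each row, take the column with the highest similarity. Below is a minimal sketch, with the values copied from the `b_ru.dot(b_en.T)` output above (the names `sim_ru_en` and `best_match` are just for illustration). With the real arrays, `b_ru.dot(b_en.T).argmax(1)` should give the same kind of result."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# similarity values copied from the b_ru.dot(b_en.T) output above\n",
"sim_ru_en = [\n",
"    [0.82902086, 0.606028, 0.33507192],\n",
"    [0.6140566, 0.91600704, 0.26718485],\n",
"    [0.37389138, 0.22469705, 0.6815194],\n",
"]\n",
"\n",
"# for each Russian sentence (row), the index of the most similar English sentence (column)\n",
"best_match = [max(range(len(row)), key=row.__getitem__) for row in sim_ru_en]\n",
"print(best_match)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},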
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Speech encoders" | |
], | |
"metadata": { | |
"id": "-VCWbNjS3tWS" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"LASER also supports speech sentence encoders for ~20 European languages. \n", | |
"\n", | |
"They project sentences into exactly the same space as LASER-2 and LASER-3 text encoders, so with them, you can use text sentences and speech utterances interchangeably.\n", | |
"\n", | |
"Warning: these models are very large (24 GB each), so you probably don't want to use them unless you have a lot of time or a very powerful GPU.\n",
"\n",
"If you can afford such large models, though, the code for their inference is very simple. The models and a code snippet are available at https://github.com/facebookresearch/fairseq/blob/ust/examples/speech_matrix/speech_laser_encoders.md. The code requires only Fairseq to be installed."
], | |
"metadata": { | |
"id": "YwF5CMpH4bLh" | |
} | |
} | |
] | |
} |