Last active
July 21, 2023 07:51
entitylinking_genre_colab_minimalcode.ipynb
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyM2lAGsvRkRpJ9zceJ2UqzO", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/raven44099/32babed67e122427ec36e1fafd142c08/entitylinking_genre_colab_minimalcode.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 1. colab GENRE\n", | |
"at the timepoint 2022.Oct.27, this is a fully functional colab script to run Facebook AI's entity linker called 'GENRE' (https://github.com/facebookresearch/GENRE). This file can be used as an entry point to write a colab script that also includes training (but I haven't written that yet).\n", | |
"\n", | |
"A 2nd approach that utilizes huggingface-transfromers package is also provided in one cell, but that code is redundant for the 1st approach." | |
], | |
"metadata": { | |
"id": "O5ixnsog3C05" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title 1.2. huggingface GENRE\n", | |
"# # %%capture\n", | |
"# !pip install transformers\n", | |
"\n", | |
"# from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n", | |
"\n", | |
"# # OPTIONAL: load the prefix tree (trie), you need to additionally download\n", | |
"# # https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and \n", | |
"# # https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl\n", | |
"# # import pickle\n", | |
"# # from trie import Trie\n", | |
"# # with open(\"kilt_titles_trie_dict.pkl\", \"rb\") as f:\n", | |
"# # trie = Trie.load_from_dict(pickle.load(f))\n", | |
"\n", | |
"# tokenizer = AutoTokenizer.from_pretrained(\"facebook/genre-kilt\")\n", | |
"# model = AutoModelForSeq2SeqLM.from_pretrained(\"facebook/genre-kilt\").eval()\n", | |
"\n", | |
"# sentences = [\"Einstein was a German physicist.\"]\n", | |
"\n", | |
"# outputs = model.generate(\n", | |
"# **tokenizer(sentences, return_tensors=\"pt\"),\n", | |
"# num_beams=5,\n", | |
"# num_return_sequences=5,\n", | |
"# # OPTIONAL: use constrained beam search\n", | |
"# # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),\n", | |
"# )\n", | |
"\n", | |
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", | |
"\n", | |
"# sentences = [\" Proteins are well known in the realm of [START] synthetic biology, where E. coli oftenis used to produce the catalyst for experiments associated with the INS gene ..\"]\n", | |
"\n", | |
"# outputs = model.generate(\n", | |
"# **tokenizer(sentences, return_tensors=\"pt\"),\n", | |
"# num_beams=5,\n", | |
"# num_return_sequences=5,\n", | |
"# # OPTIONAL: use constrained beam search\n", | |
"# # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),\n", | |
"# )\n", | |
"\n", | |
"# tokenizer.batch_decode(outputs, skip_special_tokens=True)" | |
], | |
"metadata": { | |
"id": "FogMwcT3I0Hi" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 1.3. download GENRE & configure" | |
], | |
"metadata": { | |
"id": "SriJcWRx3Ae0" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title clone GENRE \n", | |
"# %%capture\n", | |
"%cd /content/\n", | |
"!git clone https://github.com/facebookresearch/GENRE" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "sz7aq_MDSxJi", | |
"outputId": "8d766299-b421-4e90-c3ca-9c7e00e5c400", | |
"cellView": "form" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n", | |
"Cloning into 'GENRE'...\n", | |
"remote: Enumerating objects: 457, done.\u001b[K\n", | |
"remote: Counting objects: 100% (173/173), done.\u001b[K\n", | |
"remote: Compressing objects: 100% (92/92), done.\u001b[K\n", | |
"remote: Total 457 (delta 114), reused 102 (delta 78), pack-reused 284\u001b[K\n", | |
"Receiving objects: 100% (457/457), 11.00 MiB | 25.36 MiB/s, done.\n", | |
"Resolving deltas: 100% (261/261), done.\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"#@title clone & install fairseq\n", | |
"%%capture\n", | |
"!git clone --branch fixing_prefix_allowed_tokens_fn https://github.com/nicola-decao/fairseq\n", | |
"!pwd\n", | |
"%cd /content/fairseq\n", | |
"! pip install --editable .\n", | |
"#! pip install --editable ./" | |
], | |
"metadata": { | |
"id": "98AaBLuiTkGL", | |
"cellView": "form" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"'''this path.append must maybe be before runtime restart.'''\n", | |
"import sys\n", | |
"sys.path.append(\"/content/fairseq/\")\n", | |
"sys.path.append(\"/content/GENRE\")" | |
], | |
"metadata": { | |
"id": "2Htb8TkMCJs8" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"<font color='red'> restart runtime was initiially necessary here, now not anymore...! </font>" | |
], | |
"metadata": { | |
"id": "wSXi7pqM4nST" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%%capture\n", | |
"!pip install jsonlines" | |
], | |
"metadata": { | |
"id": "esf4epyTXBFX" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%cd /content/\n", | |
"!mkdir data\n", | |
"%cd data\n", | |
"### KILT prefix tree\n", | |
"!wget http://dl.fbaipublicfiles.com/GENRE/kilt_titles_trie_dict.pkl" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "0S24jRdQgZQU", | |
"outputId": "472afe65-39e1-4d6e-ba49-e4d489648549" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n", | |
"/content/data\n", | |
"--2022-11-02 06:17:52-- http://dl.fbaipublicfiles.com/GENRE/kilt_titles_trie_dict.pkl\n", | |
"Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.75.142, 172.67.9.4, 104.22.74.142, ...\n", | |
"Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.75.142|:80... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 215214973 (205M) [application/octet-stream]\n", | |
"Saving to: ‘kilt_titles_trie_dict.pkl’\n", | |
"\n", | |
"kilt_titles_trie_di 100%[===================>] 205.24M 26.0MB/s in 8.5s \n", | |
"\n", | |
"2022-11-02 06:18:01 (24.3 MB/s) - ‘kilt_titles_trie_dict.pkl’ saved [215214973/215214973]\n", | |
"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
" from GENRE/scripts_genre/download_all_models.sh #" | |
], | |
"metadata": { | |
"id": "j2cfAYgzeQrK" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"'''\n", | |
"This code sniped is missing in the README.md,\n", | |
"but can be found at 'GENRE/scripts_genre/download_all_models.sh'\n", | |
"'''\n", | |
"%cd /content/\n", | |
"!mkdir models\n", | |
"%cd models\n", | |
"!wget http://dl.fbaipublicfiles.com/GENRE/fairseq_entity_disambiguation_aidayago.tar.gz\n", | |
"!tar -zxvf fairseq_entity_disambiguation_aidayago.tar.gz" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "14ChP3M-duHJ", | |
"outputId": "d8985f54-af2f-4831-d817-56e50dee75c0" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n", | |
"/content/models\n", | |
"--2022-11-02 06:18:01-- http://dl.fbaipublicfiles.com/GENRE/fairseq_entity_disambiguation_aidayago.tar.gz\n", | |
"Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 104.22.75.142, 172.67.9.4, 104.22.74.142, ...\n", | |
"Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|104.22.75.142|:80... connected.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1201544043 (1.1G) [application/gzip]\n", | |
"Saving to: ‘fairseq_entity_disambiguation_aidayago.tar.gz’\n", | |
"\n", | |
"fairseq_entity_disa 100%[===================>] 1.12G 32.0MB/s in 38s \n", | |
"\n", | |
"2022-11-02 06:18:39 (30.3 MB/s) - ‘fairseq_entity_disambiguation_aidayago.tar.gz’ saved [1201544043/1201544043]\n", | |
"\n", | |
"fairseq_entity_disambiguation_aidayago/\n", | |
"fairseq_entity_disambiguation_aidayago/dict.source.txt\n", | |
"fairseq_entity_disambiguation_aidayago/dict.target.txt\n", | |
"fairseq_entity_disambiguation_aidayago/model.pt\n", | |
"fairseq_entity_disambiguation_aidayago/encoder.json\n", | |
"fairseq_entity_disambiguation_aidayago/vocab.bpe\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import pickle\n", | |
"from genre.trie import Trie\n", | |
"%cd /content\n", | |
"# load the prefix tree (trie)\n", | |
"with open(\"data/kilt_titles_trie_dict.pkl\", \"rb\") as f:\n", | |
" trie = Trie.load_from_dict(pickle.load(f))" | |
], | |
"metadata": { | |
"id": "ZRgMvhYNgP_W", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "64187ccf-3d20-45a3-fec7-ae7ceaba959e" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 1.3 the model" | |
], | |
"metadata": { | |
"id": "O74GGKj8qwdw" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from genre.fairseq_model import GENRE\n", | |
"%cd /content/\n", | |
"model = GENRE.from_pretrained(\"/content/models/fairseq_entity_disambiguation_aidayago\").eval()\n", | |
"\n", | |
"# for huggingface/transformers\n", | |
"# from genre.hf_model import GENRE\n", | |
"# model = GENRE.from_pretrained(\"../models/hf_entity_disambiguation_aidayago\").eval()" | |
], | |
"metadata": { | |
"id": "RGsYFjgARz__", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "d2b74c84-a029-401c-f34c-1d480ddea70d" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"/content\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"1042301B [00:00, 3081465.33B/s]\n", | |
"456318B [00:00, 882396.44B/s]\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# myPhrase = [\"Einstein was a [START_ENT] German [END_ENT] physicist.\"]\n", | |
"\n", | |
"myPhrase = [' Proteins are well known in the realm of synthetic biology, where E. coli often is used to produce material for experiments associated with the INS gene.',\n", | |
" 'Experiments associated with the [START_ENT] INS gene [END_ENT] or in very special cases other sources.',\n", | |
" 'Insulin is a small molecule. experiments associated with the [START_ENT] INS gene [END_ENT] or in very special cases other sources.']\n", | |
"model.sample(\n", | |
" sentences=myPhrase, \n", | |
" prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),\n", | |
")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5HrhxvrHf-0r", | |
"outputId": "9ea8a278-c8e0-415e-c22e-41e7f5ef0925" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"/content/fairseq/fairseq/search.py:205: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" beams_buf = indices_buf // vocab_size\n", | |
"/content/fairseq/fairseq/sequence_generator.py:659: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').\n", | |
" unfin_idx = idx // beam_size\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[[{'text': 'Biosynthesis', 'score': tensor(-1.4936)},\n", | |
" {'text': 'Sodium silicate', 'score': tensor(-1.5992)},\n", | |
" {'text': 'Biopterin', 'score': tensor(-1.6096)},\n", | |
" {'text': 'Sodium silicide', 'score': tensor(-1.9425)},\n", | |
" {'text': 'Biopterin-dependent aromatic amino acid hydroxylase',\n", | |
" 'score': tensor(-2.4378)}],\n", | |
" [{'text': 'Inositol trisphosphate', 'score': tensor(-0.1708)},\n", | |
" {'text': 'Insulin-like growth factor', 'score': tensor(-0.5887)},\n", | |
" {'text': 'Immunoglobulin E', 'score': tensor(-0.9796)},\n", | |
" {'text': 'Immunoglobulin A', 'score': tensor(-1.0120)},\n", | |
" {'text': 'Immunosuppressive drug', 'score': tensor(-1.0889)}],\n", | |
" [{'text': 'Insulin', 'score': tensor(-0.3428)},\n", | |
" {'text': 'Insulin-like growth factor', 'score': tensor(-0.3444)},\n", | |
" {'text': 'Insulin receptor', 'score': tensor(-1.1370)},\n", | |
" {'text': 'Insulin resistance', 'score': tensor(-1.5255)},\n", | |
" {'text': 'Insulin receptor substrate', 'score': tensor(-2.6037)}]]" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 9 | |
} | |
] | |
} | |
] | |
} |
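The last notebook cell passes `prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist())` to `model.sample`, which restricts generation to titles stored in the KILT prefix tree. The sketch below illustrates the idea with a toy trie over made-up token ids. `ToyTrie`, the token ids, and the three example sequences are all hypothetical stand-ins chosen for illustration; this is not GENRE's `Trie` class, and it uses plain Python lists where the notebook converts a tensor with `sent.tolist()`.

```python
from typing import Dict, List


class ToyTrie:
    """Toy prefix tree over token-id sequences (hypothetical, not GENRE's Trie)."""

    def __init__(self) -> None:
        # Nested dicts: each key is a token id, each value is the subtree below it.
        self.root: Dict[int, dict] = {}

    def add(self, sequence: List[int]) -> None:
        """Insert one complete token-id sequence, e.g. an encoded title."""
        node = self.root
        for token in sequence:
            node = node.setdefault(token, {})

    def get(self, prefix: List[int]) -> List[int]:
        """Return the token ids that may follow `prefix`, or [] if the prefix is unknown."""
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return sorted(node.keys())


# Made-up token ids: 2 = start-of-sequence, 1 = end-of-sequence.
toy_trie = ToyTrie()
toy_trie.add([2, 10, 11, 1])  # pretend this encodes one entity title
toy_trie.add([2, 10, 12, 1])  # pretend this encodes a second title sharing a prefix
toy_trie.add([2, 20, 1])      # pretend this encodes a third, unrelated title


def prefix_allowed_tokens_fn(batch_id: int, sent: List[int]) -> List[int]:
    """Same shape as the callback in the notebook: tokens generated so far -> allowed next tokens."""
    return toy_trie.get(list(sent))


print(prefix_allowed_tokens_fn(0, [2]))      # tokens allowed after the start token
print(prefix_allowed_tokens_fn(0, [2, 10]))  # tokens allowed after the shared prefix
```

At each decoding step the beam search may only pick from the returned ids, so every finished hypothesis is a sequence that was added to the trie. In the notebook, this is what keeps the candidates limited to titles in `kilt_titles_trie_dict.pkl`.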
Thanks for your comment 😄 . Appreciate it!
Nice! Do you have mGENRE as well?
No... I'm sorry.
This is really helpful. Thanks!