Skip to content

Instantly share code, notes, and snippets.

@hideojoho
Created July 8, 2020 04:39
Show Gist options
  • Save hideojoho/1a55634b122e95184b63649c6f2e2134 to your computer and use it in GitHub Desktop.
Save hideojoho/1a55634b122e95184b63649c6f2e2134 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Num GPUs Available: 1\n"
}
],
"source": [
"import tensorflow as tf\n",
"print(\"Num GPUs Available: \", len(tf.config.experimental.list_physical_devices('GPU')))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Defaulting to user installation because normal site-packages is not writeable\nRequirement already satisfied: tensorflow_datasets in /.local/lib/python3.6/site-packages (3.1.0)\nRequirement already satisfied: ipywidgets in /.local/lib/python3.6/site-packages (7.5.1)\nRequirement already satisfied: torch in /.local/lib/python3.6/site-packages (1.5.1)\nRequirement already satisfied: torchvision in /.local/lib/python3.6/site-packages (0.6.1)\nRequirement already satisfied: wrapt in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (1.12.1)\nRequirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (3.11.3)\nRequirement already satisfied: tensorflow-metadata in /.local/lib/python3.6/site-packages (from tensorflow_datasets) (0.22.2)\nRequirement already satisfied: tqdm in /.local/lib/python3.6/site-packages (from tensorflow_datasets) (4.47.0)\nRequirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (1.1.0)\nRequirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (19.3.0)\nRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (1.18.4)\nRequirement already satisfied: dill in /.local/lib/python3.6/site-packages (from tensorflow_datasets) (0.3.2)\nRequirement already satisfied: future in /.local/lib/python3.6/site-packages (from tensorflow_datasets) (0.18.2)\nRequirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (1.14.0)\nRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (2.23.0)\nRequirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tensorflow_datasets) (0.9.0)\nRequirement already satisfied: promise in /.local/lib/python3.6/site-packages (from tensorflow_datasets) (2.3)\nRequirement already satisfied: nbformat>=4.2.0 in /usr/local/lib/python3.6/dist-packages (from ipywidgets) (5.0.7)\nRequirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.6/dist-packages (from ipywidgets) (5.3.0)\nRequirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /usr/local/lib/python3.6/dist-packages (from ipywidgets) (7.16.1)\nRequirement already satisfied: widgetsnbextension~=3.5.0 in /.local/lib/python3.6/site-packages (from ipywidgets) (3.5.1)\nRequirement already satisfied: traitlets>=4.3.1 in /usr/local/lib/python3.6/dist-packages (from ipywidgets) (4.3.3)\nRequirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision) (7.1.2)\nRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.6.1->tensorflow_datasets) (46.1.3)\nRequirement already satisfied: googleapis-common-protos in /.local/lib/python3.6/site-packages (from tensorflow-metadata->tensorflow_datasets) (1.52.0)\nRequirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow_datasets) (2020.4.5.1)\nRequirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow_datasets) (1.25.9)\nRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow_datasets) (3.0.4)\nRequirement already satisfied: idna<3,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->tensorflow_datasets) (2.6)\nRequirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets) (3.2.0)\nRequirement already satisfied: ipython-genutils in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets) (0.2.0)\nRequirement already satisfied: jupyter-core in /usr/local/lib/python3.6/dist-packages (from nbformat>=4.2.0->ipywidgets) (4.6.3)\nRequirement already satisfied: jupyter-client in /usr/local/lib/python3.6/dist-packages (from ipykernel>=4.5.1->ipywidgets) (6.1.3)\nRequirement already satisfied: tornado>=4.2 in /usr/local/lib/python3.6/dist-packages (from ipykernel>=4.5.1->ipywidgets) (6.0.4)\nRequirement already satisfied: jedi>=0.10 in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.17.1)\nRequirement already satisfied: backcall in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.2.0)\nRequirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (3.0.5)\nRequirement already satisfied: pickleshare in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.5)\nRequirement already satisfied: pexpect; sys_platform != \"win32\" in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.8.0)\nRequirement already satisfied: pygments in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.6.1)\nRequirement already satisfied: decorator in /usr/local/lib/python3.6/dist-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.4.2)\nRequirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.6/dist-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.0.3)\nRequirement already satisfied: pyrsistent>=0.14.0 in /usr/local/lib/python3.6/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.16.0)\nRequirement already satisfied: importlib-metadata; python_version < \"3.8\" in /usr/local/lib/python3.6/dist-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (1.7.0)\nRequirement already satisfied: pyzmq>=13 in /usr/local/lib/python3.6/dist-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (19.0.1)\nRequirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (2.8.1)\nRequirement already satisfied: parso<0.8.0,>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.0)\nRequirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.2.5)\nRequirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.6/dist-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.6.0)\nRequirement already satisfied: nbconvert in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (5.6.1)\nRequirement already satisfied: Send2Trash in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)\nRequirement already satisfied: terminado>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.3)\nRequirement already satisfied: jinja2 in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.11.2)\nRequirement already satisfied: prometheus-client in /usr/local/lib/python3.6/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.0)\nRequirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < \"3.8\"->jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (3.1.0)\nRequirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.3)\nRequirement already satisfied: bleach in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.1.5)\nRequirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.4.2)\nRequirement already satisfied: defusedxml in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.6.0)\nRequirement already satisfied: testpath in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.4.4)\nRequirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.6/dist-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)\nRequirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.1.1)\nRequirement already satisfied: webencodings in /usr/local/lib/python3.6/dist-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)\nRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (20.4)\nRequirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.4.7)\n"
}
],
"source": [
"import sys\n",
"!{sys.executable} -m pip install tensorflow_datasets ipywidgets torch torchvision"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_datasets\n",
"from transformers import *\n",
"import torch, torchvision"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "eaa0d8b01c4b4a25b965aa45e20c9828"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "\n"
}
],
"source": [
"# Load dataset, tokenizer, model from pretrained model/vocabulary\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-cased', cache_dir='./bert_cache')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "8f67a21702e44f1e92bcc57d750129fc"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=526681800.0, style=ProgressStyle(descri…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "f36f321c04bf4e8bbf303ca7b82689af"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "\nSome weights of the model checkpoint at bert-base-cased were not used when initializing TFBertForSequenceClassification: ['mlm___cls', 'nsp___cls']\n- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\nSome weights of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['dropout_37', 'classifier']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
}
],
"source": [
"model = TFBertForSequenceClassification.from_pretrained('bert-base-cased', cache_dir='./bert_cache')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": "INFO:absl:Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: glue/mrpc/1.0.0\nINFO:absl:Load dataset info from /tmp/tmplm8wq6octfds\nINFO:absl:Field info.description from disk and from code do not match. Keeping the one from code.\nINFO:absl:Field info.citation from disk and from code do not match. Keeping the one from code.\nINFO:absl:Field info.location from disk and from code do not match. Keeping the one from code.\nINFO:absl:Generating dataset glue (./tensorflow_datasets/glue/mrpc/1.0.0)\n\u001b[1mDownloading and preparing dataset glue/mrpc/1.0.0 (download: 1.43 MiB, generated: Unknown size, total: 1.43 MiB) to ./tensorflow_datasets/glue/mrpc/1.0.0...\u001b[0m\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Completed...', max=1.0, style=Progre…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7bc40928dc9f4f1eb3e757e5c3651808"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Dl Size...', max=1.0, style=ProgressSty…",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "126216d0da854704bfe56490c500537c"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": "INFO:absl:Downloading https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc into ./tensorflow_datasets/downloads/fire.goog.com_v0_b_mtl-sent-repr.apps.com_o_2FjSIMlCiqs1QSmIykr4IRPnEHjPuGwAz5i40v8K9U0Z8.tsvalt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc.tmp.5d0fcd672ff04d74a559dac0ea7479d8...\nINFO:absl:Downloading https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt into ./tensorflow_datasets/downloads/dl.fbaip.com_sente_sente_msr_parap_trainfGxPZuQWGBti4Tbd1YNOwQr-OqxPejJ7gcp0Al6mlSk.txt.tmp.cc4ffc3de0d64b30b5875bb4df860003...\nINFO:absl:Downloading https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt into ./tensorflow_datasets/downloads/dl.fbaip.com_sente_sente_msr_parap_test0PdekMcyqYR-w4Rx_d7OTryq0J3RlYRn4rAMajy9Mak.txt.tmp.920f23e252ca42b4a30dec5a97255b9b...\n/usr/local/lib/python3.6/dist-packages/urllib3/connectionpool.py:986: InsecureRequestWarning: Unverified HTTPS request is being made to host 'dl.fbaipublicfiles.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n InsecureRequestWarning,\n/usr/local/lib/python3.6/dist-packages/urllib3/connectionpool.py:986: InsecureRequestWarning: Unverified HTTPS request is being made to host 'dl.fbaipublicfiles.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n InsecureRequestWarning,\n/usr/local/lib/python3.6/dist-packages/urllib3/connectionpool.py:986: InsecureRequestWarning: Unverified HTTPS request is being made to host 'firebasestorage.googleapis.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings\n InsecureRequestWarning,\nINFO:absl:Generating split train\n\n\n\n\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "623ba41852cd4ac38e56ca211d4a5e32"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "Shuffling and writing examples to ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-train.tfrecord\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, max=3668.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "dc23132cb3df46958846ca6cc45127c3"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": "INFO:absl:Done writing ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-train.tfrecord. Shard lengths: [3668]\nINFO:absl:Generating split validation\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "88bffc93f29944d7923a5876aca145cf"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "Shuffling and writing examples to ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-validation.tfrecord\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, max=408.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "cf66bc0f1e31464b80d361036790188c"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": "INFO:absl:Done writing ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-validation.tfrecord. Shard lengths: [408]\nINFO:absl:Generating split test\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "133b0ca879b64160920753dfebb08b00"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": "Shuffling and writing examples to ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-test.tfrecord\n"
},
{
"output_type": "display_data",
"data": {
"text/plain": "HBox(children=(FloatProgress(value=0.0, max=1725.0), HTML(value='')))",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "39d3b2c18dd7477080138b7754f4278e"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": "INFO:absl:Done writing ./tensorflow_datasets/glue/mrpc/1.0.0.incompleteOORUIX/glue-test.tfrecord. Shard lengths: [1725]\nINFO:absl:Skipping computing stats for mode ComputeStatsMode.AUTO.\nINFO:absl:Constructing tf.data.Dataset for split None, from ./tensorflow_datasets/glue/mrpc/1.0.0\n\u001b[1mDataset glue downloaded and prepared to ./tensorflow_datasets/glue/mrpc/1.0.0. Subsequent calls will reuse this data.\u001b[0m\n"
}
],
"source": [
"data = tensorflow_datasets.load('glue/mrpc', data_dir='./tensorflow_datasets')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Prepare dataset for GLUE as a tf.data.Dataset instance\n",
"train_dataset = glue_convert_examples_to_features(data['train'], tokenizer, max_length=128, task='mrpc')\n",
"valid_dataset = glue_convert_examples_to_features(data['validation'], tokenizer, max_length=128, task='mrpc')\n",
"train_dataset = train_dataset.shuffle(100).batch(32).repeat(2)\n",
"valid_dataset = valid_dataset.batch(64)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
"model.compile(optimizer=optimizer, loss=loss, metrics=[metric])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Epoch 1/2\n115/115 [==============================] - 46s 403ms/step - loss: 0.5555 - accuracy: 0.7118 - val_loss: 0.3794 - val_accuracy: 0.8407\nEpoch 2/2\n115/115 [==============================] - 44s 384ms/step - loss: 0.3185 - accuracy: 0.8727 - val_loss: 0.3951 - val_accuracy: 0.8431\n"
}
],
"source": [
"# Train and evaluate using tf.keras.Model.fit()\n",
"history = model.fit(train_dataset, batch_size=32, epochs=2, steps_per_epoch=115,\n",
" validation_data=valid_dataset, validation_steps=7)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": "WARNING:transformers.modeling_tf_pytorch_utils:All TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nWARNING:transformers.modeling_tf_pytorch_utils:All the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
}
],
"source": [
"# Load the TensorFlow model in PyTorch for inspection\n",
"model.save_pretrained('./bert_cache/')\n",
"pytorch_model = BertForSequenceClassification.from_pretrained('./bert_cache/', from_tf=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "sentence_0: This research was consistent with his findings.\nsentence_1: His findings were compatible with this research.\nsentence_2: His findings were incompatible with this research.\nsentence_3: This is a random sentence that has nothing to do with setnence_0.\n---\nsentence_1 is a paraphrase of sentence_0\nsentence_2 is a paraphrase of sentence_0\nsentence_3 is NOT a paraphrase of sentence_0\n"
}
],
"source": [
"# Quickly test a few predictions - MRPC is a paraphrasing task, let's see if our model learned the task\n",
"sentence_0 = \"This research was consistent with his findings.\"\n",
"sentence_1 = \"His findings were compatible with this research.\"\n",
"sentence_2 = \"His findings were incompatible with this research.\"\n",
"sentence_3 = \"This is a random sentence that has nothing to do with setnence_0.\"\n",
"\n",
"inputs_1 = tokenizer(sentence_0, sentence_1, add_special_tokens=True, return_tensors='pt')\n",
"inputs_2 = tokenizer(sentence_0, sentence_2, add_special_tokens=True, return_tensors='pt')\n",
"inputs_3 = tokenizer(sentence_0, sentence_3, add_special_tokens=True, return_tensors='pt')\n",
"\n",
"pred_1 = pytorch_model(inputs_1['input_ids'], token_type_ids=inputs_1['token_type_ids'])[0].argmax().item()\n",
"pred_2 = pytorch_model(inputs_2['input_ids'], token_type_ids=inputs_2['token_type_ids'])[0].argmax().item()\n",
"pred_3 = pytorch_model(inputs_3['input_ids'], token_type_ids=inputs_3['token_type_ids'])[0].argmax().item()\n",
"\n",
"print(\"sentence_0:\", sentence_0)\n",
"print(\"sentence_1:\", sentence_1)\n",
"print(\"sentence_2:\", sentence_2)\n",
"print(\"sentence_3:\", sentence_3)\n",
"print(\"---\")\n",
"print(\"sentence_1 is\", \"a paraphrase\" if pred_1 else \"not a paraphrase\", \"of sentence_0\")\n",
"print(\"sentence_2 is\", \"a paraphrase\" if pred_2 else \"NOT a paraphrase\", \"of sentence_0\")\n",
"print(\"sentence_3 is\", \"a paraphrase\" if pred_3 else \"NOT a paraphrase\", \"of sentence_0\")"
]
}
],
"metadata": {
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3-final"
},
"orig_nbformat": 2,
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@hideojoho
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment