WIC.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
},
"colab": {
"name": "WIC.ipynb",
"version": "0.3.2",
"provenance": [],
"collapsed_sections": [
"iYPt5fCnBaoH"
],
"toc_visible": true,
"include_colab_link": true
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/archetana/b97e301a3cb14d74d8b1c760f18c48f8/wic_slicing_tutorial.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vhErVlLJBanz",
"colab_type": "text"
},
"source": [
"# Snorkel Workshop: Slicing Tutorial"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "55WG0IE6Ban2",
"colab_type": "text"
},
"source": [
"## Setup\n",
"To start, let's make sure that we have the right paths/environment variables set by following the instructions in `snorkel-superglue/README.md`.\n",
"\n",
"Specifically, ensure that (1) `snorkel` is installed and (2) `SUPERGLUEDATA` is set where [download_superglue_data.py](https://github.com/HazyResearch/snorkel-superglue/blob/staging/download_superglue_data.py) was called." | |
]
},
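{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of step (2), assuming the helper scripts read the `SUPERGLUEDATA` environment variable described in the README; the path below is simply where this notebook downloads the data later on:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import os\n",
"\n",
"# Point SUPERGLUEDATA at the directory the download script will populate.\n",
"os.environ[\"SUPERGLUEDATA\"] = \"/content/snorkel-superglue/data\"\n",
"print(os.environ[\"SUPERGLUEDATA\"])"
],
"execution_count": 0,
"outputs": []
},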
{
"cell_type": "code",
"metadata": {
"id": "ZJ-KMCcRBrfG",
"colab_type": "code",
"colab": {}
},
"source": [
"%cd /content\n",
"!git clone https://github.com/HazyResearch/snorkel\n", | |
"%cd /content/snorkel\n", | |
"!git fetch --tags\n", | |
"!git checkout snorkel-superglue" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "S6a_pQhfKiza", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"!pip install -r requirements.txt" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "j5ONUAnqaLaS", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 170 | |
}, | |
"outputId": "02f51a49-079b-43b6-9559-03039e177adc" | |
}, | |
"source": [ | |
"%cd /content\n", | |
"!git clone https://github.com/HazyResearch/snorkel-superglue.git\n", | |
"%cd /content/snorkel-superglue\n", | |
"!git checkout staging" | |
], | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/content\n", | |
"Cloning into 'snorkel-superglue'...\n", | |
"remote: Enumerating objects: 322, done.\u001b[K\n", | |
"remote: Total 322 (delta 0), reused 0 (delta 0), pack-reused 322\u001b[K\n", | |
"Receiving objects: 100% (322/322), 117.52 KiB | 4.90 MiB/s, done.\n", | |
"Resolving deltas: 100% (190/190), done.\n", | |
"/content/snorkel-superglue\n", | |
"Branch 'staging' set up to track remote branch 'staging' from 'origin'.\n", | |
"Switched to a new branch 'staging'\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T6P-YVDFQMfu", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "e2b7f1ae-be8e-4cc8-eb2b-bab177f3fd2d" | |
}, | |
"source": [ | |
"!pip install -r requirements.txt" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting allennlp (from -r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/30/8c/72b14d20c9cbb0306939ea41109fc599302634fd5c59ccba1a659b7d0360/allennlp-0.8.4-py3-none-any.whl (5.7MB)\n", | |
"\u001b[K |████████████████████████████████| 5.7MB 2.8MB/s \n", | |
"\u001b[?25hCollecting jsonlines (from -r requirements.txt (line 2))\n", | |
" Downloading https://files.pythonhosted.org/packages/4f/9a/ab96291470e305504aa4b7a2e0ec132e930da89eb3ca7a82fbe03167c131/jsonlines-1.2.0-py2.py3-none-any.whl\n", | |
"Collecting pytorch_pretrained_bert (from -r requirements.txt (line 3))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)\n", | |
"\u001b[K |████████████████████████████████| 133kB 43.9MB/s \n", | |
"\u001b[?25hRequirement already satisfied: tensorboardX>=1.2 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.6)\n", | |
"Collecting awscli>=1.11.91 (from allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/4a/9b/4e513b4bebdb5c8513ad80156e83e93c372862ac45b11f32801637484494/awscli-1.16.222-py2.py3-none-any.whl (1.9MB)\n", | |
"\u001b[K |████████████████████████████████| 1.9MB 33.1MB/s \n", | |
"\u001b[?25hCollecting responses>=0.7 (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/d1/5a/b887e89925f1de7890ef298a74438371ed4ed29b33def9e6d02dc6036fd8/responses-0.10.6-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: nltk in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (3.2.5)\n", | |
"Requirement already satisfied: gevent>=1.3.6 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.4.0)\n", | |
"Collecting jsonnet>=0.10.0; sys_platform != \"win32\" (from allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a9/a8/adba6cd0f84ee6ab064e7f70cd03a2836cefd2e063fd565180ec13beae93/jsonnet-0.13.0.tar.gz (255kB)\n", | |
"\u001b[K |████████████████████████████████| 256kB 44.4MB/s \n", | |
"\u001b[?25hCollecting ftfy (from allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/75/ca/2d9a5030eaf1bcd925dab392762b9709a7ad4bd486a90599d93cd79cb188/ftfy-5.6.tar.gz (58kB)\n", | |
"\u001b[K |████████████████████████████████| 61kB 21.9MB/s \n", | |
"\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.9.205)\n", | |
"Collecting flaky (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/fe/12/0f169abf1aa07c7edef4855cca53703d2e6b7ecbded7829588ac7e7e3424/flaky-3.6.1-py2.py3-none-any.whl\n", | |
"Collecting conllu==0.11 (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/d4/2c/856344d9b69baf5b374c395b4286626181a80f0c2b2f704914d18a1cea47/conllu-0.11-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (2018.9)\n", | |
"Requirement already satisfied: pytest in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (3.6.4)\n", | |
"Collecting word2number>=1.1 (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/4a/29/a31940c848521f0725f0df6b25dca8917f13a2025b0e8fcbe5d0457e45e6/word2number-1.1.zip\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.16.0)\n", | |
"Requirement already satisfied: sqlparse>=0.2.4 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (0.3.0)\n", | |
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (0.20.2)\n", | |
"Collecting flask-cors>=3.0.7 (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/78/38/e68b11daa5d613e3a91e4bf3da76c94ac9ee0d9cd515af9c1ab80d36f709/Flask_Cors-3.0.8-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.2.0)\n", | |
"Requirement already satisfied: matplotlib>=2.2.3 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (3.0.1)\n", | |
"Collecting unidecode (from allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d0/42/d9edfed04228bacea2d824904cae367ee9efd05e6cce7ceaaedd0b0ad964/Unidecode-1.1.1-py2.py3-none-any.whl (238kB)\n", | |
"\u001b[K |████████████████████████████████| 245kB 41.9MB/s \n", | |
"\u001b[?25hCollecting numpydoc>=0.8.0 (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/6a/f3/7cfe4c616e4b9fe05540256cc9c6661c052c8a4cec2915732793b36e1843/numpydoc-0.9.1.tar.gz\n", | |
"Requirement already satisfied: spacy<2.2,>=2.0.18 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (2.1.3)\n", | |
"Collecting jsonpickle (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/07/07/c157520a3ebd166c8c24c6ae0ecae7c3968eb4653ff0e5af369bb82f004d/jsonpickle-1.2-py2.py3-none-any.whl\n", | |
"Collecting parsimonious>=0.8.0 (from allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/02/fc/067a3f89869a41009e1a7cdfb14725f8ddd246f30f63c645e8ef8a1c56f4/parsimonious-0.8.1.tar.gz (45kB)\n", | |
"\u001b[K |████████████████████████████████| 51kB 19.6MB/s \n", | |
"\u001b[?25hRequirement already satisfied: requests>=2.18 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (2.21.0)\n", | |
"Requirement already satisfied: flask>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.1.1)\n", | |
"Requirement already satisfied: tqdm>=4.19 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (4.29.1)\n", | |
"Requirement already satisfied: editdistance in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (0.5.3)\n", | |
"Requirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (1.1.0)\n", | |
"Requirement already satisfied: h5py in /usr/local/lib/python3.6/dist-packages (from allennlp->-r requirements.txt (line 1)) (2.8.0)\n", | |
"Collecting overrides (from allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/de/55/3100c6d14c1ed177492fcf8f07c4a7d2d6c996c0a7fc6a9a0a41308e7eec/overrides-1.9.tar.gz\n", | |
"Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from jsonlines->-r requirements.txt (line 2)) (1.12.0)\n", | |
"Collecting regex (from pytorch_pretrained_bert->-r requirements.txt (line 3))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/6f/a6/99eeb5904ab763db87af4bd71d9b1dfdd9792681240657a4c0a599c10a81/regex-2019.08.19.tar.gz (654kB)\n", | |
"\u001b[K |████████████████████████████████| 655kB 41.5MB/s \n", | |
"\u001b[?25hRequirement already satisfied: protobuf>=3.2.0 in /usr/local/lib/python3.6/dist-packages (from tensorboardX>=1.2->allennlp->-r requirements.txt (line 1)) (3.7.1)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1)) (0.14)\n", | |
"Collecting botocore==1.12.212 (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/89/aa/c6055e862b9de8462f3d02b011373a86bdef5586e1eb4247f0bde4889d10/botocore-1.12.212-py2.py3-none-any.whl (5.7MB)\n", | |
"\u001b[K |████████████████████████████████| 5.7MB 31.3MB/s \n", | |
"\u001b[?25hCollecting colorama<=0.3.9,>=0.2.5 (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1))\n", | |
" Downloading https://files.pythonhosted.org/packages/db/c8/7dcf9dbcb22429512708fe3a547f8b6101c0d02137acbd892505aee57adf/colorama-0.3.9-py2.py3-none-any.whl\n", | |
"Requirement already satisfied: PyYAML<=5.2,>=3.10; python_version != \"2.6\" in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1)) (3.13)\n", | |
"Collecting rsa<=3.5.0,>=3.1.2 (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1))\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/e1/ae/baedc9cb175552e95f3395c43055a6a5e125ae4d48a1d7a924baca83e92e/rsa-3.4.2-py2.py3-none-any.whl (46kB)\n", | |
"\u001b[K |████████████████████████████████| 51kB 20.4MB/s \n", | |
"\u001b[?25hRequirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from awscli>=1.11.91->allennlp->-r requirements.txt (line 1)) (0.2.1)\n", | |
"Requirement already satisfied: greenlet>=0.4.14; platform_python_implementation == \"CPython\" in /usr/local/lib/python3.6/dist-packages (from gevent>=1.3.6->allennlp->-r requirements.txt (line 1)) (0.4.15)\n", | |
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy->allennlp->-r requirements.txt (line 1)) (0.1.7)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->allennlp->-r requirements.txt (line 1)) (0.9.4)\n", | |
"Requirement already satisfied: py>=1.5.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (1.8.0)\n", | |
"Requirement already satisfied: pluggy<0.8,>=0.5 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (0.7.1)\n", | |
"Requirement already satisfied: more-itertools>=4.0.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (7.2.0)\n", | |
"Requirement already satisfied: attrs>=17.4.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (19.1.0)\n", | |
"Requirement already satisfied: atomicwrites>=1.0 in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (1.3.0)\n", | |
"Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from pytest->allennlp->-r requirements.txt (line 1)) (41.1.0)\n", | |
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp->-r requirements.txt (line 1)) (2.5.3)\n", | |
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp->-r requirements.txt (line 1)) (0.10.0)\n", | |
"Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp->-r requirements.txt (line 1)) (2.4.2)\n", | |
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.2.3->allennlp->-r requirements.txt (line 1)) (1.1.0)\n", | |
"Requirement already satisfied: sphinx>=1.6.5 in /usr/local/lib/python3.6/dist-packages (from numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (1.8.5)\n", | |
"Requirement already satisfied: Jinja2>=2.3 in /usr/local/lib/python3.6/dist-packages (from numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (2.10.1)\n", | |
"Requirement already satisfied: srsly<1.1.0,>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (0.0.7)\n", | |
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (1.0.2)\n", | |
"Requirement already satisfied: preshed<2.1.0,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (2.0.1)\n", | |
"Requirement already satisfied: blis<0.3.0,>=0.2.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (0.2.4)\n", | |
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (2.0.2)\n", | |
"Requirement already satisfied: plac<1.0.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (0.9.6)\n", | |
"Requirement already satisfied: wasabi<1.1.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (0.2.2)\n", | |
"Requirement already satisfied: jsonschema<3.0.0,>=2.6.0 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (2.6.0)\n", | |
"Requirement already satisfied: thinc<7.1.0,>=7.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy<2.2,>=2.0.18->allennlp->-r requirements.txt (line 1)) (7.0.8)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp->-r requirements.txt (line 1)) (2019.6.16)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp->-r requirements.txt (line 1)) (2.8)\n", | |
"Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp->-r requirements.txt (line 1)) (1.24.3)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.18->allennlp->-r requirements.txt (line 1)) (3.0.4)\n", | |
"Requirement already satisfied: itsdangerous>=0.24 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp->-r requirements.txt (line 1)) (1.1.0)\n", | |
"Requirement already satisfied: Werkzeug>=0.15 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp->-r requirements.txt (line 1)) (0.15.5)\n", | |
"Requirement already satisfied: click>=5.1 in /usr/local/lib/python3.6/dist-packages (from flask>=1.0.2->allennlp->-r requirements.txt (line 1)) (7.0)\n", | |
"Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<=3.5.0,>=3.1.2->awscli>=1.11.91->allennlp->-r requirements.txt (line 1)) (0.4.6)\n", | |
"Requirement already satisfied: alabaster<0.8,>=0.7 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (0.7.12)\n", | |
"Requirement already satisfied: babel!=2.0,>=1.3 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (2.7.0)\n", | |
"Requirement already satisfied: Pygments>=2.0 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (2.1.3)\n", | |
"Requirement already satisfied: sphinxcontrib-websupport in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (1.1.2)\n", | |
"Requirement already satisfied: imagesize in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (1.1.0)\n", | |
"Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (19.1)\n", | |
"Requirement already satisfied: snowballstemmer>=1.1 in /usr/local/lib/python3.6/dist-packages (from sphinx>=1.6.5->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (1.9.0)\n", | |
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/lib/python3.6/dist-packages (from Jinja2>=2.3->numpydoc>=0.8.0->allennlp->-r requirements.txt (line 1)) (1.1.1)\n", | |
"Building wheels for collected packages: jsonnet, ftfy, word2number, numpydoc, parsimonious, overrides, regex\n", | |
" Building wheel for jsonnet (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for jsonnet: filename=jsonnet-0.13.0-cp36-cp36m-linux_x86_64.whl size=3320377 sha256=006a85087234cd4b7d1a8fcf5104e2a2835ef0a92d2e632401ed8e1edeb428d2\n", | |
" Stored in directory: /root/.cache/pip/wheels/1a/30/ab/ae4a57b1df44fa20a531edb9601b27603da8f5336225691f3f\n", | |
" Building wheel for ftfy (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for ftfy: filename=ftfy-5.6-cp36-none-any.whl size=44554 sha256=9b872437541ab88439a5f17616a15d0b1df905e08e884dae21245711cc3ac158\n", | |
" Stored in directory: /root/.cache/pip/wheels/43/34/ce/cbb38d71543c408de56f3c5e26ce8ba495a0fa5a28eaaf1046\n", | |
" Building wheel for word2number (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for word2number: filename=word2number-1.1-cp36-none-any.whl size=5587 sha256=34efc91f98deeddbf54e484f6f05e90527e69b9957cec695e8917ccad1ec0939\n", | |
" Stored in directory: /root/.cache/pip/wheels/46/2f/53/5f5c1d275492f2fce1cdab9a9bb12d49286dead829a4078e0e\n", | |
" Building wheel for numpydoc (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for numpydoc: filename=numpydoc-0.9.1-cp36-none-any.whl size=31873 sha256=6906ff218f2256d96c5cf9b4ed75fae935c8706a77ce80afdaa0d2f5a34b62ad\n", | |
" Stored in directory: /root/.cache/pip/wheels/51/30/d1/92a39ba40f21cb70e53f8af96eb98f002a781843c065406500\n", | |
" Building wheel for parsimonious (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for parsimonious: filename=parsimonious-0.8.1-cp36-none-any.whl size=42710 sha256=e237871c7499fe70b060dc52347bbc354b3b8216810bb8303fafe0eeafb88a42\n", | |
" Stored in directory: /root/.cache/pip/wheels/b7/8d/e7/a0e74217da5caeb3c1c7689639b6d28ddbf9985b840bc96a9a\n", | |
" Building wheel for overrides (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for overrides: filename=overrides-1.9-cp36-none-any.whl size=4213 sha256=3f2e244eade45ba14db64ad62f025ec1fc2b2abe74ac12858a6637af17cbdb4f\n", | |
" Stored in directory: /root/.cache/pip/wheels/8d/52/86/e5a83b1797e7d263b458d2334edd2704c78508b3eea9323718\n", | |
" Building wheel for regex (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for regex: filename=regex-2019.8.19-cp36-cp36m-linux_x86_64.whl size=609226 sha256=60b3613dfb34d09b865bc07f620c4315864d4faca454afca694843acf98101e0\n", | |
" Stored in directory: /root/.cache/pip/wheels/90/04/07/b5010fb816721eb3d6dd64ed5cc8111ca23f97fdab8619b5be\n", | |
"Successfully built jsonnet ftfy word2number numpydoc parsimonious overrides regex\n", | |
"Installing collected packages: botocore, colorama, rsa, awscli, responses, jsonnet, ftfy, flaky, conllu, word2number, flask-cors, unidecode, numpydoc, regex, pytorch-pretrained-bert, jsonpickle, parsimonious, overrides, allennlp, jsonlines\n", | |
" Found existing installation: botocore 1.12.205\n", | |
" Uninstalling botocore-1.12.205:\n", | |
" Successfully uninstalled botocore-1.12.205\n", | |
" Found existing installation: rsa 4.0\n", | |
" Uninstalling rsa-4.0:\n", | |
" Successfully uninstalled rsa-4.0\n", | |
"Successfully installed allennlp-0.8.4 awscli-1.16.222 botocore-1.12.212 colorama-0.3.9 conllu-0.11 flaky-3.6.1 flask-cors-3.0.8 ftfy-5.6 jsonlines-1.2.0 jsonnet-0.13.0 jsonpickle-1.2 numpydoc-0.9.1 overrides-1.9 parsimonious-0.8.1 pytorch-pretrained-bert-0.6.2 regex-2019.8.19 responses-0.10.6 rsa-3.4.2 unidecode-1.1.1 word2number-1.1\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/vnd.colab-display-data+json": { | |
"pip_warning": { | |
"packages": [ | |
"botocore", | |
"rsa" | |
] | |
} | |
} | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3Br1MasaBan3", | |
"colab_type": "code", | |
"outputId": "b84425f9-2e82-44a6-fd3e-5c2609e2ca73", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 88 | |
} | |
}, | |
"source": [ | |
"import sys, os\n", | |
"from pathlib import Path\n", | |
"\n", | |
"if not \"cwd\" in globals():\n", | |
" cwd = Path(os.getcwd())\n", | |
"sys.path.insert(0, str(cwd))\n", | |
"print(sys.path)\n", | |
"print(cwd)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/content/snorkel\n", | |
"['/content/snorkel', '/content', '/content', '/content', '', '/env/python', '/usr/lib/python36.zip', '/usr/lib/python3.6', '/usr/lib/python3.6/lib-dynload', '/usr/local/lib/python3.6/dist-packages', '/usr/lib/python3/dist-packages', '/usr/local/lib/python3.6/dist-packages/IPython/extensions', '/root/.ipython']\n", | |
"/content/snorkel\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "7Qbn_RaiBan7", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import pandas as pd\n", | |
"# Don't truncate the sentence when viewing examples\n", | |
"pd.set_option('display.max_colwidth', -1)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "SAj3mzyTBan_", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Note, we rely heavily on the `snorkel.mtl` module, which is a great abstraction for implementing these slicing tasks. \n", | |
"Intuitively, we want an API to add extra model capacity corresponding to each slice—exactly what this module flexibly provides!" | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "5AAXgHuYBaoA",
"colab_type": "code",
"colab": {}
},
"source": [
"from snorkel.mtl.data import MultitaskDataLoader\n",
"from snorkel.mtl.model import MultitaskModel\n",
"from snorkel.mtl.snorkel_config import default_config as config\n",
"from snorkel.mtl.trainer import Trainer"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XX17UkI1BaoE",
"colab_type": "code",
"colab": {}
},
"source": [
"import superglue_tasks\n",
"from tokenizer import get_tokenizer\n",
"from utils import task_dataset_to_dataframe"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "erB05hA6QtX6",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "b887050c-e18a-4699-e1f9-7908d72cb89e"
},
"source": [
"!sh download_superglue_data.sh /content/snorkel-superglue/data"
],
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": [
"Downloading primary SuperGLUE tasks.\n",
"Downloading and extracting CB...\n",
"\tCompleted!\n",
"Downloading and extracting COPA...\n",
"\tCompleted!\n",
"Downloading and extracting MultiRC...\n",
"\tCompleted!\n",
"Downloading and extracting RTE...\n",
"\tCompleted!\n",
"Downloading and extracting WiC...\n",
"\tCompleted!\n",
"Downloading and extracting WSC...\n",
"\tCompleted!\n",
"Downloading and extracting diagnostic...\n",
"\tCompleted!\n",
"Downloading and unzipping SWAG.\n",
"--2019-08-21 17:40:14-- https://www.dropbox.com/s/cklxrrzisd3zzuh/SWAG.zip?dl=1\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.9.1, 2620:100:601f:1::a27d:901\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.9.1|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/dl/cklxrrzisd3zzuh/SWAG.zip [following]\n",
"--2019-08-21 17:40:14-- https://www.dropbox.com/s/dl/cklxrrzisd3zzuh/SWAG.zip\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com/cd/0/get/AnD2t18FEbyb-5M4J819sa4dGHe3XFSRn9J14RNIlF_mADGmBBjaT4rD3guiIai4pk4SkMnEh3dlYHJdirflgujVCXSzt2jmI9Y96bJLP1V_nQ/file?dl=1# [following]\n",
"--2019-08-21 17:40:14-- https://uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com/cd/0/get/AnD2t18FEbyb-5M4J819sa4dGHe3XFSRn9J14RNIlF_mADGmBBjaT4rD3guiIai4pk4SkMnEh3dlYHJdirflgujVCXSzt2jmI9Y96bJLP1V_nQ/file?dl=1\n",
"Resolving uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com (uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com)... 162.125.9.6, 2620:100:601f:6::a27d:906\n",
"Connecting to uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com (uc23ceae382864dd122725b0b23d.dl.dropboxusercontent.com)|162.125.9.6|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 57192456 (55M) [application/binary]\n",
"Saving to: ‘/content/snorkel-superglue/data/SWAG.zip?dl=1’\n",
"\n",
"SWAG.zip?dl=1 100%[===================>] 54.54M 7.72MB/s in 8.2s \n",
"\n",
"2019-08-21 17:40:23 (6.67 MB/s) - ‘/content/snorkel-superglue/data/SWAG.zip?dl=1’ saved [57192456/57192456]\n",
"\n",
"Archive: /content/snorkel-superglue/data/SWAG.zip\n",
" inflating: /content/snorkel-superglue/data/SWAG/README.md \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/LICENSE \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/requirements.txt \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/Dockerfile \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-numberbatch.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-numberbatch-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-glove-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-elmo.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-numberbatch-goldonly-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-cnn.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/run_experiments.sh \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-elmo-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/predict.py \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/lstm_swag.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-glove-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-elmo-goldonly-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-elmo-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/dataset_reader.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/run_experiments_ending.sh \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-numberbatch-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-glove-goldonly-endingonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/unarylstm/train-lstmbasic-glove.json \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/__init__.py \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/fasttext/\n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/fasttext/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/fasttext/prep_data.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/fasttext/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/fasttext/compute_performance.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/README.md \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-numberbatch.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-glove-goldonly-840.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-glove-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/run_experiments.sh \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-glove-840.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-elmo.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-elmo-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/dataset_reader.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-glove.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/train-numberbatch-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/decomposable_attention/decomposable_attention_swag.py \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-numberbatch.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-glove-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/run_experiments.sh \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/predict.py \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/esim_swag.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-elmo.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-elmo-goldonly.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/dataset_reader.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-glove.json \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/swag_baselines/esim/train-numberbatch-goldonly.json \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/raw_data/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/raw_data/events.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.dockerignore \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/config.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/train_lm.sh \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/pretrain_lm.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/train_lm.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/load_data.py \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/vocabulary/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/vocabulary/tokens.txt \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/vocabulary/non_padded_namespaces.txt \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/lm/simple_bilm.py \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/turktemplate.html \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/README.md \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/rebalance_dataset_mlp.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/rebalance_dataset_ensemble.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/sample_candidates.sh \n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/__init__.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/classifiers.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/questions2mturk.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/create_swag/generate_candidates/sample_candidates.py \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/evaluation.yaml \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/pytorch_misc.py \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/config \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/objects/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/objects/pack/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/objects/pack/pack-73a4d1b716989105c801b48a8a502612b22c142d.pack \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/objects/pack/pack-73a4d1b716989105c801b48a8a502612b22c142d.idx \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/objects/info/\n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/.git/HEAD \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/info/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/info/exclude \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/HEAD \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/heads/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/heads/master \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/remotes/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/remotes/origin/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/logs/refs/remotes/origin/HEAD \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/description \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/commit-msg.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/pre-rebase.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/pre-commit.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/applypatch-msg.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/prepare-commit-msg.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/post-update.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/pre-applypatch.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/pre-push.sample \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/hooks/update.sample \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/heads/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/heads/master \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/tags/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/remotes/\n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/remotes/origin/\n",
" extracting: /content/snorkel-superglue/data/SWAG/swagaf/.git/refs/remotes/origin/HEAD \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/index \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/.git/branches/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/.git/packed-refs \n",
" creating: /content/snorkel-superglue/data/SWAG/swagaf/data/\n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/val.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/test.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/README.md \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/train_full.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/val_full.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/swagaf/data/train.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/test.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/train.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/train_full.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/val.csv \n",
" inflating: /content/snorkel-superglue/data/SWAG/val_full.csv \n",
"Done.\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iYPt5fCnBaoH",
"colab_type": "text"
},
"source": [
"## Explore the WiC dataset\n",
"We'll be working with the [Words in Context (WiC) task](https://pilehvar.github.io/wic/). To start, let's look at a few examples. To do so, we'll convert them to dataframes."
]
},
{
"cell_type": "code",
"metadata": {
"id": "qs2azRoOBaoI",
"colab_type": "code",
"colab": {}
},
"source": [
"from dataloaders import get_jsonl_path\n",
"from superglue_parsers.wic import get_rows\n",
"\n",
"task_name = \"WiC\"\n",
"data_dir = \"/content/snorkel-superglue/data\"\n",
"split = \"valid\"\n",
"max_data_samples = None # max examples to include in dataset\n",
"\n",
"jsonl_path = get_jsonl_path(data_dir, task_name, split)\n",
"wic_df = pd.DataFrame.from_records(get_rows(jsonl_path, max_data_samples=max_data_samples))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "NgHUqfbOBaoL",
"colab_type": "text"
},
"source": [
"Recall, the WiC task is used to identify the intended meaning of specified words across multiple contexts—the `label` indicates whether the word is used in the same sense in both `sentence1` and `sentence2`!" | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "4QmPSEMOBaoM",
"colab_type": "code",
"outputId": "47d3f98d-6cce-42d6-e1de-4d5c38537c30",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
}
},
"source": [
"wic_df[[\"sentence1\", \"sentence2\", \"word\", \"label\"]].head()"
],
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sentence1</th>\n",
" <th>sentence2</th>\n",
" <th>word</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Room and board .</td>\n",
" <td>He nailed boards across the windows .</td>\n",
" <td>board</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Circulate a rumor .</td>\n",
" <td>This letter is being circulated among the faculty .</td>\n",
" <td>circulate</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Hook a fish .</td>\n",
" <td>He hooked a snake accidentally , and was so scared he dropped his rod into the water .</td>\n",
" <td>hook</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>For recreation he wrote poetry and solved crossword puzzles .</td>\n",
" <td>Drug abuse is often regarded as a form of recreation .</td>\n",
" <td>recreation</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Making a hobby of domesticity .</td>\n",
" <td>A royal family living in unpretentious domesticity .</td>\n",
" <td>domesticity</td>\n",
" <td>False</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sentence1 ... label\n",
"0 Room and board . ... False\n",
"1 Circulate a rumor . ... False\n",
"2 Hook a fish . ... True \n",
"3 For recreation he wrote poetry and solved crossword puzzles . ... True \n",
"4 Making a hobby of domesticity . ... False\n",
"\n",
"[5 rows x 4 columns]"
]
},
"metadata": {
"tags": []
},
"execution_count": 13
}
]
},
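{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before modeling, it can help to check the label balance. This is just a sanity-check sketch over the dataframe built above:"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# WiC is roughly balanced between same-sense (True) and different-sense (False) pairs.\n",
"print(wic_df[\"label\"].value_counts())\n",
"print(f\"{len(wic_df)} examples in the '{split}' split\")"
],
"execution_count": 0,
"outputs": []
},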
{
"cell_type": "markdown",
"metadata": {
"id": "YZTShyemBaoW",
"colab_type": "text"
},
"source": [
"## Train a model using BERT\n",
"Now, let's train a model using the Snorkel API, with the [BERT](https://arxiv.org/abs/1810.04805) model, a powerful pre-training mechanism for general language understanding.\n",
"Thanks to folks at [huggingface](https://github.com/huggingface/pytorch-pretrained-BERT), we can use this model in PyTorch with with a simple import statement!" | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "tLKHbd7_BaoX",
"colab_type": "code",
"colab": {}
},
"source": [
"bert_model = \"bert-large-cased\"\n",
"tokenizer_name = \"bert-large-cased\"\n",
"batch_size = 4\n",
"max_sequence_length = 256"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rGpZrkbUBaoa",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "338975f9-3d27-4fc3-b904-9eeaa5c73c45"
},
"source": [
"# load the word-piece tokenizer for the 'bert-large-cased' vocabulary\n",
"tokenizer = get_tokenizer(tokenizer_name)"
],
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"text": [
"100%|██████████| 213450/213450 [00:00<00:00, 5268333.12B/s]\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pW07_XmiBaog",
"colab_type": "text"
},
"source": [
"In the style of the Snorkel `Multitask` tutorial, we'll use a few helpers to load them into PyTorch datasets that we wrap with a `MultitaskDataLoader`." | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "JE5LJvRgBaoi",
"colab_type": "code",
"colab": {}
},
"source": [
"from dataloaders import get_dataset\n",
"\n",
"datasets = []\n",
"dataloaders = []\n",
"for split in [\"train\", \"valid\"]:\n",
"    # parse raw data and format it as a Pytorch dataset\n",
"    dataset = get_dataset(\n",
"        data_dir, task_name, split, tokenizer, max_data_samples, max_sequence_length\n",
"    )\n",
"    dataloader = MultitaskDataLoader(\n",
"        task_to_label_dict={task_name: \"labels\"},\n",
"        dataset=dataset,\n",
"        split=split,\n",
"        batch_size=batch_size,\n",
"        shuffle=(split == \"train\"),\n",
"    )\n",
"    datasets.append(dataset)\n",
"    dataloaders.append(dataloader)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "ddxCci8dBaom",
"colab_type": "text"
},
"source": [
"Our model is fairly simple, and identical to the baseline model suggested by the SuperGLUE creators. We feed both sentences through a pre-trained BERT module, then concatenate the output of its classification token with the final representation of the target token (the word whose sense we're disambiguating) in each sentence."
]
},
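{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal sketch of that head, not the exact module constructed by `superglue_tasks`: it concatenates the [CLS] representation with the target word's representation from each sentence, then applies a linear classifier. The names `WicHeadSketch`, `hidden`, `idx1`, and `idx2` are illustrative assumptions."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class WicHeadSketch(nn.Module):\n",
"    # Sketch only: concatenate [CLS] + both target-word reps, then classify.\n",
"    def __init__(self, hidden=1024, n_classes=2):\n",
"        super().__init__()\n",
"        self.classifier = nn.Linear(3 * hidden, n_classes)\n",
"\n",
"    def forward(self, bert_out, idx1, idx2):\n",
"        # bert_out: (batch, seq_len, hidden) token representations from BERT\n",
"        # idx1, idx2: (batch,) positions of the target word in each sentence\n",
"        batch = torch.arange(bert_out.size(0))\n",
"        cls = bert_out[:, 0, :]        # [CLS] token representation\n",
"        w1 = bert_out[batch, idx1, :]  # target word, sentence 1\n",
"        w2 = bert_out[batch, idx2, :]  # target word, sentence 2\n",
"        return self.classifier(torch.cat([cls, w1, w2], dim=-1))\n",
"\n",
"# Shape check with random inputs: expect (batch, n_classes) = (4, 2).\n",
"head = WicHeadSketch(hidden=8)\n",
"logits = head(torch.randn(4, 16, 8), torch.tensor([3, 5, 2, 7]), torch.tensor([9, 4, 6, 1]))\n",
"print(logits.shape)"
],
"execution_count": 0,
"outputs": []
},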
{
"cell_type": "code",
"metadata": {
"id": "tqwOO4gFBaon",
"colab_type": "code",
"outputId": "27d3b81a-1976-44be-87ea-e9642503da3e",
"colab": {}
},
"source": [
"# Construct base task\n",
"base_task = superglue_tasks.task_funcs[task_name](bert_model)\n",
"tasks = [base_task]\n",
"tasks"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[Task(name=WiC)]"
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "xCOqwCeEBaor",
"colab_type": "code",
"colab": {}
},
"source": [
"model = MultitaskModel(\n",
" name=f\"SuperGLUE\",\n", | |
" tasks=tasks, \n", | |
" dataparallel=False,\n", | |
" device=-1 # use CPU\n", | |
")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "gxjUzFXpBaov", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We've pretrained a model for you, but feel free to uncomment this line to experiment with it yourself!" | |
]
},
{
"cell_type": "code",
"metadata": {
"id": "U7ojPBpRBaow",
"colab_type": "code",
"colab": {}
},
"source": [
"# trainer = Trainer(**config)\n",
"# trainer.train_model(slice_model, dataloaders)" | |
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nSjnMGcUBapB",
"colab_type": "code",
"outputId": "bc1502ee-46e6-4a0f-fef9-5770abf51767",
"colab": {}
},
"source": [
"# If you're missing the model, uncomment this line:\n",
"# ! wget -nc https://www.dropbox.com/s/vix9bhzy18o3wjl/WiC_bert.pth\n",
"\n",
"# wic_path = \"WiC_bert.pth\"\n",
"# model.load(wic_path)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"--2019-08-03 02:32:48-- https://www.dropbox.com/s/vix9bhzy18o3wjl/WiC_bert.pth\n",
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.1, 2620:100:6016:1::a27d:101\n",
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.1|:443... connected.\n",
"HTTP request sent, awaiting response... 301 Moved Permanently\n",
"Location: /s/raw/vix9bhzy18o3wjl/WiC_bert.pth [following]\n",
"--2019-08-03 02:32:48-- https://www.dropbox.com/s/raw/vix9bhzy18o3wjl/WiC_bert.pth\n",
"Reusing existing connection to www.dropbox.com:443.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com/cd/0/inline/Al7aymtIUKN8POR4ob4DribD1koqNAU_hr0282L7a9GA7cFoSj629Q9q1DX7EIDdaeuFO3vNOQk7PDQdvd6hvzV8dN7wEtxHWuLSCumWVvYZ2A/file# [following]\n",
"--2019-08-03 02:32:49-- https://uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com/cd/0/inline/Al7aymtIUKN8POR4ob4DribD1koqNAU_hr0282L7a9GA7cFoSj629Q9q1DX7EIDdaeuFO3vNOQk7PDQdvd6hvzV8dN7wEtxHWuLSCumWVvYZ2A/file\n",
"Resolving uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com (uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com)... 162.125.1.6, 2620:100:6016:6::a27d:106\n",
"Connecting to uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com (uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com)|162.125.1.6|:443... connected.\n",
"HTTP request sent, awaiting response... 302 FOUND\n",
"Location: /cd/0/inline2/Al7PburfRR6N1Vf7oK23g3yqxiOJMht6SBoMocy1un04fX-jTb_541KY31jRU4mAHSpD-ugX64sjgvykWM5DIKBPm77Ssv_kh9lyCkut5jrct5jBfj_MyBYnbZ1ZzuOfVKIfP6GoHySpFGCTL2NQme6RKZ5J7ub8bD-C-yDihxydJMREeGomVddA40CBL8XzYRKIB9tPeb4V3XvA9Hnwu0kMF6qGwex7Nv7xjHmK51aAJh2S_cqIZlBYnIOy-5sY4zwk0ZpH1Q5pMCn50GTDn-n9ypQbENNqSTKqoCpHgxxk1e0ZvL9YPcqCzkBpxcVFnJ7N_U0aIajJQmuWEEy98BHA/file [following]\n",
"--2019-08-03 02:32:50-- https://uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com/cd/0/inline2/Al7PburfRR6N1Vf7oK23g3yqxiOJMht6SBoMocy1un04fX-jTb_541KY31jRU4mAHSpD-ugX64sjgvykWM5DIKBPm77Ssv_kh9lyCkut5jrct5jBfj_MyBYnbZ1ZzuOfVKIfP6GoHySpFGCTL2NQme6RKZ5J7ub8bD-C-yDihxydJMREeGomVddA40CBL8XzYRKIB9tPeb4V3XvA9Hnwu0kMF6qGwex7Nv7xjHmK51aAJh2S_cqIZlBYnIOy-5sY4zwk0ZpH1Q5pMCn50GTDn-n9ypQbENNqSTKqoCpHgxxk1e0ZvL9YPcqCzkBpxcVFnJ7N_U0aIajJQmuWEEy98BHA/file\n",
"Reusing existing connection to uc5fbf1e57c6929ec88de0a85822.dl.dropboxusercontent.com:443.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 1334440755 (1.2G) [application/octet-stream]\n",
"Saving to: ‘WiC_bert.pth’\n",
"\n",
"WiC_bert.pth 100%[===================>] 1.24G 36.3MB/s in 35s \n",
"\n",
"2019-08-03 02:33:25 (36.4 MB/s) - ‘WiC_bert.pth’ saved [1334440755/1334440755]\n",
"\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pRzhnFVsBapG",
"colab_type": "text"
},
"source": [
"How well do we do on the valid set?"
]
},
{
"cell_type": "code",
"metadata": {
"id": "1ehRdkCOBapH",
"colab_type": "code",
"outputId": "db90e040-4a80-4acb-ade9-b011b09450bc",
"colab": {}
},
"source": [
"%%time\n",
"model.score(dataloaders[1])"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"CPU times: user 6min 49s, sys: 980 ms, total: 6min 50s\n",
"Wall time: 4min 14s\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"{'WiC/SuperGLUE/valid/accuracy': 0.7460815047021944}"
]
},
"metadata": {
"tags": []
},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZeQyDmGBapL",
"colab_type": "text"
},
"source": [
"## Error analysis (to give us ideas for slicing)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3ShsSHrRBapM",
"colab_type": "text"
},
"source": [
"The key to debugging machine learning models---error analysis! let's look at a few examples that we get wrong." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "WaIr76a-BapN", | |
"colab_type": "code", | |
"outputId": "a18a69ad-245d-4aaf-8d87-57d1a684d0c5", | |
"colab": {} | |
}, | |
"source": [ | |
"%%time\n", | |
"results = model.predict(dataloaders[1], return_preds=True)\n", | |
"golds, preds = results[\"golds\"][task_name], results[\"preds\"][task_name]" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 6min 25s, sys: 1.71 s, total: 6min 26s\n", | |
"Wall time: 3min 54s\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dheHQKdLBapS", | |
"colab_type": "code", | |
"outputId": "4c3554e0-b6a7-452e-f556-88f3b08165ce", | |
"colab": {} | |
}, | |
"source": [ | |
"incorrect_preds = golds != preds\n", | |
"wic_df[incorrect_preds][[\"sentence1\", \"sentence2\", \"word\", \"label\"]].head()" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>sentence1</th>\n", | |
" <th>sentence2</th>\n", | |
" <th>word</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>Circulate a rumor .</td>\n", | |
" <td>This letter is being circulated among the faculty .</td>\n", | |
" <td>circulate</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>Making a hobby of domesticity .</td>\n", | |
" <td>A royal family living in unpretentious domesticity .</td>\n", | |
" <td>domesticity</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>5</th>\n", | |
" <td>The child 's acquisition of language .</td>\n", | |
" <td>That graphite tennis racquet is quite an acquisition .</td>\n", | |
" <td>acquisition</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>7</th>\n", | |
" <td>They swam in the nude .</td>\n", | |
" <td>The marketing rule ' nude sells ' spread from verbal to visual mainstream media in the 20th century .</td>\n", | |
" <td>nude</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>16</th>\n", | |
" <td>He took the manuscript in both hands and gave it a mighty tear .</td>\n", | |
" <td>There were big tears rolling down Lisa 's cheeks .</td>\n", | |
" <td>tear</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" sentence1 \\\n", | |
"1 Circulate a rumor . \n", | |
"4 Making a hobby of domesticity . \n", | |
"5 The child 's acquisition of language . \n", | |
"7 They swam in the nude . \n", | |
"16 He took the manuscript in both hands and gave it a mighty tear . \n", | |
"\n", | |
" sentence2 \\\n", | |
"1 This letter is being circulated among the faculty . \n", | |
"4 A royal family living in unpretentious domesticity . \n", | |
"5 That graphite tennis racquet is quite an acquisition . \n", | |
"7 The marketing rule ' nude sells ' spread from verbal to visual mainstream media in the 20th century . \n", | |
"16 There were big tears rolling down Lisa 's cheeks . \n", | |
"\n", | |
" word label \n", | |
"1 circulate False \n", | |
"4 domesticity False \n", | |
"5 acquisition False \n", | |
"7 nude False \n", | |
"16 tear False " | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "aILg9eLlBapY", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We notice that one particular error mode occurs when the target **word** is a _verb_. Let's investigate further...\n", | |
"\n", | |
"We view examples where we make the wrong prediction _and_ the target word is a verb." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "0oZsYxzdBapa", | |
"colab_type": "code", | |
"outputId": "bc64bec7-fa4c-4f9a-a316-c7c2e95037c2", | |
"colab": {} | |
}, | |
"source": [ | |
"target_is_verb = wic_df[\"pos\"] == \"V\"\n", | |
"df_wrong_and_target_is_verb = wic_df[incorrect_preds & target_is_verb]\n", | |
"df_wrong_and_target_is_verb[[\"sentence1\", \"sentence2\", \"word\", \"pos\", \"label\"]].head()" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>sentence1</th>\n", | |
" <th>sentence2</th>\n", | |
" <th>word</th>\n", | |
" <th>pos</th>\n", | |
" <th>label</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>Circulate a rumor .</td>\n", | |
" <td>This letter is being circulated among the faculty .</td>\n", | |
" <td>circulate</td>\n", | |
" <td>V</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>45</th>\n", | |
" <td>To clutch power .</td>\n", | |
" <td>She clutched her purse .</td>\n", | |
" <td>clutch</td>\n", | |
" <td>V</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>62</th>\n", | |
" <td>She used to wait down at the Dew Drop Inn .</td>\n", | |
" <td>Wait here until your car arrives .</td>\n", | |
" <td>wait</td>\n", | |
" <td>V</td>\n", | |
" <td>False</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>78</th>\n", | |
" <td>Wear gloves so your hands stay warm .</td>\n", | |
" <td>Stay with me , please .</td>\n", | |
" <td>stay</td>\n", | |
" <td>V</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>83</th>\n", | |
" <td>You need to push quite hard to get this door open .</td>\n", | |
" <td>Nora pushed through the crowd .</td>\n", | |
" <td>push</td>\n", | |
" <td>V</td>\n", | |
" <td>True</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" sentence1 \\\n", | |
"1 Circulate a rumor . \n", | |
"45 To clutch power . \n", | |
"62 She used to wait down at the Dew Drop Inn . \n", | |
"78 Wear gloves so your hands stay warm . \n", | |
"83 You need to push quite hard to get this door open . \n", | |
"\n", | |
" sentence2 word pos label \n", | |
"1 This letter is being circulated among the faculty . circulate V False \n", | |
"45 She clutched her purse . clutch V True \n", | |
"62 Wait here until your car arrives . wait V False \n", | |
"78 Stay with me , please . stay V True \n", | |
"83 Nora pushed through the crowd . push V True " | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 18 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "A_Kz4knfBapm", | |
"colab_type": "code", | |
"outputId": "575552b7-3e08-4cdc-85d6-5b60b6f50942", | |
"colab": {} | |
}, | |
"source": [ | |
"len(df_wrong_and_target_is_verb) / len(wic_df[incorrect_preds])" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"0.3765432098765432" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 19 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "DOvyBRjQBapt", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"This error mode accounts for over **37%** of our incorrect predictions! Let's address with _slicing_." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "boDEUBrrBapu", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Write slicing functions\n", | |
"We write slicing functions to target specific subsets of the data that we care about—this could correspond to the examples we find underperforming in an error analysis, or specific subsets that are application critical (e.g. night-time images in a self-driving dataset). Then, we'd like to add slice-specific capacity to our model so that it pays more attention to these examples!\n", | |
"\n", | |
"We build our slicing functions in the same way that we write labeling functions—with a decorator: `@slicing_function()`. These slicing functions can also be passed previously defined preprocessors, resources, etc. that the slicing function depends on it—just like with labeling fucntions." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "EEJ6VAWCBapv", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from snorkel.slicing.sf import slicing_function\n", | |
"from snorkel.types import DataPoint\n", | |
"\n", | |
"@slicing_function()\n", | |
"def SF_verb(x: DataPoint) -> int:\n", | |
" return x.pos == 'V'\n", | |
"\n", | |
"slicing_functions = [SF_verb]\n", | |
"slice_names = [sf.name for sf in slicing_functions]" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
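{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Writing more slicing functions is just as easy. As a purely hypothetical sketch (this slice is not used elsewhere in the tutorial), the function below targets examples where the target word appears verbatim in both sentences; it reads only the `word`, `sentence1`, and `sentence2` columns shown earlier." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": {}, | |
"source": [ | |
"@slicing_function()\n", | |
"def SF_shared_surface_form(x: DataPoint) -> int:\n", | |
"    # Hypothetical slice: the target word's surface form appears in both contexts.\n", | |
"    return x.word in x.sentence1.lower() and x.word in x.sentence2.lower()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |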
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "OdkD4xx3Bap4", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Train a _slice-aware_ model" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "zSvusR1ABap5", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Now, let's update our tasks to add _additional capacity_ corresponding to each slice we've specified.\n", | |
"\n", | |
"For each slice, the model will contain two \"task heads\" (PyTorch modules):\n", | |
"- The \"indicator head\" is trained to classify whether each example is a member of that slice or not (a binary problem)\n", | |
"- The \"predictor head\" is trained on the base task using only those examples that were identified as being in the slice, so it becomes an effective expert on those examples.\n", | |
"\n", | |
"At a high level, the helper method `convert_to_slicing_tasks()` will take an existing task and create the following:\n", | |
"- Two task heads (ind + pred) for the \"base slice,\" which all examples belong to\n", | |
"- Two task heads (ind + pred) for each slice you specified with a slicing function\n", | |
"- A new \"master head\" that makes predictions for the main task while taking advantage of information learned by the slice-specific task heads.\n", | |
"\n", | |
"For each example, the indicator heads specify whether that example is in their slice or not. \n", | |
"The magnitude of the predictor head output is used as a proxy for the slice-specific classifier's confidence.\n", | |
"These two scores are multiplied together to make a weighted combination of the representations learned by each of the predictor heads. \n", | |
"It is this reweighted representation (which accentuates those features that are most relevant to making good predictions on members of those slices) that is used by the master head to make the final prediction. \n", | |
"\n", | |
"Note that this plays nicely into our MTL abstraction—additional tasks are easy to pop on and off our network, and they allow us to provide \"spot\" capacity to target and improve performance on particular subsets of our data." | |
] | |
}, | |
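{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"To make the reweighting concrete, here is a minimal numpy sketch of the combination step, using made-up toy values and our own variable names. This is an illustration of the idea only, not Snorkel's actual combiner module." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": {}, | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"def softmax(z, axis=-1):\n", | |
"    z = z - z.max(axis=axis, keepdims=True)\n", | |
"    e = np.exp(z)\n", | |
"    return e / e.sum(axis=axis, keepdims=True)\n", | |
"\n", | |
"# Toy setup: batch of 2 examples, 2 slices (verb slice + base), hidden dim 4.\n", | |
"ind_probs = np.array([[0.9, 1.0], [0.1, 1.0]])  # indicator heads: P(in slice)\n", | |
"pred_conf = np.array([[2.0, 0.5], [0.3, 1.5]])  # predictor heads: |logit| proxy\n", | |
"slice_reprs = np.random.randn(2, 2, 4)          # per-slice representations\n", | |
"\n", | |
"# Membership and confidence multiply, then normalize across slices...\n", | |
"weights = softmax(ind_probs * pred_conf, axis=1)              # shape (2, 2)\n", | |
"# ...and the master head consumes the reweighted representation.\n", | |
"reweighted = (weights[:, :, None] * slice_reprs).sum(axis=1)  # shape (2, 4)\n", | |
"print(weights.round(2))" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |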
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "nyUJgQzRBap6", | |
"colab_type": "code", | |
"outputId": "fee0ca1f-e406-45cf-9ec6-10a62f14e584", | |
"colab": {} | |
}, | |
"source": [ | |
"from snorkel.slicing.utils import convert_to_slice_tasks\n", | |
"\n", | |
"slice_tasks = convert_to_slice_tasks(base_task, slice_names)\n", | |
"slice_tasks" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"[Task(name=WiC_slice:SF_verb_ind),\n", | |
" Task(name=WiC_slice:base_ind),\n", | |
" Task(name=WiC_slice:SF_verb_pred),\n", | |
" Task(name=WiC_slice:base_pred),\n", | |
" Task(name=WiC)]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 21 | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "CBwr5Un9BaqA", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We then update our dataloaders to include the label sets for these slices so that those heads can be trained as well in addition to the overall task head." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9g8Wdu4RBaqC", | |
"colab_type": "code", | |
"outputId": "4ed9a79f-bb64-44a3-e0e7-712beec2536d", | |
"colab": {} | |
}, | |
"source": [ | |
"from snorkel.slicing.apply import PandasSFApplier\n", | |
"from snorkel.slicing.utils import add_slice_labels\n", | |
"\n", | |
"slice_dataloaders = []\n", | |
"applier = PandasSFApplier(slicing_functions)\n", | |
"\n", | |
"for dl in dataloaders:\n", | |
" df = task_dataset_to_dataframe(dl.dataset)\n", | |
" S_matrix = applier.apply(df)\n", | |
" # updates dataloaders in place\n", | |
" add_slice_labels(dl, base_task, S_matrix, slice_names)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"100%|██████████| 5428/5428 [00:00<00:00, 35696.04it/s]\n", | |
"100%|██████████| 638/638 [00:00<00:00, 31381.53it/s]\n" | |
], | |
"name": "stderr" | |
} | |
] | |
}, | |
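{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"Before training, it's worth sanity-checking how much of the data the slice covers. Rather than inspecting the applier output, we can read the coverage straight off the dataframe (assuming `wic_df` holds the same split we scored above):" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": {}, | |
"source": [ | |
"# Fraction of examples whose target word is tagged as a verb\n", | |
"verb_fraction = (wic_df[\"pos\"] == \"V\").mean()\n", | |
"print(f\"SF_verb covers {verb_fraction:.1%} of the examples\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |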
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "MxSSYKcBBaqH", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"We initialize a new _slice-aware model_, and train!" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Oaijn4Y2BaqJ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"slice_model = MultitaskModel(\n", | |
" name=f\"SuperGLUE\", \n", | |
" tasks=slice_tasks, \n", | |
" dataparallel=False,\n", | |
" device=-1\n", | |
")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "hrgTAMtbBaqP", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Again, we've loaded a pretrained model for you to explore on your own, but you can explore training if you'd like." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4n5vO-3WBaqQ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# trainer = Trainer(**config)\n", | |
"# trainer.train_model(slice_model, dataloaders)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "lW2Jp4b7BaqU", | |
"colab_type": "code", | |
"outputId": "0e549187-b891-4dfa-a087-0e5fa4e4593a", | |
"colab": {} | |
}, | |
"source": [ | |
"# If you're missing the model, uncomment this line:\n", | |
"# ! wget -nc https://www.dropbox.com/s/h6620vfeompgu9o/WiC_slice_verb.pth\n", | |
"\n", | |
"# slice_wic_path = \"WiC_slice_verb.pth\"\n", | |
"# slice_model.load(slice_wic_path)" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"--2019-08-03 02:42:34-- https://www.dropbox.com/s/h6620vfeompgu9o/WiC_slice_verb.pth\n", | |
"Resolving www.dropbox.com (www.dropbox.com)... 162.125.1.1, 2620:100:6016:1::a27d:101\n", | |
"Connecting to www.dropbox.com (www.dropbox.com)|162.125.1.1|:443... connected.\n", | |
"HTTP request sent, awaiting response... 301 Moved Permanently\n", | |
"Location: /s/raw/h6620vfeompgu9o/WiC_slice_verb.pth [following]\n", | |
"--2019-08-03 02:42:35-- https://www.dropbox.com/s/raw/h6620vfeompgu9o/WiC_slice_verb.pth\n", | |
"Reusing existing connection to www.dropbox.com:443.\n", | |
"HTTP request sent, awaiting response... 302 Found\n", | |
"Location: https://uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com/cd/0/inline/Al7vz4cwvR11sKbF4BUlGxmTu8vQtg7g7XagSOrG-wAzgNiEhjM8_j_lazTfoIGaak6nWjNxNY7mCUfVK9EqoB43JP2PwyrRibH40Rpuog3R3Q/file# [following]\n", | |
"--2019-08-03 02:42:35-- https://uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com/cd/0/inline/Al7vz4cwvR11sKbF4BUlGxmTu8vQtg7g7XagSOrG-wAzgNiEhjM8_j_lazTfoIGaak6nWjNxNY7mCUfVK9EqoB43JP2PwyrRibH40Rpuog3R3Q/file\n", | |
"Resolving uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com (uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com)... 162.125.1.6, 2620:100:6016:6::a27d:106\n", | |
"Connecting to uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com (uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com)|162.125.1.6|:443... connected.\n", | |
"HTTP request sent, awaiting response... 302 FOUND\n", | |
"Location: /cd/0/inline2/Al6ESJLKXJ9eHJZxmSypkMHBsK8cVZ_k3ttNkwxnOnzs3t6V9pR9vwBbmNJBA29RQ1Mu--GcFxYD7Ydr6E69v9xTwPf7zP1sGxJRdBEZjIIcl69i1n3C5dPKpZcUejbbIOx3Hs4ayXTzq-WVAuJgcFLJzF4zsVdNplmQsBbu55TG52c28Zyxc-7qkKmpgYf5cP02lc5krrNO8rDFWx_WhyWzz5pHZY2RyeRWYa2qDLV3aDG_oV2CoCpKIZiNNNrdBUz-QnTCEDLYe8SDnj1zyX_lsN6FNNCmPpsJzBT0L0Hi8trFbSXtXBpGKCa8itRlKVPy25GcKPBpEJzXXuviNayU/file [following]\n", | |
"--2019-08-03 02:42:36-- https://uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com/cd/0/inline2/Al6ESJLKXJ9eHJZxmSypkMHBsK8cVZ_k3ttNkwxnOnzs3t6V9pR9vwBbmNJBA29RQ1Mu--GcFxYD7Ydr6E69v9xTwPf7zP1sGxJRdBEZjIIcl69i1n3C5dPKpZcUejbbIOx3Hs4ayXTzq-WVAuJgcFLJzF4zsVdNplmQsBbu55TG52c28Zyxc-7qkKmpgYf5cP02lc5krrNO8rDFWx_WhyWzz5pHZY2RyeRWYa2qDLV3aDG_oV2CoCpKIZiNNNrdBUz-QnTCEDLYe8SDnj1zyX_lsN6FNNCmPpsJzBT0L0Hi8trFbSXtXBpGKCa8itRlKVPy25GcKPBpEJzXXuviNayU/file\n", | |
"Reusing existing connection to uc961666a77e61c224a3d20242f4.dl.dropboxusercontent.com:443.\n", | |
"HTTP request sent, awaiting response... 200 OK\n", | |
"Length: 1410039344 (1.3G) [application/octet-stream]\n", | |
"Saving to: ‘WiC_slice_verb.pth’\n", | |
"\n", | |
"WiC_slice_verb.pth 100%[===================>] 1.31G 35.3MB/s in 39s \n", | |
"\n", | |
"2019-08-03 02:43:15 (34.4 MB/s) - ‘WiC_slice_verb.pth’ saved [1410039344/1410039344]\n", | |
"\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "tFxUrQVrBaqe", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Evaluate _slice-aware_ model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4_pO5cGTBaqh", | |
"colab_type": "code", | |
"outputId": "cba4e629-f080-470f-ab4b-b9a651ed5348", | |
"colab": {} | |
}, | |
"source": [ | |
"%%time \n", | |
"slice_model.score(dataloaders[1])" | |
], | |
"execution_count": 0, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/home/ubuntu/snorkel-superglue/.env/lib/python3.6/site-packages/snorkel/slicing/modules/slice_combiner.py:40: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", | |
" for slice_ind_name in slice_ind_op_names\n", | |
"/home/ubuntu/snorkel-superglue/.env/lib/python3.6/site-packages/snorkel/slicing/modules/slice_combiner.py:47: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", | |
" for slice_pred_name in slice_pred_op_names\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"CPU times: user 5min 22s, sys: 1.43 s, total: 5min 24s\n", | |
"Wall time: 2min 50s\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"{'WiC/SuperGLUE/valid/accuracy': 0.7554858934169278,\n", | |
" 'WiC_slice:SF_verb_ind/SuperGLUE/valid/f1': 0.5687022900763358,\n", | |
" 'WiC_slice:SF_verb_pred/SuperGLUE/valid/accuracy': 0.448559670781893,\n", | |
" 'WiC_slice:base_ind/SuperGLUE/valid/f1': 1.0,\n", | |
" 'WiC_slice:base_pred/SuperGLUE/valid/accuracy': 0.7570532915360502}" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 27 | |
} | |
] | |
}, | |
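{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a quick sanity check on the headline number below, we can recompute the gain directly from the two validation accuracies reported by `score()`:" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": {}, | |
"source": [ | |
"base_acc = 0.7460815047021944   # baseline BERT model, scored earlier\n", | |
"slice_acc = 0.7554858934169278  # slice-aware model, scored above\n", | |
"print(f\"{(slice_acc - base_acc) * 100:.2f} accuracy points\")  # -> 0.94" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |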
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "UYWMSvdZBaqm", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"With some simple error analysis and an interface to specifying which _slice_ of the data we care about, we've improved our model **0.94 accuracy points** over a previous state-of-the-art model!" | |
] | |
} | |
] | |
} |