```python
%load_ext autoreload
%autoreload 2
```
```python
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.nn import util as nn_util
```
```python
class Config(dict):
    """A simple store for hyperparameters, readable both as dict entries and as attributes."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for k, v in kwargs.items():
            setattr(self, k, v)

    def set(self, key, val):
        self[key] = val
        setattr(self, key, val)

config = Config(
    testing=True,
    seed=1,
    batch_size=64,
    lr=3e-4,
    epochs=2,
    hidden_sz=64,
    max_seq_len=100,  # necessary to limit memory usage
    max_vocab_size=100000,
)
```
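As a quick illustration (not part of the original notebook), `Config` entries are accessible both as dict items and as attributes, and `set` keeps the two views in sync:

```python
assert config.batch_size == config["batch_size"] == 64
config.set("batch_size", 32)  # updates both the attribute and the dict entry
assert config.batch_size == config["batch_size"] == 32
config.set("batch_size", 64)  # restore the value used below
```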
```python
from allennlp.common.checks import ConfigurationError
```
```python
USE_GPU = torch.cuda.is_available()
```
```python
DATA_ROOT = Path("../data") / "jigsaw"
```
Set the random seed manually so that results can be replicated.
```python
torch.manual_seed(config.seed)
```

```
<torch._C.Generator at 0x1176dd710>
```
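`torch.manual_seed` only seeds PyTorch's generator. For stricter reproducibility you would typically also seed Python's and NumPy's generators; this step is an addition, not part of the original notebook:

```python
import random

# seed the other random number generators the pipeline might touch
random.seed(config.seed)
np.random.seed(config.seed)
```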
# Load Data
```python
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader
```
### Prepare dataset
```python
label_cols = ["toxic", "severe_toxic", "obscene",
              "threat", "insult", "identity_hate"]
```
```python
from allennlp.data.fields import TextField, MetadataField, ArrayField
# imported here (rather than only in a later cell) so this cell is self-contained:
from allennlp.data.token_indexers import SingleIdTokenIndexer

class JigsawDatasetReader(DatasetReader):
    def __init__(self, tokenizer: Callable[[str], List[str]] = lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len: Optional[int] = config.max_seq_len) -> None:
        super().__init__(lazy=False)  # lazy=False: read() returns a plain list of instances
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token], id: str,
                         labels: np.ndarray) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"tokens": sentence_field}

        id_field = MetadataField(id)
        fields["id"] = id_field

        label_field = ArrayField(array=labels)
        fields["label"] = label_field

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_csv(file_path)
        if config.testing: df = df.head(1000)
        for i, row in df.iterrows():
            yield self.text_to_instance(
                [Token(x) for x in self.tokenizer(row["comment_text"])],
                row["id"], row[label_cols].values,
            )
```
### Prepare token handlers

We will use the spaCy tokenizer here.
```python
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.token_indexers import SingleIdTokenIndexer

# the token indexer is responsible for mapping tokens to integers
token_indexer = SingleIdTokenIndexer()

# build the spaCy splitter once: constructing it inside the tokenizer
# function would reload the pipeline on every call
word_splitter = SpacyWordSplitter(language='en_core_web_sm', pos_tags=False)

def tokenizer(x: str):
    # truncate to max_seq_len tokens to limit memory usage
    return [w.text for w in word_splitter.split_words(x)[:config.max_seq_len]]
```
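A quick call shows why spaCy is preferable to naive whitespace splitting; this check is illustrative only (the exact output depends on the spaCy model):

```python
# spaCy splits punctuation and contractions (e.g. "weren't" -> "were", "n't"),
# which str.split() would leave attached to the neighboring word
tokenizer("They weren't vandalisms, just closure on some GAs.")
```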
```python
reader = JigsawDatasetReader(
    tokenizer=tokenizer,
    token_indexers={"tokens": token_indexer}
)
```
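As a smoke test, `text_to_instance` can also be called directly on a toy example; this snippet is illustrative and not part of the original notebook:

```python
demo = reader.text_to_instance(
    [Token(w) for w in "this is a harmless comment".split()],
    "demo_id",
    np.zeros(len(label_cols), dtype="float32"),
)
print(demo.fields["tokens"], demo.fields["label"])
```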
```python
train_ds, test_ds = (reader.read(DATA_ROOT / fname) for fname in ["train.csv", "test_proced.csv"])
val_ds = None
```

```
267it [00:02, 94.93it/s]
251it [00:01, 172.26it/s]
```
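The notebook never fills in `val_ds`. Because `lazy=False` makes `reader.read` return a plain list of instances, a validation set could be carved out of the training data; a sketch, not part of the original (the outputs below assume it was not run):

```python
# hold out the last 10% of training instances for validation (sketch)
split = int(0.9 * len(train_ds))
train_ds, val_ds = train_ds[:split], train_ds[split:]
```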
```python
len(train_ds)
```

```
267
```
```python
train_ds[:10]
```

```
[<allennlp.data.instance.Instance at 0x1a2b034160>,
 <allennlp.data.instance.Instance at 0x1a2b016208>,
 <allennlp.data.instance.Instance at 0x1a2afec748>,
 <allennlp.data.instance.Instance at 0x1a2af92828>,
 <allennlp.data.instance.Instance at 0x1a2af8a4a8>,
 <allennlp.data.instance.Instance at 0x1a2af7e630>,
 <allennlp.data.instance.Instance at 0x1a2af79710>,
 <allennlp.data.instance.Instance at 0x1a2af66550>,
 <allennlp.data.instance.Instance at 0x10bd9a518>,
 <allennlp.data.instance.Instance at 0x1a28d5def0>]
```
```python
vars(train_ds[0].fields["tokens"])
```

```
{'tokens': [Explanation,
  Why,
  the,
  edits,
  made,
  under,
  my,
  username,
  Hardcore,
  Metallica,
  Fan,
  were,
  reverted,
  ?,
  They,
  were,
  n't,
  vandalisms,
  ,,
  just,
  closure,
  on,
  some,
  GAs,
  after,
  I,
  voted,
  at,
  New,
  York,
  Dolls,
  FAC,
  .,
  And,
  please,
  do,
  n't,
  remove,
  the,
  template,
  from,
  the,
  talk,
  page,
  since,
  I,
  'm,
  retired,
  now.89.205.38.27],
 '_token_indexers': {'tokens': <allennlp.data.token_indexers.single_id_token_indexer.SingleIdTokenIndexer at 0x1a27b07400>},
 '_indexed_tokens': None,
 '_indexer_name_to_indexed_token': None}
```
### Prepare vocabulary
```python
vocab = Vocabulary.from_instances(train_ds, max_vocab_size=config.max_vocab_size)
```

```
01/29/2019 19:57:00 - INFO - allennlp.data.vocabulary - Fitting token dictionary from dataset.
100%|██████████| 267/267 [00:00<00:00, 11635.95it/s]
```
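It can be worth inspecting what the vocabulary actually learned. `get_vocab_size` and `get_token_index` are standard `Vocabulary` methods; the token looked up here is just an example:

```python
print(vocab.get_vocab_size("tokens"))        # includes the padding and unknown entries
print(vocab.get_token_index("the"))          # id assigned to a frequent word
print(vocab.get_token_index("@@UNKNOWN@@"))  # id used for out-of-vocabulary words
```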
### Prepare iterator

The iterator is responsible for batching the data and preparing it for input into the model. We'll use the `BucketIterator`, which batches text sequences of similar lengths together to minimize padding.
```python
from allennlp.data.iterators import BucketIterator
```

```python
iterator = BucketIterator(batch_size=config.batch_size,
                          sorting_keys=[("tokens", "num_tokens")])
```
We need to tell the iterator how to numericalize the text data, which we do by passing it the vocabulary. This step is easy to forget, so be careful!
```python
iterator.index_with(vocab)
```
### Read sample
```python
batch = next(iter(iterator(train_ds)))
```
```python
batch
```

```
{'tokens': {'tokens': tensor([[ 131,  264,   21,  ...,    0,    0,    0],
         [  74,  203,   24,  ...,    0,    0,    0],
         [   5,   85,   26,  ...,    0,    0,    0],
         ...,
         [   5,  103, 1068,  ...,    0,    0,    0],
         [2972,  622, 1099,  ...,    0,    0,    0],
         [   5, 3301,    8,  ...,    0,    0,    0]])},
 'id': ['0029541a38c523a0',
  '003d77a20601cec1',
  '006fda507acd9769',
  '00173958f46763a2',
  '00a5394e626e72c6',
  '005e2ae8f864f76c',
  '0060c5c9030b2d14',
  '0095756047a71716',
  '0000997932d777bf',
  '0061b075244dd234',
  '002d6c9d9f85e81f',
  '001735f961a23fc4',
  '000113f07ec002fd',
  '0029b87aa9c7dc4a',
  '00070ef96486d6f9',
  '007bc29766a43e3c',
  '004f5608984d99f1',
  '000b08c464718505',
  '007bbfa4da2bc32d',
  '00537730daf8c5f1',
  '008f22e7b58e559b',
  '009b3b15f1ada72f',
  '004f981460421bdf',
  '000f35deef84dc4a',
  '002f0e29c60807b1',
  '003dbd1b9b354c1f',
  '004f6dbe69f3545d',
  '00733f0a4a58cf42',
  '0057b7710cb5ebb2',
  '006f2c1459f3b6b1',
  '00349c6325526c11',
  '0030614cfd96d9d1',
  '006120d209a4a46c',
  '00744c2f77391702',
  '007571394afafcb5',
  '0005c987bdfc9d4b',
  '0052a7e684beeb1a',
  '0022cf8467ebc9fd',
  '004de318396bbf8b',
  '0053bab79133c0fc',
  '001956c382006abd',
  '0015f4aa35ebe9b5',
  '000ffab30195c5e1',
  '002a6beca33307b3',
  '0028d62e8a5629aa',
  '0063dd8f202a698a',
  '008198c5a9d85a8e',
  '007f1839ada915e6',
  '0037e59caead9dab',
  '00a20f187531df59',
  '00585c1da10b448b',
  '00961bcaadd6a278',
  '0082b4b42b3f07a1',
  '005b214511a69b4b',
  '008e2acf5bcf4be4',
  '002264ea4d5f2887',
  '0013a8b1a5f26bcb',
  '0074b307c2d9a100',
  '008f93320e3661b8',
  '008344a80c43b8c9',
  '0038f191ffc93d75',
  '0069e6d57a3beb51',
  '0087c131ccffe160',
  '00a330961879175c'],
 'label': tensor([[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 1., 0., 1., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]])}
```
```python
batch["tokens"]["tokens"]
```

```
tensor([[ 131,  264,   21,  ...,    0,    0,    0],
        [  74,  203,   24,  ...,    0,    0,    0],
        [   5,   85,   26,  ...,    0,    0,    0],
        ...,
        [   5,  103, 1068,  ...,    0,    0,    0],
        [2972,  622, 1099,  ...,    0,    0,    0],
        [   5, 3301,    8,  ...,    0,    0,    0]])
```
```python
batch["tokens"]["tokens"].shape
```

```
torch.Size([64, 84])
```
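Every sequence is padded to the longest one in its batch (84 tokens here), and padding uses index 0, so the true lengths can be recovered straight from the tensor. A small check that is not in the original:

```python
lengths = (batch["tokens"]["tokens"] != 0).sum(dim=1)
print(lengths.min().item(), lengths.max().item())  # max should equal the padded width, 84
```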
# Prepare Model

```python
import torch
import torch.nn as nn
import torch.optim as optim
```
```python
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.nn.util import get_text_field_mask
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder

class BaselineModel(Model):
    def __init__(self, word_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 out_sz: int = len(label_cols)):
        # note: this relies on the global `vocab` built above
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.projection = nn.Linear(self.encoder.get_output_dim(), out_sz)
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, tokens: Dict[str, torch.Tensor],
                id: Any, label: torch.Tensor) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(tokens)
        embeddings = self.word_embeddings(tokens)
        state = self.encoder(embeddings, mask)
        class_logits = self.projection(state)

        output = {"class_logits": class_logits}
        output["loss"] = self.loss(class_logits, label)

        return output
```
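The model returns raw logits. Since the six labels are not mutually exclusive, per-label probabilities come from an element-wise sigmoid rather than a softmax, matching the `BCEWithLogitsLoss` above. An illustrative snippet, not part of the original:

```python
logits = torch.randn(2, len(label_cols))  # stand-in logits for two examples
probs = torch.sigmoid(logits)             # one independent probability per label
```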
### Prepare embeddings

```python
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

# +2 leaves room for the padding and unknown tokens the Vocabulary adds;
# vocab.get_vocab_size("tokens") would give the exact count
token_embedding = Embedding(num_embeddings=config.max_vocab_size + 2,
                            embedding_dim=300, padding_index=0)
# the embedder maps the input tokens to the appropriate embedding matrix
word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder({"tokens": token_embedding})
```
```python
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper

encoder: Seq2VecEncoder = PytorchSeq2VecWrapper(nn.LSTM(word_embeddings.get_output_dim(),
                                                        config.hidden_sz, bidirectional=True, batch_first=True))
```
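Because the LSTM is bidirectional, the wrapper concatenates the final states of both directions, so the encoder's output width is twice `hidden_sz`. A quick check (not in the original):

```python
assert encoder.get_output_dim() == 2 * config.hidden_sz  # 128 here
```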
Notice how simple and modular the code for initializing the model is. All the complexity is delegated to each component.
```python
model = BaselineModel(
    word_embeddings,
    encoder,
)
```

```python
if USE_GPU:
    model.cuda()
```
# Basic sanity checks
```python
batch = nn_util.move_to_device(batch, 0 if USE_GPU else -1)
```
```python
tokens = batch["tokens"]
labels = batch["label"]
```
```python
tokens
```

```
{'tokens': tensor([[ 131,  264,   21,  ...,    0,    0,    0],
        [  74,  203,   24,  ...,    0,    0,    0],
        [   5,   85,   26,  ...,    0,    0,    0],
        ...,
        [   5,  103, 1068,  ...,    0,    0,    0],
        [2972,  622, 1099,  ...,    0,    0,    0],
        [   5, 3301,    8,  ...,    0,    0,    0]])}
```
```python
mask = get_text_field_mask(tokens)
mask
```

```
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
```
```python
embeddings = model.word_embeddings(tokens)
state = model.encoder(embeddings, mask)
class_logits = model.projection(state)
class_logits
```

```
tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],
        [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],
        [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],
        [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],
        [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],
        [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],
        [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],
        [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],
        [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],
        [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],
        [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],
        [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],
        [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],
        [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],
        [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],
        [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],
        [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],
        [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],
        [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],
        [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],
        [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],
        [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],
        [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],
        [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],
        [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],
        [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],
        [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],
        [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],
        [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],
        [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],
        [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],
        [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],
        [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],
        [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],
        [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],
        [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],
        [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],
        [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],
        [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],
        [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],
        [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],
        [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],
        [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],
        [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],
        [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],
        [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],
        [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],
        [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],
        [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],
        [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],
        [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],
        [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],
        [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],
        [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],
        [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],
        [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],
        [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],
        [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],
        [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],
        [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],
        [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],
        [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],
        [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],
        [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],
       grad_fn=<AddmmBackward>)
```
```python
model(**batch)
```

```
{'class_logits': tensor([[-0.0372, 0.0801, -0.0296, 0.0281, 0.0153, -0.0244],
         [-0.0387, 0.0801, -0.0318, 0.0281, 0.0164, -0.0250],
         [-0.0384, 0.0804, -0.0303, 0.0254, 0.0170, -0.0253],
         [-0.0393, 0.0794, -0.0324, 0.0276, 0.0162, -0.0238],
         [-0.0388, 0.0807, -0.0317, 0.0281, 0.0175, -0.0239],
         [-0.0397, 0.0785, -0.0307, 0.0275, 0.0172, -0.0237],
         [-0.0382, 0.0805, -0.0304, 0.0266, 0.0164, -0.0243],
         [-0.0376, 0.0781, -0.0321, 0.0286, 0.0144, -0.0258],
         [-0.0392, 0.0809, -0.0318, 0.0294, 0.0155, -0.0242],
         [-0.0386, 0.0801, -0.0310, 0.0295, 0.0176, -0.0244],
         [-0.0380, 0.0806, -0.0310, 0.0264, 0.0162, -0.0255],
         [-0.0396, 0.0808, -0.0306, 0.0273, 0.0169, -0.0257],
         [-0.0389, 0.0808, -0.0314, 0.0292, 0.0173, -0.0251],
         [-0.0398, 0.0805, -0.0308, 0.0264, 0.0161, -0.0255],
         [-0.0376, 0.0790, -0.0300, 0.0289, 0.0143, -0.0261],
         [-0.0393, 0.0804, -0.0301, 0.0296, 0.0165, -0.0248],
         [-0.0391, 0.0809, -0.0312, 0.0263, 0.0168, -0.0254],
         [-0.0388, 0.0803, -0.0315, 0.0268, 0.0168, -0.0255],
         [-0.0381, 0.0799, -0.0306, 0.0290, 0.0154, -0.0245],
         [-0.0379, 0.0810, -0.0316, 0.0274, 0.0140, -0.0264],
         [-0.0392, 0.0802, -0.0315, 0.0271, 0.0163, -0.0254],
         [-0.0377, 0.0803, -0.0314, 0.0279, 0.0165, -0.0237],
         [-0.0394, 0.0814, -0.0301, 0.0294, 0.0172, -0.0246],
         [-0.0384, 0.0813, -0.0299, 0.0286, 0.0169, -0.0247],
         [-0.0399, 0.0799, -0.0328, 0.0273, 0.0149, -0.0255],
         [-0.0383, 0.0796, -0.0309, 0.0283, 0.0150, -0.0262],
         [-0.0395, 0.0779, -0.0302, 0.0266, 0.0139, -0.0243],
         [-0.0394, 0.0797, -0.0304, 0.0267, 0.0151, -0.0246],
         [-0.0395, 0.0805, -0.0316, 0.0269, 0.0161, -0.0256],
         [-0.0374, 0.0806, -0.0300, 0.0263, 0.0157, -0.0255],
         [-0.0385, 0.0813, -0.0313, 0.0279, 0.0175, -0.0247],
         [-0.0374, 0.0794, -0.0316, 0.0272, 0.0154, -0.0256],
         [-0.0388, 0.0797, -0.0314, 0.0269, 0.0165, -0.0253],
         [-0.0401, 0.0795, -0.0311, 0.0280, 0.0162, -0.0247],
         [-0.0391, 0.0808, -0.0308, 0.0271, 0.0166, -0.0259],
         [-0.0399, 0.0808, -0.0315, 0.0270, 0.0164, -0.0260],
         [-0.0390, 0.0803, -0.0306, 0.0267, 0.0164, -0.0247],
         [-0.0394, 0.0803, -0.0299, 0.0286, 0.0168, -0.0236],
         [-0.0387, 0.0800, -0.0304, 0.0267, 0.0163, -0.0255],
         [-0.0409, 0.0800, -0.0297, 0.0295, 0.0179, -0.0246],
         [-0.0381, 0.0802, -0.0332, 0.0277, 0.0154, -0.0244],
         [-0.0377, 0.0800, -0.0298, 0.0282, 0.0167, -0.0252],
         [-0.0396, 0.0812, -0.0302, 0.0295, 0.0158, -0.0243],
         [-0.0394, 0.0799, -0.0318, 0.0285, 0.0161, -0.0245],
         [-0.0390, 0.0808, -0.0313, 0.0279, 0.0153, -0.0254],
         [-0.0387, 0.0797, -0.0304, 0.0258, 0.0163, -0.0245],
         [-0.0393, 0.0803, -0.0324, 0.0285, 0.0175, -0.0229],
         [-0.0387, 0.0778, -0.0314, 0.0280, 0.0158, -0.0258],
         [-0.0375, 0.0809, -0.0312, 0.0277, 0.0151, -0.0259],
         [-0.0395, 0.0806, -0.0313, 0.0276, 0.0151, -0.0245],
         [-0.0387, 0.0797, -0.0311, 0.0275, 0.0142, -0.0256],
         [-0.0394, 0.0804, -0.0304, 0.0289, 0.0174, -0.0231],
         [-0.0379, 0.0788, -0.0325, 0.0275, 0.0152, -0.0239],
         [-0.0387, 0.0802, -0.0319, 0.0261, 0.0165, -0.0248],
         [-0.0401, 0.0804, -0.0319, 0.0292, 0.0139, -0.0255],
         [-0.0386, 0.0815, -0.0327, 0.0297, 0.0152, -0.0263],
         [-0.0402, 0.0799, -0.0302, 0.0263, 0.0168, -0.0246],
         [-0.0391, 0.0801, -0.0312, 0.0280, 0.0169, -0.0242],
         [-0.0393, 0.0799, -0.0310, 0.0270, 0.0163, -0.0258],
         [-0.0383, 0.0800, -0.0317, 0.0283, 0.0172, -0.0236],
         [-0.0367, 0.0808, -0.0306, 0.0292, 0.0149, -0.0246],
         [-0.0386, 0.0797, -0.0308, 0.0266, 0.0170, -0.0247],
         [-0.0389, 0.0798, -0.0319, 0.0276, 0.0171, -0.0256],
         [-0.0388, 0.0804, -0.0307, 0.0261, 0.0168, -0.0248]],
        grad_fn=<AddmmBackward>),
 'loss': tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)}
```
```python
loss = model(**batch)["loss"]
```
```python
loss
```

```
tensor(0.6964, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
```
```python
loss.backward()
```
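After the backward pass, every parameter that took part in the forward computation should carry a gradient; a simple check that is not in the original:

```python
# confirm gradients flowed back through the projection and embedding layers
print(model.projection.weight.grad.norm())
print(token_embedding.weight.grad.norm())
```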
# Train
```python
optimizer = optim.Adam(model.parameters(), lr=config.lr)
```
```python
from allennlp.training.trainer import Trainer

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    cuda_device=0 if USE_GPU else -1,
    num_epochs=config.epochs,
)
```
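If a validation set had been built earlier, it could be passed to the trainer so the `Validation` column in the epoch logs below is populated; a sketch that assumes `val_ds` was filled in (e.g. by the split sketched above), not part of the original:

```python
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,  # assumes val_ds is a list of instances, not None
    cuda_device=0 if USE_GPU else -1,
    num_epochs=config.epochs,
)
```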
```python
metrics = trainer.train()
```

```
01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Beginning training.
01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Epoch 0/1
01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 315.695104
01/29/2019 19:57:05 - INFO - allennlp.training.trainer - Training
loss: 0.6929 ||: 100%|██████████| 5/5 [00:08<00:00, 1.54s/it]
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training | Validation
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - loss | 0.693 | N/A
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - cpu_memory_MB | 315.695 | N/A
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Estimated training time remaining: 0:00:08
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Epoch 1/1
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Peak CPU memory usage MB: 734.257152
01/29/2019 19:57:14 - INFO - allennlp.training.trainer - Training
loss: 0.6820 ||: 100%|██████████| 5/5 [00:07<00:00, 1.76s/it]
01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Training | Validation
01/29/2019 19:57:22 - INFO - allennlp.training.trainer - loss | 0.682 | N/A
01/29/2019 19:57:22 - INFO - allennlp.training.trainer - cpu_memory_MB | 734.257 | N/A
01/29/2019 19:57:22 - INFO - allennlp.training.trainer - Epoch duration: 00:00:08
```
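To actually produce predictions for the test set, one can run the trained model in evaluation mode over `test_ds`. This is a minimal sketch, not part of the original notebook; it uses `BasicIterator` (which, unlike `BucketIterator`, preserves instance order) and assumes `test_proced.csv` carries the same label columns, since the reader and `forward` both expect them:

```python
from allennlp.data.iterators import BasicIterator

pred_iterator = BasicIterator(batch_size=config.batch_size)
pred_iterator.index_with(vocab)

model.eval()
all_probs = []
with torch.no_grad():
    for b in pred_iterator(test_ds, num_epochs=1, shuffle=False):
        b = nn_util.move_to_device(b, 0 if USE_GPU else -1)
        logits = model(**b)["class_logits"]
        all_probs.append(torch.sigmoid(logits).cpu().numpy())
preds = np.concatenate(all_probs, axis=0)  # shape: (num_test_instances, 6)
```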