Created
September 8, 2020 11:14
-
-
Save cahya-wirawan/b36e91cae21a6a7f9a10e1c85f59d9ae to your computer and use it in GitHub Desktop.
BERT - GPT2 - CNN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# BERT - GPT2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import nlp\n", | |
"import logging\n", | |
"import transformers \n", | |
"from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel, Trainer, TrainingArguments\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3.0.2\n" | |
] | |
} | |
], | |
"source": [ | |
"print(transformers.__version__)\n", | |
"\n", | |
"# With transformers 3.0.2 the GPT2 decoder's forward() rejects `encoder_hidden_states`,\n", | |
"# so EncoderDecoderModel training crashes with a TypeError only after all the data has\n", | |
"# been preprocessed. Fail fast here with an actionable message instead.\n", | |
"# NOTE(review): 3.1.0 is believed to be the first release with GPT2 cross-attention -\n", | |
"# confirm against the transformers release notes.\n", | |
"_major_minor = tuple(int(part) for part in transformers.__version__.split(\".\")[:2])\n", | |
"if _major_minor < (3, 1):\n", | |
"    raise RuntimeError(\n", | |
"        \"BERT -> GPT2 EncoderDecoderModel needs transformers>=3.1.0 \"\n", | |
"        f\"(found {transformers.__version__}); upgrade transformers and restart the kernel.\"\n", | |
"    )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"logging.basicConfig(level=logging.INFO)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /root/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391\n", | |
"INFO:transformers.configuration_utils:Model config BertConfig {\n", | |
" \"architectures\": [\n", | |
" \"BertForMaskedLM\"\n", | |
" ],\n", | |
" \"attention_probs_dropout_prob\": 0.1,\n", | |
" \"gradient_checkpointing\": false,\n", | |
" \"hidden_act\": \"gelu\",\n", | |
" \"hidden_dropout_prob\": 0.1,\n", | |
" \"hidden_size\": 768,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"intermediate_size\": 3072,\n", | |
" \"layer_norm_eps\": 1e-12,\n", | |
" \"max_position_embeddings\": 512,\n", | |
" \"model_type\": \"bert\",\n", | |
" \"num_attention_heads\": 12,\n", | |
" \"num_hidden_layers\": 12,\n", | |
" \"pad_token_id\": 0,\n", | |
" \"type_vocab_size\": 2,\n", | |
" \"vocab_size\": 28996\n", | |
"}\n", | |
"\n", | |
"INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/bert-base-cased-pytorch_model.bin from cache at /root/.cache/torch/transformers/d8f11f061e407be64c4d5d7867ee61d1465263e24085cfa26abf183fdc830569.3fadbea36527ae472139fe84cddaa65454d7429f12d543d80bfc3ad70de55ac2\n", | |
"INFO:transformers.modeling_utils:All model checkpoint weights were used when initializing BertModel.\n", | |
"\n", | |
"INFO:transformers.modeling_utils:All the weights of BertModel were initialized from the model checkpoint at bert-base-cased.\n", | |
"If your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertModel for predictions without further training.\n", | |
"INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /root/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.db13c9bc9c7bdd738ec89e069621d88e05dc670366092d809a9cbcac6798e24e\n", | |
"INFO:transformers.configuration_utils:Model config GPT2Config {\n", | |
" \"activation_function\": \"gelu_new\",\n", | |
" \"architectures\": [\n", | |
" \"GPT2LMHeadModel\"\n", | |
" ],\n", | |
" \"attn_pdrop\": 0.1,\n", | |
" \"bos_token_id\": 50256,\n", | |
" \"embd_pdrop\": 0.1,\n", | |
" \"eos_token_id\": 50256,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"layer_norm_epsilon\": 1e-05,\n", | |
" \"model_type\": \"gpt2\",\n", | |
" \"n_ctx\": 1024,\n", | |
" \"n_embd\": 768,\n", | |
" \"n_head\": 12,\n", | |
" \"n_layer\": 12,\n", | |
" \"n_positions\": 1024,\n", | |
" \"resid_pdrop\": 0.1,\n", | |
" \"summary_activation\": null,\n", | |
" \"summary_first_dropout\": 0.1,\n", | |
" \"summary_proj_to_labels\": true,\n", | |
" \"summary_type\": \"cls_index\",\n", | |
" \"summary_use_proj\": true,\n", | |
" \"task_specific_params\": {\n", | |
" \"text-generation\": {\n", | |
" \"do_sample\": true,\n", | |
" \"max_length\": 50\n", | |
" }\n", | |
" },\n", | |
" \"vocab_size\": 50257\n", | |
"}\n", | |
"\n", | |
"INFO:transformers.modeling_encoder_decoder:Initializing gpt2 as a decoder model. Cross attention layers are added to gpt2 and randomly initialized if gpt2's architecture allows for cross attention layers.\n", | |
"INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/gpt2-pytorch_model.bin from cache at /root/.cache/torch/transformers/d71fd633e58263bd5e91dd3bde9f658bafd81e11ece622be6a3c2e4d42d8fd89.778cf36f5c4e5d94c8cd9cefcf2a580c8643570eb327f0d4a1f007fab2acbdf1\n", | |
"INFO:transformers.modeling_utils:All model checkpoint weights were used when initializing GPT2LMHeadModel.\n", | |
"\n", | |
"WARNING:transformers.modeling_utils:Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']\n", | |
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", | |
"INFO:transformers.configuration_encoder_decoder:Set `config.is_decoder=True` for decoder_config\n", | |
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n" | |
] | |
} | |
], | |
"source": [ | |
"# warm-start a seq2seq model from a pretrained BERT encoder and a pretrained GPT2 decoder\n", | |
"model = EncoderDecoderModel.from_encoder_decoder_pretrained(\n", | |
"    encoder_pretrained_model_name_or_path=\"bert-base-cased\",\n", | |
"    decoder_pretrained_model_name_or_path=\"gpt2\",\n", | |
")\n", | |
"# cache is currently not supported by EncoderDecoder framework\n", | |
"model.decoder.config.use_cache = False\n", | |
"bert_tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# reuse BERT's existing special tokens: CLS acts as BOS, SEP acts as EOS\n", | |
"bert_tokenizer.bos_token, bert_tokenizer.eos_token = (\n", | |
"    bert_tokenizer.cls_token,\n", | |
"    bert_tokenizer.sep_token,\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# make sure GPT2 wraps every sequence as BOS + tokens + EOS\n", | |
"def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n", | |
"    \"\"\"Return `token_ids_0` with the tokenizer's BOS id prepended and EOS id appended.\n", | |
"\n", | |
"    `token_ids_1` is accepted for signature compatibility but is not used.\n", | |
"    \"\"\"\n", | |
"    prefix = [self.bos_token_id]\n", | |
"    suffix = [self.eos_token_id]\n", | |
"    return prefix + token_ids_0 + suffix" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /root/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71\n", | |
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /root/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n" | |
] | |
} | |
], | |
"source": [ | |
"# patch the class so every GPT2 tokenizer instance wraps sequences as BOS + tokens + EOS\n", | |
"setattr(GPT2Tokenizer, \"build_inputs_with_special_tokens\", build_inputs_with_special_tokens)\n", | |
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", | |
"## gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"/output/gpt2-id-100/small\")\n", | |
"# GPT2 ships without a pad token: reuse unk_token -> be careful here as unk_token_id == eos_token_id == bos_token_id\n", | |
"gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# set decoding params\n", | |
"# generate() takes its defaults from `model.config`; attributes assigned directly on the\n", | |
"# model object (model.num_beams, ...) are never read, so every setting goes on the config.\n", | |
"model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id\n", | |
"model.config.eos_token_id = gpt2_tokenizer.eos_token_id\n", | |
"model.config.max_length = 142\n", | |
"model.config.min_length = 56\n", | |
"model.config.no_repeat_ngram_size = 3\n", | |
"model.config.early_stopping = True\n", | |
"model.config.length_penalty = 2.0\n", | |
"model.config.num_beams = 4" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py for additional imports.\n", | |
"INFO:filelock:Lock 140606189019032 acquired on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n", | |
"INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail\n", | |
"INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.py\n", | |
"INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/dataset_infos.json to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/dataset_infos.json\n", | |
"INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.json\n", | |
"INFO:filelock:Lock 140606189019032 released on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n", | |
"INFO:nlp.info:Loading Dataset Infos from /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.builder:Overwrite dataset info from restored data version.\n", | |
"INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.builder:Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8)\n", | |
"INFO:nlp.builder:Constructing Dataset for split train, from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.utils.info_utils:All the checksums matched successfully for post processing resources\n", | |
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py for additional imports.\n", | |
"INFO:filelock:Lock 140605783911784 acquired on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n", | |
"INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail\n", | |
"INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.py\n", | |
"INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/dataset_infos.json to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/dataset_infos.json\n", | |
"INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.json\n", | |
"INFO:filelock:Lock 140605783911784 released on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n", | |
"INFO:nlp.info:Loading Dataset Infos from /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.builder:Overwrite dataset info from restored data version.\n", | |
"INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.builder:Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8)\n", | |
"INFO:nlp.builder:Constructing Dataset for split validation[:5%], from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n", | |
"INFO:nlp.utils.info_utils:All the checksums matched successfully for post processing resources\n" | |
] | |
} | |
], | |
"source": [ | |
"# load train and validation data (full train split, first 5% of validation)\n", | |
"train_dataset, val_dataset = (\n", | |
"    nlp.load_dataset(\"cnn_dailymail\", \"3.0.0\", split=split_name)\n", | |
"    for split_name in (\"train\", \"validation[:5%]\")\n", | |
")\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py for additional imports.\n", | |
"INFO:filelock:Lock 140606158875056 acquired on /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py.lock\n", | |
"INFO:nlp.load:Found main folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge\n", | |
"INFO:nlp.load:Found specific version folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1\n", | |
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/rouge.py\n", | |
"INFO:nlp.load:Couldn't find dataset infos file at https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/dataset_infos.json\n", | |
"INFO:nlp.load:Found metadata file for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/rouge.json\n", | |
"INFO:filelock:Lock 140606158875056 released on /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py.lock\n", | |
"INFO:filelock:Lock 140605692902928 acquired on /root/.cache/huggingface/metrics/rouge/default/1.0.0/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/1-rouge-0.arrow.lock\n" | |
] | |
} | |
], | |
"source": [ | |
"# load rouge for validation\n", | |
"rouge = nlp.load_metric(\"rouge\", experiment_id=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"encoder_length = 512\n", | |
"decoder_length = 128\n", | |
"batch_size = 16\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# map data correctly\n", | |
"def map_to_encoder_decoder_inputs(batch):\n", | |
"    \"\"\"Tokenize a batch of articles/highlights into encoder-decoder model inputs.\n", | |
"\n", | |
"    Articles go through the BERT tokenizer (padded/truncated to `encoder_length`) and\n", | |
"    highlights through the GPT2 tokenizer (padded/truncated to `decoder_length`).\n", | |
"    `labels` are the decoder ids with every position whose attention mask is 0 set to -100.\n", | |
"    The batch is updated in place and returned.\n", | |
"    \"\"\"\n", | |
"    # Tokenizer will automatically set [BOS] <text> [EOS]\n", | |
"    # use bert tokenizer here for encoder\n", | |
"    inputs = bert_tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=encoder_length)\n", | |
"    # force summarization <= 128\n", | |
"    outputs = gpt2_tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=decoder_length)\n", | |
"\n", | |
"    # pad_token_id alone cannot tell padding from real tokens (for GPT2 pad == eos == bos),\n", | |
"    # so the attention mask decides which label positions are excluded with -100\n", | |
"    labels = [\n", | |
"        [token if mask != 0 else -100 for mask, token in zip(masks, tokens)]\n", | |
"        for masks, tokens in zip(outputs.attention_mask, outputs.input_ids)\n", | |
"    ]\n", | |
"\n", | |
"    batch[\"input_ids\"] = inputs.input_ids\n", | |
"    batch[\"attention_mask\"] = inputs.attention_mask\n", | |
"    batch[\"decoder_input_ids\"] = outputs.input_ids\n", | |
"    batch[\"labels\"] = labels\n", | |
"    batch[\"decoder_attention_mask\"] = outputs.attention_mask\n", | |
"\n", | |
"    assert all(len(ids) == encoder_length for ids in inputs.input_ids)\n", | |
"    assert all(len(ids) == decoder_length for ids in outputs.input_ids)\n", | |
"\n", | |
"    return batch\n", | |
"\n", | |
"\n", | |
"def compute_metrics(pred):\n", | |
"    \"\"\"Return rounded ROUGE-2 precision/recall/F-measure for an evaluation prediction.\n", | |
"\n", | |
"    `pred.predictions` may hold token ids (2-D) or raw decoder logits (3-D); logits are\n", | |
"    reduced to ids with argmax over the last axis. `pred.label_ids` is not modified.\n", | |
"    \"\"\"\n", | |
"    pred_ids = pred.predictions\n", | |
"    if getattr(pred_ids, \"ndim\", 2) == 3:\n", | |
"        # presumably (batch, seq_len, vocab) logits from the default evaluation loop\n", | |
"        pred_ids = pred_ids.argmax(-1)\n", | |
"\n", | |
"    # copy first so the caller's label array is not overwritten in place\n", | |
"    # (assumes a NumPy array, as handed over by Trainer - TODO confirm)\n", | |
"    labels_ids = pred.label_ids.copy()\n", | |
"    # -100 positions are not real token ids; swap in EOS so they can be decoded\n", | |
"    labels_ids[labels_ids == -100] = gpt2_tokenizer.eos_token_id\n", | |
"\n", | |
"    # all unnecessary tokens are removed\n", | |
"    pred_str = gpt2_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", | |
"    label_str = gpt2_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n", | |
"\n", | |
"    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=[\"rouge2\"])[\"rouge2\"].mid\n", | |
"\n", | |
"    return {\n", | |
"        \"rouge2_precision\": round(rouge_output.precision, 4),\n", | |
"        \"rouge2_recall\": round(rouge_output.recall, 4),\n", | |
"        \"rouge2_fmeasure\": round(rouge_output.fmeasure, 4),\n", | |
"    }\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:nlp.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cache-13d913db1e04ac617cf6323b5be63ae6.arrow\n", | |
"INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels'] columns (when key is int or slice) and don't output other (un-formated) columns.\n", | |
"INFO:nlp.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cache-c84479f093fba2dbcbc88f32a9900f77.arrow\n", | |
"INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels'] columns (when key is int or slice) and don't output other (un-formated) columns.\n" | |
] | |
} | |
], | |
"source": [ | |
"def _to_model_inputs(dataset):\n", | |
"    \"\"\"Tokenize a raw split and expose only the model input columns as torch tensors.\"\"\"\n", | |
"    dataset = dataset.map(\n", | |
"        map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=[\"article\", \"highlights\"],\n", | |
"    )\n", | |
"    dataset.set_format(\n", | |
"        type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n", | |
"    )\n", | |
"    return dataset\n", | |
"\n", | |
"\n", | |
"# make train dataset ready\n", | |
"train_dataset = _to_model_inputs(train_dataset)\n", | |
"\n", | |
"# same for validation dataset\n", | |
"val_dataset = _to_model_inputs(val_dataset)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# set training arguments - these params are not really tuned, feel free to change\n", | |
"training_args = TrainingArguments(\n", | |
"    # checkpoints\n", | |
"    output_dir=\"./\",\n", | |
"    overwrite_output_dir=True,\n", | |
"    save_steps=1000,\n", | |
"    save_total_limit=10,\n", | |
"    # what to run\n", | |
"    do_train=True,\n", | |
"    do_eval=True,\n", | |
"    #predict_from_generate=True,\n", | |
"    #evaluate_during_training=True,\n", | |
"    # batching and schedule\n", | |
"    per_device_train_batch_size=batch_size,\n", | |
"    per_device_eval_batch_size=batch_size,\n", | |
"    warmup_steps=2000,\n", | |
"    fp16=False,\n", | |
"    # reporting cadence\n", | |
"    logging_steps=1000,\n", | |
"    eval_steps=1000,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:transformers.training_args:PyTorch: setting up devices\n", | |
"INFO:transformers.trainer:Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/html": [ | |
"\n", | |
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n", | |
" Project page: <a href=\"https://app.wandb.ai/cahya/huggingface\" target=\"_blank\">https://app.wandb.ai/cahya/huggingface</a><br/>\n", | |
" Run page: <a href=\"https://app.wandb.ai/cahya/huggingface/runs/2q9d4if2\" target=\"_blank\">https://app.wandb.ai/cahya/huggingface/runs/2q9d4if2</a><br/>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:wandb.run_manager:system metrics and metadata threads started\n", | |
"INFO:wandb.run_manager:checking resume status, waiting at most 10 seconds\n", | |
"INFO:wandb.run_manager:resuming run from id: UnVuOnYxOjJxOWQ0aWYyOmh1Z2dpbmdmYWNlOmNhaHlh\n", | |
"INFO:wandb.run_manager:upserting run before process can begin, waiting at most 10 seconds\n", | |
"INFO:wandb.run_manager:saving pip packages\n", | |
"INFO:wandb.run_manager:initializing streaming files api\n", | |
"INFO:wandb.run_manager:unblocking file change observer, beginning sync with W&B servers\n", | |
"INFO:wandb.run_manager:shutting down system stats and metadata service\n", | |
"INFO:wandb.run_manager:file/dir modified: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/config.yaml\n", | |
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/requirements.txt\n", | |
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-summary.json\n", | |
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-history.jsonl\n", | |
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-metadata.json\n", | |
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-events.jsonl\n", | |
"INFO:wandb.run_manager:stopping streaming files and file change observer\n", | |
"INFO:wandb.run_manager:file/dir modified: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-metadata.json\n" | |
] | |
} | |
], | |
"source": [ | |
"# instantiate trainer: model + arguments, then the two data splits, then the metric hook\n", | |
"trainer = Trainer(\n", | |
"    model=model,\n", | |
"    args=training_args,\n", | |
"    train_dataset=train_dataset,\n", | |
"    eval_dataset=val_dataset,\n", | |
"    compute_metrics=compute_metrics,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:transformers.trainer:***** Running training *****\n", | |
"INFO:transformers.trainer: Num examples = 287113\n", | |
"INFO:transformers.trainer: Num Epochs = 3\n", | |
"INFO:transformers.trainer: Instantaneous batch size per device = 16\n", | |
"INFO:transformers.trainer: Total train batch size (w. parallel, distributed & accumulation) = 128\n", | |
"INFO:transformers.trainer: Gradient Accumulation steps = 1\n", | |
"INFO:transformers.trainer: Total optimization steps = 6732\n" | |
] | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "b5b54269b5604a078896beb13cb96452", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"data": { | |
"application/vnd.jupyter.widget-view+json": { | |
"model_id": "757b3693158a4c278138b542e2ffd69a", | |
"version_major": 2, | |
"version_minor": 0 | |
}, | |
"text/plain": [ | |
"HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2244.0, style=ProgressStyle(description_w…" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
}, | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/opt/conda/conda-bld/pytorch_1591914895884/work/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\n", | |
"\n" | |
] | |
}, | |
{ | |
"ename": "TypeError", | |
"evalue": "Caught TypeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\", line 60, in _worker\n output = module(*input, **kwargs)\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\n File \"/root/Work/transformers/src/transformers/modeling_encoder_decoder.py\", line 290, in forward\n **kwargs_decoder,\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\nTypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'\n", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-16-c108335b43e0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# start training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/Work/transformers/src/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, model_path)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 499\u001b[0;31m \u001b[0mtr_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 500\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m if (step + 1) % self.args.gradient_accumulation_steps == 0 or (\n", | |
"\u001b[0;32m~/Work/transformers/src/transformers/trainer.py\u001b[0m in \u001b[0;36m_training_step\u001b[0;34m(self, model, inputs, optimizer)\u001b[0m\n\u001b[1;32m 620\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"mems\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_past\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 622\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 623\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# model outputs are always tuple in transformers (see doc)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 550\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 551\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mreplicas\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 156\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;31m# (https://bugs.python.org/issue2651), so we work around it.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 394\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mKeyErrorMessage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 395\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m: Caught TypeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\", line 60, in _worker\n output = module(*input, **kwargs)\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\n File \"/root/Work/transformers/src/transformers/modeling_encoder_decoder.py\", line 290, in forward\n **kwargs_decoder,\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\nTypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'\n" | |
] | |
} | |
], | |
"source": [ | |
"# start training\n", | |
"trainer.train()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment