{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "792293ef-bc9c-4955-a865-9a8084b7f05f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "\n",
    "import transformers as hft\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c7703674-614c-48b9-9a8d-02dc0f5c2f38",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NQDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Yields (question, positive passage title, positive passage text) triples.\"\"\"\n",
    "\n",
    "    def __init__(self, data: list):\n",
    "        self.data = data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return (\n",
    "            self.data[idx]['question'],\n",
    "            self.data[idx]['positive_ctxs'][0]['title'],\n",
    "            self.data[idx]['positive_ctxs'][0]['text'],\n",
    "        )\n",
    "\n",
    "def get_schedule_linear(\n",
    "    optimizer,\n",
    "    warmup_steps,\n",
    "    total_training_steps,\n",
    "    steps_shift=0,\n",
    "    last_epoch=-1,\n",
    "):\n",
    "    \"\"\"\n",
    "    Create a schedule with a learning rate that decreases linearly after\n",
    "    linearly increasing during a warmup period.\n",
    "\n",
    "    Source: https://github.com/facebookresearch/DPR/blob/1ee31c6c53/dpr/utils/model_utils.py\n",
    "    \"\"\"\n",
    "    def lr_lambda(current_step):\n",
    "        current_step += steps_shift\n",
    "        if current_step < warmup_steps:\n",
    "            return float(current_step) / float(max(1, warmup_steps))\n",
    "        return max(\n",
    "            1e-7,\n",
    "            float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)),\n",
    "        )\n",
    "\n",
    "    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch)\n",
    "\n",
    "def criterion(S, target=None):\n",
    "    # In-batch negatives: row i of S scores query i against every passage in\n",
    "    # the batch; by default the correct (positive) passage is the diagonal entry.\n",
    "    softS = F.log_softmax(S, dim=1)\n",
    "    if target is None:\n",
    "        target = torch.arange(0, S.shape[0])\n",
    "\n",
    "    target = target.to(softS.device)\n",
    "    loss = F.nll_loss(softS, target, reduction=\"mean\")\n",
    "\n",
    "    return loss"
   ]
  },
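  {
   "cell_type": "markdown",
   "id": "in-batch-negatives-note",
   "metadata": {},
   "source": [
    "The loss above is the standard in-batch negatives objective: for a batch of *n* question/passage pairs, the *n*-by-*n* similarity matrix is treated as *n* classification problems whose correct class sits on the diagonal. A minimal sanity check on toy values (no model involved, added for illustration), confirming that `criterion` matches cross-entropy with diagonal targets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "in-batch-negatives-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy 3x3 similarity matrix: diagonal entries are the positive pairs.\n",
    "S_toy = torch.tensor([\n",
    "    [5.0, 1.0, 0.5],\n",
    "    [0.2, 4.0, 0.1],\n",
    "    [0.3, 0.9, 6.0],\n",
    "])\n",
    "# criterion defaults to diagonal targets [0, 1, 2].\n",
    "assert torch.isclose(criterion(S_toy), F.cross_entropy(S_toy, torch.arange(3)))\n",
    "criterion(S_toy)"
   ]
  },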
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "66f6c6cb-91db-455e-beb6-7dccb1962537",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "warmup_steps = 1237\n",
    "num_epochs = 40\n",
    "learning_rate = 2e-5  # 2e-5 for NQ, 1e-5 for other datasets\n",
    "max_length = 256\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "40018b4e-db1e-4fa9-9f99-26292cd50119",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "q_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "ctx_tokenizer = hft.AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "\n",
    "q_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\").to(device)\n",
    "ctx_encoder = hft.AutoModel.from_pretrained(\"bert-base-uncased\").to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "17fe0612-88d0-4013-88d7-29ba31d96873",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 33.6 s, sys: 11.3 s, total: 44.9 s\n",
      "Wall time: 44.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "train_json = json.load(open('data/biencoder-nq-train.json'))\n",
    "valid_json = json.load(open('data/biencoder-nq-dev.json'))"
   ]
  },
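  {
   "cell_type": "markdown",
   "id": "data-format-note",
   "metadata": {},
   "source": [
    "`NQDataset` reads three fields from each record: `question`, plus the `title` and `text` of the first entry in `positive_ctxs`. A quick sanity check (added for illustration) that the loaded files have that shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "data-format-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the fields NQDataset relies on.\n",
    "example = train_json[0]\n",
    "print(example['question'])\n",
    "print(example['positive_ctxs'][0]['title'])\n",
    "print(example['positive_ctxs'][0]['text'][:100])\n",
    "print(len(train_json), 'train /', len(valid_json), 'dev examples')"
   ]
  },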
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5cd6e088-b706-469c-b95c-4f664f1ed962",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    NQDataset(train_json), batch_size=batch_size, shuffle=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "beb5f7fd-9778-4e7e-8fd5-138041085fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(\n",
    "    [{\"params\": ctx_encoder.parameters()}, {\"params\": q_encoder.parameters()}],\n",
    "    lr=learning_rate,\n",
    "    weight_decay=0.0,\n",
    ")\n",
    "total_training_steps = num_epochs * len(train_loader)\n",
    "scheduler = get_schedule_linear(\n",
    "    optimizer, warmup_steps=warmup_steps, total_training_steps=total_training_steps\n",
    ")"
   ]
  },
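  {
   "cell_type": "markdown",
   "id": "schedule-probe-note",
   "metadata": {},
   "source": [
    "As a sanity check, the schedule multiplier should ramp linearly to 1.0 over `warmup_steps`, then decay linearly toward zero at `total_training_steps`. A minimal probe (added for illustration) that reads the lambda directly, so the real scheduler state is not advanced:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "schedule-probe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probe the warmup/decay shape without stepping the real scheduler.\n",
    "lr_fn = scheduler.lr_lambdas[0]\n",
    "for step in [0, warmup_steps // 2, warmup_steps, total_training_steps // 2, total_training_steps]:\n",
    "    print(step, round(lr_fn(step), 4))"
   ]
  },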
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5d99c605-32a8-43f8-9e06-53d4b2219318",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch #0: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=4.66]\n",
      "Epoch #1: 100%|██████████| 460/460 [07:38<00:00, 1.00it/s, loss=2.72]\n",
      "Epoch #2: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.653]\n",
      "Epoch #3: 100%|██████████| 460/460 [07:36<00:00, 1.01it/s, loss=0.406]\n",
      "Epoch #4: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.381] \n",
      "Epoch #5: 100%|██████████| 460/460 [07:35<00:00, 1.01it/s, loss=0.242] \n",
      "Epoch #6: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.224] \n",
      "Epoch #7: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0726]\n",
      "Epoch #8: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0708]\n",
      "Epoch #9: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.11] \n",
      "Epoch #10: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0283] \n",
      "Epoch #11: 100%|██████████| 460/460 [07:32<00:00, 1.02it/s, loss=0.0475] \n",
      "Epoch #12: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0762] \n",
      "Epoch #13: 73%|███████▎ | 334/460 [05:29<02:04, 1.01it/s, loss=0.019] IOPub message rate exceeded.\n",
      "The Jupyter server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--ServerApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "ServerApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "Epoch #25: 100%|██████████| 460/460 [07:34<00:00, 1.01it/s, loss=0.0115] \n",
      "Epoch #26: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.00752] \n",
      "Epoch #27: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.0207] \n",
      "Epoch #28: 27%|██▋ | 125/460 [02:04<05:30, 1.01it/s, loss=0.00159] IOPub message rate exceeded.\n",
      "The Jupyter server will temporarily stop sending output\n",
      "to the client in order to avoid crashing it.\n",
      "To change this limit, set the config variable\n",
      "`--ServerApp.iopub_msg_rate_limit`.\n",
      "\n",
      "Current values:\n",
      "ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
      "ServerApp.rate_limit_window=3.0 (secs)\n",
      "\n",
      "Epoch #39: 100%|██████████| 460/460 [07:33<00:00, 1.01it/s, loss=0.000152]\n"
     ]
    }
   ],
   "source": [
    "results = {}\n",
    "\n",
    "for epoch_num in range(num_epochs):\n",
    "    results[epoch_num] = []\n",
    "\n",
    "    ctx_encoder.train()\n",
    "    q_encoder.train()\n",
    "\n",
    "    pbar = tqdm(train_loader, desc=f'Epoch #{epoch_num}')\n",
    "\n",
    "    for q, t, c in pbar:\n",
    "        tokenizer_kwargs = dict(max_length=max_length, return_tensors=\"pt\", truncation=True, padding=True)\n",
    "        queries = q_tokenizer(list(q), **tokenizer_kwargs).to(device)\n",
    "        # Passages are encoded as (title, text) pairs, i.e. \"title [SEP] text\".\n",
    "        passages = ctx_tokenizer(list(t), list(c), **tokenizer_kwargs).to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        Q = q_encoder(**queries).pooler_output\n",
    "        P = ctx_encoder(**passages).pooler_output.to(Q.device)\n",
    "        # Batch similarity matrix: S[i, j] = <query i, passage j>.\n",
    "        S = torch.mm(Q, P.T)\n",
    "\n",
    "        # In-batch negatives loss with diagonal targets (see criterion above).\n",
    "        loss = criterion(S)\n",
    "\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        scheduler.step()\n",
    "\n",
    "        pbar.set_postfix({'loss': loss.item()})\n",
    "\n",
    "        results[epoch_num].append(loss.item())"
   ]
  },
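  {
   "cell_type": "markdown",
   "id": "retrieval-sketch-note",
   "metadata": {},
   "source": [
    "For a quick qualitative check after training, the two encoders are used exactly as in the loop above: encode with `pooler_output`, then rank passages by dot product. The query and passages below are made-up strings for illustration only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "retrieval-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank two toy passages against a toy query (illustrative strings only).\n",
    "q_encoder.eval()\n",
    "ctx_encoder.eval()\n",
    "with torch.no_grad():\n",
    "    query = q_tokenizer(['who wrote hamlet'], return_tensors='pt', truncation=True, padding=True).to(device)\n",
    "    titles = ['Hamlet', 'Python (programming language)']\n",
    "    texts = [\n",
    "        'Hamlet is a tragedy written by William Shakespeare.',\n",
    "        'Python is a high-level programming language.',\n",
    "    ]\n",
    "    passages = ctx_tokenizer(titles, texts, return_tensors='pt', truncation=True, padding=True).to(device)\n",
    "    Q = q_encoder(**query).pooler_output\n",
    "    P = ctx_encoder(**passages).pooler_output\n",
    "    scores = (Q @ P.T).squeeze(0)\n",
    "scores  # the higher score should land on the Shakespeare passage"
   ]
  },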
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "32bee074-3854-4e4a-8b18-9228f03d97bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('models/', exist_ok=True)\n",
    "\n",
    "with open('models/results.json', 'w') as f:\n",
    "    json.dump(results, f)\n",
    "\n",
    "ctx_encoder.save_pretrained('models/dpr-nq-reproduced/ctx-encoder')\n",
    "q_encoder.save_pretrained('models/dpr-nq-reproduced/q-encoder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "88818375-d83a-4d17-8014-3b8c8bf95a25",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('models/dpr-nq-reproduced/q-encoder/tokenizer_config.json',\n",
       " 'models/dpr-nq-reproduced/q-encoder/special_tokens_map.json',\n",
       " 'models/dpr-nq-reproduced/q-encoder/vocab.txt',\n",
       " 'models/dpr-nq-reproduced/q-encoder/added_tokens.json',\n",
       " 'models/dpr-nq-reproduced/q-encoder/tokenizer.json')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ctx_tokenizer.save_pretrained('models/dpr-nq-reproduced/ctx-encoder')\n",
    "q_tokenizer.save_pretrained('models/dpr-nq-reproduced/q-encoder')"
   ]
  },
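  {
   "cell_type": "markdown",
   "id": "reload-note",
   "metadata": {},
   "source": [
    "The saved encoders and tokenizers can be reloaded later with the same `AutoModel`/`AutoTokenizer` calls used above (a minimal sketch, assuming the paths written by the previous cells):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "reload-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the fine-tuned question encoder and its tokenizer from disk.\n",
    "q_tok = hft.AutoTokenizer.from_pretrained('models/dpr-nq-reproduced/q-encoder')\n",
    "q_enc = hft.AutoModel.from_pretrained('models/dpr-nq-reproduced/q-encoder').to(device)"
   ]
  }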
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
} |