Last active
February 15, 2021 17:58
-
-
Save davidefiocco/6c77070d838328c0b13546886de5c06a to your computer and use it in GitHub Desktop.
Text classification in PyTorch to refactor with petastorm.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Text classification in PyTorch to refactor with petastorm.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOa6VtIQB5zcrstMpyPUWiu", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/davidefiocco/6c77070d838328c0b13546886de5c06a/text-classification-in-pytorch-to-refactor-with-petastorm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "C3dhjnf3cJcd" | |
}, | |
"source": [ | |
"!pip install transformers --quiet" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "ZplY26R4oz50" | |
}, | |
"source": [ | |
"The code below is a PyTorch text classifier adapted from https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html, with small changes so that it works on a custom pandas DataFrame. \r\n", | |
"How can this be adapted to work with PySpark DataFrames instead of pandas DataFrames?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "FVfGgGQycB7Y" | |
}, | |
"source": [ | |
"import pandas as pd\r\n", | |
"import torch\r\n", | |
"from torch.utils.data.dataset import Dataset\r\n", | |
"from transformers import BertTokenizer\r\n", | |
"import torch.nn as nn\r\n", | |
"import torch.nn.functional as F\r\n", | |
"from torch.utils.data import DataLoader\r\n", | |
"\r\n", | |
"# Tokenize with the pretrained HuggingFace BERT tokenizer.\r\n", | |
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\r\n", | |
"\r\n", | |
"# Toy corpus: two sentences with opposite labels, repeated 100 times.\r\n", | |
"text = [\"This is a test.\", \"This is not a test.\"]*100\r\n", | |
"label = [1, 0]*100\r\n", | |
"\r\n", | |
"# Build the frame and derive the token-id column in a single expression.\r\n", | |
"df = pd.DataFrame({\"text\": text, \"label\": label}).assign(\r\n", | |
"    tokenized=lambda frame: frame[\"text\"].apply(tokenizer.encode)\r\n", | |
")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"id": "SLL5N35nn8CL", | |
"outputId": "630fcf0f-c1be-4556-86ae-664538d9c966" | |
}, | |
"source": [ | |
"df.head()" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>text</th>\n", | |
" <th>label</th>\n", | |
" <th>tokenized</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>This is a test.</td>\n", | |
" <td>1</td>\n", | |
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>This is not a test.</td>\n", | |
" <td>0</td>\n", | |
" <td>[101, 2023, 2003, 2025, 1037, 3231, 1012, 102]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>This is a test.</td>\n", | |
" <td>1</td>\n", | |
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>This is not a test.</td>\n", | |
" <td>0</td>\n", | |
" <td>[101, 2023, 2003, 2025, 1037, 3231, 1012, 102]</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>This is a test.</td>\n", | |
" <td>1</td>\n", | |
" <td>[101, 2023, 2003, 1037, 3231, 1012, 102]</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" text label tokenized\n", | |
"0 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]\n", | |
"1 This is not a test. 0 [101, 2023, 2003, 2025, 1037, 3231, 1012, 102]\n", | |
"2 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]\n", | |
"3 This is not a test. 0 [101, 2023, 2003, 2025, 1037, 3231, 1012, 102]\n", | |
"4 This is a test. 1 [101, 2023, 2003, 1037, 3231, 1012, 102]" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 3 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dRMFGrjPcFI_" | |
}, | |
"source": [ | |
"# Using pandas dataframe I use this class to create a custom PyTorch dataset\r\n", | |
"\r\n", | |
"class TokenizedDataset(Dataset):\r\n", | |
"    \"\"\"Map-style PyTorch dataset backed by a pandas DataFrame.\r\n", | |
"\r\n", | |
"    The frame must provide a `tokenized` column (a sequence of token ids\r\n", | |
"    per row) and a `label` column.\r\n", | |
"    \"\"\"\r\n", | |
"\r\n", | |
"    def __init__(self, df):\r\n", | |
"        self.data = df\r\n", | |
"\r\n", | |
"    def __getitem__(self, index):\r\n", | |
"        \"\"\"Return `(token_ids, label)` for the row at position `index`.\"\"\"\r\n", | |
"        # DataLoader samplers and random_split hand out positions in\r\n", | |
"        # [0, len(self)), so look rows up by position (.iloc). Label-based\r\n", | |
"        # .loc only works while the frame carries a default RangeIndex.\r\n", | |
"        row = self.data.iloc[index]\r\n", | |
"        text = torch.LongTensor(row[\"tokenized\"])\r\n", | |
"        label = row[\"label\"]\r\n", | |
"        return (text, label)\r\n", | |
"\r\n", | |
"    def __len__(self):\r\n", | |
"        \"\"\"Return the number of rows in the backing frame.\"\"\"\r\n", | |
"        return len(self.data)\r\n", | |
"\r\n", | |
"train_dataset = TokenizedDataset(df)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "1bv2nnj_co5K" | |
}, | |
"source": [ | |
"def generate_batch(batch):\r\n", | |
"    \"\"\"Collate `(token_tensor, label)` pairs into one flat batch.\r\n", | |
"\r\n", | |
"    Returns `(text, offsets, label)`: every token tensor concatenated into\r\n", | |
"    one tensor, the start position of each sample inside that\r\n", | |
"    concatenation, and the labels as a LongTensor.\r\n", | |
"    \"\"\"\r\n", | |
"    sequences = [sample[0] for sample in batch]\r\n", | |
"    label = torch.LongTensor([sample[1] for sample in batch])\r\n", | |
"\r\n", | |
"    # Each sample starts where the previous one ended: prepend 0, drop the\r\n", | |
"    # last length, then take the running sum.\r\n", | |
"    lengths = [0]\r\n", | |
"    lengths.extend(len(seq) for seq in sequences)\r\n", | |
"    offsets = torch.tensor(lengths[:-1]).cumsum(dim=0)\r\n", | |
"\r\n", | |
"    text = torch.cat(sequences)\r\n", | |
"    return text, offsets, label" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "mNzkJINlcpl6" | |
}, | |
"source": [ | |
"def train_func(sub_train_):\r\n", | |
"    \"\"\"Run one training epoch over `sub_train_`.\r\n", | |
"\r\n", | |
"    Relies on the notebook-level `model`, `optimizer`, `criterion`,\r\n", | |
"    `scheduler`, `device` and `BATCH_SIZE`. Returns the summed per-batch\r\n", | |
"    loss and the number of correct predictions, each divided by\r\n", | |
"    `len(sub_train_)`.\r\n", | |
"    \"\"\"\r\n", | |
"    running_loss = 0\r\n", | |
"    n_correct = 0\r\n", | |
"    loader = DataLoader(\r\n", | |
"        sub_train_,\r\n", | |
"        batch_size=BATCH_SIZE,\r\n", | |
"        shuffle=True,\r\n", | |
"        collate_fn=generate_batch,\r\n", | |
"    )\r\n", | |
"\r\n", | |
"    for text, offsets, cls in loader:\r\n", | |
"        optimizer.zero_grad()\r\n", | |
"        text = text.to(device)\r\n", | |
"        offsets = offsets.to(device)\r\n", | |
"        cls = cls.to(device)\r\n", | |
"\r\n", | |
"        output = model(text, offsets)\r\n", | |
"        loss = criterion(output, cls)\r\n", | |
"        running_loss += loss.item()\r\n", | |
"        loss.backward()\r\n", | |
"        optimizer.step()\r\n", | |
"        n_correct += (output.argmax(1) == cls).sum().item()\r\n", | |
"\r\n", | |
"    # Adjust the learning rate once the epoch is finished.\r\n", | |
"    scheduler.step()\r\n", | |
"\r\n", | |
"    return running_loss / len(sub_train_), n_correct / len(sub_train_)\r\n", | |
"\r\n", | |
"def test(data_):\r\n", | |
"    \"\"\"Evaluate the notebook-level `model` on `data_` without gradients.\r\n", | |
"\r\n", | |
"    Relies on the notebook-level `criterion`, `device` and `BATCH_SIZE`.\r\n", | |
"    Returns the summed per-batch loss and the number of correct\r\n", | |
"    predictions, each divided by `len(data_)`.\r\n", | |
"    \"\"\"\r\n", | |
"    total_loss = 0\r\n", | |
"    acc = 0\r\n", | |
"    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)\r\n", | |
"    for text, offsets, cls in data:\r\n", | |
"        text, offsets, cls = text.to(device), offsets.to(device), cls.to(device)\r\n", | |
"        with torch.no_grad():\r\n", | |
"            output = model(text, offsets)\r\n", | |
"            # Keep the per-batch loss under its own name: rebinding the\r\n", | |
"            # running total to the batch tensor would discard everything\r\n", | |
"            # accumulated so far and report only the last batch.\r\n", | |
"            batch_loss = criterion(output, cls)\r\n", | |
"            total_loss += batch_loss.item()\r\n", | |
"            acc += (output.argmax(1) == cls).sum().item()\r\n", | |
"\r\n", | |
"    return total_loss / len(data_), acc / len(data_)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "b63NVBxqd5T1" | |
}, | |
"source": [ | |
"class TextSentiment(nn.Module):\r\n", | |
"    \"\"\"Bag-of-embeddings classifier: an EmbeddingBag feeding one linear layer.\"\"\"\r\n", | |
"\r\n", | |
"    def __init__(self, vocab_size, embed_dim, num_class):\r\n", | |
"        super().__init__()\r\n", | |
"        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\r\n", | |
"        self.fc = nn.Linear(embed_dim, num_class)\r\n", | |
"        self.init_weights()\r\n", | |
"\r\n", | |
"    def init_weights(self):\r\n", | |
"        \"\"\"Draw embedding and linear weights from U(-0.5, 0.5); zero the bias.\"\"\"\r\n", | |
"        bound = 0.5\r\n", | |
"        for weight in (self.embedding.weight, self.fc.weight):\r\n", | |
"            weight.data.uniform_(-bound, bound)\r\n", | |
"        self.fc.bias.data.zero_()\r\n", | |
"\r\n", | |
"    def forward(self, text, offsets):\r\n", | |
"        \"\"\"Return class scores for the bags in `text` delimited by `offsets`.\"\"\"\r\n", | |
"        pooled = self.embedding(text, offsets)\r\n", | |
"        scores = self.fc(pooled)\r\n", | |
"        return scores" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "ak81UL1FkKly" | |
}, | |
"source": [ | |
"# Use the GPU when one is available, otherwise fall back to the CPU.\r\n", | |
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "I-qVHqzqmoxE" | |
}, | |
"source": [ | |
"# Hyperparameters handed to TextSentiment.\r\n", | |
"# NOTE(review): VOCAB_SIZE has to be larger than any token id the tokenizer\r\n", | |
"# emits -- confirm it matches the tokenizer in use.\r\n", | |
"VOCAB_SIZE = 31090\r\n", | |
"EMBED_DIM = 768\r\n", | |
"NUM_CLASS = 2\r\n", | |
"\r\n", | |
"model = TextSentiment(\r\n", | |
"    vocab_size=VOCAB_SIZE,\r\n", | |
"    embed_dim=EMBED_DIM,\r\n", | |
"    num_class=NUM_CLASS,\r\n", | |
").to(device)" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "QYrKEBF9d565", | |
"outputId": "6a9a86f5-be10-4a1e-929f-8931d420e43e" | |
}, | |
"source": [ | |
"import time\r\n", | |
"from torch.utils.data.dataset import random_split\r\n", | |
"# Training configuration.\r\n", | |
"N_EPOCHS = 5\r\n", | |
"BATCH_SIZE = 8\r\n", | |
"min_valid_loss = float('inf')\r\n", | |
"\r\n", | |
"criterion = torch.nn.CrossEntropyLoss().to(device)\r\n", | |
"optimizer = torch.optim.SGD(model.parameters(), lr=4.0)\r\n", | |
"# Multiply the learning rate by 0.9 on every scheduler step.\r\n", | |
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)\r\n", | |
"\r\n", | |
"# Keep 95% of the samples for training; the remainder is the validation split.\r\n", | |
"train_len = int(len(train_dataset) * 0.95)\r\n", | |
"valid_len = len(train_dataset) - train_len\r\n", | |
"sub_train_, sub_valid_ = random_split(train_dataset, [train_len, valid_len])\r\n", | |
"\r\n", | |
"for epoch in range(N_EPOCHS):\r\n", | |
"\r\n", | |
"    start_time = time.time()\r\n", | |
"    train_loss, train_acc = train_func(sub_train_)\r\n", | |
"    valid_loss, valid_acc = test(sub_valid_)\r\n", | |
"\r\n", | |
"    # divmod keeps minutes and seconds as whole numbers; true division\r\n", | |
"    # would leave a float that only prints correctly because %d truncates.\r\n", | |
"    mins, secs = divmod(int(time.time() - start_time), 60)\r\n", | |
"\r\n", | |
"    print('Epoch: %d' %(epoch + 1), \" | time in %d minutes, %d seconds\" %(mins, secs))\r\n", | |
"    print(f'\\tLoss: {train_loss:.4f}(train)\\t|\\tAcc: {train_acc * 100:.1f}%(train)')\r\n", | |
"    print(f'\\tLoss: {valid_loss:.4f}(valid)\\t|\\tAcc: {valid_acc * 100:.1f}%(valid)')" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Epoch: 1 | time in 0 minutes, 0 seconds\n", | |
"\tLoss: 0.2762(train)\t|\tAcc: 84.7%(train)\n", | |
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n", | |
"Epoch: 2 | time in 0 minutes, 0 seconds\n", | |
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n", | |
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n", | |
"Epoch: 3 | time in 0 minutes, 0 seconds\n", | |
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n", | |
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n", | |
"Epoch: 4 | time in 0 minutes, 0 seconds\n", | |
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n", | |
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n", | |
"Epoch: 5 | time in 0 minutes, 0 seconds\n", | |
"\tLoss: 0.0000(train)\t|\tAcc: 100.0%(train)\n", | |
"\tLoss: 0.0000(valid)\t|\tAcc: 100.0%(valid)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment