Last active
March 24, 2020 15:55
-
-
Save stas00/0ba5d30f0109967324f122bfcc8b52f5 to your computer and use it in GitHub Desktop.
bert training loop w/ validation loss reporting (and more compact)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# drop in replacement for the training loop in https://mccormickml.com/2019/07/22/BERT-fine-tuning/

# ---- cell 1 ----
# Fine-tune `model` for `epochs` epochs and, after every epoch, report the
# validation accuracy together with the average training/validation losses.
#
# Names this cell expects to already exist (it never defines them itself):
#   model, optimizer, scheduler, device, epochs,
#   train_dataloader, validation_dataloader
# Each batch is unpacked as three tensors: input ids, attention mask, labels.
#
# Results left behind for later cells:
#   train_loss_values, valid_loss_values -- one average loss per epoch.
import random
import time

import numpy as np
import torch

# This training code is based on the `run_glue.py` script here:
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128


def format_elapsed(seconds):
    """Return a duration given in seconds as an HH:MM:SS string.

    Fractions of a second are truncated. Unlike
    ``time.strftime('%H:%M:%S', time.gmtime(seconds))`` the hour field keeps
    growing past 23 instead of wrapping around to 00.
    """
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


# Set the seed value all over the place to make this reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Store the average loss after each epoch so we can plot them.
train_loss_values, valid_loss_values = [], []

print("epoch | vald acc | trn loss | vld loss | time")

total_start_time = time.perf_counter()

# For each epoch...
for epoch in range(1, epochs + 1):
    start_time = time.perf_counter()

    ### Training
    train_total_loss = 0
    model.train()
    for batch in train_dataloader:
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        # Gradients accumulate across backward() calls, so clear them first.
        model.zero_grad()

        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        # With labels supplied, the loss is the first element of the output.
        loss = outputs[0]
        train_total_loss += loss.item()

        loss.backward()
        # Clip the gradient norm to 1.0 before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # The learning-rate schedule is advanced once per batch.
        scheduler.step()

    avg_train_loss = train_total_loss / len(train_dataloader)
    train_loss_values.append(avg_train_loss)

    ### Validation
    valid_total_loss = 0
    num_correct = 0
    valid_total_len = 0
    model.eval()
    for batch in validation_dataloader:
        # Count samples (not batches) so a smaller last batch is handled.
        valid_total_len += batch[0].shape[0]
        b_input_ids, b_input_mask, b_labels = (t.to(device) for t in batch)

        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)

        # Index rather than unpack: this matches the training half above and
        # keeps working if the model returns more than (loss, logits).
        loss, logits = outputs[0], outputs[1]
        valid_total_loss += loss.item()

        # Predicted class = index of the largest logit along dim 1.
        pred = logits.argmax(dim=1)
        num_correct += (pred == b_labels.view_as(pred)).sum().item()

    # Note: the reported epoch time covers both training and validation.
    epoch_time = time.perf_counter() - start_time
    valid_acc = num_correct / valid_total_len
    avg_valid_loss = valid_total_loss / len(validation_dataloader)
    valid_loss_values.append(avg_valid_loss)

    print(f"{epoch:5d} | {valid_acc:.5f} | {avg_train_loss:.5f} | {avg_valid_loss:.5f} | {format_elapsed(epoch_time)}")

total_time = time.perf_counter() - total_start_time
# The value is rendered as HH:MM:SS, so it is not labelled as seconds.
print(f"Total runtime: {format_elapsed(total_time)}")
# ---- cell 2 ---- | |
# plot both losses | |
import matplotlib.pyplot as plt | |
%matplotlib inline | |
import seaborn as sns | |
# Use plot styling from seaborn. | |
sns.set(style='darkgrid') | |
# Increase the plot size and font size. | |
sns.set(font_scale=1.5) | |
plt.rcParams["figure.figsize"] = (12,6) | |
plt.plot(train_loss_values, 'b-o', label="train") | |
plt.plot(valid_loss_values, 'g-o', label="valid") | |
# Label the plot. | |
plt.title("Training/validation loss") | |
plt.xlabel("Epoch") | |
plt.ylabel("Loss") | |
plt.legend() | |
plt.show() |
How do I get the Precision and Recall scores with the code you shared?
Since you already have the predictions and the targets, you can just use sklearn.metrics. Edit the above code to inject:
# https://scikit-learn.org/stable/modules/model_evaluation.html
# Sketch of how to collect validation predictions/targets and score them with
# sklearn. `[...]` marks the parts of the existing loop that stay unchanged.
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
[...]
targets = []
preds = []
for batch in validation_dataloader:
    [...]
    pred = output.argmax(1, keepdim=True).float()
    preds.append(pred.cpu().numpy())
    targets.append(b_labels.cpu().numpy())
    [...]
# `pred` is built with keepdim=True, so every chunk is a column of shape
# (batch, 1); flatten to 1-D so predictions line up with the targets.
preds = np.concatenate(preds).ravel()
targets = np.concatenate(targets).ravel()
# sklearn metrics take (y_true, y_pred) in that order. Passing them the other
# way round silently swaps the reported precision and recall.
valid_acc = accuracy_score(targets, preds)
valid_precision = precision_score(targets, preds, average='macro')
valid_recall = recall_score(targets, preds, average='macro')
valid_f1 = f1_score(targets, preds, average='macro')
I haven't tested it, but it should be more or less correct.
Finally, I want to draw a Precision-Recall Curve.
How do I get the Precision and Recall scores with the code you shared?
Since you already have the predictions and the targets, you can just use sklearn.metrics. Edit the above code to inject:
# https://scikit-learn.org/stable/modules/model_evaluation.html
# Sketch of how to collect validation predictions/targets and score them with
# sklearn. `[...]` marks the parts of the existing loop that stay unchanged.
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
[...]
targets = []
preds = []
for batch in validation_dataloader:
    [...]
    pred = output.argmax(1, keepdim=True).float()
    preds.append(pred.cpu().numpy())
    targets.append(b_labels.cpu().numpy())
    [...]
# `pred` is built with keepdim=True, so every chunk is a column of shape
# (batch, 1); flatten to 1-D so predictions line up with the targets.
preds = np.concatenate(preds).ravel()
targets = np.concatenate(targets).ravel()
# sklearn metrics take (y_true, y_pred) in that order. Passing them the other
# way round silently swaps the reported precision and recall.
valid_acc = accuracy_score(targets, preds)
valid_precision = precision_score(targets, preds, average='macro')
valid_recall = recall_score(targets, preds, average='macro')
valid_f1 = f1_score(targets, preds, average='macro')
I haven't tested it, but it should be more or less correct.
Finally, I want to draw a Precision-Recall Curve.
thank you for the reply. :)
Thanks, I solved it. Have a nice day!
You will find learning to use the search engines to your great advantage. Try this.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hello! Thank you very much for sharing. :)
I have a question.
How do I get the Precision and Recall scores with the code you shared?
Finally, I want to draw a Precision-Recall Curve.