@stas00
Last active March 24, 2020 15:55
BERT training loop with validation loss reporting (and more compact)
# Drop-in replacement for the training loop in https://mccormickml.com/2019/07/22/BERT-fine-tuning/
# ---- cell 1 ----
# `model`, `optimizer`, `scheduler`, `device`, `epochs`, `train_dataloader` and
# `validation_dataloader` are expected to be defined already, as in that tutorial.
import random
import time

import numpy as np
import torch

# This training code is based on the `run_glue.py` script here:
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128

# Set the seed value all over the place to make this reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

# Store the average loss after each epoch so we can plot them.
train_loss_values, valid_loss_values = [], []

print("epoch | vld acc  | trn loss | vld loss | time")
total_start_time = time.perf_counter()

# For each epoch...
for epoch in range(1, epochs + 1):
    start_time = time.perf_counter()

    ### Training
    train_total_loss = 0
    model.train()
    for batch in train_dataloader:
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        model.zero_grad()
        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        # When labels are passed, the first output is the loss.
        loss = outputs[0]
        train_total_loss += loss.item()
        loss.backward()
        # Clip the gradient norm to 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    avg_train_loss = train_total_loss / len(train_dataloader)
    train_loss_values.append(avg_train_loss)

    ### Validation
    valid_total_loss = 0
    num_correct = 0
    valid_total_len = 0
    model.eval()
    for batch in validation_dataloader:
        valid_total_len += batch[0].shape[0]
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask,
                            labels=b_labels)
        # With labels passed, the model returns (loss, logits).
        loss, output = outputs
        valid_total_loss += loss.item()
        # The predicted class is the index of the largest logit.
        pred = output.argmax(1)
        num_correct += (pred == b_labels).sum().item()
    epoch_time = time.perf_counter() - start_time
    valid_acc = num_correct / valid_total_len
    avg_valid_loss = valid_total_loss / len(validation_dataloader)
    valid_loss_values.append(avg_valid_loss)
    print(f"{epoch:5d} | {valid_acc:8.5f} | {avg_train_loss:8.5f} | {avg_valid_loss:8.5f} | "
          f"{time.strftime('%H:%M:%S', time.gmtime(epoch_time))}")

total_time = time.perf_counter() - total_start_time
print(f"Total runtime: {time.strftime('%H:%M:%S', time.gmtime(total_time))}")
# ---- cell 2 ----
# Plot both losses.
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns

# Use seaborn's darkgrid styling and increase the font size.
sns.set(style='darkgrid', font_scale=1.5)
# Increase the plot size.
plt.rcParams["figure.figsize"] = (12, 6)

# Epochs are numbered from 1 in the training loop, so use the same numbers on the x-axis.
epoch_nums = range(1, len(train_loss_values) + 1)
plt.plot(epoch_nums, train_loss_values, 'b-o', label="train")
plt.plot(epoch_nums, valid_loss_values, 'g-o', label="valid")

# Label the plot.
plt.title("Training/validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()
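In cell 1, validation accuracy is divided by the number of samples (`valid_total_len`), while both average losses are divided by the number of batches (`len(...dataloader)`). Assuming the model's loss is a mean over each batch (the default for cross-entropy), this weights a smaller final batch as heavily as a full one. The minimal sketch below uses made-up batch numbers and hypothetical helper names to contrast the two averages; to get the per-sample mean in the loop, you would accumulate `loss.item() * batch_size` and divide by the sample count.

```python
# Hypothetical per-batch results: (mean loss of the batch, number of samples in it).
# The last batch is smaller, as happens when the dataset size is not a multiple of batch_size.
batch_results = [(0.50, 32), (0.40, 32), (1.00, 4)]


def batch_averaged_loss(results):
    """Mean of the per-batch means: summed loss.item() divided by the number of batches."""
    return sum(loss for loss, _ in results) / len(results)


def sample_weighted_loss(results):
    """Mean loss over individual samples: each batch mean is weighted by its batch size."""
    total_samples = sum(n for _, n in results)
    return sum(loss * n for loss, n in results) / total_samples


print(f"batch-averaged:  {batch_averaged_loss(batch_results):.4f}")
print(f"sample-weighted: {sample_weighted_loss(batch_results):.4f}")
```

When every batch has the same size, the two values coincide; the gap only comes from the uneven last batch.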
@gogokre

gogokre commented Mar 15, 2020

Hello! Thank you very much for sharing. :)
I have a question.
How do I get the Precision and Recall scores with the code you shared?

Finally, I want to draw a Precision-Recall Curve.

@stas00
Author

stas00 commented Mar 15, 2020

How do I get the Precision and Recall scores with the code you shared?

Since you already have predictions and targets, you can just use sklearn.metrics. Edit the above code to inject:

# https://scikit-learn.org/stable/modules/model_evaluation.html
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
        [...]
        targets = []
        preds   = []
        for batch in validation_dataloader:
            [...]
            # 1-D array of predicted class ids (no keepdim, so the shape matches the labels)
            pred = output.argmax(1)
            preds.append(pred.cpu().numpy())
            targets.append(b_labels.cpu().numpy())
        [...]
        preds   = np.concatenate(preds)
        targets = np.concatenate(targets)
        # sklearn.metrics functions take (y_true, y_pred), in that order
        valid_acc       = accuracy_score( targets, preds)
        valid_precision = precision_score(targets, preds, average='macro')
        valid_recall    = recall_score(   targets, preds, average='macro')
        valid_f1        = f1_score(       targets, preds, average='macro')

I haven't tested it, but it should be more or less correct.
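To see what `average='macro'` computes, and why the `(y_true, y_pred)` argument order matters, here is a minimal pure-Python sketch: precision and recall are computed per class and then averaged with equal weight. `macro_precision_recall` is a hypothetical helper, not part of sklearn. It assumes 0 for a class that is never predicted (or never occurs), which matches sklearn's default result in that case (sklearn also emits a warning). Passing the arguments in the wrong order makes precision and recall trade places.

```python
def macro_precision_recall(y_true, y_pred):
    """Macro-averaged precision and recall over the union of labels in y_true and y_pred.

    Each class gets its own precision and recall; the results are then averaged
    with equal weight per class. A class that is never predicted (or never occurs)
    contributes 0 to precision (or recall).
    """
    classes = sorted(set(y_true) | set(y_pred))
    precisions, recalls = [], []
    for c in classes:
        true_pos = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        predicted = sum(1 for p in y_pred if p == c)  # TP + FP for class c
        actual = sum(1 for t in y_true if t == c)     # TP + FN for class c
        precisions.append(true_pos / predicted if predicted else 0.0)
        recalls.append(true_pos / actual if actual else 0.0)
    return sum(precisions) / len(classes), sum(recalls) / len(classes)


# Hypothetical labels, for illustration only.
targets = [0, 0, 0, 0, 1, 1]
preds = [0, 0, 0, 1, 1, 1]
print(macro_precision_recall(targets, preds))  # correct order: (y_true, y_pred)
print(macro_precision_recall(preds, targets))  # swapped: precision and recall trade places
```

Note that sklearn's macro F1 is the mean of the per-class F1 scores, not the harmonic mean of the macro precision and recall shown here.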

Finally, I want to draw a Precision-Recall Curve.

Google

@gogokre

gogokre commented Mar 16, 2020

Thank you for the reply. :)
I solved it. Have a nice day!

@gogokre

gogokre commented Mar 24, 2020

Sorry, can I ask you one more question?

I want to draw a continuous precision-recall curve like this, but what should I do? :)

@stas00
Author

stas00 commented Mar 24, 2020

You will find that learning to use search engines works to your great advantage. Try this.
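A continuous precision-recall curve needs a score per sample (for example, the softmax probability of the positive class) rather than the argmax predictions used above; the curve is traced by sweeping a decision threshold over those scores. The following is a minimal pure-Python sketch for binary classification, assuming 0/1 labels with at least one positive. Both helper names and the example logits are hypothetical. In practice, sklearn's `precision_recall_curve` computes essentially the same points from `(y_true, scores)`, and plotting recall on the x-axis against precision on the y-axis gives the curve.

```python
import math


def positive_class_prob(logits):
    """Softmax probability of class 1 from a pair of logits [z0, z1] (binary task assumed)."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    return exps[1] / sum(exps)


def precision_recall_points(y_true, scores):
    """Return (threshold, precision, recall) tuples, one per distinct score.

    A sample counts as predicted positive when its score >= threshold.
    y_true holds 0/1 labels and must contain at least one positive.
    """
    total_pos = sum(y_true)
    # Visit samples from the highest score to the lowest, lowering the threshold as we go.
    ranked = sorted(zip(scores, y_true), key=lambda pair: pair[0], reverse=True)
    points = []
    tp = fp = 0
    for i, (score, label) in enumerate(ranked):
        if label == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after every sample sharing this score has been counted.
        if i + 1 == len(ranked) or ranked[i + 1][0] != score:
            points.append((score, tp / (tp + fp), tp / total_pos))
    return points


# Hypothetical validation logits and labels, for illustration only.
val_logits = [[-1.0, 2.0], [0.5, 0.1], [0.0, 1.0], [1.5, -0.5], [2.0, -2.0]]
val_labels = [1, 0, 1, 1, 0]
val_scores = [positive_class_prob(z) for z in val_logits]
for threshold, precision, recall in precision_recall_points(val_labels, val_scores):
    print(f"threshold={threshold:.3f}  precision={precision:.3f}  recall={recall:.3f}")
```

With a real model, the scores would come from applying a softmax to the validation logits and collecting the positive-class column across all batches.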
