
@rohan-paul
Created May 28, 2023 18:29
import gc

import torch
from tqdm import tqdm

# `criterion` (the loss function) and `CONFIG` (a dict with an 'n_accumulate' key)
# are assumed to be defined elsewhere in the notebook.

def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()
    dataset_size = 0
    running_loss = 0.0

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    """ The `total` argument of tqdm specifies the total number of iterations (i.e. updates
    to the progress bar). Here it is len(dataloader), the number of batches in the dataloader. """
    for step, data in bar:
        ids = data['input_ids'].to(device, dtype=torch.long)
        mask = data['attention_mask'].to(device, dtype=torch.long)
        targets = data['target'].to(device, dtype=torch.long)

        batch_size = ids.size(0)

        outputs = model(ids, mask)
        loss = criterion(outputs, targets)

        """ Gradient accumulation happens in the loss scaling and the .backward() call below.
        Gradient accumulation means accumulating gradients over several mini-batches before
        performing a single weight-update step. It is achieved by the following two lines
        of train_one_epoch(): """
        loss = loss / CONFIG['n_accumulate']

        """ The backward() call on the next line computes the gradients of the loss with
        respect to the model parameters. Importantly, these gradients are not cleared after
        the computation; they remain stored in the .grad attributes of the parameters.
        Instead of updating the parameters right away, each call adds the newly computed
        gradients to the ones already accumulated. This is repeated for the specified
        number of mini-batches. """
        loss.backward()

        # After accumulating gradients over the desired number of mini-batches,
        # perform the weight-update step.
        if (step + 1) % CONFIG['n_accumulate'] == 0:
            # Performs the actual parameter update using the accumulated gradients.
            optimizer.step()

            # Clears the accumulated gradients from the parameters to prepare for the next
            # round of accumulation. This happens after every CONFIG['n_accumulate'] batches,
            # as checked by the if condition above.
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        # Note: `loss` here is the scaled loss (already divided by n_accumulate).
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])

    gc.collect()

    return epoch_loss
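For context, here is a minimal, self-contained sketch of how train_one_epoch() might be driven. The ToyClassifier, ToyTextDataset, CONFIG values, batch size, and learning rate below are illustrative stand-ins, not the original notebook's setup; only the call signature and the batch dict keys ('input_ids', 'attention_mask', 'target') match what the function above expects.

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader

class ToyClassifier(nn.Module):
    """ Hypothetical stand-in for the real model: forward(input_ids, attention_mask) -> logits. """
    def __init__(self, vocab_size=1000, hidden=32, n_classes=3):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, hidden)
        self.fc = nn.Linear(hidden, n_classes)

    def forward(self, input_ids, attention_mask):
        x = self.emb(input_ids)                              # (B, T, H)
        m = attention_mask.unsqueeze(-1).float()             # (B, T, 1)
        x = (x * m).sum(1) / m.sum(1).clamp(min=1.0)         # masked mean pooling
        return self.fc(x)                                    # (B, n_classes)

class ToyTextDataset(Dataset):
    """ Hypothetical dataset that yields dicts with the keys train_one_epoch() reads. """
    def __init__(self, n_samples=64, seq_len=16, vocab_size=1000, n_classes=3):
        self.input_ids = torch.randint(0, vocab_size, (n_samples, seq_len))
        self.attention_mask = torch.ones(n_samples, seq_len, dtype=torch.long)
        self.target = torch.randint(0, n_classes, (n_samples,))

    def __len__(self):
        return len(self.target)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx],
                'attention_mask': self.attention_mask[idx],
                'target': self.target[idx]}

# Illustrative config; only 'n_accumulate' is actually used by train_one_epoch().
CONFIG = {'n_accumulate': 4, 'epochs': 2, 'learning_rate': 1e-3}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = nn.CrossEntropyLoss()

model = ToyClassifier().to(device)
train_loader = DataLoader(ToyTextDataset(), batch_size=8, shuffle=True)

optimizer = AdamW(model.parameters(), lr=CONFIG['learning_rate'])
# scheduler.step() only runs on batches where the optimizer steps,
# so T_max is based on the number of updates, not the number of batches.
updates_per_epoch = max(len(train_loader) // CONFIG['n_accumulate'], 1)
scheduler = CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'] * updates_per_epoch)

for epoch in range(1, CONFIG['epochs'] + 1):
    train_loss = train_one_epoch(model, optimizer, scheduler,
                                 dataloader=train_loader, device=device, epoch=epoch)
    print(f"epoch {epoch}: train loss = {train_loss:.4f}")

With batch_size=8 and n_accumulate=4, the optimizer sees an effective batch size of 32: four mini-batches' gradients are summed in the parameters' .grad attributes before a single optimizer.step().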