PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated
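If you want to run that snippet on its own, a minimal sketch of the surrounding setup could look like the following; the toy linear model, SGD optimizer, dummy dataset and step counts are purely illustrative assumptions standing in for the names the loop expects:

import torch
from torch import nn

# Illustrative stand-ins for the names used in the loop above.
model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_function = nn.CrossEntropyLoss()
accumulation_steps = 4
evaluation_steps = 8

# A dummy "training set" of (inputs, labels) mini-batches.
training_set = [(torch.randn(8, 16), torch.randint(0, 2, (8,)))
                for _ in range(32)]

def evaluate_model():
    model.eval()
    with torch.no_grad():
        pass  # run a validation pass here
    model.train()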
@thomwolf Thanks for the code. Just a little fix in the condition at line 7:
if (i+1) % accumulation_steps == 0:
This assumes that the number of batches is perfectly divisible by the accumulation steps. However, if there are, say, 10 batches and the accumulation steps are 4, the last two batches would never make it to the optimizer.step().
And even if you add an extra condition for that, you still need to adjust the normalization denominator, because the last group accumulates only 2 steps, not 4.
So the updated code would be:
n = len(training_set)
remainder_batches = n % accumulation_steps  # number of batches in the last, partial group

for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    remaining = n - i
    # update the denominator if the remaining batches are <= the number of remainder batches
    denominator = remainder_batches if remaining <= remainder_batches else accumulation_steps
    loss = loss / denominator
    loss.backward()
    if (i+1) % accumulation_steps == 0 or i == n - 1:  # add condition for the last iteration
        optimizer.step()
        model.zero_grad()
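A quick isolated check of that denominator selection, assuming the same 10 batches with 4 accumulation steps plus an evenly divisible case of 8 batches (the helper name pick_denominator and these sizes are just for illustration):

def pick_denominator(n, accumulation_steps, i):
    # Mirrors the conditional expression in the loop above.
    remainder_batches = n % accumulation_steps
    remaining = n - i
    return remainder_batches if remaining <= remainder_batches else accumulation_steps

# 10 batches, 4 accumulation steps: the final partial group of 2 gets denominator 2.
assert [pick_denominator(10, 4, i) for i in range(10)] == [4] * 8 + [2, 2]

# 8 batches, 4 accumulation steps: the remainder is 0, but since `remaining` is
# always at least 1, the zero denominator is never selected.
assert [pick_denominator(8, 4, i) for i in range(8)] == [4] * 8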
You can emulate the logic in a standalone script as follows:
def get_value():
    # Stand-in for a per-batch loss value.
    return 5

n = 10
steps = 4
values = []
val = 0
remainder = n % steps

for i in range(n):
    a = get_value()
    remaining = n - i
    # Use the size of the trailing partial group as the denominator for its batches.
    if remaining <= remainder:
        denom = n % steps
    else:
        denom = steps
    val += a / denom
    print(i, denom)
    if (i + 1) % steps == 0 or i == n - 1:
        values.append(val)
        val = 0
        print("update")
@arquolo - okay, thank you for the clarification.