Created
February 24, 2021 16:17
-
-
Save seanbenhur/786239a47685e9ba4768f7701e09bdf7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage sketch: train a model with the SAM optimizer wrapper, which is driven
# by two forward-backward passes per loop iteration (first_step / second_step).
# The bare `...` lines are placeholders for code omitted from this example.
import torch  # needed for torch.optim.SGD below

from sam import SAM

...

model = YourModel()
# Pass the optimizer *class* (not an instance); SAM receives it together with
# the hyper-parameters below.
# NOTE(review): presumably SAM instantiates it internally — confirm in sam.py.
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)

...

# `inputs` rather than `input`, to avoid shadowing the builtin.
for inputs, outputs in data:
    # first forward-backward pass
    loss = loss_function(outputs, model(inputs))  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    # The model is called again instead of reusing `loss` from the first pass.
    loss_function(outputs, model(inputs)).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)

...
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment