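Minimal training-loop example for Sharpness-Aware Minimization (SAM) in PyTorch. SAM wraps a base optimizer and performs two forward-backward passes per batch: first_step perturbs the weights toward the local worst case, and second_step restores them and applies the base optimizer's update using the gradient computed at the perturbed point. The sam module is assumed here to provide a SAM wrapper exposing this two-step interface (for example, the widely used davda54/sam implementation).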
import torch

from sam import SAM
...
model = YourModel()
base_optimizer = torch.optim.SGD  # define an optimizer for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...
for input, output in data:
    # first forward-backward pass
    loss = loss_function(output, model(input))  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_function(output, model(input)).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)
...
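For reference, below is a minimal sketch of what a SAM-style wrapper with this first_step/second_step interface can look like internally. It illustrates the algorithm from Foret et al. (2020), not the actual sam module imported above: first_step ascends to the local worst-case weights w + e(w), where e(w) = rho * grad / ||grad||, and second_step restores the original weights and applies the base optimizer's update there. The rho default of 0.05 and the class layout are assumptions for illustration.

import torch


class SAM(torch.optim.Optimizer):
    """Sketch of Sharpness-Aware Minimization as a two-step optimizer wrapper."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # the base optimizer shares the same parameter groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)
            for p in group["params"]:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                    # climb to the worst case w + e(w)
                self.state[p]["e_w"] = e_w     # remember the perturbation
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]["e_w"])   # restore the original weights
        self.base_optimizer.step()             # sharpness-aware update at w
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        # global L2 norm over all parameter gradients
        return torch.norm(
            torch.stack([
                p.grad.norm(p=2)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]),
            p=2,
        )

Because the wrapper only perturbs and restores weights around the base optimizer's step, any torch.optim optimizer (SGD, Adam, ...) can serve as base_optimizer, as in the usage example above.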