Skip to content

Instantly share code, notes, and snippets.

@thuwyh
Last active June 2, 2020 01:13
Show Gist options
  • Select an option

  • Save thuwyh/371c1448e01c1c5a2718ec01b3987cbc to your computer and use it in GitHub Desktop.

Select an option

Save thuwyh/371c1448e01c1c5a2718ec01b3987cbc to your computer and use it in GitHub Desktop.
fgm
class FGM:
    """Fast Gradient Method (FGM) adversarial perturbation of named parameters.

    Intended use per training step: run a normal forward/backward pass so the
    matching parameters have gradients, call ``attack()`` to shift them along
    their L2-normalised gradient, run a second forward/backward pass on the
    perturbed model, then call ``restore()`` to put the clean weights back.
    """

    def __init__(self, model):
        """Wrap *model*; its ``named_parameters()`` are matched by substring."""
        self.model = model
        # parameter name -> clone of its data taken before the first attack()
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings'):
        """Add ``epsilon * grad / ||grad||`` to each trainable parameter whose
        name contains *emb_name*.

        Every matching parameter is backed up so ``restore()`` can undo the
        change. A parameter is left unperturbed when it has no gradient yet,
        when its gradient norm is zero, or when the norm is not finite.
        """
        for name, param in self.model.named_parameters():
            if not (param.requires_grad and emb_name in name):
                continue
            # Keep only the first backup so that repeated attacks without an
            # intervening restore() still return to the clean weights.
            if name not in self.backup:
                self.backup[name] = param.data.clone()
            if param.grad is None:
                # e.g. attack() called before backward(): no direction to follow
                continue
            norm = torch.norm(param.grad)
            # Guard against division by zero and against NaN/inf poisoning
            # the weights.
            if norm != 0 and torch.isfinite(norm):
                r_at = epsilon * param.grad / norm
                param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        """Put back the weights saved by ``attack()`` and clear the backup.

        Raises:
            AssertionError: a matching parameter has no backup, i.e.
                ``attack()`` was not called with the same *emb_name*.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup, (
                    f"no backup for {name!r}; call attack() first"
                )
                param.data = self.backup[name]
        self.backup = {}
def train_with_fgm(model, data_loader, loss_fn, optimizer, epsilon=1.,
                   emb_name='word_embeddings'):
    """Run one pass over *data_loader* using FGM adversarial training.

    For each batch:
      1. a normal forward/backward pass fills the gradients FGM needs;
      2. ``FGM.attack`` perturbs the parameters matching *emb_name*, and a
         second forward/backward pass accumulates the adversarial gradient;
      3. ``FGM.restore`` puts back the clean weights;
      4. the optimizer steps on the summed gradients.

    NOTE(review): assumes *data_loader* yields ``(inputs, label)`` pairs and
    that ``model(inputs)`` / ``loss_fn(output, label)`` match that layout —
    adapt to your loader.
    """
    fgm = FGM(model)
    # Start from clean gradients so the first batch is not polluted by
    # whatever was accumulated before this call.
    optimizer.zero_grad()
    for inputs, label in data_loader:
        # normal pass
        output = model(inputs)
        loss = loss_fn(output, label)
        loss.backward()
        # adversarial pass: gradients add onto the normal ones
        fgm.attack(epsilon=epsilon, emb_name=emb_name)
        output_adv = model(inputs)
        loss_adv = loss_fn(output_adv, label)
        loss_adv.backward()
        # restore clean weights before the parameter update
        fgm.restore(emb_name=emb_name)
        # optimization
        optimizer.step()
        optimizer.zero_grad()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment