-
-
Save thuwyh/371c1448e01c1c5a2718ec01b3987cbc to your computer and use it in GitHub Desktop.
fgm
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# FGM class
class FGM():
    """Fast Gradient Method (FGM) adversarial-training helper.

    Perturbs a model's embedding weights along the gradient direction
    after the normal backward pass, so a second forward/backward pass
    accumulates gradients from the adversarial example as well.

    Typical use per training step:
        loss.backward(); fgm.attack(); loss2.backward(); fgm.restore()
    """

    def __init__(self, model):
        # Model whose embedding parameters will be perturbed.
        self.model = model
        # Maps parameter name -> clean weight tensor saved by attack().
        self.backup = {}

    def attack(self, epsilon=1., emb_name='word_embeddings'):
        """Add an epsilon-scaled, L2-normalized gradient perturbation to
        every trainable parameter whose name contains ``emb_name``.

        Args:
            epsilon: magnitude of the perturbation (r = eps * g / ||g||).
            emb_name: substring identifying embedding parameters.
        """
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                # Save clean weights first so restore() can always undo
                # the attack, even if no perturbation is applied below.
                self.backup[name] = param.data.clone()
                # A parameter may have no gradient this step (e.g. it was
                # unused in the forward pass); skip it instead of crashing
                # on torch.norm(None).
                if param.grad is None:
                    continue
                norm = torch.norm(param.grad)
                # Guard zero and NaN norms: dividing by them would write
                # inf/NaN into the weights.
                if norm != 0 and not torch.isnan(norm):
                    r_at = epsilon * param.grad / norm
                    param.data.add_(r_at)

    def restore(self, emb_name='word_embeddings'):
        """Restore the clean embedding weights saved by the last attack()."""
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
# --- Example: using FGM inside a training loop ---
fgm = FGM(model)
for batch_input, label in dataloader:
    # Normal forward/backward: accumulates gradients on the clean example.
    output = model(batch_input)
    loss = loss_fn(output, label)
    loss.backward()
    # Adversarial step: perturb the embeddings along the gradient, then
    # accumulate gradients from the perturbed forward pass on top of the
    # clean-example gradients (no zero_grad in between, by design).
    fgm.attack()
    output = model(batch_input)
    loss = loss_fn(output, label)
    loss.backward()
    # Undo the embedding perturbation before the weight update.
    fgm.restore()
    # Update with the combined (clean + adversarial) gradients.
    optimizer.step()
    optimizer.zero_grad()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment