# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.

import torch
from torch import nn


def patch_attention(m):
    # Wrap the module's forward so it always returns per-head attention weights.
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    # Forward hook that stores the attention weights, i.e., the second element of
    # MultiheadAttention's (attn_output, attn_output_weights) output tuple.
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out[1])

    def clear(self):
        self.outputs = []


d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.0
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

# Patch and hook the self-attention module of the last encoder layer.
save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

seq_len = 20
# (seq_len, batch_size, d_model) because batch_first defaults to False.
X = torch.rand(seq_len, 1, d_model)

with torch.no_grad():
    out = transformer(X)

# Attention weights for the last layer: (nhead, seq_len, seq_len).
print(save_output.outputs[0][0])
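The script above only captures the last layer's attention. A minimal sketch of an extension (not part of the original gist; it reuses transformer, X, patch_attention, and SaveOutput from above, and save_all/handles/attn_maps are just new names for this sketch) patches and hooks every layer and stacks the per-layer maps:

# Hook every encoder layer, not just the last one. Re-patching the already-patched
# last layer is harmless: the new wrapper just sets the same kwargs again.
save_all = SaveOutput()
handles = []
for layer in transformer.layers:
    patch_attention(layer.self_attn)
    handles.append(layer.self_attn.register_forward_hook(save_all))

with torch.no_grad():
    out = transformer(X)

# Each saved entry has shape (1, nhead, seq_len, seq_len); stack the per-layer maps.
attn_maps = torch.stack([w[0] for w in save_all.outputs])
print(attn_maps.shape)  # Expected: torch.Size([6, 8, 20, 20]) with the settings above.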
Thanks, @kklemon. I've incorporated your monkey patching suggestion into the gist (with credit).
If I set batch_first=True and use X = torch.rand(1, seq_len, d_model), it returns [] for the attention weights, so it seems patch_attention does not work. Do you know how to solve this problem? Thanks :)
@77komorebi - as you can see here, using batch_first=True leads to the TransformerEncoderLayer calling torch._transformer_encoder_layer_fwd. As a result, the MultiheadAttention layer is never called, and so the forward hook is never activated (the code for how hooks are used in Modules can be found here).
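One possible workaround (a sketch only, assuming a PyTorch version whose fast-path checks skip torch._transformer_encoder_layer_fwd while training is enabled, and relying on dropout=0.0 so train() and eval() give identical outputs) is to leave the encoder in training mode, so the patched MultiheadAttention.forward, and therefore the hook, still runs with batch_first=True. The _bf names are just new variables for this sketch, and the other names come from the script above:

# Sketch of a batch_first=True workaround: training mode makes PyTorch skip the
# fused fast path, so the standard Python forward (and the hook) runs.
encoder_layer_bf = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward, dropout, batch_first=True
)
transformer_bf = nn.TransformerEncoder(encoder_layer_bf, num_layers)
transformer_bf.train()  # Deliberately not eval(); dropout=0.0, so outputs match.

save_output_bf = SaveOutput()
patch_attention(transformer_bf.layers[-1].self_attn)
hook_handle_bf = transformer_bf.layers[-1].self_attn.register_forward_hook(save_output_bf)

X_bf = torch.rand(1, seq_len, d_model)  # (batch_size, seq_len, d_model).
with torch.no_grad():
    out_bf = transformer_bf(X_bf)

print(save_output_bf.outputs[0].shape)  # Expected: (1, nhead, seq_len, seq_len).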
@airalcorn2 Thanks!
Hi all, thanks for this script, it is very helpful. When I run it as is, I run out of GPU memory, but when I run it like this:
patch_attention(model.transformer_layer.self_attn)
save_output = SaveOutput()
hook_handle = model.transformer_layer.self_attn.register_forward_hook(save_output)
I do not. Is extracting the weights from the TransformerEncoderLayer not as useful as extracting them directly from the layers? Thanks in advance.
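If GPU memory is the bottleneck, one option (a sketch, not from the thread; SaveOutputCPU is a hypothetical name) is to move each attention map off the GPU inside the hook, since the saved weights otherwise accumulate on the device for every hooked layer and every forward pass:

# Memory-friendlier variant of SaveOutput: detach each attention map and copy it
# to the CPU as soon as the hook fires, so GPU memory does not grow over time.
class SaveOutputCPU:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # module_out[1] holds the attention weights.
        self.outputs.append(module_out[1].detach().cpu())

    def clear(self):
        self.outputs = []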
The attention module can easily be patched to return attention weights. This also works seamlessly with the rest of the Transformer implementation, since the rest of the code simply disregards the returned weights anyway.