# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.

import torch
from torch import nn


def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False

        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        self.outputs.append(module_out[1])

    def clear(self):
        self.outputs = []


d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.0
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

seq_len = 20
X = torch.rand(seq_len, 1, d_model)

with torch.no_grad():
    out = transformer(X)

print(save_output.outputs[0][0])
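A note on the saved tensor (my addition, based on the MultiheadAttention documentation): with need_weights=True and average_attn_weights=False, the hook stores the per-head weights, so save_output.outputs[0] should have shape (batch, nhead, seq_len, seq_len), i.e. (1, 8, 20, 20) here, and the final print shows the (8, 20, 20) maps for the single batch element. The hook can also be removed once it is no longer needed:

# Quick sanity check (not part of the original gist).
assert save_output.outputs[0].shape == (1, nhead, seq_len, seq_len)

# Detach the hook once the weights have been collected.
hook_handle.remove()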
The attention module can be easily patched to return attention weights. This also works flawlessly with the rest of the Transformer implementation, since it simply disregards the returned weights anyway.
...
def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs['need_weights'] = True
        kwargs['average_attn_weights'] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


# Patch and hook every attention module in the encoder.
for module in transformer.modules():
    if isinstance(module, nn.MultiheadAttention):
        patch_attention(module)
        module.register_forward_hook(save_output)
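(My addition, for reference.) After a forward pass with every layer patched and hooked as above, save_output.outputs holds one per-head attention map per MultiheadAttention, in execution order:

# Sketch, assuming the gist's patch_attention / SaveOutput and the loop above.
save_output.clear()
with torch.no_grad():
    transformer(torch.rand(seq_len, 1, d_model))

# One (batch, nhead, seq_len, seq_len) tensor per hooked attention module.
attn = torch.stack(save_output.outputs)
print(attn.shape)  # e.g. torch.Size([6, 1, 8, 20, 20]) if only the loop registered hooks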
Thanks, @kklemon. I've incorporated your monkey patching suggestion into the gist (with credit).
If I set batch_first=True and use X = torch.rand(1, seq_len, d_model), it returns [] for the attention weights; it seems like patch_attention does not work. Do you know how to solve this problem? Thanks :)
@77komorebi - as you can see here, using batch_first=True leads to TransformerEncoderLayer calling torch._transformer_encoder_layer_fwd. As a result, the MultiheadAttention layer is never called, and so the forward hook is never activated (the code for how hooks are used in Modules can be found here).
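A possible workaround (my own sketch, not part of the original answer, and only checked against the fast-path conditions in recent PyTorch releases): the fast path is only taken when autograd is not tracking any of the tensors involved, so running the forward pass without torch.no_grad() should fall back to the regular MultiheadAttention.forward and make the patched forward and the hook fire again, even with batch_first=True:

# Hedged sketch: assumes the fast path is skipped whenever autograd is active
# (true for the PyTorch versions I checked). Reuses patch_attention / SaveOutput
# and the hyperparameters from the gist above, but rebuilds the encoder with
# batch_first=True.
encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward, dropout, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

X = torch.rand(1, seq_len, d_model)
out = transformer(X)  # no torch.no_grad() here, so the fast path is not used
print(save_output.outputs[0].shape)  # expected: torch.Size([1, 8, 20, 20])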
@airalcorn2 Thanks!
Hi all, thanks for this script, it is very helpful. When I run it as is, I run out of GPU memory, but when I run it like this:

patch_attention(model.transformer_layer.self_attn)
save_output = SaveOutput()
hook_handle = model.transformer_layer.self_attn.register_forward_hook(save_output)

I do not. Is extracting the weights from the TransformerEncoderLayer not as useful as extracting them directly from the layers? Thanks in advance.
@PaulOppelt - it looks like the implementations for the various Transformer layers have changed a lot in the nearly two years since I posted this. The easiest thing to do (which isn't particularly easy!) might be to just write custom TransformerEncoder and TransformerEncoderLayer classes where you set need_weights=True in the self.self_attn call in _sa_block.
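For anyone who ends up here: a rough sketch of that idea (my code, not the gist's). _sa_block is a private method whose signature has changed across PyTorch versions, and subclassing only the layer may already be enough, since nn.TransformerEncoder deep-copies the layer it is given:

# Sketch under those assumptions: override _sa_block so self_attn is asked for
# per-head attention weights instead of discarding them. Extra keyword arguments
# (e.g. is_causal in newer releases) are absorbed by **kwargs.
import torch
from torch import nn


class AttnWeightsEncoderLayer(nn.TransformerEncoderLayer):
    def _sa_block(self, x, attn_mask, key_padding_mask, **kwargs):
        attn_out, attn_weights = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=False,
        )
        # Stash the weights on the layer for later inspection.
        self.last_attn_weights = attn_weights
        return self.dropout1(attn_out)


# Usage sketch: batch_first is left at its default (False), so in the versions I
# checked the fast path (which bypasses _sa_block entirely) is not taken.
layer = AttnWeightsEncoderLayer(512, 8, 2048, 0.0)
encoder = nn.TransformerEncoder(layer, 2)
encoder.eval()

with torch.no_grad():
    encoder(torch.rand(20, 1, 512))

print(encoder.layers[-1].last_attn_weights.shape)  # e.g. torch.Size([1, 8, 20, 20])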