# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.

import torch

from torch import nn


def patch_attention(m):
    # TransformerEncoderLayer calls self_attn with need_weights=False, so
    # override the keyword arguments to get per-head attention weights back.
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # module_out is (attn_output, attn_weights); keep only the weights.
        self.outputs.append(module_out[1])

    def clear(self):
        self.outputs = []


d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.0
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

# Patch and hook only the last layer's self-attention module.
save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

seq_len = 20
X = torch.rand(seq_len, 1, d_model)  # (seq_len, batch, d_model)

with torch.no_grad():
    out = transformer(X)

# Attention weights of the first (only) batch element, one matrix per head.
print(save_output.outputs[0][0])
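To make it clearer what the hook ends up storing, here is a minimal sketch that calls nn.MultiheadAttention directly with the same two keyword arguments the patch forces. The small sizes (embed_dim=16, num_heads=4, seq_len=5, batch=2) are arbitrary values chosen for illustration, not taken from the gist.

```python
import torch
from torch import nn

# Illustrative sizes only; the gist itself uses d_model=512 and nhead=8.
embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2

mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0)
mha.eval()

x = torch.rand(seq_len, batch, embed_dim)  # (seq_len, batch, embed_dim)

with torch.no_grad():
    # The module returns a tuple: (attn_output, attn_weights).
    attn_out, per_head = mha(x, x, x, need_weights=True, average_attn_weights=False)
    _, averaged = mha(x, x, x, need_weights=True, average_attn_weights=True)

print(attn_out.shape)  # (seq_len, batch, embed_dim)
print(per_head.shape)  # (batch, num_heads, seq_len, seq_len)
print(averaged.shape)  # (batch, seq_len, seq_len)

# Each row of an attention matrix is a softmax over keys, so it sums to 1.
row_sums = per_head.sum(dim=-1)
```

The second element of the tuple is what SaveOutput keeps via module_out[1]. With the gist's settings, save_output.outputs[0] should therefore have shape (1, 8, 20, 20), and indexing it with [0] selects the eight per-head matrices for the single batch element.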
If I set batch_first=True and use X = torch.rand(1, seq_len, d_model), it returns [] for the attention weights; it seems like patch_attention does not work. Do you know how to solve this problem? Thanks :)
@77komorebi - as you can see here, using batch_first=True leads to the TransformerEncoderLayer calling torch._transformer_encoder_layer_fwd. As a result, the MultiheadAttention layer is never called, and so the forward hook is never activated (the code for how hooks are used in Modules can be found here).
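The mechanism can be illustrated with a toy module. This is only a sketch: the Wrapper class and its use_fused flag are hypothetical stand-ins for the two code paths, not PyTorch's actual TransformerEncoderLayer implementation. A forward hook fires only when the submodule itself is called; a parent that uses the submodule's parameters through a functional or fused call never triggers it.

```python
import torch
import torch.nn.functional as F
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical parent module with a regular path and a 'fused' path."""

    def __init__(self, use_fused):
        super().__init__()
        self.inner = nn.Linear(4, 4)
        self.use_fused = use_fused

    def forward(self, x):
        if self.use_fused:
            # Uses inner's parameters directly; inner.__call__ never runs,
            # so hooks registered on inner are skipped.
            return F.linear(x, self.inner.weight, self.inner.bias)
        # Calling the submodule goes through __call__, which runs the hooks.
        return self.inner(x)


calls = []


def record(module, module_in, module_out):
    calls.append(tuple(module_out.shape))


x = torch.rand(2, 4)

regular = Wrapper(use_fused=False)
regular.inner.register_forward_hook(record)
regular(x)
n_regular = len(calls)  # hook fired once

fused = Wrapper(use_fused=True)
fused.inner.register_forward_hook(record)
fused(x)
n_fused = len(calls) - n_regular  # hook did not fire

print(n_regular, n_fused)
```

In the same way, when the encoder layer takes its fused path, self_attn's forward is never invoked, so neither the monkey patch nor the hook has any effect and save_output.outputs stays empty.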
@airalcorn2 Thanks!
Hi all, thanks for this script; it is very helpful. When I run it as is, I run out of GPU memory. But when I run it like this:
patch_attention(model.transformer_layer.self_attn)
save_output = SaveOutput()
hook_handle = model.transformer_layer.self_attn.register_forward_hook(save_output)
I do not. Is extracting the weights from the TransformerEncoderLayer not as useful as extracting them directly from the layers? Thanks in advance.
Thanks, @kklemon. I've incorporated your monkey patching suggestion into the gist (with credit).