@airalcorn2
Last active July 15, 2024 19:38
A simple script for extracting the attention weights from a PyTorch Transformer.
# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.
import torch

from torch import nn


def patch_attention(m):
    # Wrap the module's forward so it always returns per-head attention weights.
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False

        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # module_out is (attn_output, attn_weights); keep only the weights.
        self.outputs.append(module_out[1])

    def clear(self):
        self.outputs = []


d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.0
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

# Patch and hook the self-attention module of the last encoder layer.
save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

seq_len = 20
X = torch.rand(seq_len, 1, d_model)

with torch.no_grad():
    out = transformer(X)

# Attention weights of the first (and only) batch element: (nhead, seq_len, seq_len).
print(save_output.outputs[0][0])
@kklemon

kklemon commented Dec 19, 2022

The attention module can easily be patched to return attention weights. This also works with the rest of the Transformer implementation, since the encoder layer simply discards the attention weights returned by the module anyway.

...

def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs['need_weights'] = True
        kwargs['average_attn_weights'] = False

        return forward_orig(*args, **kwargs)

    m.forward = wrap

for module in transformer.modules():
    if isinstance(module, nn.MultiheadAttention):
        patch_attention(module)
        module.register_forward_hook(save_output)
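
As a rough, self-contained sketch of the loop above (not necessarily how anyone in this thread runs it), the following patches every nn.MultiheadAttention in a small encoder and stacks the captured weights. The tiny sizes, the `handles` list, and the `.detach()` call in the hook are my own assumptions for illustration; it also assumes the default batch_first=False, as in the gist.

```python
import torch
from torch import nn


def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Assumption: detach so the stored weights never keep a graph alive.
        self.outputs.append(module_out[1].detach())


# Small, arbitrary sizes chosen only to keep the example fast.
d_model, nhead, num_layers, seq_len = 32, 4, 3, 10
encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=64, dropout=0.0)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

save_output = SaveOutput()
handles = []  # hypothetical helper list so the hooks can be removed later
for module in transformer.modules():
    if isinstance(module, nn.MultiheadAttention):
        patch_attention(module)
        handles.append(module.register_forward_hook(save_output))

with torch.no_grad():
    transformer(torch.rand(seq_len, 1, d_model))

# Hooks fire in layer order, so index 0 is the first encoder layer.
all_weights = torch.stack(save_output.outputs)
print(all_weights.shape)  # (num_layers, batch, nhead, seq_len, seq_len)

for handle in handles:
    handle.remove()
```

Each hook adds one (batch, nhead, seq_len, seq_len) tensor per forward pass, so hooking every layer stores num_layers times as much as hooking only the last one.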

@airalcorn2
Author

Thanks, @kklemon. I've incorporated your monkey patching suggestion into the gist (with credit).

@77komorebi

If I set batch_first=True and use X = torch.rand(1, seq_len, d_model), save_output.outputs comes back as [], so it seems that patch_attention does not work in this case. Do you know how to solve this problem? Thanks :)

@airalcorn2
Author

@77komorebi - as you can see here, using batch_first=True leads the TransformerEncoderLayer to call torch._transformer_encoder_layer_fwd directly. As a result, the MultiheadAttention module is never called, and so the forward hook is never triggered (the code for how hooks are used in Modules can be found here).
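
One possible workaround, offered as a sketch rather than a confirmed fix: as far as I can tell from TransformerEncoderLayer.forward, the torch._transformer_encoder_layer_fwd path is skipped when autograd is enabled and the parameters require gradients, so running the forward pass outside torch.no_grad() should route through the regular MultiheadAttention call again. That condition is my reading of the PyTorch source and may differ between versions; the small sizes and the `captured` list below are assumptions for illustration.

```python
import torch
from torch import nn


def patch_attention(m):
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


captured = []  # hypothetical storage for the hooked attention weights


def save_weights(module, module_in, module_out):
    # Detach because autograd is deliberately left enabled below.
    captured.append(module_out[1].detach())


d_model, nhead, seq_len = 32, 4, 10
encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward=64, dropout=0.0, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
transformer.eval()

last_attn = transformer.layers[-1].self_attn
patch_attention(last_attn)
hook_handle = last_attn.register_forward_hook(save_weights)

X = torch.rand(1, seq_len, d_model)  # (batch, seq_len, d_model) with batch_first=True

# Assumption: with grad enabled and trainable parameters, the fused path is not taken.
with torch.enable_grad():
    out = transformer(X)

hook_handle.remove()
attn = captured[0]
print(attn.shape)  # (batch, nhead, seq_len, seq_len)
```

The trade-off is that the forward pass now builds an autograd graph, which costs extra memory compared with running under torch.no_grad().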

@77komorebi

@airalcorn2 Thanks!

@Adam-Thiesen

Hi all, thanks for this script; it is very helpful. When I run it as is, I run out of GPU memory. But when I run it like this:

patch_attention(model.transformer_layer.self_attn)
save_output = SaveOutput()
hook_handle = model.transformer_layer.self_attn.register_forward_hook(save_output)

I do not. Is extracting the weights from the TransformerEncoderLayer not as useful as extracting them directly from the layers? Thanks in advance.
