A simple script for extracting the attention weights from a PyTorch Transformer.
# Inspired by: https://towardsdatascience.com/the-one-pytorch-trick-which-you-should-know-2d5e9c1da2ca.
# Monkey patching idea suggested by @kklemon here:
# https://gist.github.com/airalcorn2/50ec06517ce96ecc143503e21fa6cb91?permalink_comment_id=4407423#gistcomment-4407423.

import torch
from torch import nn


def patch_attention(m):
    """Force a MultiheadAttention module to return per-head attention weights."""
    forward_orig = m.forward

    def wrap(*args, **kwargs):
        kwargs["need_weights"] = True
        kwargs["average_attn_weights"] = False
        return forward_orig(*args, **kwargs)

    m.forward = wrap


class SaveOutput:
    """Forward hook that stores the attention weights returned by MultiheadAttention."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # module_out is (attn_output, attn_output_weights); keep the weights.
        self.outputs.append(module_out[1])

    def clear(self):
        self.outputs = []


d_model = 512
nhead = 8
dim_feedforward = 2048
dropout = 0.0
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)
transformer.eval()

# Patch and hook the self-attention module of the last encoder layer.
save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

seq_len = 20
X = torch.rand(seq_len, 1, d_model)

with torch.no_grad():
    out = transformer(X)

print(save_output.outputs[0][0])
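For reference (not part of the original gist): with need_weights=True and average_attn_weights=False, the saved tensor holds per-head weights. A minimal sanity check, assuming the (batch, nhead, tgt_len, src_len) shape documented for MultiheadAttention:

# Sanity check (assumed shape: (batch, nhead, tgt_len, src_len)).
attn = save_output.outputs[0]
print(attn.shape)           # expected: torch.Size([1, 8, 20, 20])
print(attn[0, 0].sum(-1))   # each row of attention weights should sum to ~1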
@airalcorn2 Thanks!
Hi all, thanks for this script; it is very helpful. When I run it as is, I run out of GPU memory, but when I run it like this:
patch_attention(model.transformer_layer.self_attn)
save_output = SaveOutput()
hook_handle = model.transformer_layer.self_attn.register_forward_hook(save_output)
I do not. Is extracting the weights from the TransformerEncoderLayer not as useful as extracting them directly from the layers? Thanks in advance
@77komorebi - as you can see here, using batch_first=True leads to the TransformerEncoderLayer calling torch._transformer_encoder_layer_fwd. As a result, the MultiheadAttention layer is never called, and so the forward hook is never activated (the code for how hooks are used in Modules can be found here).
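If you still want to use the forward hook with batch_first=True, one possible workaround is sketched below. It is not from this thread: it relies on the fused fast path being skipped when the module is in training mode, so verify the behavior against your PyTorch version. With dropout set to 0.0, running the extraction pass in train() mode gives the same outputs as eval() mode.

# Sketch of a possible workaround (reuses patch_attention, SaveOutput, and the
# hyperparameters from the script above). train() mode disqualifies the fused
# fast path, so MultiheadAttention.forward runs and the hook fires; with
# dropout=0.0 the outputs match eval() mode. Verify against your PyTorch version.
encoder_layer = nn.TransformerEncoderLayer(
    d_model, nhead, dim_feedforward, dropout, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)

save_output = SaveOutput()
patch_attention(transformer.layers[-1].self_attn)
hook_handle = transformer.layers[-1].self_attn.register_forward_hook(save_output)

X = torch.rand(1, seq_len, d_model)  # (batch, seq, d_model) because batch_first=True

transformer.train()  # avoid the fast path so the hook is called
with torch.no_grad():
    out = transformer(X)
transformer.eval()

print(save_output.outputs[0].shape)  # expected: torch.Size([1, 8, 20, 20])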