@rwightman
Created November 21, 2023 18:39
Extract attention maps from timm ViTs with Torch FX
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names

# Disable fused F.scaled_dot_product_attention so the softmax node is exposed in the traced graph
timm.layers.set_fused_attn(False)
mm = timm.create_model('vit_medium_patch16_gap_256.sw_in12k_ft_in1k', pretrained=True)
# get_graph_node_names returns (train_nodes, eval_nodes); keep the softmax nodes from the first list
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
with torch.no_grad():
    output = ff(torch.randn(2, 3, 256, 256))  # dict: node name -> post-softmax attention tensor
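The extractor returns a dict keyed by node name, so a natural next step is turning each attention tensor into per-query spatial maps. The following is a minimal sketch, not part of the original gist: `attention_to_grid` is a hypothetical helper, and the stand-in dict assumes node names of the form `blocks.<i>.attn.softmax`, 8 heads, and 256 patch tokens with no class token (a 256x256 input with 16x16 patches gives a 16x16 grid). Check the real keys and shapes of `output` before relying on any of these.

```python
import torch


def attention_to_grid(attn, grid_size, num_prefix_tokens=0):
    """Average over heads and reshape each query's attention into a 2D map.

    attn: (batch, heads, tokens, tokens) post-softmax attention.
    Returns a tensor of shape (batch, patch_tokens, grid_h, grid_w).
    """
    attn = attn.mean(dim=1)  # average heads -> (batch, tokens, tokens)
    # Drop prefix (class/register) tokens, if the model has any.
    attn = attn[:, num_prefix_tokens:, num_prefix_tokens:]
    b, n, _ = attn.shape
    gh, gw = grid_size
    if n != gh * gw:
        raise ValueError(f"{n} patch tokens do not fit a {gh}x{gw} grid")
    return attn.reshape(b, n, gh, gw)


# Stand-in for the extractor output (ASSUMED names and shapes, random values),
# used here so the sketch runs without downloading pretrained weights.
fake_output = {
    f"blocks.{i}.attn.softmax": torch.randn(2, 8, 256, 256).softmax(dim=-1)
    for i in range(2)
}

maps = {name: attention_to_grid(a, (16, 16)) for name, a in fake_output.items()}
for name, m in maps.items():
    print(name, tuple(m.shape))
```

With the real model, passing `output` in place of `fake_output` should work the same way, provided the token count matches the grid; for ViTs that do carry a class token, set `num_prefix_tokens` accordingly.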