@jgoodie
Created January 19, 2025 19:42
Visualize BERT self-attention weights as a heatmap
import torch
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel

# Load the pre-trained tokenizer and model; output_attentions=True makes the
# forward pass return the per-layer attention weights
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", clean_up_tokenization_spaces=False)
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

# Example sentence; "bank" is ambiguous, which makes the attention pattern interesting
sentence = "the otter swam across the river to the other bank"

# Tokenize, look up the input word embeddings, and run a forward pass.
# Passing inputs_embeds is equivalent to passing input_ids here: the model
# still adds position and segment embeddings internally.
inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    embeddings = model.embeddings.word_embeddings(inputs["input_ids"])
    outputs = model(inputs_embeds=embeddings)

# outputs.attentions is a tuple with one tensor per layer, each shaped
# (batch, num_heads, seq_len, seq_len); plot layer 0, batch 0, head 0
attention = outputs.attentions
attention_matrix = attention[0][0][0].numpy()

# Label both axes with the (sub)word tokens, including [CLS] and [SEP]
labels = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
sns.heatmap(attention_matrix, xticklabels=labels, yticklabels=labels, cmap="plasma")
plt.title("Attention Weights (layer 0, head 0)")
plt.tight_layout()
plt.show()
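
The heatmap above shows only the first head of the first layer, but bert-base has 12 layers with 12 heads each, and different heads learn different patterns. Below is a minimal sketch of one way to compare them, assuming the `outputs` and `labels` variables from the script above are still in scope; the layer index is an arbitrary choice for illustration.

# Sketch: plot every attention head of one layer in a single figure.
# Assumes `outputs` and `labels` from the script above are in scope.
layer = 0  # arbitrary layer to inspect (0-11 for bert-base)
layer_attention = outputs.attentions[layer][0]  # (num_heads, seq_len, seq_len)
fig, axes = plt.subplots(3, 4, figsize=(16, 12))
for head, ax in enumerate(axes.flat):
    sns.heatmap(layer_attention[head].numpy(), xticklabels=labels,
                yticklabels=labels, cmap="plasma", cbar=False, ax=ax)
    ax.set_title(f"Layer {layer}, head {head}")
plt.tight_layout()
plt.show()

Scanning the grid this way makes it easy to spot common head behaviors, such as heads that attend mostly to the previous token or to the [SEP] token.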