Skip to content

Instantly share code, notes, and snippets.

@ArthurZucker
Created October 5, 2024 12:10
Show Gist options
  • Save ArthurZucker/a6f6027c8c74cbc04ef67356bd3d8b14 to your computer and use it in GitHub Desktop.
Save ArthurZucker/a6f6027c8c74cbc04ef67356bd3d8b14 to your computer and use it in GitHub Desktop.
Split q k v in sam
from transformers.models.sam.modeling_sam import SamVisionAttention, SamModel, SamVisionLayer
from transformers import SamProcessor
import torch.nn as nn
import torch
from transformers.models.sam import modeling_sam
from PIL import Image
import requests
import matplotlib.pyplot as plt
import numpy as np
from transformers import pipeline
# Issue on https://github.com/huggingface/transformers/issues/33928
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    """SAM vision attention with separate ``q``, ``k`` and ``v`` projections.

    The parent constructor is run first and its fused ``qkv`` layer is then
    deleted and replaced by three ``nn.Linear`` layers of size
    ``hidden_size -> hidden_size``. A load-state-dict pre-hook converts
    checkpoints that still contain the fused ``qkv.*`` tensors, so existing
    weights can be loaded into the split layout.

    Everything else used in ``forward`` (``scale``, ``use_rel_pos``,
    ``rel_pos_h``/``rel_pos_w``, ``dropout``, ``proj``, ``num_attention_heads``
    and ``add_decomposed_rel_pos``) is inherited from ``SamVisionAttention``.
    """

    def __init__(self, config, window_size):
        """Build the parent module, then swap the fused projection for three.

        Args:
            config: Model configuration; ``hidden_size`` and ``qkv_bias`` are
                read here, the rest is consumed by the parent constructor.
            window_size: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, window_size)
        del self.qkv
        # Separate q, k, v projections
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)

    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        """Rewrite this module's fused ``qkv.*`` entries into ``q.*``/``k.*``/``v.*``.

        Runs as a load-state-dict pre-hook and mutates ``state_dict`` in place.
        Only keys that belong to this module (``prefix + "qkv."``) are touched,
        so entries owned by other modules are never rewritten and each layer
        converts exactly its own tensors.

        Args:
            state_dict: Mapping of parameter names to tensors being loaded.
            prefix: Fully qualified name prefix of this module in ``state_dict``.
            *args: Remaining hook arguments; unused.
        """
        fused_prefix = prefix + "qkv."
        # Snapshot the matching keys first: the dict is mutated inside the loop.
        fused_keys = [name for name in state_dict if name.startswith(fused_prefix)]
        for fused_key in fused_keys:
            suffix = fused_key[len(fused_prefix):]
            # Assumes the fused tensor stacks the query, key and value blocks
            # in that order along dim 0.
            q, k, v = state_dict.pop(fused_key).chunk(3, dim=0)
            state_dict[prefix + "q." + suffix] = q
            state_dict[prefix + "k." + suffix] = k
            state_dict[prefix + "v." + suffix] = v

    def _project_heads(self, projection, hidden_states):
        """Apply ``projection`` and fold the heads into the batch dimension.

        Args:
            projection: One of the ``q``/``k``/``v`` linear layers.
            hidden_states: Tensor of shape ``(batch, height, width, channels)``.

        Returns:
            Tensor of shape ``(batch * num_heads, height * width, head_dim)``.
        """
        batch_size, height, width, _ = hidden_states.shape
        num_tokens = height * width
        projected = projection(hidden_states).reshape(
            batch_size, num_tokens, self.num_attention_heads, -1
        )
        return projected.permute(0, 2, 1, 3).reshape(
            batch_size * self.num_attention_heads, num_tokens, -1
        )

    def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> tuple:
        """Compute multi-head self-attention over a 2D grid of tokens.

        Args:
            hidden_states: Tensor of shape ``(batch, height, width, channels)``.
            output_attentions: When true, also return the attention weights.

        Returns:
            A 2-tuple ``(attn_output, attn_weights)``. ``attn_output`` has shape
            ``(batch, height, width, -1)`` after the output projection.
            ``attn_weights`` is the post-softmax, pre-dropout weight tensor, or
            ``None`` when ``output_attentions`` is false.
        """
        batch_size, height, width, _ = hidden_states.shape
        query = self._project_heads(self.q, hidden_states)
        key = self._project_heads(self.k, hidden_states)
        value = self._project_heads(self.v, hidden_states)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        # Softmax is evaluated in float32 and cast back to the query dtype.
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Unfold heads from the batch dimension and merge them back into channels.
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        return attn_output, (attn_weights if output_attentions else None)
# Replace the attention class on the modeling module before the model is built;
# presumably the vision layers resolve this name at construction time — verify
# against the installed transformers version.
modeling_sam.SamVisionAttention = SamVisionAttentionSplit
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
# Pass the computed device instead of a hard-coded CUDA index so the script
# also runs on machines without a GPU.
generator = pipeline(
    "mask-generation",
    device=device,
    points_per_batch=256,
    model=model,
    image_processor=processor.image_processor,
)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are ``(height, width)`` and that
            holds exactly ``height * width`` elements; non-zero entries are
            painted with the overlay colour.
        ax: Object exposing ``imshow`` (e.g. a matplotlib ``Axes``).
        random_color: If true, use a random RGB colour; otherwise a fixed blue.
            The alpha channel is 0.6 in both cases.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    # Broadcast (h, w, 1) against (1, 1, 4) to get an (h, w, 4) RGBA image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
# Download the source image; fail fast with a clear HTTP error instead of an
# opaque image-decoding error if the request does not succeed.
response = requests.get(image_url, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw).convert("RGB")

# Show the image and overlay every predicted mask in its own random colour.
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.savefig("dummy_test")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment