Split q, k, v in SAM
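This gist (from huggingface/transformers issue #33928) shows how to replace SAM's fused `qkv` projection with separate `q`, `k`, and `v` linear layers while still loading the original `facebook/sam-vit-base` checkpoint: a `load_state_dict` pre-hook splits the pretrained fused weights on the fly, the custom attention class is monkey-patched into `modeling_sam` before the model is instantiated, and the script finishes by running the `mask-generation` pipeline and plotting the predicted masks.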
import numpy as np
import requests
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image

from transformers import SamProcessor, pipeline
from transformers.models.sam import modeling_sam
from transformers.models.sam.modeling_sam import SamModel, SamVisionAttention
# Issue: https://github.com/huggingface/transformers/issues/33928
class SamVisionAttentionSplit(SamVisionAttention, nn.Module):
    def __init__(self, config, window_size):
        super().__init__(config, window_size)
        del self.qkv
        # Separate q, k, v projections in place of the fused qkv layer
        self.q = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.k = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        self.v = nn.Linear(config.hidden_size, config.hidden_size, bias=config.qkv_bias)
        # Split the pretrained fused qkv weights before they are loaded
        self._register_load_state_dict_pre_hook(self.split_q_k_v_load_hook)
    def split_q_k_v_load_hook(self, state_dict, prefix, *args):
        keys_to_delete = []
        for key in list(state_dict.keys()):
            if "qkv." in key:
                # Split q, k, v from the combined projection
                q, k, v = state_dict[key].chunk(3, dim=0)
                # Replace with individual q, k, v projections
                state_dict[key.replace("qkv.", "q.")] = q
                state_dict[key.replace("qkv.", "k.")] = k
                state_dict[key.replace("qkv.", "v.")] = v
                # Mark the old qkv key for deletion
                keys_to_delete.append(key)
        # Remove old qkv keys
        for key in keys_to_delete:
            del state_dict[key]
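    # Note (an addition, not in the original gist): in the transformers SAM
    # implementation the fused projection is nn.Linear(dim, 3 * dim), so its
    # weight has shape (3 * hidden_size, hidden_size) and chunk(3, dim=0)
    # yields the q, k, and v weights (and biases) in that order.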
    def forward(self, hidden_states: torch.Tensor, output_attentions=False):
        batch_size, height, width, _ = hidden_states.shape
        qkv_shapes = (batch_size * self.num_attention_heads, height * width, -1)
        query = self.q(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        key = self.k(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)
        value = self.v(hidden_states).reshape((batch_size, height * width, self.num_attention_heads, -1)).permute(0, 2, 1, 3).reshape(qkv_shapes)

        attn_weights = (query * self.scale) @ key.transpose(-2, -1)
        if self.use_rel_pos:
            attn_weights = self.add_decomposed_rel_pos(
                attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
        attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
        attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
        attn_output = self.proj(attn_output)

        if output_attentions:
            outputs = (attn_output, attn_weights)
        else:
            outputs = (attn_output, None)
        return outputs
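# Sanity check (an addition to the original gist; the underscore names are
# hypothetical locals): chunking a fused qkv weight along dim 0 and applying
# the pieces separately reproduces the fused projection, which is exactly
# what the load hook above relies on.
_hidden = 8
_x = torch.randn(2, _hidden)
_fused = nn.Linear(_hidden, 3 * _hidden, bias=False)
_q_w, _k_w, _v_w = _fused.weight.chunk(3, dim=0)
_q_ref, _k_ref, _v_ref = _fused(_x).chunk(3, dim=-1)
assert torch.allclose(_x @ _q_w.T, _q_ref, atol=1e-6)
assert torch.allclose(_x @ _v_w.T, _v_ref, atol=1e-6)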
# Patch the vision attention class before the model is instantiated
modeling_sam.SamVisionAttention = SamVisionAttentionSplit

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
generator = pipeline(
    "mask-generation", model=model, image_processor=processor.image_processor,
    device=device, points_per_batch=256,
)

image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url)
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
raw_image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
    show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.savefig("dummy_test.png")