# Monkey-patch the OWLv2 vision encoder's self-attention to use FlashAttention.
import torch
import torch.nn as nn
from einops import rearrange
from transformers import Owlv2Processor, Owlv2ForObjectDetection
from flash_attn.modules.mha import FlashSelfAttention


class Attention(nn.Module):
    """Drop-in replacement for OWLv2's self-attention backed by FlashAttention."""

    def __init__(self, old_attn):
        super().__init__()
        self.embed_dim = old_attn.embed_dim
        self.num_heads = old_attn.num_heads
        self.head_dim = old_attn.head_dim
        # Reuse the pretrained projection weights from the original module.
        self.q_proj = old_attn.q_proj
        self.k_proj = old_attn.k_proj
        self.v_proj = old_attn.v_proj
        self.out_proj = old_attn.out_proj
        # FlashSelfAttention's default softmax scale (1/sqrt(head_dim))
        # matches the scaling used by the original Owlv2Attention.
        self.attn = FlashSelfAttention()

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        causal_attention_mask=None,
        output_attentions=False,
    ):
        # Pack q, k, v into the (batch, seq, 3, heads, head_dim) layout
        # that FlashSelfAttention expects.
        qkv = torch.cat(
            (
                self.q_proj(hidden_states),
                self.k_proj(hidden_states),
                self.v_proj(hidden_states),
            ),
            dim=-1,
        )
        qkv = rearrange(
            qkv, "... (three h d) -> ... three h d", three=3, h=self.num_heads
        )
        attn_output = self.attn(qkv)
        # Merge heads and project out. Return None for the attention weights
        # to keep the Owlv2Attention interface.
        output = rearrange(attn_output, "... h d -> ... (h d)")
        output = self.out_proj(output)
        return output, None
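

# Optional sanity check, a sketch not in the original gist: compare one patched
# layer against the stock attention on random input. Call it after the model
# below has been loaded (bf16, CUDA) but before the layers are swapped, e.g.
# check_parity(model.owlv2.vision_model.encoder.layers[0].self_attn).
# The tolerance is loose because FlashAttention runs in bf16.
def check_parity(old_attn, seq_len=16, atol=3e-2):
    x = torch.randn(
        1, seq_len, old_attn.embed_dim, dtype=torch.bfloat16, device="cuda"
    )
    with torch.no_grad():
        ref, _ = old_attn(x)  # original Owlv2Attention
        out, _ = Attention(old_attn)(x)  # FlashAttention replacement
    assert torch.allclose(ref, out, atol=atol), "patched attention diverges"
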
model_id = "google/owlv2-large-patch14-ensemble"
processor = Owlv2Processor.from_pretrained(model_id)
model = Owlv2ForObjectDetection.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
)

# Swap every vision-encoder layer's attention for the FlashAttention version.
# The projections inside `Attention` reuse the original weights, so no extra
# device or dtype moves are needed.
for layer in model.owlv2.vision_model.encoder.layers:
    layer.self_attn = Attention(layer.self_attn)
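

# A minimal end-to-end sketch (not part of the original gist). The image path
# and text queries are placeholders; post_process_object_detection delegates
# to the Owlv2 image processor in recent transformers releases.
from PIL import Image

image = Image.open("cats.jpg").convert("RGB")  # placeholder: any RGB image
texts = [["a photo of a cat", "a photo of a dog"]]
inputs = processor(text=texts, images=image, return_tensors="pt").to("cuda")
# FlashAttention kernels only accept fp16/bf16, so cast the pixel values to
# match the model's dtype.
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

with torch.no_grad():
    outputs = model(**inputs)

# Rescale boxes to the original image size, given as (height, width).
target_sizes = torch.tensor([image.size[::-1]], device="cuda")
results = processor.post_process_object_detection(
    outputs, threshold=0.2, target_sizes=target_sizes
)
for score, label, box in zip(
    results[0]["scores"], results[0]["labels"], results[0]["boxes"]
):
    print(texts[0][int(label)], round(float(score), 3), box.tolist())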