Demonstrate fix and parity of CLIP AttentionPool2d
""" | |
This gist demonstrates the equivalence between the existing CLIP `AttentionPool2d` | |
and the proposed `AttentionPool2dFix`, which only computes attention where needed. | |
""" | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        # x has shape [(HW+1), B, C]: an attention output was computed for every
        # query position, even though only the pooled token (row 0) is returned.
        return x[0]
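
# AttentionPool2dFix below is identical to AttentionPool2d except that only the pooled
# token (x[:1]) is passed as the query, so multi_head_attention_forward computes a
# single attention row instead of HW + 1 of them. Each query row attends to the keys
# independently, so restricting the query does not change the result for that row.
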
class AttentionPool2dFix(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x[:1], key=x, value=x,  # only the pooled token is used as the query
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )
        return x.squeeze(0)  # [1, B, C] -> [B, C]

if __name__ == "__main__":
    batch_dim = 5
    spacial_dim = 8
    embed_dim = 16
    num_heads = 4
    output_dim = 2
    x = torch.randn(batch_dim, embed_dim, spacial_dim, spacial_dim)

    pool1 = AttentionPool2d(spacial_dim, embed_dim, num_heads, output_dim)
    y1 = pool1(x)
    assert y1.shape == (batch_dim, output_dim)

    pool2 = AttentionPool2dFix(spacial_dim, embed_dim, num_heads, output_dim)
    # Make sure parameter state is the same
    pool2.load_state_dict(pool1.state_dict())
    y2 = pool2(x)
    assert y2.shape == (batch_dim, output_dim)

    assert torch.allclose(y1, y2)
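
    # ------------------------------------------------------------------
    # Optional rough timing sketch (not part of the original parity check).
    # The sizes below roughly follow CLIP RN50's attention pool (7x7 grid,
    # 2048 channels, 32 heads, 1024-dim output); they are illustrative
    # assumptions, and the absolute numbers will vary by hardware.
    # ------------------------------------------------------------------
    import time

    def time_pool(pool, inp, iters=10):
        # Average wall-clock seconds per forward pass (CPU, no grad).
        with torch.no_grad():
            pool(inp)  # warm-up
            start = time.perf_counter()
            for _ in range(iters):
                pool(inp)
        return (time.perf_counter() - start) / iters

    big_x = torch.randn(8, 2048, 7, 7)
    slow_pool = AttentionPool2d(7, 2048, 32, 1024)
    fast_pool = AttentionPool2dFix(7, 2048, 32, 1024)
    fast_pool.load_state_dict(slow_pool.state_dict())
    print(f"AttentionPool2d:    {time_pool(slow_pool, big_x) * 1e3:.2f} ms/iter")
    print(f"AttentionPool2dFix: {time_pool(fast_pool, big_x) * 1e3:.2f} ms/iter")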