Self-attention with 2D axial RoPE (rotary position embeddings)
""" | |
This code was originally obtained from: | |
https://github.com/meta-llama/codellama/blob/main/llama/model.py | |
adapted from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py | |
""" | |
import torch
import torch.nn as nn
from functools import partial

import einops

# Select the best available attention backend.
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except ImportError:
        ATTENTION_MODE = 'math'
def init_pos_xy(height: int, width: int):
    # Get the (row, column) coordinates of each patch in a row-major grid.
    token_idx = torch.arange(height * width, dtype=torch.float32)  # indices of all patches
    pos_x = (token_idx // width).float()  # x coordinate: row index
    pos_y = (token_idx % width).float()   # y coordinate: column index
    return pos_x, pos_y
def compute_axial_cis(dim: int, height: int, width: int, theta: float = 10_000):
    """Compute the complex frequencies for 2D axial RoPE.
    inputs:
        - dim: the per-head dimension (embedding dim // num_heads)
        - height: the image height // patch_size (number of patch rows)
        - width: the image width // patch_size (number of patch columns)
        - theta: the base of the exponential frequency schedule
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    pos_x, pos_y = init_pos_xy(height, width)
    freqs_x = torch.outer(pos_x, freqs)  # rotations for the first half of the head dimension
    freqs_y = torch.outer(pos_y, freqs)  # rotations for the second half of the head dimension
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)  # unit complex numbers exp(i * freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)  # unit complex numbers exp(i * freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)  # (height * width, dim // 2)
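# Quick sanity check (illustrative, not part of the original gist): for a 16x16 patch grid
# and a per-head dimension of 64, the returned tensor holds one complex rotation per token
# and per channel pair, i.e. shape (16 * 16, 64 // 2) with a complex dtype, e.g.
#   freqs_cis = compute_axial_cis(dim=64, height=16, width=16)
#   assert freqs_cis.shape == (256, 32) and freqs_cis.is_complex()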
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # (B, H, N, D) -> (B, H, N, D//2, 2) -> (B, H, N, D//2)
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # (B, H, N, D) -> (B, H, N, D//2, 2) -> (B, H, N, D//2)
    assert freqs_cis.shape == (xq_.shape[-2], xq_.shape[-1]), f'freqs_cis.shape != (xq_.shape[-2], xq_.shape[-1]), {freqs_cis.shape} != {(xq_.shape[-2], xq_.shape[-1])}'
    xq_out = torch.view_as_real(torch.einsum('bhnd,nd->bhnd', xq_, freqs_cis)).flatten(3)
    xk_out = torch.view_as_real(torch.einsum('bhnd,nd->bhnd', xk_, freqs_cis)).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
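# Shape contract (comments added for clarity; the names below are illustrative):
#   xq, xk    : (B, H, N, D) real tensors; consecutive channel pairs are viewed as complex
#               numbers and rotated by the matching entry of freqs_cis.
#   freqs_cis : (N, D // 2) complex tensor, e.g. the output of compute_axial_cis.
#   returns   : rotated (xq, xk), both (B, H, N, D), cast back to the input dtype and device.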
# modified from https://github.com/baofff/U-ViT/blob/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/libs/uvit.py#L66
def calculate_attention(q, k, v, attn_drop_fn, num_heads=1, scale=1, mask=None, dropout=0., training=True):
    if ATTENTION_MODE == 'flash':
        q, k, v = q.float(), k.float(), v.float()  # float32 more stable?
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout if training else 0)
        attn = einops.rearrange(attn, 'B H N D -> B N (H D)')
    elif ATTENTION_MODE == 'xformers':
        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)  # (B, N, H, D)
        attn = xformers.ops.memory_efficient_attention(q, k, v, p=dropout if training else 0)
        attn = einops.rearrange(attn, 'B N H D -> B N (H D)', H=num_heads)
    elif ATTENTION_MODE == 'math':
        B, H, N, D = q.shape
        attn = (q @ k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = attn_drop_fn(attn)
        attn = (attn @ v).transpose(1, 2).reshape(B, N, H * D)
    else:
        raise NotImplementedError(f'unknown ATTENTION_MODE: {ATTENTION_MODE}')
    return attn
class Attention(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., **kwargs):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qkv_proj.num_heads = num_heads
        self.attn_drop = attn_drop
        self.attn_drop_fn = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (K, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, H, N, D)

        x = calculate_attention(q, k, v, self.attn_drop_fn, num_heads=self.num_heads, scale=self.scale, dropout=self.attn_drop, training=self.training)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class RoPEAttention(Attention):
    """Multi-head Attention block with rotary position embeddings."""
    def __init__(self, *args, rope_theta=10_000, **kwargs):
        super().__init__(*args, **kwargs)

        self.compute_cis = partial(compute_axial_cis, dim=self.dim // self.num_heads, theta=rope_theta)
        freqs_cis = self.compute_cis(height=kwargs['height'], width=kwargs['width'])
        self.extra_tokens = kwargs.get('extras', 0)
        self.freqs_cis = freqs_cis

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (K, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, H, N, D)

        ######### Apply rotary position embedding
        # skip the extra tokens ([CLS], [SEP], time embedding, etc.)
        q[:, :, self.extra_tokens:], k[:, :, self.extra_tokens:] = apply_rotary_emb(
            q[:, :, self.extra_tokens:], k[:, :, self.extra_tokens:], freqs_cis=self.freqs_cis.to(x.device)
        )
        #########

        x = calculate_attention(q, k, v, self.attn_drop_fn, num_heads=self.num_heads, scale=self.scale, dropout=self.attn_drop, training=self.training)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
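# Minimal usage sketch (an addition for illustration; the shapes and parameter values below
# are assumptions, not part of the original gist). It builds a RoPEAttention layer for a
# 224x224 image with 16x16 patches plus one extra [CLS]-style token, then runs a forward pass.
if __name__ == "__main__":
    dim, num_heads = 384, 6
    grid = 224 // 16  # 14 patches per side
    attn = RoPEAttention(dim, num_heads=num_heads, height=grid, width=grid, extras=1)
    x = torch.randn(2, 1 + grid * grid, dim)  # (batch, extra token + patches, dim)
    out = attn(x)
    print(out.shape)  # expected: torch.Size([2, 197, 384])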