self attention with RoPE
"""
This code was originally obtained from:
https://github.com/meta-llama/codellama/blob/main/llama/model.py
adapted from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py
"""
import torch
import torch.nn as nn
from functools import partial
import einops

# pick the fastest attention backend available
if hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    try:
        import xformers
        import xformers.ops
        ATTENTION_MODE = 'xformers'
    except ImportError:
        ATTENTION_MODE = 'math'
def init_pos_xy(height: int, width: int):
    # get the (row, column) coordinates of each patch in a row-major height x width grid
    token_idx = torch.arange(height * width, dtype=torch.float32)  # index of each patch token
    pos_x = (token_idx // width).float()  # x coordinate; row index
    pos_y = (token_idx % width).float()   # y coordinate; column index
    return pos_x, pos_y
def compute_axial_cis(dim: int, height: int, width: int, theta: float = 10_000):
    """Compute 2D axial RoPE frequencies.
    inputs:
        - dim: the per-head dimension of the tokens (embed_dim // num_heads)
        - height: the height of the image // patch_size (number of patch rows)
        - width: the width of the image // patch_size (number of patch columns)
        - theta: the base used for the rotary frequencies
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    pos_x, pos_y = init_pos_xy(height, width)
    freqs_x = torch.outer(pos_x, freqs)  # rotations for the first half of the channels
    freqs_y = torch.outer(pos_y, freqs)  # rotations for the second half of the channels
    freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
    freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
    return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
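# Shape check (sketch, with assumed numbers): for head_dim = 64 and a 14 x 14 patch grid,
#   freqs_cis = compute_axial_cis(dim=64, height=14, width=14)
#   freqs_cis.shape == torch.Size([196, 32])   # (num_patches, head_dim // 2)
# i.e. dim // 4 = 16 frequencies encode the row position and 16 the column position
# of each of the 14 * 14 = 196 patches.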
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # view consecutive channel pairs as complex numbers: (B, H, N, D) -> (B, H, N, D//2, 2) -> (B, H, N, D//2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.shape == (xq_.shape[-2], xq_.shape[-1]), \
        f'freqs_cis.shape != (xq_.shape[-2], xq_.shape[-1]), {freqs_cis.shape} != {(xq_.shape[-2], xq_.shape[-1])}'
    # rotate q and k by the position-dependent unit complex numbers, then flatten back to (B, H, N, D)
    xq_out = torch.view_as_real(torch.einsum('bhnd,nd->bhnd', xq_, freqs_cis)).flatten(3)
    xk_out = torch.view_as_real(torch.einsum('bhnd,nd->bhnd', xk_, freqs_cis)).flatten(3)
    return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
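# Example (sketch, with assumed shapes): given xq, xk of shape (B, H, N, 64) and freqs_cis of
# shape (N, 32), each consecutive (even, odd) channel pair of q and k is rotated by the angle
# stored in freqs_cis for that position; the outputs keep the original (B, H, N, 64) shape.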
# modified from https://github.com/baofff/U-ViT/blob/ce551708dc9cde9818d2af7d84dfadfeb7bd9034/libs/uvit.py#L66
def calculate_attention(q, k, v, attn_drop_fn, num_heads=1, scale=1, mask=None, dropout=0., training=True):
    # q, k, v: (B, H, N, D); returns the attention output of shape (B, N, H * D)
    if ATTENTION_MODE == 'flash':
        q, k, v = q.float(), k.float(), v.float()  # float32 more stable?
        attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=dropout if training else 0.)
        attn = einops.rearrange(attn, 'B H N D -> B N (H D)')
    elif ATTENTION_MODE == 'xformers':
        q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)  # (B, N, H, D)
        attn = xformers.ops.memory_efficient_attention(q, k, v, p=dropout if training else 0.)
        attn = einops.rearrange(attn, 'B N H D -> B N (H D)', H=num_heads)
    elif ATTENTION_MODE == 'math':
        B, H, N, D = q.shape
        attn = (q @ k.transpose(-2, -1)) * scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = attn_drop_fn(attn)
        attn = (attn @ v).transpose(1, 2).reshape(B, N, H * D)
    else:
        raise NotImplementedError(f'unknown ATTENTION_MODE: {ATTENTION_MODE}')
    return attn
class Attention(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., **kwargs):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qkv_proj.num_heads = num_heads
        self.attn_drop = attn_drop
        self.attn_drop_fn = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (K, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, H, N, D)
        x = calculate_attention(q, k, v, self.attn_drop_fn, num_heads=self.num_heads, scale=self.scale, dropout=self.attn_drop, training=self.training)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class RoPEAttention(Attention):
    """Multi-head Attention block with rotary position embeddings."""
    def __init__(self, *args, rope_theta=10_000, **kwargs):
        super().__init__(*args, **kwargs)
        self.compute_cis = partial(compute_axial_cis, dim=self.dim // self.num_heads, theta=rope_theta)
        freqs_cis = self.compute_cis(height=kwargs['height'], width=kwargs['width'])
        self.extra_tokens = kwargs['extras'] if 'extras' in kwargs else 0
        self.freqs_cis = freqs_cis

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (K, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]  # (B, H, N, D)
        ######### apply rotary position embedding
        # skip the extra tokens ([CLS], [SEP], time_embedding, etc.)
        q[:, :, self.extra_tokens:], k[:, :, self.extra_tokens:] = apply_rotary_emb(
            q[:, :, self.extra_tokens:], k[:, :, self.extra_tokens:], freqs_cis=self.freqs_cis.to(x.device))
        #########
        x = calculate_attention(q, k, v, self.attn_drop_fn, num_heads=self.num_heads, scale=self.scale, dropout=self.attn_drop, training=self.training)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
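
# Minimal usage sketch (not part of the original gist): dim, num_heads, the grid size and the
# single extra token below are assumed values chosen only to illustrate the expected shapes.
if __name__ == '__main__':
    torch.manual_seed(0)
    dim, num_heads, height, width, extras = 256, 8, 14, 14, 1
    attn = RoPEAttention(dim, num_heads=num_heads, height=height, width=width, extras=extras)
    x = torch.randn(2, extras + height * width, dim)  # (B, N, C) with one [CLS]-like extra token
    out = attn(x)
    print(out.shape)  # expected: torch.Size([2, 197, 256])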