#######################
# CODE BASED ON https://github.com/hyunwoongko/transformer/blob/master/README.md
#######################
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from einops import rearrange, repeat
from rotary_embedding_torch import RotaryEmbedding
#from torchtune.modules import RotaryPositionalEmbeddings as RotaryEmbedding

try:
    from flash_attn import (
        flash_attn_func,
        flash_attn_kvpacked_func,
        flash_attn_qkvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
        flash_attn_with_kvcache,
    )
    from flash_attn.utils.distributed import get_dim_for_local_rank
except ImportError:
    flash_attn_func = None
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
    flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
    flash_attn_with_kvcache = None
    get_dim_for_local_rank = None

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
except ImportError:
    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

#try:
#    from flash_attn.layers.rotary import RotaryEmbedding
#except ImportError:
#    RotaryEmbedding = None
# From https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    def get_slopes_power_of_2(nheads):
        start = 2 ** (-(2 ** -(math.log2(nheads) - 3)))
        ratio = start
        return [start * ratio**i for i in range(nheads)]

    if math.log2(nheads).is_integer():
        return get_slopes_power_of_2(nheads)
    else:
        closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
        )
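# Worked example (editorial note, not from the original gist): for a power-of-two
# head count the slopes form a geometric sequence, e.g.
#   get_alibi_slopes(8) == [2**-1, 2**-2, ..., 2**-8]
#                       == [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
# For head counts that are not powers of two, the remaining slopes are interleaved
# from the next power-of-two sequence, following the ALiBi reference implementation above.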
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.attention = Attention(d_model // n_head)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # 1. project with the weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. split tensors by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. scaled dot-product attention to compute similarity
        out, attention = self.attention(q, k, v, mask)

        # 4. concat heads and pass through the output projection
        out = self.concat(out)
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        """
        Split tensor by number of heads.
        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
        return tensor

    @staticmethod
    def concat(tensor):
        """
        Inverse of self.split(tensor: torch.Tensor).
        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor
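# Shape trace for the split/concat round trip (editorial sketch with illustrative
# sizes, assuming d_model=512 and n_head=8):
#   split:  (B, S, 512)  -> view -> (B, S, 8, 64) -> transpose(1, 2) -> (B, 8, S, 64)
#   concat: (B, 8, S, 64) -> transpose(1, 2) -> (B, S, 8, 64) -> view -> (B, S, 512)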
## ~147 hrs per epoch for a 1 B-parameter model on 4 GPUs - definitely faster, for all dtypes, than the vanilla matmul attention below (which materializes the full O(n^2) attention matrix)
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head // 2)
#        self.first = True
#        self.dropout_value = dropout
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
#        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
#            scores = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout_value if self.training else 0, is_causal=True if mask is not None else False)
#
#        return scores, None
# can't conclude this yet - for the 1 B model, this FlashAttention version is faster
class Attention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(
        self,
        d_head,
        dropout=0.1,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.1,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic
        self.rotary_embed = RotaryEmbedding(d_head // 2)

    def forward(self, q, k, v, mask=None, causal=None, cu_seqlens=None, max_seqlen=None):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: tensors of shape (B, H, S, D) containing the queries, keys,
                and values; they are packed into a single qkv tensor before the
                FlashAttention call. If cu_seqlens is None and max_seqlen is None,
                the packed qkv has shape (B, S, 3, H, D). If cu_seqlens and
                max_seqlen are given, qkv has shape (total, 3, H, D), where total
                is the sum of the sequence lengths in the batch.
            causal: if passed, will override self.causal
            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
                of the sequences in the batch, used to index into qkv.
            max_seqlen: int. Maximum sequence length in the batch.
        Returns:
        --------
            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
                else (B, S, H, D) from FlashAttention; the result is then transposed
                back to the head-first (B, H, S, D) layout before being returned.
        """
        # q.shape = (B, H, S, D)
        # Any non-None mask forces causal attention, overriding the causal argument.
        causal = True if mask is not None else False
        ## with torch.cuda.amp.autocast():
        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)
        ## print(q.shape, k.shape)  # torch.Size([1024, 16, 31, 64])
        ## q = q.to(v.dtype)
        ## k = k.to(v.dtype)
        ## q = self.rotary_embed(q)
        ## k = self.rotary_embed(k)
        # pack into (B, S, 3, H, D) for the qkv-packed FlashAttention kernels
        qkv = torch.cat(
            [q.transpose(1, 2).unsqueeze(2), k.transpose(1, 2).unsqueeze(2), v.transpose(1, 2).unsqueeze(2)],
            dim=2,
        )
        assert qkv.dtype in [torch.float16, torch.bfloat16], f'{qkv.dtype=}'
        assert qkv.is_cuda
        causal = self.causal if causal is None else causal
        unpadded = cu_seqlens is not None
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        if unpadded:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            out = flash_attn_varlen_qkvpacked_func(
                qkv,
                cu_seqlens,
                max_seqlen,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        else:
            out = flash_attn_qkvpacked_func(
                qkv,
                self.drop.p if self.training else 0.0,
                softmax_scale=self.softmax_scale,
                causal=causal,
                alibi_slopes=self.alibi_slopes,
                window_size=self.window_size,
                deterministic=self.deterministic,
            )
        return out.transpose(1, 2), None
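# Note (editorial sketch, not part of the original module): the qkv-packed call
# above is equivalent to FlashAttention's unpacked interface, where q/k/v are
# each laid out as (B, S, H, D):
#
#   out = flash_attn_func(
#       q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
#       dropout_p=0.0, softmax_scale=None, causal=True,
#   )  # -> (B, S, H, D)
#
# The commented-out variants below use this flash_attn_func path directly.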
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head // 2)
#        self.first = True
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
##        if self.first:
##            print(f'{q.shape=} {k.shape=} {v.shape=}')
##            self.first = False
#
##        scores = flash_attn_func(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2))
##        F.scaled_dot_product_attention(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2))
#        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
#
##        raise Exception(f'{q.shape=} {k.shape=} {v.shape=} {scores.shape=} {mask.shape=} {mask=}')
#        if mask is not None:
#            #mask_value = 1e-9 if scores.dtype == torch.float32 else -1e4
#            scores = scores.masked_fill(mask == 0, -torch.inf)
#
#        p_attn = self.softmax(scores)
#        p_attn = self.dropout(p_attn)
#        scores = torch.matmul(p_attn, v)
#        #raise Exception(f'{q.shape=} {k.shape=} {v.shape=} {scores.shape=} {mask.shape=} {mask=}')
#        # q.shape=torch.Size([128, 20, 31, 80]) k.shape=torch.Size([128, 20, 31, 80]) v.shape=torch.Size([128, 20, 31, 80]) scores.shape=torch.Size([128, 20, 31, 80]) mask.shape=torch.Size([31, 31]) mask=tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
#        return scores, p_attn
## DISREGARD BELOW VERSIONS
#class Attention(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head // 2)
#        self.first = True
#        self.dropout_value = dropout
#
#    def forward(self, q, k, v, mask=None):
#        # apply RoPE
#        q = self.rotary_embed.rotate_queries_or_keys(q)
#        k = self.rotary_embed.rotate_queries_or_keys(k)
#
#        d_k = k.size(-1)
#        raise Exception(f'{type(q)=} {type(k)=} {type(v)=}')
#        scores = flash_attn_func(q.transpose(1,2), k.transpose(1,2), v.transpose(1,2), dropout_p=self.dropout_value, causal=True if mask is not None else False).transpose(1,2)
#
#        return scores, None


#class Attention2(nn.Module):
#    def __init__(self, d_head, dropout=0.1):
#        super().__init__()
#        self.dropout = torch.nn.Dropout(dropout)
#        self.softmax = nn.Softmax(dim=-1)
#        self.rotary_embed = RotaryEmbedding(d_head // 2)  # Assuming this is a defined class elsewhere
#        self.first = True
#
#    def forward(self, q, k, v, mask=None):
#        with autocast(enabled=True, dtype=torch.float16):
#            # Apply RoPE (Rotary Positional Embedding)
#            q = self.rotary_embed.rotate_queries_or_keys(q)
#            k = self.rotary_embed.rotate_queries_or_keys(k)
#
#            # Transpose the tensors to match FlashAttention's expected input shape
#            q_transposed = q.transpose(1, 2).to(dtype=torch.float16)
#            k_transposed = k.transpose(1, 2).to(dtype=torch.float16)
#            v_transposed = v.transpose(1, 2).to(dtype=torch.float16)
#
#            if self.first:
#                print(f'{q_transposed.shape=} {k_transposed.shape=} {v_transposed.shape=} {mask.shape=}')
#
#            # Pass the transposed and cast tensors to the FlashAttention function
#            scores = flash_attn_func(q_transposed, k_transposed, v_transposed, dropout_p=self.dropout if self.training else 0, causal=True if mask is None else False)
#
#            if self.first:
#                print(f'{q_transposed.shape=} {k_transposed.shape=} {v_transposed.shape=} {scores.shape=} {mask.shape=}')
#                self.first = False
#
##            # You may want to keep the softmax operation in fp32 for numerical stability
##            if mask is not None:
##                scores = scores.float()  # Convert back to fp32 if necessary
##                scores = scores.masked_fill(mask == 0, float('-inf'))
#
##            if mask is not None:
##                # Ensure mask is broadcastable to the size of scores.
##                # This might involve unsqueezing dimensions or ensuring it has the right shape.
##                # For example, if mask should cover the sequence length, which is the last dimension:
##                mask = mask.unsqueeze(1).unsqueeze(2)  # Adding dimensions to match scores shape
##                # The mask needs to be the same dtype as scores, and usually the mask is not in fp16.
##                # Convert mask to the same dtype as scores, if necessary
##                mask = mask.to(dtype=scores.dtype)
##                # Use broadcasting to apply the mask
##                scores = scores.masked_fill(mask == 0, float('-inf'))
##
##            p_attn = self.softmax(scores)
##            p_attn = self.dropout(p_attn)
##
##            # Ensure that 'v' is in the correct dtype before matmul if needed
##            v = v.to(dtype=scores.dtype)
##
##            return torch.matmul(p_attn, v), p_attn
#            return scores, None
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()  # inplace=True)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
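# --- Minimal usage sketch (editorial addition, not part of the original gist) ---
# Illustrative sizes only; assumes flash-attn and rotary-embedding-torch are
# installed and a FlashAttention-capable CUDA GPU is available. FlashAttention
# requires fp16/bf16 inputs, so the module and input are cast to bfloat16;
# depending on the rotary-embedding-torch version, you may additionally need the
# commented-out `q = q.to(v.dtype)` / `k = k.to(v.dtype)` casts above.
if __name__ == "__main__":
    torch.manual_seed(0)
    d_model, n_head, batch, seq = 512, 8, 2, 128
    mha = MultiHeadAttention(d_model, n_head).cuda().to(torch.bfloat16)
    x = torch.randn(batch, seq, d_model, device="cuda", dtype=torch.bfloat16)
    # Any non-None mask switches the FlashAttention path to causal attention.
    causal_mask = torch.tril(torch.ones(seq, seq, device="cuda"))
    out = mha(x, x, x, mask=causal_mask)
    print(out.shape)  # expected: torch.Size([2, 128, 512])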