A toy example of flash attention implemented in NumPy.
# A minimal example of flash attention implemented in NumPy
# Contact: bingquanxia AT qq.com

import unittest
from typing import List

import numpy as np
import torch


class SoftMax(object):
    """
    Softmax in NumPy. A naive implementation.
    """

    def forward(self, x: List[float]):
        # loop 1: get the maximum value
        max_x = -np.inf
        for t in x:
            max_x = t if t > max_x else max_x

        # loop 2: get the accumulated sum of exp(x_i - max_x)
        accum_exp = 0.
        for t in x:
            accum_exp += np.exp(t - max_x)

        # loop 3: get the softmax output by dividing exp(x_i - max_x) by accum_exp
        output = [0. for _ in range(len(x))]
        for i, t in enumerate(x):
            output[i] = np.exp(t - max_x) / accum_exp

        return output

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
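

# The naive softmax above makes three passes over the input: one for the max,
# one for the exp-sum and one for the outputs. The tiled variant below fuses
# the max and exp-sum passes into a single sweep by rescaling the running sum
# whenever a new maximum appears; flash attention applies the same
# online-softmax trick block by block.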
class SoftMaxWithTiling(object):
    """
    Softmax with tiling in NumPy. A naive implementation.
    """

    def forward(self, x: List[float]):
        # loop 1: get the maximum value of x and the accumulated exponential values
        max_x = -np.inf
        accum_exp = 0.
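        # Invariant maintained by the loop below: after processing a prefix of x,
        # max_x is the running maximum and accum_exp equals
        #     sum over the seen t of exp(t - max_x).
        # When a larger maximum appears, the old sum is rescaled by
        # exp(max_x - max_x_new) so that the invariant still holds.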
        for t in x:
            max_x_new = t if t > max_x else max_x
            accum_exp = np.exp(max_x - max_x_new) * accum_exp + np.exp(t - max_x_new)
            max_x = max_x_new

        # loop 2: get the softmax output by dividing exp(x_i - max_x) by accum_exp
        out = [0. for _ in range(len(x))]
        for i, t in enumerate(x):
            out[i] = np.exp(t - max_x) / accum_exp

        return out

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class SoftMaxTest(unittest.TestCase):
    """
    Unit test for SoftMax and SoftMaxWithTiling.
    """

    def test_softmax(self):
        n_test = 10
        for _ in range(n_test):
            n_elem = np.random.randint(1, 11)
            x = np.random.randn(n_elem).tolist()
            expected = torch.nn.functional.softmax(torch.tensor(x), dim=-1).tolist()
            out = SoftMax()(x)
            self.assertTrue(np.allclose(expected, out, atol=1e-4))
            out_with_tiling = SoftMaxWithTiling()(x)
            self.assertTrue(np.allclose(expected, out_with_tiling, atol=1e-4))


class StandardAttention(object):
    def __init__(self) -> None:
        """
        Standard attention implemented in NumPy.
        Formula:
            P = QK^T
            S = softmax(P / sqrt(d_k))
            O = SV
        Reference:
            "Attention Is All You Need"
        URL:
            https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
        """
        pass

    def _validity_check(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> None:
        assert q.ndim == 3, "q should be a 3D tensor"  # [batch_size, seq_len, hidden_size]
        assert k.ndim == 3, "k should be a 3D tensor"
        assert v.ndim == 3, "v should be a 3D tensor"
        assert q.shape[0] == k.shape[0], "batch_size of q and k should be the same"
        assert q.shape[2] == k.shape[2], "hidden_size of q and k should be the same"
        assert q.shape[2] == v.shape[2], "hidden_size of q and v should be the same"

    def forward(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
        self._validity_check(q, k, v)

        batch_size, q_len, hidden_size = q.shape
        denom = np.sqrt(hidden_size)

        attn = np.matmul(q, k.transpose(0, 2, 1))  # [batch_size, q_len, k_len]
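        # Subtracting the row-wise max before exponentiating keeps the exponents
        # non-positive for numerical stability; since a constant shift per row
        # cancels out in the normalization, the softmax result is unchanged.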
        attn = np.exp((attn - attn.max(axis=-1, keepdims=True)) / denom)
        attn = attn / attn.sum(axis=-1, keepdims=True)
        out = np.matmul(attn, v)  # [batch_size, q_len, hidden_size]
        return out

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


def self_attention(x):
    return StandardAttention()(x, x, x)


class FlashAttention(object):
    def __init__(self, row_block_size: int, col_block_size: int) -> None:
        """
        Flash Attention in NumPy.
        Reference:
            "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness"
            https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf
        Args:
            row_block_size: block size of query
            col_block_size: block size of key
        """
        # (Line 1): set the block sizes.
        # We set the block sizes for query and key manually, since we do not know
        # the on-chip SRAM size of the GPU.
        self.row_block_size = row_block_size
        self.col_block_size = col_block_size

    def _validity_check(self, q: np.ndarray, k: np.ndarray, v: np.ndarray) -> None:
        assert q.ndim == 3, "q should be a 3D tensor"  # [batch_size, seq_len, hidden_size]
        assert k.ndim == 3, "k should be a 3D tensor"
        assert v.ndim == 3, "v should be a 3D tensor"
        assert q.shape[0] == k.shape[0], "batch_size of q and k should be the same"
        assert q.shape[2] == k.shape[2] == v.shape[2], "hidden_size of q, k and v should be the same"
        assert q.shape[1] % self.row_block_size == 0 and k.shape[1] % self.col_block_size == 0, \
            "seq_len should be divisible by block_size"

    @staticmethod
    def load(arr, st, ed, step):
        """Simulate the process that moves data from HBM to SRAM."""
        return arr[:, st * step: ed * step]

    @staticmethod
    def write(arr, val, st, ed, step):
        """Simulate the process that moves data from SRAM to HBM."""
        arr[:, st * step: ed * step] = val
    def forward(self, q, k, v):
        """
        The following implementation strictly follows Algorithm 1 in the FlashAttention paper,
        except that it is batched, i.e. batch_size is the first dimension of q, k and v.
        Algorithm 1 is on page 5 of the original FlashAttention paper.
        """
        self._validity_check(q, k, v)

        batch_size, q_len, hidden_size = q.shape
        k_len = k.shape[1]

        # (Line 2): initialize O, l and m
        # O: output, will be updated in a row-block-wise manner
        out = np.zeros((batch_size, q_len, hidden_size))
        # l: exp-sum of each row block, will be the denominator in softmax.
        # l will be updated in an exponential-moving-average-like way.
        l = np.zeros((batch_size, q_len))
        # m: max of each row block, will be part of the numerator in softmax.
        # m will also be updated in an exponential-moving-average-like way.
        m = np.zeros((batch_size, q_len))
        m.fill(-np.inf)

        # (Line 3): divide q into row blocks and k, v into column blocks
        Tr = q_len // self.row_block_size  # Tr: number of row blocks
        Tc = k_len // self.col_block_size  # Tc: number of column blocks
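        # For example, with q_len = 8 and row_block_size = 4, Tr = 2 row blocks
        # of 4 queries each; Tc is analogous for the key/value sequence.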
        # (Line 4): pass. We do not need to explicitly split the output into row blocks,
        # but we will update the output in a row-block-wise manner to simulate the process of FlashAttention.

        # (Line 5): iterate over column blocks
        for j in range(Tc):
            # (Line 6): load the key and value blocks
            # kj: key block,   [batch_size, col_block_size, hidden_size]
            # vj: value block, [batch_size, col_block_size, hidden_size]
            kj = self.load(k, j, j + 1, self.col_block_size)
            vj = self.load(v, j, j + 1, self.col_block_size)

            # (Line 7): iterate over row blocks
            for i in range(Tr):
                # (Line 8): load the query block. [batch_size, row_block_size, hidden_size]
                qi = self.load(q, i, i + 1, self.row_block_size)
                oi = self.load(out, i, i + 1, self.row_block_size)
                mi = self.load(m, i, i + 1, self.row_block_size)
                li = self.load(l, i, i + 1, self.row_block_size)

                # (Line 9): compute the dot-product attention scores
                sij = np.matmul(qi, kj.transpose(0, 2, 1)) / np.sqrt(hidden_size)

                # (Line 10): compute the block-local max, unnormalized softmax, and exp-sum
                mij = np.max(sij, axis=-1)  # [batch_size, row_block_size]
                pij = np.exp(sij - mij[..., np.newaxis])  # [batch_size, row_block_size, col_block_size]
                lij = pij.sum(axis=-1)  # [batch_size, row_block_size]

                # (Line 11): update m and l
                # 11.a. update m, the max of each row block
                m_new = np.maximum.reduce([mi, mij])
                # 11.b. update l, the accumulated exp-sum of each row block
                l_new = np.exp(mi - m_new) * li + np.exp(mij - m_new) * lij

                # (Line 12): update the output
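                # oi holds the attention output of this query block over the key/value
                # blocks processed so far, normalized by the old statistics (mi, li).
                # li * exp(mi - m_new) undoes the old normalization and rescales the old
                # max to m_new; exp(mij - m_new) rescales the new block's contribution
                # pij @ vj; dividing by l_new re-normalizes the combined result.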
                temp = (li[..., np.newaxis] * np.exp(mi - m_new)[..., np.newaxis] * oi
                        + np.exp(mij - m_new)[..., np.newaxis] * np.matmul(pij, vj))
                temp /= l_new[..., np.newaxis]
                self.write(out, temp, i, i + 1, self.row_block_size)

                # (Line 13): store the m and l of the current row block back to the global m and l
                self.write(m, m_new, i, i + 1, self.row_block_size)
                self.write(l, l_new, i, i + 1, self.row_block_size)

        return out

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
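

# Illustrative usage sketch (not part of the original gist): compare FlashAttention
# against StandardAttention on a small random input. The shapes and block sizes
# below are arbitrary, chosen so that each sequence length is divisible by its
# block size. The unit tests below perform the same comparison more thoroughly.
def _demo_flash_attention():
    rng = np.random.default_rng(seed=0)
    q = rng.standard_normal((1, 8, 16))   # [batch_size, q_len, hidden_size]
    k = rng.standard_normal((1, 12, 16))  # [batch_size, k_len, hidden_size]
    v = rng.standard_normal((1, 12, 16))
    flash_out = FlashAttention(row_block_size=4, col_block_size=4)(q, k, v)
    standard_out = StandardAttention()(q, k, v)
    assert np.allclose(flash_out, standard_out, atol=1e-8)
    return flash_out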


class FlashAttentionTest(unittest.TestCase):

    def run_test(self, batch_size, q_len, k_len, hidden_size, row_block_size, col_block_size):
        # generate random inputs
        q = np.random.randn(batch_size, q_len, hidden_size)
        k = np.random.randn(batch_size, k_len, hidden_size)
        v = np.random.randn(batch_size, k_len, hidden_size)

        # standard attention
        standard_out = StandardAttention()(q, k, v)

        eps = 1e-8

        # scaled_dot_product_attention of PyTorch
        torch_out = torch.nn.functional.scaled_dot_product_attention(*map(torch.from_numpy, [q, k, v]))
        self.assertTrue(np.allclose(standard_out, torch_out.numpy(), atol=eps))

        # flash attention
        attn = FlashAttention(row_block_size=row_block_size, col_block_size=col_block_size)
        flash_out = attn(q, k, v)
        self.assertTrue(np.allclose(standard_out, flash_out, atol=eps))

    def test(self):
        n_test = 2
        batch_size = 2
        for row_block_size in (2, 4):
            for col_block_size in (4, 8):
                for factor in (10, 20):
                    q_len = row_block_size * factor
                    k_len = col_block_size * factor
                    for _ in range(n_test):
                        for hidden_size in (8, 16, 32):
                            self.run_test(batch_size, q_len, k_len, hidden_size, row_block_size, col_block_size)


if __name__ == "__main__":
    unittest.main()