- Attention
This section is aimed at understanding a model writer's POV.
Classical attention looks like:
Attention(Q, K, V) = Softmax(Q @ K.T) @ V
These variants apply a function to the result of the first matmul in attention. This looks something like:
Attention(Q, K, V) = Softmax(score_mod(Q @ K.T)) @ V
Some examples of attention variants which use this score modifier:
- Masked Attention:
score_mod := mask(dim0, dim1) ? x : -inf
- Causal Mask:
score_mod := dim0 >= dim1 ? x : -inf
- ALiBi Bias:
score_mod := x + (bias[head_dim] * (dim0 - dim1))
- TanH:
score_mod := tanh(x)
- SDPA:
score_mod := x / scale
- Attention with dynamic dimensions
# Same as masked attention
score_mod := mask(dim0, dim1) ? x : -inf
The second class of variants affects attention post-softmax or replaces the softmax entirely:
- Sigmoid Attention:
Sigmoid(Q @ K.T) @ V
- FP8 Attention
Clip(Softmax(Q @ K.T), 0, FP8_MAX) @ V
Variants in this class are relatively rare.
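To make the score_mod hook concrete, here is a minimal NumPy sketch of the first class (all names here are mine, not from any framework), plugging the causal mask in as the score modifier:

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerically safe
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V, score_mod=lambda s: s):
    # score_mod is applied to the result of the first matmul
    S = score_mod(Q @ K.T)
    return softmax(S, axis=-1) @ V

def causal_mask(S):
    # dim0 >= dim1 ? x : -inf
    i = np.arange(S.shape[0])[:, None]
    j = np.arange(S.shape[1])[None, :]
    return np.where(i >= j, S, -np.inf)

Q, K, V = (np.random.randn(8, 16) for _ in range(3))
out = attention(Q, K, V, score_mod=causal_mask)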
This section is aimed at understanding a kernel writer's POV.
Classical Attention, from a codegen POV looks like:
Q: MxK1
K: K2xK1
V: K2xN
Attention(Q, K, V) := Softmax(Q @ K.T) @ V
Safe-Softmax(x) := { m(x) = e^(x - max(x)) ; m(x) / \sum m(x) }
We will start with this form and derive Flash Attention in a vectorized compute form, not in a SIMT, CUDA-kernel form.
The whole point of an algorithm like Flash Attention is to have a single reduction loop. In its standard form, attention has 4 reduction loops:
// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T
// Softmax reduction loop (max)
// max: M
for k in range(K2):
    max = max(max, S[:][k])
// Softmax reduction loop (sum)
// sum: M
for k in range(K2):
    sum += e^(S[:][k] - max)
P = e^(S - max) / sum
// Matmul P @ V reduction loop
// out: MxN
for k in range(K2):
    out += P[:][k] @ V[k][:]
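For reference, here is a hedged NumPy transcription of these four loops (the dims and variable names are mine), checked against a direct softmax(Q @ K.T) @ V:

import numpy as np

M, K1, K2, N = 8, 64, 32, 64
Q = np.random.randn(M, K1)
K = np.random.randn(K2, K1)
V = np.random.randn(K2, N)

# Matmul Q @ K.T reduction loop
S = np.zeros((M, K2))
for k in range(K1):
    S += np.outer(Q[:, k], K[:, k])          # Q[:][k] @ K[:][k].T

# Softmax reduction loop (max)
row_max = np.full(M, -np.inf)
for k in range(K2):
    row_max = np.maximum(row_max, S[:, k])

# Softmax reduction loop (sum)
row_sum = np.zeros(M)
for k in range(K2):
    row_sum += np.exp(S[:, k] - row_max)
P = np.exp(S - row_max[:, None]) / row_sum[:, None]

# Matmul P @ V reduction loop
out = np.zeros((M, N))
for k in range(K2):
    out += np.outer(P[:, k], V[k, :])

ref = (np.exp(S - S.max(-1, keepdims=True))
       / np.exp(S - S.max(-1, keepdims=True)).sum(-1, keepdims=True)) @ V
assert np.allclose(out, ref)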
This brilliant paper (https://arxiv.org/abs/1805.02867) introduced a beautiful formulation of softmax, online softmax, which enables us to combine the reduction loops of softmax.
Online Softmax as a parallel operation:
[mi si] xor [mj sj] := { m = max(mi, mj); [ m, si * e^{mi - m} + sj * e^{mj - m}] }
[mV sV] = [x1 1] xor [x2 1] xor [x3 1] xor ... [xV 1]
Softmax(x_i) = e^(x_i - mV) / sV
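The combine above is associative, so the running (max, sum) pairs can be reduced in any grouping. A quick scalar NumPy check (names are mine):

import numpy as np

def combine(a, b):
    # [mi si] xor [mj sj]
    mi, si = a
    mj, sj = b
    m = max(mi, mj)
    return m, si * np.exp(mi - m) + sj * np.exp(mj - m)

x = np.random.randn(16)
m, s = x[0], 1.0
for xi in x[1:]:
    m, s = combine((m, s), (xi, 1.0))

# s is the safe-softmax denominator; element i of the softmax is e^(x_i - m) / s
softmax = np.exp(x - m) / s
assert np.allclose(softmax, np.exp(x - x.max()) / np.exp(x - x.max()).sum())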
With online softmax, we get down to 3 reduction loops from 4:
// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T
// Online Softmax reduction loop
// max: M
// sum: M
for k in range(K2):
    x = S[:][k]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    // yield
    max = nmax
    sum = nsum
P = e^(S - max) / sum
// Matmul P @ V reduction loop
// out: MxN
for k in range(K2):
    out += P[:][k] @ V[k][:]
The Flash Attention paper extended this online softmax form to online softmax + matmul. It doesn't introduce a nice parallel form like the online softmax paper, but the operation can be represented in a similar way.
Online Softmax + Matmul as a parallel operation:
[mi si ai] xor [mj sj aj] := { m = max(mi, mj);
                               [ m,
                                 si * e^{mi - m} + sj * e^{mj - m},
                                 ai * e^{mi - m} + aj * e^{mj - m} ] }
// V: K2xN
// aV: MxN
[mV sV aV] = [x1 1 v1] xor
[x2 1 v2] xor
[x3 1 v3] xor ... [xV 1 vV]
Attention = aV / sV
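The same style of check works for the three-tuple combine (names are mine): every leaf is (x_i, 1, v_i), the accumulators rescale with the same factors as the sums, and aV / sV reproduces softmax(x) @ V for one row of scores:

import numpy as np

def combine(a, b):
    mi, si, ai = a
    mj, sj, aj = b
    m = max(mi, mj)
    return (m,
            si * np.exp(mi - m) + sj * np.exp(mj - m),
            ai * np.exp(mi - m) + aj * np.exp(mj - m))

K2, N = 32, 8
x = np.random.randn(K2)        # one row of S
V = np.random.randn(K2, N)

m, s, a = x[0], 1.0, V[0]
for k in range(1, K2):
    m, s, a = combine((m, s, a), (x[k], 1.0, V[k]))

ref = (np.exp(x - x.max()) / np.exp(x - x.max()).sum()) @ V
assert np.allclose(a / s, ref)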
With this addition, we are down to 2 reduction loops.
// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T
// Online Softmax + Matmul
// max: M
// sum: M
// acc: MxN
for k in range(K2):
    x = S[:][k]
    v = V[k][:]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    nacc = acc * e^(max - nmax) + e^(x - nmax) @ v
    // yield
    max = nmax
    sum = nsum
    acc = nacc
// Divide by Sum
acc /= sum
Model writers tell us that the K1 dimension here is generally small (typically < 256). This allows us to completely unroll the K1 loop, leaving one reduction loop.
// Flash Attention
// max: M
// sum: M
// acc: MxN
for k2 in range(K2):
    // Compute column k2 of S = Q @ K.T (K1 loop fully unrolled in practice)
    // S[:][k2]: M
    S[:][k2] = 0
    for k1 in range(K1):
        S[:][k2] += Q[:][k1] @ K[k2][k1].T
    x = S[:][k2]
    v = V[k2][:]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    nacc = acc * e^(max - nmax) + e^(x - nmax) @ v
    // yield
    max = nmax
    sum = nsum
    acc = nacc
// Divide by Sum
acc /= sum
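A hedged, runnable NumPy version of this single-loop form, checked against the reference (names are mine; the inner K1 loop is written as a dot product, since it is fully unrolled anyway):

import numpy as np

def flash_attention(Q, K, V):
    M, K1 = Q.shape
    K2, N = V.shape
    row_max = np.full(M, -np.inf)
    row_sum = np.zeros(M)
    acc = np.zeros((M, N))
    for k2 in range(K2):
        x = Q @ K[k2, :]                        # column k2 of Q @ K.T, shape (M,)
        v = V[k2, :]                            # row k2 of V, shape (N,)
        nmax = np.maximum(row_max, x)
        correction = np.exp(row_max - nmax)     # rescale old sum/acc to the new max
        p = np.exp(x - nmax)
        row_sum = row_sum * correction + p
        acc = acc * correction[:, None] + np.outer(p, v)
        row_max = nmax
    return acc / row_sum[:, None]

M, K1, K2, N = 8, 64, 32, 64
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)
S = Q @ K.T
ref = (np.exp(S - S.max(-1, keepdims=True))
       / np.exp(S - S.max(-1, keepdims=True)).sum(-1, keepdims=True)) @ V
assert np.allclose(flash_attention(Q, K, V), ref)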
Note that K1 being small is the key difference between attention and an MLP. It lets us chain the two matmuls, while doing the same in an MLP is generally a bad idea since the K1 dimension is large.
Attention can actually be written in its perfectly nested loop form:
// S = Q @ K.T
S[:][:] = 0
for m in range(M): // parallel
    for k2 in range(K2): // parallel
        for k1 in range(K1): // reduction
            S[m][k2] += Q[m][k1] * K[k2][k1]
// max, sum, acc = OnlineSoftmax(S) @ V
max[:] = -inf
sum[:] = 0
acc[:][:] = 0
for m in range(M): // parallel
    cmax = max[m]
    csum = sum[m]
    cacc = acc[m][:]
    for k2 in range(K2): // reduction
        x = S[m][k2]
        v = V[k2][:]
        nmax = max(cmax, x)
        norm = e^(cmax - nmax)
        p = e^(x - nmax)
        nsum = p + csum * norm
        nacc = p @ v + cacc * norm
        cmax = nmax
        csum = nsum
        cacc = nacc
    max[m] = cmax
    sum[m] = csum
    acc[m][:] = cacc
// result = acc / sum
for m in range(M): // parallel
    result[m][:] = acc[m][:] / sum[m]
The cool part about this perfectly nested loop form is that you can think of Flash Attention as coalescing the K2 loops of the two nests (and coalescing the M loop and distributing it to workgroups).
This loop nest can actually be represented completely with linalg.generic ops:
func.func @attention(%Q : tensor<4096x64xf32>,
                     %K : tensor<4096x64xf32>,
                     %V : tensor<4096x64xf32>,
                     %bias : tensor<4096x64xf32>)
                     -> tensor<4096x64xf32> {
  %c_sum = arith.constant 0.0 : f32
  %c_max = arith.constant -1.0e38 : f32
  %empty = tensor.empty() : tensor<4096x4096xf32>
  %rowempty = tensor.empty() : tensor<4096x64xf32>
  // S = Q @ K.T
  %S = linalg.generic
    { indexing_maps = [affine_map<(m, k2, k1) -> (m, k1)>,
                       affine_map<(m, k2, k1) -> (k2, k1)>,
                       affine_map<(m, k2, k1) -> (m, k2)>],
      iterator_types = ["parallel", "parallel", "reduction"] }
    ins(%Q, %K : tensor<4096x64xf32>, tensor<4096x64xf32>)
    outs(%empty : tensor<4096x4096xf32>) {
  ^bb0(%q : f32, %k : f32, %s : f32):
    %mul = arith.mulf %q, %k : f32
    %add = arith.addf %mul, %s : f32
    linalg.yield %add : f32
  } -> tensor<4096x4096xf32>
  %maxinit = linalg.fill ins(%c_max : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %suminit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %accinit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  // max, sum, acc = OnlineSoftmax(S) @ V
  %MAX, %SUM, %PV = linalg.generic
    { indexing_maps = [affine_map<(m, n, k2) -> (m, k2)>,
                       affine_map<(m, n, k2) -> (k2, n)>,
                       affine_map<(m, n, k2) -> (m, n)>,
                       affine_map<(m, n, k2) -> (m, n)>,
                       affine_map<(m, n, k2) -> (m, n)>],
      iterator_types = ["parallel", "parallel", "reduction"] }
    ins(%S, %V : tensor<4096x4096xf32>, tensor<4096x64xf32>)
    outs(%maxinit, %suminit, %accinit : tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>) {
  ^bb0(%x : f32, %v : f32, %m : f32, %sum : f32, %a : f32):
    // online softmax
    %nmax = arith.maximumf %m, %x : f32
    %tmp = arith.subf %m, %nmax : f32
    %norm = math.exp %tmp : f32
    %normsum = arith.mulf %norm, %sum : f32
    %ptmp = arith.subf %x, %nmax : f32
    %p = math.exp %ptmp : f32
    %nsum = arith.addf %normsum, %p : f32
    // matmul
    %normacc = arith.mulf %a, %norm : f32
    %pv = arith.mulf %p, %v : f32
    %nacc = arith.addf %normacc, %pv : f32
    linalg.yield %nmax, %nsum, %nacc : f32, f32, f32
  } -> (tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>)
  // result = acc / sum
  %out = linalg.generic
    { indexing_maps = [affine_map<(m, n) -> (m, n)>,
                       affine_map<(m, n) -> (m, n)>,
                       affine_map<(m, n) -> (m, n)>],
      iterator_types = ["parallel", "parallel"] }
    ins(%SUM, %PV : tensor<4096x64xf32>, tensor<4096x64xf32>)
    outs(%rowempty : tensor<4096x64xf32>) {
  ^bb0(%s : f32, %pv : f32, %o : f32):
    %div = arith.divf %pv, %s : f32
    linalg.yield %div : f32
  } -> tensor<4096x64xf32>
  func.return %out : tensor<4096x64xf32>
}
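If the linalg syntax is hard to read, here is a NumPy rendering of the same three generics (dims shrunk from 4096 for speed; names are mine). Note how max/sum are carried at MxN, i.e. broadcast along n, to match the indexing maps above:

import numpy as np

M, K1, K2, N = 128, 64, 128, 64
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)

# Generic 1: S = Q @ K.T  (parallel m, parallel k2, reduction k1)
S = Q @ K.T

# Generic 2: online softmax + matmul (parallel m, parallel n, reduction k2).
# max/sum are carried at MxN, i.e. broadcast along n, as in the indexing maps.
MAX = np.full((M, N), -1.0e38)
SUM = np.zeros((M, N))
ACC = np.zeros((M, N))
for k2 in range(K2):
    x = S[:, k2:k2 + 1]                 # Mx1, broadcasts against the MxN state
    v = V[k2:k2 + 1, :]                 # 1xN, broadcasts along m
    nmax = np.maximum(MAX, x)
    norm = np.exp(MAX - nmax)
    p = np.exp(x - nmax)
    SUM = SUM * norm + p
    ACC = ACC * norm + p * v
    MAX = nmax

# Generic 3: out = ACC / SUM  (parallel m, parallel n)
out = ACC / SUM

ref = (np.exp(S - S.max(-1, keepdims=True))
       / np.exp(S - S.max(-1, keepdims=True)).sum(-1, keepdims=True)) @ V
assert np.allclose(out, ref)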
Note how online softmax + PV matmul is actually one generic. This is the true form of attention. Today, the way attention is usually decomposed actually handicaps us.
Simply tiling the K2 dimension on the second linalg.generic and fusing gives us a single reduction loop, which is Flash Attention 2. However, this representation has two issues:
- The softmax+matmul generic is not vectorizable. It is non-trivial to reason that the generic actually has block dependencies instead of scalar dependencies (see the sketch below). (Bigger Issue)
- max/sum are broadcast along the N dimension. While this can be fixed after tiling with subset hoisting, it still adds redundant computation that has to be eliminated by an analysis. (Smaller Issue)
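To illustrate what block dependencies mean here, a sketch (mine, not the compiler's actual decomposition) of the K2-tiled update that this fusion produces: each step consumes a BLOCK_K2-wide tile of scores and rescales the running sum/accumulator once per tile, not once per scalar:

import numpy as np

def online_softmax_matmul_blockwise(S, V, block_k2=32):
    # Blockwise (tile-granularity) online softmax + matmul over the K2 dim.
    M, K2 = S.shape
    _, N = V.shape
    row_max = np.full(M, -np.inf)
    row_sum = np.zeros(M)
    acc = np.zeros((M, N))
    for k2 in range(0, K2, block_k2):
        s_blk = S[:, k2:k2 + block_k2]                  # M x block_k2 tile of scores
        v_blk = V[k2:k2 + block_k2, :]                  # block_k2 x N tile of values
        nmax = np.maximum(row_max, s_blk.max(axis=1))   # block-level max
        correction = np.exp(row_max - nmax)
        p = np.exp(s_blk - nmax[:, None])
        row_sum = row_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v_blk     # block matmul
        row_max = nmax
    return row_max, row_sum, acc

M, K1, K2, N = 8, 64, 128, 64
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)
S = Q @ K.T
_, row_sum, acc = online_softmax_matmul_blockwise(S, V)
ref = (np.exp(S - S.max(-1, keepdims=True))
       / np.exp(S - S.max(-1, keepdims=True)).sum(-1, keepdims=True)) @ V
assert np.allclose(acc / row_sum[:, None], ref)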
Until we figure out how to represent these block dependencies in linalg, we use a custom op: linalg_ext.online_softmax_matmul. This op can be thought of as the softmax+matmul linalg.generic above, but it can be decomposed into its blockwise form.
func.func @attention(%Q : tensor<4096x64xf32>,
                     %K : tensor<4096x64xf32>,
                     %V : tensor<4096x64xf32>,
                     %bias : tensor<4096x64xf32>)
                     -> tensor<4096x64xf32> {
  %c_sum = arith.constant 0.0 : f32
  %c_max = arith.constant -1.0e38 : f32
  %empty = tensor.empty() : tensor<4096x4096xf32>
  %rowempty = tensor.empty() : tensor<4096x64xf32>
  // S = Q @ K.T
  %S = linalg.generic
    { indexing_maps = [affine_map<(m, k2, k1) -> (m, k1)>,
                       affine_map<(m, k2, k1) -> (k2, k1)>,
                       affine_map<(m, k2, k1) -> (m, k2)>],
      iterator_types = ["parallel", "parallel", "reduction"] }
    ins(%Q, %K : tensor<4096x64xf32>, tensor<4096x64xf32>)
    outs(%empty : tensor<4096x4096xf32>) {
  ^bb0(%q : f32, %k : f32, %s : f32):
    %mul = arith.mulf %q, %k : f32
    %add = arith.addf %mul, %s : f32
    linalg.yield %add : f32
  } -> tensor<4096x4096xf32>
  %maxinit = linalg.fill ins(%c_max : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %suminit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %accinit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  // max, sum, acc = OnlineSoftmax(S) @ V, as a single blockwise-decomposable op
  %MAX, %SUM, %PV = linalg_ext.online_softmax_matmul
    { indexing_maps = [affine_map<(m, n, k2) -> (m, k2)>,
                       affine_map<(m, n, k2) -> (k2, n)>,
                       affine_map<(m, n, k2) -> (m, n)>,
                       affine_map<(m, n, k2) -> (m, n)>,
                       affine_map<(m, n, k2) -> (m, n)>],
      iterator_types = ["parallel", "parallel", "reduction"] }
    ins(%S, %V : tensor<4096x4096xf32>, tensor<4096x64xf32>)
    outs(%maxinit, %suminit, %accinit : tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>)
    -> (tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>)
  // result = acc / sum
  %out = linalg.generic
    { indexing_maps = [affine_map<(m, n) -> (m, n)>,
                       affine_map<(m, n) -> (m, n)>,
                       affine_map<(m, n) -> (m, n)>],
      iterator_types = ["parallel", "parallel"] }
    ins(%SUM, %PV : tensor<4096x64xf32>, tensor<4096x64xf32>)
    outs(%rowempty : tensor<4096x64xf32>) {
  ^bb0(%s : f32, %pv : f32, %o : f32):
    %div = arith.divf %pv, %s : f32
    linalg.yield %div : f32
  } -> tensor<4096x64xf32>
  func.return %out : tensor<4096x64xf32>
}