@Groverkss
Last active January 22, 2025 05:45

Attention

Section 1 : What do attention variants really look like?

This section is aimed at understanding a model writer's POV.

Classical attention looks like:

Attention(Q, K, V) = Softmax(Q @ K.T) @ V
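
A minimal NumPy sketch of this baseline, just to fix notation (the shapes and helper names are illustrative, not anything prescribed here):

import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: MxK1, K: K2xK1, V: K2xN
    S = Q @ K.T               # MxK2 scores
    P = softmax(S, axis=-1)   # row-wise softmax
    return P @ V              # MxN output

Q, K, V = np.random.randn(128, 64), np.random.randn(256, 64), np.random.randn(256, 32)
out = attention(Q, K, V)      # (128, 32)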

Variant class 1 : Score modifiers (Common)

These variants apply a function on the result of the first matmul in attention. This looks something like:

Attention(Q, K, V) = Softmax(score_mod (Q @ K.T)) @ V

Some examples of attention variants which use this score modifier:

  1. Masked Attention:
score_mod := mask(dim0, dim1) ? x : -inf
  2. Causal Mask:
score_mod := dim0 >= dim1 ? x : -inf
  3. ALiBi Bias:
score_mod := x + (bias[head_dim] * (dim0 - dim1))
  4. TanH:
score_mod := tanh(x)
  5. SDPA:
score_mod := x / scale
  6. Attention with dynamic dimensions:
# Same as masked attention
score_mod := mask(dim0, dim1) ? x : -inf
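
A minimal NumPy sketch of how a score_mod slots between the two matmuls, with the causal mask as the worked example (the score_mod signature and the lambdas are illustrative, not a real API):

import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(Q, K, V, score_mod=lambda s, i, j: s):
    S = Q @ K.T                          # MxK2 scores
    i = np.arange(S.shape[0])[:, None]   # dim0: query index
    j = np.arange(S.shape[1])[None, :]   # dim1: key index
    return softmax(score_mod(S, i, j)) @ V

# Causal mask: masked scores go to -inf so they contribute 0 after softmax.
causal = lambda s, i, j: np.where(i >= j, s, -np.inf)
# SDPA-style scaling (scale chosen here as sqrt of the head dim, 16).
sdpa = lambda s, i, j: s / np.sqrt(16)

Q, K, V = np.random.randn(8, 16), np.random.randn(8, 16), np.random.randn(8, 4)
out = attention(Q, K, V, score_mod=causal)   # (8, 4)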

Variant class 2 : Quantization and Not Really Attention? (Rare)

The second class of variants modifies attention after the softmax, or replaces the softmax entirely:

  1. Sigmoid Attention:
Sigmoid(Q @ K.T) @ V
  2. FP8 Attention:
Clip(Softmax(Q @ K.T), 0, FP8_MAX) @ V

This class of variants is usually rare.
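
For completeness, hedged NumPy sketches of these two variants (FP8_MAX here assumes the e4m3 format's maximum of 448; the real value depends on the FP8 type in use):

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid_attention(Q, K, V):
    # Softmax is replaced wholesale; there is no row-wise normalization.
    return (1.0 / (1.0 + np.exp(-(Q @ K.T)))) @ V

FP8_MAX = 448.0   # assumed e4m3 max; adjust for the FP8 format actually used

def fp8_attention(Q, K, V):
    # Literal rendering of Clip(Softmax(Q @ K.T), 0, FP8_MAX) @ V.
    return np.clip(softmax(Q @ K.T), 0.0, FP8_MAX) @ V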

Section 2 : The Math Behind Attention Codegen

This section is aimed at understanding a kernel writer's POV.

Classical Attention, from a codegen POV, looks like:

Q: MxK1
K: K2xK1
V: K2xN

Attention(Q, K, V) := Softmax(Q @ K.T) @ V

Safe-Softmax(x) := { m(x) = e^(x - max(x)) ; m(x) / \sum m(x) }
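
A NumPy sketch of safe softmax: the shifted exponent is always <= 0, so it never overflows:

import numpy as np

def safe_softmax(x):
    # m(x) = e^(x - max(x)); dividing by its sum gives the same result as
    # naive softmax, but e^x is never evaluated on a large positive input.
    m = np.exp(x - x.max(axis=-1, keepdims=True))
    return m / m.sum(axis=-1, keepdims=True)

x = np.array([1000.0, 1001.0, 1002.0])   # naive e^x would overflow here
print(safe_softmax(x))                    # ~[0.090, 0.245, 0.665]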

We will start from this form and derive Flash Attention in a vectorized compute form, not in a SIMT CUDA kernel form.

Reduction loops in attention

The whole point of an algorithm like Flash Attention is to have a single reduction loop. In its standard form, Attention has 4 reduction loops:

// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T

// Softmax reduction loop (max)
// max: M
for k in range(K2):
    max = max(max, S[:][k])

// Softmax reduction loop (sum)
// sum: M
for k in range(K2):
    sum += e^(S[:][k] - max)

P = e^(S - max) / sum

// Matmul P @ V reduction loop
// out: MxN
for k in range(K2):
    out += P[:][k] @ V[k][:]
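
For reference, a vectorized NumPy rendering of these four passes (shapes are illustrative); each pass corresponds to one of the loops above:

import numpy as np

M, K1, K2, N = 128, 64, 256, 32
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)

# Pass 1: S = Q @ K.T (reduction over K1)
S = Q @ K.T                                              # MxK2

# Pass 2: row max (reduction over K2)
row_max = S.max(axis=1, keepdims=True)                   # Mx1

# Pass 3: row sum of shifted exponentials (reduction over K2)
row_sum = np.exp(S - row_max).sum(axis=1, keepdims=True) # Mx1

P = np.exp(S - row_max) / row_sum                        # MxK2

# Pass 4: out = P @ V (reduction over K2)
out = P @ V                                              # MxN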

Online Softmax?

This brilliant paper: https://arxiv.org/abs/1805.02867 introduced a beautiful formulation of softmax as online softmax, which enables us to combine the reduction loops of softmax.

Online Softmax as a parallel operation:

[mi si] xor [mj sj] := { m = max(mi, mj); [ m, si * e^{mi - m} + sj * e^{mj - m}] }

[mV sV] = [x1 1] xor [x2 1] xor [x3 1] xor ... [xV 1]

Softmax(xi) = e^(xi - mV) / sV
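
A small NumPy sketch of this combine operator, using functools.reduce as a stand-in for the reduction and checking the result against ordinary softmax:

import numpy as np
from functools import reduce

def combine(a, b):
    # The associative "xor" operator: merge two (max, sum) partial states.
    mi, si = a
    mj, sj = b
    m = max(mi, mj)
    return m, si * np.exp(mi - m) + sj * np.exp(mj - m)

x = np.random.randn(1000)
mV, sV = reduce(combine, [(xi, 1.0) for xi in x])

online = np.exp(x - mV) / sV                  # softmax recovered from (mV, sV)
ref = np.exp(x - x.max()); ref /= ref.sum()
assert np.allclose(online, ref)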

Attention with online softmax:

With online softmax, we get down to 3 reduction loops from 4:

// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T

// Online Softmax reduction loop
// max: M
// sum: M
for k in range(K2):
    x = S[:][k]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    // yield
    max = nmax
    sum = nsum

P = e^(S - max) / sum

// Matmul P @ V reduction loop
// out: MxN
for k in range(K2):
    out += P[:][k] @ V[k][:]

Online Softmax plus Matmul?

The Flash Attention paper extended this online softmax form to online softmax + matmul. That paper doesn't introduce a nice parallel form like the online softmax paper does, but the operation can be represented in a similar way.

Online Softmax + Matmul as a parallel operation:

[mi si ai] xor [mj sj aj] := { m = max(mi, mj); 
               [ m,
                 si * e^{mi - m} + sj * e^{mj - m},
                 ai * e^{mi - m} + aj * e^{mj - m} ] }

// V: K2xN
// aV: MxN
[mV sV aV] = [x1 1 v1] xor 
             [x2 1 v2] xor 
             [x3 1 v3] xor ... [xV 1 vV]

Attention = aV / sV
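
A matching NumPy sketch of the three-element combine for a single query row (M = 1 for clarity; shapes are illustrative), checked against the reference:

import numpy as np
from functools import reduce

def combine(a, b):
    # Merge two (max, sum, acc) partial states; acc carries the running P @ V.
    mi, si, ai = a
    mj, sj, aj = b
    m = max(mi, mj)
    return (m,
            si * np.exp(mi - m) + sj * np.exp(mj - m),
            ai * np.exp(mi - m) + aj * np.exp(mj - m))

K1, K2, N = 16, 64, 8
q = np.random.randn(K1)          # one query row
K = np.random.randn(K2, K1)
V = np.random.randn(K2, N)

x = K @ q                        # scores for this row, length K2
_, sV, aV = reduce(combine, [(x[k], 1.0, V[k]) for k in range(K2)])
row = aV / sV                    # == softmax(x) @ V

P = np.exp(x - x.max()); P /= P.sum()
assert np.allclose(row, P @ V)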

Attention with Online Softmax+Matmul

With this addition, we are down to 2 reduction loops.

// Matmul Q @ K.T reduction loop
// S: MxK2
for k in range(K1):
    S += Q[:][k] @ K[:][k].T

// Online Softmax + Matmul
// max: M
// sum: M
// acc: MxN
for k in range(K2):
    x = S[:][k]
    v = V[k][:]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    nacc = acc * e^(max - nmax) + e^(x - nmax) @ v
    // yield
    max = nmax
    sum = nsum
    acc = nacc

// Divide by Sum
acc /= sum

Towards a single reduction loop

Model writers tell us that the K1 dimension here is generally small (typically < 256). This allows us to completely unroll the K1 loop, leaving one reduction loop.

// Flash Attention
// max: M
// sum: M
// acc: MxN
for k2 in range(K2):
    // Matmul Q @ K.T reduction loop
    // s: M (one column of S = Q @ K.T)
    s = 0
    for k1 in range(K1):
        s += Q[:][k1] @ K[k2][k1].T

    x = s
    v = V[k2][:]
    nmax = max(max, x)
    nsum = sum * e^(max - nmax) + e^(x - nmax)
    nacc = acc * e^(max - nmax) + e^(x - nmax) @ v
    // yield
    max = nmax
    sum = nsum
    acc = nacc

// Divide by Sum
acc /= sum

Note that K1 being small is the key difference between Attention and an MLP. It lets us chain the two matmuls, while doing so in an MLP is generally a bad idea because its K1 dimension is big.
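
A runnable NumPy version of the single-reduction-loop form above, processing one key/value row per iteration and checked against the reference at the end (shapes are illustrative):

import numpy as np

M, K1, K2, N = 64, 32, 128, 16
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)

mx  = np.full(M, -np.inf)        # running row max, length M
sm  = np.zeros(M)                # running row sum, length M
acc = np.zeros((M, N))           # running P @ V accumulator, MxN

for k2 in range(K2):
    x = Q @ K[k2]                # one column of S = Q @ K.T (reduction over K1)
    v = V[k2]                    # matching value row, length N

    nmax = np.maximum(mx, x)
    nsum = sm * np.exp(mx - nmax) + np.exp(x - nmax)
    nacc = acc * np.exp(mx - nmax)[:, None] + np.exp(x - nmax)[:, None] * v[None, :]

    mx, sm, acc = nmax, nsum, nacc

out = acc / sm[:, None]

# Reference check.
S = Q @ K.T
P = np.exp(S - S.max(1, keepdims=True)); P /= P.sum(1, keepdims=True)
assert np.allclose(out, P @ V)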

Perfectly nested loop form

Attention can actually be written in its perfectly nested loop form:

// S = Q @ K.T
S[:][:] = 0
for M in range(M): // parallel
    for K2 in range(K2): // parallel
        for K1 in range(K1): // reduction
            S[M][K2] += Q[M][K1] @ K[K2][K1]

// max, sum, acc = OnlineSoftmax(S) @ V
max[:] = -inf
sum[:] = 0
acc[:][:] = 0
for M in range(M): // parallel
    cmax = max[M]
    csum = sum[M]
    cacc = acc[M][:]
    for K2 in range(K2): // reduction
        x = S[M][K2]
        v = V[K2][:]

        nmax = max(cmax, x)
        norm = e^(cmax - nmax)
        p    = e^(x - nmax)

        nsum = p + csum * norm
        nacc = p @ v + cacc * norm

        cmax = nmax
        csum = nsum
        cacc = nacc

    max[M] = cmax
    sum[M] = csum
    acc[M][:] = cacc

// result = acc / sum
for M in range(M):
    result[M][:] = acc[M][:] / sum[M]

The cool part about this perfectly nested loop form is that you can think of flash attention as coalescing the K2 loop between the two perfectly nested loops (and coalescing the M loop and distributing it to workgroups).

linalg.generic form

This loop nest can actually be represented completely as linalg.generics:

func.func @attention(%Q : tensor<4096x64xf32>, 
                     %K : tensor<4096x64xf32>, 
                     %V : tensor<4096x64xf32>, 
                     %bias : tensor<4096x64xf32>) 
                     -> tensor<4096x64xf32> {

  %c_sum = arith.constant 0.0 : f32
  %c_max = arith.constant -1.0e+38 : f32
  %empty = tensor.empty() : tensor<4096x4096xf32>
  %rowempty = tensor.empty() : tensor<4096x64xf32>

  %S = linalg.generic
  { indexing_maps = [affine_map<(m, k2, k1) -> (m, k1)>,
                     affine_map<(m, k2, k1) -> (k2, k1)>,
                     affine_map<(m, k2, k1) -> (m, k2)>],
  iterator_types = ["parallel", "parallel", "reduction"] }
  ins(%Q, %K : tensor<4096x64xf32>, tensor<4096x64xf32>)
  outs(%empty : tensor<4096x4096xf32>) {
  ^bb0(%q : f32, %k : f32, %s : f32):
    %mul = arith.mulf %q, %k : f32
    %add = arith.addf %mul, %s : f32
    linalg.yield %add : f32
  } -> tensor<4096x4096xf32>

  %maxinit = linalg.fill ins(%c_max : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %suminit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %accinit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>

  %MAX, %SUM, %PV = linalg.generic
  { indexing_maps = [affine_map<(m, n, k2) -> (m, k2)>,
                     affine_map<(m, n, k2) -> (k2, n)>,
                     affine_map<(m, n, k2) -> (m, n)>,
                     affine_map<(m, n, k2) -> (m, n)>,
                     affine_map<(m, n, k2) -> (m, n)>],
  iterator_types = ["parallel", "parallel", "reduction"] }
  ins(%S, %V : tensor<4096x4096xf32>, tensor<4096x64xf32>)
  outs(%maxinit, %suminit, %accinit : tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>) {
  ^bb0(%x : f32, %v : f32, %m : f32, %s : f32, %a : f32):
    // softmax
    %nmax           = arith.maximumf %m, %x   : f32
    %tmp            = arith.subf %m, %nmax    : f32
    %norm           = math.exp2 %tmp          : f32
    %normsum        = arith.mulf %norm, %s    : f32
    %ktmp           = arith.subf %x, %nmax    : f32
    %k              = math.exp2 %ktmp         : f32
    %nsum           = arith.addf %normsum, %k : f32
    // matmul
    %normacc        = arith.mulf %a, %norm    : f32
    %pv             = arith.mulf %k, %v       : f32
    %nacc           = arith.addf %normacc, %pv : f32
    linalg.yield      %nmax, %nsum, %nacc     : f32, f32, f32
  } -> tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>

  %out = linalg.generic
  {indexing_maps = [affine_map<(m, n) -> (m, n)>,
                    affine_map<(m, n) -> (m, n)>,
                    affine_map<(m, n) -> (m, n)>],
  iterator_types = ["parallel", "parallel"]}
  ins(%SUM, %PV : tensor<4096x64xf32>, tensor<4096x64xf32>)
  outs(%rowempty : tensor<4096x64xf32>) {
  ^bb0(%s : f32, %pv : f32, %o : f32):
    %div = arith.divf %pv, %s : f32
    linalg.yield %div : f32
  } -> tensor<4096x64xf32>

  func.return %out : tensor<4096x64xf32>
}

Note how online softmax + the PV matmul is actually one generic. This is the true form of attention. Today, the way attention is usually decomposed into separate softmax and matmul ops actually handicaps us.

Simply tiling the K2 dimension on the second linalg.generic and fusing gives us a single reduction loop, which is Flash Attention 2.
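
To make that concrete, here is a NumPy sketch of the K2-tiled form; the tile size and shapes are illustrative, not what the compiler would actually pick. Each iteration loads one K2 tile, rescales the running state, and accumulates:

import numpy as np

M, K1, K2, N, TILE = 64, 32, 256, 16, 64
Q, K, V = np.random.randn(M, K1), np.random.randn(K2, K1), np.random.randn(K2, N)

mx  = np.full((M, 1), -np.inf)   # running row max
sm  = np.zeros((M, 1))           # running row sum
acc = np.zeros((M, N))           # running P @ V accumulator

for k2 in range(0, K2, TILE):
    S_tile = Q @ K[k2:k2 + TILE].T               # M x TILE block of scores

    nmax = np.maximum(mx, S_tile.max(1, keepdims=True))
    P    = np.exp(S_tile - nmax)                 # M x TILE unnormalized probs
    corr = np.exp(mx - nmax)                     # rescale factor for old state

    sm   = sm * corr + P.sum(1, keepdims=True)
    acc  = acc * corr + P @ V[k2:k2 + TILE]
    mx   = nmax

out = acc / sm

# Reference check.
S = Q @ K.T
Pref = np.exp(S - S.max(1, keepdims=True)); Pref /= Pref.sum(1, keepdims=True)
assert np.allclose(out, Pref @ V)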

Problems with this generic form

  1. The softmax+matmul generic is not vectorizable. It is non-trivial to reason that the generic actually has block dependencies instead of scalar dependencies. (Bigger Issue)

  2. max/sum are broadcast along the N dimension. While this can be fixed after tiling with subset-hoisting, it still adds redundant computation that has to be eliminated by an analysis. (Smaller Issue)

Solution

Until we figure out how to represent these block dependencies in linalg, we use a custom op: linalg_ext.online_softmax_matmul. This op can be thought of as the softmax+matmul linalg generic, but it can be decomposed into its blockwise form.

linalg IR with custom op

func.func @attention(%Q : tensor<4096x64xf32>, 
                     %K : tensor<4096x64xf32>, 
                     %V : tensor<4096x64xf32>, 
                     %bias : tensor<4096x64xf32>) 
                     -> tensor<4096x64xf32> {

  %c_sum = arith.constant 0.0 : f32
  %c_max = arith.constant -1.0e+38 : f32
  %empty = tensor.empty() : tensor<4096x4096xf32>
  %rowempty = tensor.empty() : tensor<4096x64xf32>

  %S = linalg.generic
  { indexing_maps = [affine_map<(m, k2, k1) -> (m, k1)>,
                     affine_map<(m, k2, k1) -> (k2, k1)>,
                     affine_map<(m, k2, k1) -> (m, k2)>],
  iterator_types = ["parallel", "parallel", "reduction"] }
  ins(%Q, %K : tensor<4096x64xf32>, tensor<4096x64xf32>)
  outs(%empty : tensor<4096x4096xf32>) {
  ^bb0(%q : f32, %k : f32, %s : f32):
    %mul = arith.mulf %q, %k : f32
    %add = arith.addf %mul, %s : f32
    linalg.yield %add : f32
  } -> tensor<4096x4096xf32>

  %maxinit = linalg.fill ins(%c_max : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %suminit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>
  %accinit = linalg.fill ins(%c_sum : f32) outs(%rowempty : tensor<4096x64xf32>) -> tensor<4096x64xf32>

  %MAX, %SUM, %PV = linalg_ext.online_softmax_matmul
  { indexing_maps = [affine_map<(m, n, k2) -> (m, k2)>,
                     affine_map<(m, n, k2) -> (k2, n)>,
                     affine_map<(m, n, k2) -> (m, n)>,
                     affine_map<(m, n, k2) -> (m, n)>,
                     affine_map<(m, n, k2) -> (m, n)>],
  iterator_types = ["parallel", "parallel", "reduction"] }
  ins(%S, %V : tensor<4096x4096xf32>, tensor<4096x64xf32>)
  outs(%maxinit, %suminit, %accinit : tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>)
  -> tensor<4096x64xf32>, tensor<4096x64xf32>, tensor<4096x64xf32>

  %out = linalg.generic
  {indexing_maps = [affine_map<(m, n) -> (m, n)>,
                    affine_map<(m, n) -> (m, n)>,
                    affine_map<(m, n) -> (m, n)>],
  iterator_types = ["parallel", "parallel"]}
  ins(%SUM, %PV : tensor<4096x64xf32>, tensor<4096x64xf32>)
  outs(%rowempty : tensor<4096x64xf32>) {
  ^bb0(%s : f32, %pv : f32, %o : f32):
    %div = arith.divf %pv, %s : f32
    linalg.yield %div : f32
  } -> tensor<4096x64xf32>

  func.return %out : tensor<4096x64xf32>
}
@ita9naiwa

ita9naiwa commented Jan 18, 2025

Hi, thanks for the great note! I couldn't find an easier or more comprehensible tutorial for flash attention and MLIR linalg.

I found some errata:

=========

score_mod := mask(dim0, dim1) ? x : 0
-> 0 might be -inf,

=========

[mV sV aV] := { m = max(mi, mj); 
               [ m,
                 si * e^{mi - m} + sj * e^{mj - m},
                 e^{mi - m} @ vi + e^{mj - m} @ vj ] }

// V: K2xN
// aV: MxN
[mV sV aV] = [x1 1 v1] xor 
             [x2 1 v2] xor 
             [x3 1 v3] xor ... [xV 1 vV]

-> sV <- s would be a more proper name?

=========

// Softmax reduction loop (sum)
// sum: M
for K in range(K2):
    sum += e^(max - S[:][k])

sum += e^(max - S[:][k]) should be sum += e^(S[:][k] - max)?
