import numpy as np

def flash_attention_forward(q, k, v, block_cols=16, block_rows=16):
    """Compute softmax(q @ k.T / sqrt(d)) @ v tile by tile.

    The score matrix is never materialised in full: keys/values are
    visited in blocks of ``block_cols`` rows and queries in blocks of
    ``block_rows`` rows, while a running row maximum and a running
    softmax denominator are maintained so that each partial output can be
    rescaled when a later block raises the row maximum.

    Parameters
    ----------
    q : np.ndarray
        Query matrix of shape ``(n_q, d)``.
    k : np.ndarray
        Key matrix of shape ``(n_kv, d)``.
    v : np.ndarray
        Value matrix of shape ``(n_kv, d_v)``.
    block_cols : int, optional
        Number of key/value rows per tile. The last tile may be shorter.
    block_rows : int, optional
        Number of query rows per tile. The last tile may be shorter.

    Returns
    -------
    out : np.ndarray
        Attention output of shape ``(n_q, d_v)`` (float64).
    row_sum : np.ndarray
        Shape ``(n_q, 1)``; per-row sum of ``exp(score - row_max)``.
    row_max : np.ndarray
        Shape ``(n_q, 1)``; per-row maximum of the scaled scores.

    Raises
    ------
    ValueError
        If a block size is not positive, or ``k`` and ``v`` have a
        different number of rows.
    """
    if block_cols <= 0 or block_rows <= 0:
        raise ValueError("block sizes must be positive")
    if k.shape[0] != v.shape[0]:
        raise ValueError("k and v must have the same number of rows")

    n_q, dim = q.shape
    n_kv = k.shape[0]
    n_col_tiles = (n_kv + block_cols - 1) // block_cols
    n_row_tiles = (n_q + block_rows - 1) // block_rows
    scale = 1 / np.sqrt(dim)

    # The output is an accumulator, so it starts at zero (not at arbitrary
    # values): it is then well defined even if no key block is visited.
    out = np.zeros((n_q, v.shape[-1]))
    row_sum = np.zeros((n_q, 1))
    row_max = np.full((n_q, 1), -np.inf)

    for j in range(n_col_tiles):
        cols = slice(j * block_cols, (j + 1) * block_cols)
        k_j = k[cols]
        v_j = v[cols]
        for i in range(n_row_tiles):
            rows = slice(i * block_rows, (i + 1) * block_rows)
            o_i = out[rows]
            l_i = row_sum[rows]
            m_i = row_max[rows]

            # Softmax statistics local to this (query tile, key tile) pair.
            s_ij = scale * (q[rows] @ k_j.T)
            m_ij = np.max(s_ij, axis=1, keepdims=True)
            p_ij = np.exp(s_ij - m_ij)
            l_ij = np.sum(p_ij, axis=1, keepdims=True)

            # Bring the old and the new statistics onto the common maximum.
            # On the first key tile m_i is -inf, so old_rescale is exactly 0.
            m_new = np.maximum(m_i, m_ij)
            old_rescale = np.exp(m_i - m_new)
            new_rescale = np.exp(m_ij - m_new)
            l_new = old_rescale * l_i + new_rescale * l_ij

            # o_i is stored already normalised, hence the multiplication by
            # the old denominator l_i before re-normalising by l_new.
            out[rows] = (
                old_rescale * l_i * o_i + new_rescale * (p_ij @ v_j)
            ) / l_new
            row_sum[rows] = l_new
            row_max[rows] = m_new

    return out, row_sum, row_max


# Problem size: N_out queries attend over N_inp key/value rows of width d.
N_inp = 64
N_out = 64
d = 128
Q = np.random.randn(N_out, d)
K = np.random.randn(N_inp, d)
V = np.random.randn(N_inp, d)

# Tile sizes and the resulting tile counts (ceil division).
Bc = 16
Br = 16
Tc = (N_inp + Bc - 1) // Bc
Tr = (N_out + Br - 1) // Br
scale_factor = 1 / np.sqrt(Q.shape[-1])

# O: attention output, L: softmax denominators, M: row maxima.
O, L, M = flash_attention_forward(Q, K, V, Bc, Br)

# Dense reference implementation used to validate the tiled result.
S_ = scale_factor * Q @ K.T
P_ = np.exp(S_ - np.max(S_, axis=1, keepdims=True))
O_ = (P_ / np.sum(P_, axis=1, keepdims=True)) @ V
assert np.allclose(O, O_)