import numpy as np


def flash_attention_forward(Q, K, V, Bc=16, Br=16):
    """Compute softmax(Q @ K.T / sqrt(d)) @ V with the tiled FlashAttention-1 forward pass.

    Key/value rows are processed in tiles of ``Bc`` (outer loop) and query rows
    in tiles of ``Br`` (inner loop).  For every query row a running maximum
    ``M`` and a running normaliser ``L`` are kept, so the full N_out x N_inp
    score matrix is never materialised; earlier partial outputs are rescaled
    whenever a new tile raises the row maximum.

    Args:
        Q: Query matrix of shape (N_out, d).
        K: Key matrix of shape (N_inp, d).
        V: Value matrix of shape (N_inp, d_v); d_v may differ from d.
        Bc: Number of key/value rows per tile (positive int).
        Br: Number of query rows per tile (positive int).

    Returns:
        Tuple ``(O, L, M)``:
          * ``O`` -- attention output, shape (N_out, d_v).
          * ``L`` -- per-row sum of exp(S - M) over all keys, shape (N_out, 1).
          * ``M`` -- per-row maximum of the scaled scores S, shape (N_out, 1).

    Raises:
        ValueError: if the inputs are not 2-D, their shapes are inconsistent,
            there are no keys, or a block size is not positive.
    """
    Q, K, V = np.asarray(Q), np.asarray(K), np.asarray(V)
    if Q.ndim != 2 or K.ndim != 2 or V.ndim != 2:
        raise ValueError("Q, K and V must be 2-D arrays")
    if Q.shape[1] != K.shape[1]:
        raise ValueError(
            f"Q and K must share the feature dimension, got {Q.shape[1]} and {K.shape[1]}"
        )
    if K.shape[0] != V.shape[0]:
        raise ValueError(
            f"K and V must have the same number of rows, got {K.shape[0]} and {V.shape[0]}"
        )
    if K.shape[0] == 0:
        raise ValueError("attention over an empty key set is undefined")
    if Bc < 1 or Br < 1:
        raise ValueError(f"block sizes must be positive, got Bc={Bc}, Br={Br}")

    n_out, n_inp = Q.shape[0], K.shape[0]
    scale = Q.shape[1] ** -0.5  # 1 / sqrt(d) as a plain Python float
    dtype = np.result_type(Q, K, V, 1.0)

    # Zero output, zero normaliser and -inf maximum make the first tile's update
    # reduce to an ordinary per-tile softmax (exp(-inf - m) == 0 kills the old
    # terms). O must start finite: a non-finite start would give 0 * inf = nan.
    O = np.zeros((n_out, V.shape[1]), dtype=dtype)
    L = np.zeros((n_out, 1), dtype=dtype)
    M = np.full((n_out, 1), -np.inf, dtype=dtype)

    for kv_start in range(0, n_inp, Bc):
        Kj = K[kv_start:kv_start + Bc]
        Vj = V[kv_start:kv_start + Bc]
        for q_start in range(0, n_out, Br):
            rows = slice(q_start, q_start + Br)
            Oi, li, mi = O[rows], L[rows], M[rows]

            Sij = scale * (Q[rows] @ Kj.T)
            mij = Sij.max(axis=1, keepdims=True)
            Pij = np.exp(Sij - mij)  # shifted by the tile max for stability
            lij = Pij.sum(axis=1, keepdims=True)

            mi_new = np.maximum(mi, mij)
            alpha = np.exp(mi - mi_new)  # rescales the previous running state
            beta = np.exp(mij - mi_new)  # rescales this tile's contribution
            li_new = alpha * li + beta * lij

            # The right-hand side is fully evaluated before the slice is
            # overwritten, so reading through the views Oi/li/mi is safe.
            O[rows] = (li * alpha * Oi + beta * (Pij @ Vj)) / li_new
            L[rows] = li_new
            M[rows] = mi_new

    return O, L, M


def reference_attention(Q, K, V):
    """Dense softmax(Q @ K.T / sqrt(d)) @ V, used as ground truth for the tiled version."""
    Q, K, V = np.asarray(Q), np.asarray(K), np.asarray(V)
    S = (Q @ K.T) * Q.shape[-1] ** -0.5
    P = np.exp(S - S.max(axis=1, keepdims=True))  # max-shift for numerical stability
    return (P / P.sum(axis=1, keepdims=True)) @ V


# Self-check: tiled result must match the dense reference on random inputs.
N_inp = 64
N_out = 64
d = 128
Bc = 16
Br = 16
Q = np.random.randn(N_out, d)
K = np.random.randn(N_inp, d)
V = np.random.randn(N_inp, d)

O, L, M = flash_attention_forward(Q, K, V, Bc=Bc, Br=Br)
O_ = reference_attention(Q, K, V)
# Unlike a bare `assert`, this is not stripped under `python -O` and reports the
# mismatching entries; tolerances equal np.allclose's defaults.
np.testing.assert_allclose(O, O_, rtol=1e-5, atol=1e-8)