@Lyken17
Created September 30, 2022 10:08
flash attention
# Forward pass of a multi-head self-attention block. Assumes the module defines
# self.qkv_proj, self.o_proj, self.num_heads, and self.head_dim, and that
# x has shape (B, T, E) with E = num_heads * head_dim.
import math

import torch
import torch.nn.functional as F

batch_size, seq_length, embed_dim = x.size()  # (B, T, E)

# Project the input to queries, keys, and values with a single linear layer.
qkv = self.qkv_proj(x)  # (B, T, 3*E)

# Split into heads: head_dim = embed_dim // num_heads.
qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)  # (B, T, H, 3*HD)
qkv = qkv.permute(0, 2, 1, 3)  # (B, H, T, 3*HD)
q, k, v = qkv.chunk(3, dim=-1)  # each (B, H, T, HD)

# Scaled dot-product attention, materializing the full (T, T) score matrix:
# O(T^2 * HD) compute and O(T^2) memory per head.
# (B, H, T, HD) x (B, H, HD, T) -> (B, H, T, T)
attn_logits = torch.matmul(q, k.transpose(-2, -1))
attn_logits = attn_logits / math.sqrt(q.shape[-1])  # scale by sqrt(head_dim)
attention = F.softmax(attn_logits, dim=-1)

# (B, H, T, T) x (B, H, T, HD) -> (B, H, T, HD)
values = torch.matmul(attention, v)
# Equivalently: values, attention = scaled_dot_product(q, k, v, mask=mask)

# Merge the heads back and apply the output projection.
values = values.permute(0, 2, 1, 3)  # (B, T, H, HD)
values = values.reshape(batch_size, seq_length, embed_dim)  # (B, T, E)
o = self.o_proj(values)
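
For reference, below is a minimal self-contained sketch of the same module with the explicit matmul/softmax/matmul core replaced by torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0+), which can dispatch to a fused FlashAttention kernel on supported GPUs and avoids storing the (T, T) attention matrix. The attribute names mirror the snippet above; the class name, __init__, and the usage example at the bottom are assumptions added for illustration, not part of the original gist.

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiheadSelfAttention(nn.Module):
    """Hypothetical wrapper matching the attribute names used in the snippet."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.o_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        batch_size, seq_length, embed_dim = x.size()  # (B, T, E)
        qkv = self.qkv_proj(x)  # (B, T, 3*E)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (B, H, T, 3*HD)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, H, T, HD)

        # Fused scaled dot-product attention; the default scale is
        # 1/sqrt(head_dim), matching the explicit version above.
        values = F.scaled_dot_product_attention(q, k, v)  # (B, H, T, HD)

        values = values.permute(0, 2, 1, 3).reshape(batch_size, seq_length, embed_dim)
        return self.o_proj(values)


# Quick shape check (sizes are made up for illustration).
if __name__ == "__main__":
    mha = MultiheadSelfAttention(embed_dim=256, num_heads=8)
    x = torch.randn(2, 128, 256)  # (B, T, E)
    print(mha(x).shape)  # torch.Size([2, 128, 256])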