Horace He (Chillee)

Chillee / merge_attention.py
Last active February 22, 2025 17:13
Merge Attention
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch.set_default_device('cuda')
q, k, v = [torch.randn(8, 8, 1024, 64, requires_grad=True) for _ in range(3)]
causal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx >= kv_idx, None, None, 1024, 1024)
uncausal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx < kv_idx, None, None, 1024, 1024)
ref_out = flex_attention(q, k, v)
causal_out, causal_lse = flex_attention(q, k, v, block_mask=causal_mask, return_lse=True)
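The preview cuts off before the merge itself; below is a hedged sketch of that step, combining the two disjointly masked outputs via their log-sum-exps (the merge_attention helper is an assumption for illustration, not copied from the gist).

uncausal_out, uncausal_lse = flex_attention(q, k, v, block_mask=uncausal_mask, return_lse=True)

def merge_attention(out_a, lse_a, out_b, lse_b):
    # Weight each partial output by the share of the softmax normalizer it saw.
    # Fully masked rows (lse = -inf) may need extra care depending on the PyTorch version.
    merged_lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - merged_lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - merged_lse).unsqueeze(-1)
    return out_a * w_a + out_b * w_b

merged_out = merge_attention(causal_out, causal_lse, uncausal_out, uncausal_lse)
print((merged_out - ref_out).abs().max())  # should be close to zero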
Chillee / peak_mm_perf.py
Last active February 6, 2025 10:05
H100 peak matmul FLOPS
import torch
from triton.testing import do_bench
import torch._inductor.config as config
config.max_autotune_gemm_backends = "cutlass"
torch.set_default_device('cuda')
a = torch.randn(4224, 8192, dtype=torch.bfloat16)
b = torch.randn(2048, 8192, dtype=torch.bfloat16).t()
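The preview stops before the benchmark; here is a hedged continuation. The compiled wrapper and the TFLOPS arithmetic are assumptions, and the cutlass backend setting above only takes effect when Inductor autotunes GEMMs (e.g. torch.compile with mode="max-autotune").

@torch.compile(mode="max-autotune")
def mm(x, y):
    return x @ y

ms = do_bench(lambda: mm(a, b))
flops = 2 * a.shape[0] * a.shape[1] * b.shape[1]  # 2*M*K*N for an (M, K) @ (K, N) matmul
print(f"{flops / (ms * 1e-3) / 1e12:.1f} TFLOPS")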
Chillee / create_block_mask.py
Created October 30, 2024 00:17
Compiling `create_block_mask`
import torch
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, noop_mask
torch.manual_seed(0)
torch.set_default_device('cuda')
SLIDING_WINDOW = 1024  # window size chosen for illustration; the gist's value is not shown
def sliding_window(b, h, q_idx, kv_idx):
    # causal sliding-window mask: attend only to keys at or before the query, within the window
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW)
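A hedged sketch of the comparison the title points at, timing eager create_block_mask against a torch.compile'd version for the mask above (sequence length assumed).

S = 8192
print("eager:   ", do_bench(lambda: create_block_mask(sliding_window, None, None, S, S)))
compiled_create_block_mask = torch.compile(create_block_mask)
print("compiled:", do_bench(lambda: compiled_create_block_mask(sliding_window, None, None, S, S)))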
Chillee / prefixlm.py
Created August 14, 2024 17:38
FlexAttention examples
import torch.nn as nn
import copy
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks, create_mask
from triton.testing import do_bench
from functools import partial
torch.set_default_device('cuda')
B = 4
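The preview ends after the batch size; below is a hedged PrefixLM example in the spirit of the title. The prefix length, head count, sequence length, and head dimension are assumptions, not values from the gist (which likely varies the prefix per batch element).

H, S, D, PREFIX = 8, 1024, 64, 256

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def prefix_full(b, h, q_idx, kv_idx):
    # every query may attend to the shared prefix
    return kv_idx < PREFIX

prefix_lm_mask = create_block_mask(or_masks(causal, prefix_full), None, None, S, S)
q, k, v = [torch.randn(B, H, S, D, dtype=torch.bfloat16) for _ in range(3)]
out = torch.compile(flex_attention)(q, k, v, block_mask=prefix_lm_mask)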
Chillee / flex_attention_tutorial.py
Last active February 21, 2025 09:33
flex_attention_tutorial.py
import torch
from torch.nn.attention._flex_attention import _create_block_mask, _create_mask
from functools import partial
from torch.nn.attention._flex_attention import _flex_attention
from triton.testing import do_bench
import torch.nn.functional as F
from functools import lru_cache
torch.set_default_device('cuda')
# Example usage
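The example usage is truncated; here is a hedged stand-in written against the public API that replaced the private imports above (the score_mod, slope, and shapes are illustrative, not taken from the gist).

from torch.nn.attention.flex_attention import flex_attention

def relative_positional(score, b, h, q_idx, kv_idx):
    # ALiBi-style distance penalty with a single slope, applied to the pre-softmax score
    return score - 0.1 * (q_idx - kv_idx).abs()

q, k, v = [torch.randn(4, 8, 2048, 64, dtype=torch.bfloat16) for _ in range(3)]
out = torch.compile(flex_attention)(q, k, v, score_mod=relative_positional)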
Chillee / assoc_scan.py
Last active February 6, 2025 10:14
Higher Order Kernel - associative scan
import torch
import torch.nn as nn
from torch._higher_order_ops.associative_scan import associative_scan
from triton.testing import do_bench
torch.set_default_device('cuda')
def combine_fn(i, j):
    ia, ib = i
    ja, jb = j
    return ia * ja, ib * ja + jb
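What combine_fn encodes is composition of the affine maps h -> a*h + b, i.e. the linear recurrence h_t = a_t * h_{t-1} + b_t. Below is a hedged sequential reference plus a call to the prototype op; the shapes are assumed, and the associative_scan signature and constraints may differ across PyTorch versions.

def reference_scan(a, b, dim):
    # sequential reference: h_t = a_t * h_{t-1} + b_t, starting from h = 0
    h = torch.zeros_like(b.select(dim, 0))
    outs = []
    for t in range(b.size(dim)):
        h = a.select(dim, t) * h + b.select(dim, t)
        outs.append(h)
    return torch.stack(outs, dim=dim)

a = torch.rand(2, 8, 1024) * 0.9
b = torch.randn(2, 8, 1024)
# prototype API; the exact calling convention (and whether torch.compile is required) is assumed here
out_a, out_b = associative_scan(combine_fn, (a, b), 2)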
Chillee / mm_weird.py
Last active July 31, 2024 06:20
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)
def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    # 2*M*K*N FLOPs for an (M, K) @ (K, N) matmul, reported as TFLOPS
    flops = 2 * A.shape[0] * A.shape[1] * B.shape[1]
    return flops / (ms * 1e-3) / 1e12
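A hedged illustration of the post's point (shapes assumed): the same matmul achieves different throughput on random versus all-zero inputs, since more "predictable" data toggles fewer bits and draws less power.

A = torch.randn(8192, 8192, dtype=torch.bfloat16)
B = torch.randn(8192, 8192, dtype=torch.bfloat16)
print("randn inputs:", get_flops(A, B))
print("zero inputs: ", get_flops(torch.zeros_like(A), torch.zeros_like(B)))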
Chillee / attention_dim_bench.py
Created April 12, 2024 05:13
You Could Have Invented Flash-Attention!
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
torch.set_default_device('cuda')
def get_flops_achieved(f):
    # Count the FLOPs f performs, time it, and report achieved teraflops per second.
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")
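A hedged usage example (shapes assumed) measuring achieved throughput for scaled dot product attention, which is the workload the accompanying post builds toward.

import torch.nn.functional as F

q, k, v = [torch.randn(8, 16, 4096, 64, dtype=torch.bfloat16) for _ in range(3)]
get_flops_achieved(lambda: F.scaled_dot_product_attention(q, k, v))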
Chillee / Q1.py
Last active April 8, 2024 04:07
What Shapes Do Matrix Multiplications Like?
import torch
from triton.testing import do_bench
torch.set_default_device('cuda')
for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    print(do_bench(lambda: torch.mm(A, B)))
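For comparison, a hedged extra data point not shown in the preview: the fully aligned 2048-cubed case as a baseline.

A = torch.randn(2048, 2048, dtype=torch.bfloat16)
B = torch.randn(2048, 2048, dtype=torch.bfloat16)
print("M=2048, K=2048, N=2048")
print(do_bench(lambda: torch.mm(A, B)))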
Chillee / lora_example.py
Last active May 14, 2023 09:45
lora_example.py
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map
class LoraTensor(object):
    def __init__(self, weights, A, B):
        self.weights = weights
        self.A = A
        self.B = B
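A hedged sketch of the math this wrapper represents (the helper and shapes are assumptions, not the gist's tree_map machinery): the effective weight is W + B @ A, and the low-rank part can be applied without materializing it.

def lora_linear(x, t):
    # assumes weights: (out, in), A: (r, in), B: (out, r)
    # x @ (W + B A)^T == x @ W^T + (x @ A^T) @ B^T
    return x @ t.weights.t() + (x @ t.A.t()) @ t.B.t()

W = torch.randn(256, 128)
lora = LoraTensor(W, A=torch.randn(8, 128), B=torch.zeros(256, 8))
x = torch.randn(4, 128)
print(torch.allclose(lora_linear(x, lora), x @ W.t()))  # True while B is zero-initialized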