Horace He (Chillee)

@Chillee
Chillee / create_block_mask.py
Created October 30, 2024 00:17
Compiling `create_block_mask`
import torch
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, noop_mask
torch.manual_seed(0)
torch.set_default_device('cuda')
def sliding_window(b, h, q_idx, kv_idx):
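The preview cuts off at the mask definition. A minimal sketch of how the rest plausibly goes (WINDOW, B, H, and S are assumed here, not taken from the gist): the point is that create_block_mask itself can be wrapped in torch.compile to speed up building the BlockMask.
# Sketch continuation (uses the imports above; all shapes/sizes are assumed)
WINDOW = 1024
B, H, S = None, None, 8192   # broadcast over batch/heads, 8K sequence
def sliding_window(b, h, q_idx, kv_idx):
    # causal attention restricted to a trailing window of WINDOW tokens
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= WINDOW)
eager_ms = do_bench(lambda: create_block_mask(sliding_window, B, H, S, S))
compiled_create = torch.compile(create_block_mask)
compiled_ms = do_bench(lambda: compiled_create(sliding_window, B, H, S, S))
print(f"create_block_mask eager: {eager_ms:.3f} ms, compiled: {compiled_ms:.3f} ms")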
@Chillee
Chillee / prefixlm.py
Created August 14, 2024 17:38
FlexAttention examples
import torch.nn as nn
import copy
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks, create_mask
from triton.testing import do_bench
from functools import partial
torch.set_default_device('cuda')
B = 4
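The preview stops at B = 4. A sketch of the prefix-LM pattern the filename refers to, with assumed shapes and prefix length: a bidirectional-prefix mask is OR-ed with a causal mask via or_masks, turned into a BlockMask, and passed to flex_attention.
# Sketch (uses the imports above; shapes and PREFIX are assumed)
B, H, S, D = 4, 16, 4096, 64
PREFIX = 1024
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx < PREFIX          # every query can attend to the prefix
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx          # causal attention everywhere else
prefix_lm = or_masks(prefix_mask, causal_mask)
block_mask = create_block_mask(prefix_lm, B=None, H=None, Q_LEN=S, KV_LEN=S)
q, k, v = [torch.randn(B, H, S, D, dtype=torch.bfloat16) for _ in range(3)]
compiled_flex = torch.compile(flex_attention)
print(do_bench(lambda: compiled_flex(q, k, v, block_mask=block_mask)))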
@Chillee
Chillee / flex_attention_tutorial.py
Last active November 2, 2024 15:41
flex_attention_tutorial.py
import torch
from torch.nn.attention._flex_attention import _create_block_mask, _create_mask
from functools import partial
from torch.nn.attention._flex_attention import _flex_attention
from triton.testing import do_bench
import torch.nn.functional as F
from functools import lru_cache
torch.set_default_device('cuda')
# Example usage
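This gist predates the public FlexAttention API, which is why it imports the underscore-prefixed prototypes. A minimal sketch of the same idea using the public names that replaced them (the relative-bias score_mod and shapes are illustrative, not the gist's own code):
# Public-API sketch (assumed example)
import torch
from torch.nn.attention.flex_attention import flex_attention
torch.set_default_device('cuda')
B, H, S, D = 4, 16, 2048, 64
q, k, v = [torch.randn(B, H, S, D, dtype=torch.bfloat16) for _ in range(3)]
def relative_bias(score, b, h, q_idx, kv_idx):
    # add a simple relative-position bias to each attention score
    return score + (q_idx - kv_idx)
out = torch.compile(flex_attention)(q, k, v, score_mod=relative_bias)
print(out.shape)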
@Chillee
Chillee / assoc_scan.py
Last active May 31, 2024 21:52
Higher Order Kernel - associative scan
import torch
import torch.nn as nn
from torch._higher_order_ops.associative_scan import associative_scan
from triton.testing import do_bench
torch.set_default_device('cuda')
def combine_fn(i, j):
    # compose two affine updates (a, b), each read as the map x -> a*x + b
    ia, ib = i
    ja, jb = j
    return ia * ja, ib * ja + jb
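Because that composition is associative, the linear recurrence h_t = a_t*h_{t-1} + b_t can be computed as a parallel scan. A sketch of how it might be driven (shapes are assumed, and this prototype op's signature has shifted between PyTorch versions):
# Sketch usage (assumed shapes; associative_scan is a prototype API)
B, S = 4, 512
a = torch.rand(B, S)      # decay/gate coefficients in [0, 1)
b = torch.randn(B, S)     # inputs
def reference(a, b):
    # sequential linear recurrence h_t = a_t * h_{t-1} + b_t with h_0 = 0
    h, outs = torch.zeros(a.shape[0]), []
    for t in range(a.shape[1]):
        h = a[:, t] * h + b[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)
scan = torch.compile(lambda a, b: associative_scan(combine_fn, (a, b), dim=1))
_, h_parallel = scan(a, b)
print(torch.allclose(h_parallel, reference(a, b), atol=1e-3, rtol=1e-3))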
@Chillee
Chillee / mm_weird.py
Last active July 31, 2024 06:20
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)
def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
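The preview stops inside get_flops. A sketch of the experiment the title describes, with assumed shapes: finish the TFLOPS computation and compare random inputs against all-zeros, the "predictable" case.
# Sketch continuation (assumed; the gist sweeps more input distributions than this)
def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flop = 2 * A.shape[0] * A.shape[1] * B.shape[1]   # 2*M*K*N FLOPs per matmul
    return flop / (ms * 1e-3) / 1e12                  # achieved TFLOPS
M = K = N = 8192  # assumed shape
A_rand, B_rand = torch.randn(M, K, dtype=torch.bfloat16), torch.randn(K, N, dtype=torch.bfloat16)
A_zero, B_zero = torch.zeros(M, K, dtype=torch.bfloat16), torch.zeros(K, N, dtype=torch.bfloat16)
print(f"random data: {get_flops(A_rand, B_rand):.1f} TFLOPS")
print(f"zero data:   {get_flops(A_zero, B_zero):.1f} TFLOPS")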
@Chillee
Chillee / attention_dim_bench.py
Created April 12, 2024 05:13
You Could Have Invented Flash-Attention!
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
torch.set_default_device('cuda')
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
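The preview ends partway through get_flops_achieved. A sketch that plausibly finishes the helper and applies it the way the title suggests, benchmarking scaled-dot-product attention at different head dims (the completion and all shapes are assumed, not copied from the gist).
# Sketch (assumed continuation and shapes)
import torch.nn.functional as F
from functools import partial
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12:.1f} TF/s")
B, S, D_total = 8, 4096, 2048
for head_dim in [64, 128]:
    n_heads = D_total // head_dim
    q, k, v = [torch.randn(B, n_heads, S, head_dim, dtype=torch.bfloat16) for _ in range(3)]
    print(f"head_dim={head_dim}:", end=" ")
    get_flops_achieved(partial(F.scaled_dot_product_attention, q, k, v))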
@Chillee
Chillee / Q1.py
Last active April 8, 2024 04:07
What Shapes Do Matrix Multiplications Like?
import torch
from triton.testing import do_bench
torch.set_default_device('cuda')
for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    print(do_bench(lambda: torch.mm(A, B)))
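The raw do_bench numbers above are in milliseconds; since the three shapes have nearly identical FLOP counts, converting to achieved TFLOPS makes them directly comparable. A small sketch of that conversion (an illustrative addition, not part of the gist):
# Convert the timings above into achieved TFLOPS
for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    ms = do_bench(lambda: torch.mm(A, B))
    tflops = 2 * M * K * N / (ms * 1e-3) / 1e12
    print(f"M={M}, K={K}, N={N}: {ms:.3f} ms, {tflops:.1f} TFLOPS")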
@Chillee
Chillee / lora_example.py
Last active May 14, 2023 09:45
lora_example.py
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map
class LoraTensor(object):
    def __init__(self, weights, A, B):
        self.weights = weights
        self.A = A
        self.B = B
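The preview shows only the constructor of the LoraTensor wrapper. Here is a minimal sketch of the same LoRA idea using the parametrize module the gist imports: the effective weight is W + B @ A, with only A and B trained (ranks and sizes are assumed; this is not the gist's tensor-wrapper implementation).
# Minimal LoRA-via-parametrization sketch (assumed; the gist wraps weights in LoraTensor instead)
class LoRAParametrization(nn.Module):
    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))  # zero-init so W is unchanged at start
    def forward(self, W):
        return W + self.B @ self.A   # effective weight seen by the layer
linear = nn.Linear(512, 512)
linear.weight.requires_grad_(False)          # freeze the base weight
parametrize.register_parametrization(linear, "weight", LoRAParametrization(512, 512))
print(linear(torch.randn(4, 512)).shape)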
@Chillee
Chillee / mfu_compute.py
Last active November 15, 2024 12:47
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
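The preview stops at the timing line. A sketch that plausibly finishes the computation and adds an MFU ratio against an assumed peak (the peak figure, shapes, and the peak_tflops argument are illustrative, not from the gist):
# Sketch continuation (989 TF/s is an assumed H100-class bf16 peak; yours will differ)
def get_flops_achieved(f, peak_tflops=None):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    achieved_tflops = total_flops / (ms_per_iter * 1e-3) / 1e12
    msg = f"{achieved_tflops:.1f} TF/s"
    if peak_tflops is not None:
        msg += f" ({100 * achieved_tflops / peak_tflops:.1f}% MFU)"
    print(msg)
A = torch.randn(8192, 8192, dtype=torch.bfloat16, device='cuda')
B = torch.randn(8192, 8192, dtype=torch.bfloat16, device='cuda')
get_flops_achieved(lambda: torch.mm(A, B), peak_tflops=989)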
@Chillee
Chillee / 1-pw_op_fusion.py
Last active August 1, 2024 18:11
PT 2.0 Benchmarks
import torch
import torch._inductor.config
import time
torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    for _ in range(warmup):
        f()
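The preview ends inside the warmup loop of bench. A sketch of how the helper plausibly finishes, plus the kind of pointwise-op fusion comparison the filename points at, eager vs. torch.compile (the timing body and shapes are assumed; the profile flag from the preview's signature is omitted here):
# Sketch (assumed continuation)
def bench(f, name=None, iters=100, warmup=5, display=True):
    for _ in range(warmup):
        f()
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    ms_per_iter = (time.time() - begin) * 1e3 / iters
    if display:
        print(f"{name}: {ms_per_iter:.3f} ms/iter")
    return ms_per_iter
def pointwise_ops(x):
    # a chain of elementwise ops that Inductor can fuse into one kernel
    return (x.cos().sin() * 2 + 1).relu()
x = torch.randn(2**26, device='cuda')
bench(lambda: pointwise_ops(x), name="eager")
compiled = torch.compile(pointwise_ops)
bench(lambda: compiled(x), name="compiled")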