import torch
from triton.testing import do_bench
from torch.nn.attention.flex_attention import create_block_mask, flex_attention, noop_mask

torch.manual_seed(0)
torch.set_default_device('cuda')

def sliding_window(b, h, q_idx, kv_idx):
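The preview of this gist is truncated at the mask definition. A minimal sketch of a plausible continuation; the window size, tensor shapes, and the dense noop_mask baseline are all assumptions, not values from the gist:

WINDOW = 1024  # assumed window size

def sliding_window(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    in_window = q_idx - kv_idx <= WINDOW
    return causal & in_window

B, H, S, D = 16, 16, 8192, 64
q, k, v = [torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3)]
block_mask = create_block_mask(sliding_window, B, H, S, S)
fa = torch.compile(flex_attention)  # flex_attention is meant to be compiled
print(do_bench(lambda: fa(q, k, v, block_mask=block_mask)))
# noop_mask is imported above, presumably as a dense-attention baseline:
dense_mask = create_block_mask(noop_mask, B, H, S, S)
print(do_bench(lambda: fa(q, k, v, block_mask=dense_mask)))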
import torch
import torch.nn as nn
import copy
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, or_masks, create_mask
from triton.testing import do_bench
from functools import partial

torch.set_default_device('cuda')

B = 4
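The preview stops right after `B = 4`. A hedged sketch of how an `or_masks` benchmark plausibly continues; the head count, sequence length, prefix length, and the particular masks being combined are assumptions, not the gist's actual values:

H, S, D = 16, 4096, 64
PREFIX_LEN = 1024  # hypothetical prefix boundary

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def prefix(b, h, q_idx, kv_idx):
    return kv_idx < PREFIX_LEN

# prefix-LM style mask: bidirectional over the prefix, causal elsewhere
mask_mod = or_masks(causal, prefix)
block_mask = create_block_mask(mask_mod, B, H, S, S)
q, k, v = [torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3)]
fa = torch.compile(flex_attention)
print(do_bench(lambda: fa(q, k, v, block_mask=block_mask)))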
import torch
import torch.nn.functional as F
from torch.nn.attention._flex_attention import _flex_attention, _create_block_mask, _create_mask
from triton.testing import do_bench
from functools import partial, lru_cache

torch.set_default_device('cuda')

# Example usage
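This gist predates the public API and imports the underscore-prefixed prototypes. A hedged reconstruction of what the example usage likely looked like; the prototype signatures shifted between nightlies, so the `block_mask` keyword and the lru_cache helper below are assumptions (the public `flex_attention` API is the stable equivalent):

B, H, S, D = 4, 16, 4096, 64

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

@lru_cache  # block masks depend only on shapes, so cache them across calls
def create_block_mask_cached(B, H, M, N):
    return _create_block_mask(causal_mask, B, H, M, N)

q, k, v = [torch.randn(B, H, S, D, dtype=torch.float16) for _ in range(3)]
block_mask = create_block_mask_cached(B, H, S, S)
flex = torch.compile(_flex_attention)
print(do_bench(lambda: flex(q, k, v, block_mask=block_mask)))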
import torch
import torch.nn as nn
from torch._higher_order_ops.associative_scan import associative_scan
from triton.testing import do_bench

torch.set_default_device('cuda')

# Composes two segments of the linear recurrence h = a * h + b:
# applying (ia, ib) then (ja, jb) gives h -> ja * (ia * h + ib) + jb,
# i.e. the composed coefficients (ia * ja, ib * ja + jb). Composition is
# associative, so the whole recurrence can run as a parallel scan.
def combine_fn(i, j):
    ia, ib = i
    ja, jb = j
    return ia * ja, ib * ja + jb
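A hedged usage sketch, assuming `associative_scan` accepts a tuple of tensors and a scan dim; it is a private higher-order op whose exact signature and eager-mode support vary across PyTorch versions:

T = 1024
a = torch.rand(4, T)   # decay coefficients a_t
b = torch.randn(4, T)  # inputs b_t
# out_b[:, t] holds h_t for h_t = a_t * h_{t-1} + b_t with h_{-1} = 0
out_a, out_b = associative_scan(combine_fn, (a, b), dim=1)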
import torch
import random
from triton.testing import do_bench
from collections import defaultdict
from functools import partial

torch.set_default_device('cuda')
random.seed(0)

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
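The helper is cut off after the timing call. A plausible completion, assuming from the name that it converts the measured time into achieved throughput:

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    M, K = A.shape
    _, N = B.shape
    # a matmul does 2*M*K*N FLOPs; convert ms/iter to TFLOPS
    return (2 * M * K * N) / (ms * 1e-3) / 1e12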
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

torch.set_default_device('cuda')

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
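The preview ends right after the FLOP count is read out. A plausible completion that times the same callable and reports achieved throughput (the exact print format is assumed):

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()                    # run once under the counter to tally FLOPs
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)  # then time it without the counting overhead
    iters_per_second = 1e3 / ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12:.2f} TF/s achieved")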
import torch
from triton.testing import do_bench

torch.set_default_device('cuda')

for M, K, N in [(2047, 2048, 2048), (2048, 2047, 2048), (2048, 2048, 2047)]:
    A = torch.randn(M, K, dtype=torch.bfloat16)
    B = torch.randn(K, N, dtype=torch.bfloat16)
    print(f"M={M}, K={K}, N={N}")
    print(do_bench(lambda: torch.mm(A, B)))
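Each case shaves one element off a different dimension of an otherwise 2048-cubed matmul; on most GPUs the three variants run at noticeably different speeds because a misaligned dimension can defeat vectorized loads. For reference, a fully aligned baseline to compare against (not part of the original preview):

A = torch.randn(2048, 2048, dtype=torch.bfloat16)
B = torch.randn(2048, 2048, dtype=torch.bfloat16)
print("M=2048, K=2048, N=2048 (aligned baseline)")
print(do_bench(lambda: torch.mm(A, B)))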
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.utils._pytree import tree_map

# Bundles a frozen base weight with low-rank LoRA factors A and B
# (standard LoRA convention: effective weight = weights + B @ A).
class LoraTensor(object):
    def __init__(self, weights, A, B):
        self.weights = weights
        self.A = A
        self.B = B
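The class body is cut off here. Given the `parametrize` and `tree_map` imports, the full gist presumably intercepts ops on the wrapper; as a simpler stand-in, here is a minimal runnable sketch that gets the same effect with a weight parametrization (all names below are illustrative, not from the gist):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class LoraParametrization(nn.Module):
    # Rewrites a layer's weight as W + B @ A on the fly.
    def __init__(self, in_features, out_features, rank=8):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(out_features, rank))  # zero init: starts as a no-op patch

    def forward(self, W):
        return W + self.B @ self.A

linear = nn.Linear(64, 64)
linear.weight.requires_grad_(False)  # freeze the base weight; only A and B train
parametrize.register_parametrization(linear, "weight", LoraParametrization(64, 64))
print(linear(torch.randn(2, 64)).shape)  # torch.Size([2, 64])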
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
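This is a near-duplicate of the FLOP-counting gist above, truncated one line later. Assuming it finishes by printing achieved TF/s as sketched earlier, a usage example (model, shapes, and dtype are all illustrative):

import torch.nn as nn
torch.set_default_device('cuda')

model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(4)]).half()
x = torch.randn(64, 4096, dtype=torch.half)

compiled = torch.compile(model)
compiled(x)  # warm up compilation outside the timed region
get_flops_achieved(lambda: model(x))     # eager throughput
get_flops_achieved(lambda: compiled(x))  # compiled throughput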
import torch
import torch._inductor.config
import time

torch._inductor.config.triton.cudagraphs = False
torch.set_float32_matmul_precision('high')

def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    for _ in range(warmup):
        f()
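The preview ends inside the warmup loop. A plausible completion with CUDA-synchronized wall-clock timing; the `profile=True` branch (presumably a torch.profiler trace) is left out, and the exact print format is assumed:

def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    for _ in range(warmup):
        f()                   # warm up: trigger compilation, caches, etc.
    torch.cuda.synchronize()  # don't let queued warmup kernels leak into the timing
    begin = time.time()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.time() - begin) * 1e6 / iters
    if display:
        print(f"{name}: {us_per_iter:.2f} us/iter")
    return us_per_iter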