Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
random.seed(0)

def get_flops(A, B):
    # do_bench returns the median runtime in milliseconds; an (M, K) @ (K, N)
    # matmul performs 2 * M * K * N flops, so flops * (1e3 / ms) is flops/sec.
    ms = do_bench(lambda: torch.mm(A, B))
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    return (1e3 / ms) * flops

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    # Reinterpret random bf16 values as int16 and mask out all but one bit,
    # so every entry becomes either 0.0 or 8.0 after the cast back.
    x = torch.randn(*shape, dtype=dtype)
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    # Zero out roughly half the entries (everything that was negative).
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(x < 0, 0, x)
    return x

original_setups = [
    ("randn", torch.randn),
    ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]

results = defaultdict(list)
setups = list(original_setups)
ITERS = 10
# Shuffle the run order each iteration so no setup consistently benefits
# from (or suffers under) the GPU's thermal and power state.
for _ in range(ITERS):
    random.shuffle(setups)
    for name, f in setups:
        results[name].append(get_flops(*get_tensors(f)))

def median(x):
    x = sorted(x)
    if len(x) % 2 == 0:
        return (x[len(x)//2] + x[(len(x) - 1)//2]) / 2
    else:
        return x[len(x)//2]

# Report median throughput in TFLOPS per setup.
for name, _ in original_setups:
    print(f"{name}: {median(results[name])/1e12}")
# Sweep GPU power limits and clock speeds: for each power limit, find the
# highest locked clock at which throughput drops below 90% of the
# unconstrained value, i.e. where the matmul becomes power-limited.
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
import subprocess
random.seed(0)

def set_gpu_limits(ref_sm_clock=1810, power_limit=330):
    # Lock the SM clock and cap the power limit on GPU 0 (requires sudo).
    subprocess.check_output([
        "sudo",
        "nvidia-smi",
        "-i",
        "0",
        f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
    ])
    subprocess.check_output([
        "sudo",
        "nvidia-smi",
        "-i",
        "0",
        f"-pl={power_limit}",
    ])

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    return (1e3 / ms) * flops

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    # Keep only ~10% of entries nonzero.
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(torch.rand_like(x) > 0.1, 0, x)
    return x

def checkerboard(*shape, dtype=torch.bfloat16):
    # Zero out entries in an alternating checkerboard pattern.
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where((torch.arange(shape[0]).view(1, -1) - torch.arange(shape[1]).view(-1, 1)) % 2 == 0, x, 0)
    return x

def ternary(*shape, dtype=torch.bfloat16):
    # Uniform random values from {-1, 0, 1}.
    x = torch.randint(low=-1, high=2, size=shape, dtype=torch.bfloat16)
    return x

original_setups = [
    # ("zeros", torch.zeros),
    ("randn", torch.randn),
    # ("checkerboard", checkerboard),
    # ("sparse", sparse),
    # ("rand", torch.rand),
    # ("ternary", ternary),
    # ("one bit", one_bit_random),
    # ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
    # ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
]

def get_results(clocks, power):
    set_gpu_limits(clocks, power)
    results = defaultdict(list)
    setups = list(original_setups)
    ITERS = 10
    for _ in range(ITERS):
        random.shuffle(setups)
        for name, f in setups:
            results[name].append(get_flops(*get_tensors(f)))
    def median(x):
        x = sorted(x)
        if len(x) % 2 == 0:
            return (x[len(x)//2] + x[(len(x) - 1)//2]) / 2
        else:
            return x[len(x)//2]
    # for name, _ in original_setups:
    #     print(f"{name}: {median(results[name])/1e12}")
    # print(median(results['zeros']) / median(results["randn"]))
    return median(results['randn'])

start_clocks = 1980  # H100 max boost clock
for power in reversed([150, 200, 250, 300, 350, 400, 450, 500]):
    max_clocks = 1980  # H100
    start_flops = get_results(max_clocks, power)
    # Step the locked clock down until throughput falls below 90% of the
    # unconstrained throughput at this power limit.
    for clocks in range(start_clocks, 200, -100):
        # print(power, clocks)
        cur_flops = get_results(clocks, power)
        if cur_flops < start_flops * 0.9:
            print("Done: ", power, clocks)
            start_clocks = clocks
            break
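A sweep like this leaves the GPU with locked clocks and a reduced power cap, so it's worth undoing both afterwards. A small cleanup sketch follows; the function name reset_gpu_limits is hypothetical, and the 700 W default is typical for an H100 SXM but should be confirmed with nvidia-smi -q -d POWER on your machine.

import subprocess

def reset_gpu_limits(default_power=700):
    # Unlock the SM clocks on GPU 0 (requires sudo).
    subprocess.check_output([
        "sudo", "nvidia-smi", "-i", "0", "--reset-gpu-clocks",
    ])
    # Restore the power limit to the card's default (assumed 700 W here).
    subprocess.check_output([
        "sudo", "nvidia-smi", "-i", "0", f"-pl={default_power}",
    ])

reset_gpu_limits()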
@photomz If you haven't seen it, this is the associated article: https://www.thonking.ai/p/strangely-matrix-multiplications
Speculative execution is an interesting thought - it's definitely a phenomenon that looks "fairly similar" on the surface, and I thought a bit about it. But it's hard to see why `randn` and `rand` would lead to differing speculative execution behavior on the GPUs.
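For what it's worth, the power-based explanation does predict a difference between the two: the bf16 outputs of rand and randn have differently distributed bit patterns, and (per the article's toggling argument) data that flips more bits costs more power to push through the multipliers. A quick sketch for eyeballing that follows; the helper name mean_set_bits is just for illustration, and the bit-reinterpretation via tensor.view(torch.int16) is the same trick the benchmark scripts use.

import torch
torch.set_default_device('cuda')

def mean_set_bits(x):
    # Reinterpret bf16 values as int16, then count set bits per element.
    bits = x.view(torch.int16).to(torch.int32) & 0xFFFF
    counts = torch.zeros_like(bits)
    for i in range(16):
        counts += (bits >> i) & 1
    return counts.float().mean().item()

for name, f in [("randn", torch.randn), ("rand", torch.rand)]:
    x = f(8192, 8192, dtype=torch.bfloat16)
    print(name, mean_set_bits(x))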