Last active
July 31, 2024 06:20
-
-
Save Chillee/42e4635c59760a74cb3b4ba7ea5ad9f8 to your computer and use it in GitHub Desktop.
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
torch.set_default_device('cuda') | |
from triton.testing import do_bench | |
from collections import defaultdict | |
from functools import partial | |
import random | |
random.seed(0) | |
def get_flops(A, B):
    """Benchmark ``torch.mm(A, B)`` and return its throughput in FLOP/s.

    Args:
        A: Left operand; its two dimensions supply the row and inner extents.
        B: Right operand; its second dimension supplies the column extent.

    Returns:
        ``2 * A.shape[0] * A.shape[1] * B.shape[1]`` scaled by ``1e3 / ms``,
        where ``ms`` is the value returned by ``do_bench`` (treated here as
        milliseconds per call).
    """
    ms = do_bench(lambda: torch.mm(A, B))
    rows, inner = A.shape[0], A.shape[1]
    cols = B.shape[1]
    # One multiply and one add for every (row, inner, col) triple.
    flops = 2 * rows * inner * cols
    return (1e3 / ms) * flops
# Problem size: A is built as (M, K); B is built as (N, K) and transposed.
M = 8192
N = 8192
K = 8192


def get_tensors(f):
    """Build the two bfloat16 matmul operands with the factory ``f``.

    Args:
        f: Callable invoked as ``f(rows, cols, dtype=torch.bfloat16)``.

    Returns:
        A tuple ``(A, B)`` where ``A`` is ``f(M, K, ...)`` and ``B`` is
        ``f(N, K, ...)`` after ``.t()``.
    """
    A = f(M, K, dtype=torch.bfloat16)
    # Created as (N, K) and then transposed, so B is a transposed view
    # rather than a freshly laid-out (K, N) tensor.
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Return a tensor whose entries are each either 0 or 8.

    A ``randn`` sample is reinterpreted bit-for-bit as int16, every bit
    except ``0b1000`` is masked away, and the remaining integer (0 or 8) is
    converted numerically back to ``dtype``.

    Args:
        *shape: Dimensions forwarded to ``torch.randn``.
        dtype: Dtype of the sample and of the result.
            NOTE(review): the int16 reinterpretation presumes a 16-bit dtype
            such as the default bfloat16 -- confirm before passing others.

    Returns:
        Tensor of ``dtype`` holding only the values 0 and 8.
    """
    bit_mask = 0b1000
    noise = torch.randn(*shape, dtype=dtype)
    # view() reinterprets the raw bits; it does not convert the values.
    surviving_bits = noise.view(torch.int16) & bit_mask
    # to() converts numerically, so the integers 0/8 become 0.0/8.0.
    return surviving_bits.to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16):
    """Return ``randn`` samples with every negative entry replaced by 0.

    Args:
        *shape: Dimensions forwarded to ``torch.randn``.
        dtype: Dtype of the sampled tensor.

    Returns:
        Tensor in which entries that were ``< 0`` are 0 and all other
        entries keep their sampled value.
    """
    samples = torch.randn(*shape, dtype=dtype)
    return torch.where(samples < 0, 0, samples)
# (label, factory) pairs. Each factory is called by ``get_tensors`` as
# ``factory(rows, cols, dtype=torch.bfloat16)``.
original_setups = [
    ("randn", torch.randn),
    (
        "twos",
        lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype),
    ),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]
# Throughput samples (as returned by ``get_flops``) keyed by setup label.
results = defaultdict(list)
# Work on a copy so ``original_setups`` keeps its declared order.
setups = list(original_setups)
ITERS = 10
for _ in range(ITERS):
    # Visit the setups in a different order on every round.
    random.shuffle(setups)
    for name, f in setups:
        # Fresh operands are generated for every single measurement.
        results[name].append(get_flops(*get_tensors(f)))
def median(x):
    """Return the median of the values in ``x``.

    Args:
        x: Iterable of mutually comparable numbers; it is not modified.

    Returns:
        The middle value of the sorted data when the length is odd, or the
        mean of the two middle values when it is even.

    Raises:
        IndexError: If ``x`` is empty.
    """
    ordered = sorted(x)
    mid = len(ordered) // 2
    if len(ordered) % 2:
        return ordered[mid]
    # Even length: average the two values straddling the midpoint.
    return (ordered[mid] + ordered[mid - 1]) / 2
# Print one line per setup, in declaration order, with the median
# throughput divided by 1e12.
for name, _ in original_setups:
    print(f"{name}: {median(results[name])/1e12}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
torch.set_default_device('cuda') | |
from triton.testing import do_bench | |
from collections import defaultdict | |
from functools import partial | |
import random | |
import subprocess | |
random.seed(0) | |
def set_gpu_limits(ref_sm_clock=1810, power_limit=330, gpu_index=0):
    """Lock the GPU clocks and set the power limit through ``nvidia-smi``.

    Runs two ``sudo nvidia-smi`` commands, one after the other.

    Args:
        ref_sm_clock: Value used for both ends of ``--lock-gpu-clocks``
            (presumably MHz -- confirm against the nvidia-smi manual).
        power_limit: Value passed to ``-pl`` (presumably watts -- confirm).
        gpu_index: GPU selected with ``-i``. Defaults to ``0``, the index
            that was previously hard-coded.

    Raises:
        subprocess.CalledProcessError: If either command exits non-zero.
    """
    command_prefix = ["sudo", "nvidia-smi", "-i", str(gpu_index)]
    settings = (
        f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
        f"-pl={power_limit}",
    )
    for setting in settings:
        # check_output captures stdout, so nvidia-smi's output is discarded.
        subprocess.check_output([*command_prefix, setting])
def get_flops(A, B):
    """Benchmark ``torch.mm(A, B)`` and return its throughput in FLOP/s.

    Args:
        A: Left operand; its two dimensions supply the row and inner extents.
        B: Right operand; its second dimension supplies the column extent.

    Returns:
        ``2 * A.shape[0] * A.shape[1] * B.shape[1]`` scaled by ``1e3 / ms``,
        where ``ms`` is the value returned by ``do_bench`` (treated here as
        milliseconds per call).
    """
    ms = do_bench(lambda: torch.mm(A, B))
    rows, inner = A.shape[0], A.shape[1]
    cols = B.shape[1]
    # One multiply and one add for every (row, inner, col) triple.
    flops = 2 * rows * inner * cols
    return (1e3 / ms) * flops
# Problem size: A is built as (M, K); B is built as (N, K) and transposed.
M = 8192
N = 8192
K = 8192


def get_tensors(f):
    """Build the two bfloat16 matmul operands with the factory ``f``.

    Args:
        f: Callable invoked as ``f(rows, cols, dtype=torch.bfloat16)``.

    Returns:
        A tuple ``(A, B)`` where ``A`` is ``f(M, K, ...)`` and ``B`` is
        ``f(N, K, ...)`` after ``.t()``.
    """
    A = f(M, K, dtype=torch.bfloat16)
    # Created as (N, K) and then transposed, so B is a transposed view
    # rather than a freshly laid-out (K, N) tensor.
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B
def one_bit_random(*shape, dtype=torch.bfloat16):
    """Return a tensor whose entries are each either 0 or 8.

    A ``randn`` sample is reinterpreted bit-for-bit as int16, every bit
    except ``0b1000`` is masked away, and the remaining integer (0 or 8) is
    converted numerically back to ``dtype``.

    Args:
        *shape: Dimensions forwarded to ``torch.randn``.
        dtype: Dtype of the sample and of the result.
            NOTE(review): the int16 reinterpretation presumes a 16-bit dtype
            such as the default bfloat16 -- confirm before passing others.

    Returns:
        Tensor of ``dtype`` holding only the values 0 and 8.
    """
    bit_mask = 0b1000
    noise = torch.randn(*shape, dtype=dtype)
    # view() reinterprets the raw bits; it does not convert the values.
    surviving_bits = noise.view(torch.int16) & bit_mask
    # to() converts numerically, so the integers 0/8 become 0.0/8.0.
    return surviving_bits.to(dtype=dtype)
def sparse(*shape, dtype=torch.bfloat16, density=0.1):
    """Return ``randn`` samples in which most entries are replaced by 0.

    An independent ``rand_like`` draw is made for every entry; the entry is
    zeroed whenever that draw is greater than ``density``.

    Args:
        *shape: Dimensions forwarded to ``torch.randn``.
        dtype: Dtype of the sampled tensor.
        density: Threshold the uniform draw is compared against. Defaults
            to ``0.1``, the value that was previously hard-coded.

    Returns:
        Tensor with the sampled value where the draw was ``<= density`` and
        0 elsewhere.
    """
    samples = torch.randn(*shape, dtype=dtype)
    zero_out = torch.rand_like(samples) > density
    return torch.where(zero_out, 0, samples)
def checkerboard(*shape, dtype=torch.bfloat16):
    """Return 2-D ``randn`` samples zeroed in a checkerboard pattern.

    Entry ``(i, j)`` keeps its sampled value when ``(i - j) % 2 == 0`` and
    is set to 0 otherwise.

    Args:
        *shape: Two dimensions ``(rows, cols)``; only ``shape[0]`` and
            ``shape[1]`` are used to build the mask.
        dtype: Dtype of the sampled tensor.

    Returns:
        Tensor of shape ``(shape[0], shape[1])`` with the pattern applied.
    """
    samples = torch.randn(*shape, dtype=dtype)
    # Rows vary along dim 0 and columns along dim 1 so the mask has the same
    # (rows, cols) layout as ``samples``. The earlier orientation produced a
    # (cols, rows) mask, which only lined up for square shapes.
    row_idx = torch.arange(shape[0]).view(-1, 1)
    col_idx = torch.arange(shape[1]).view(1, -1)
    keep = (row_idx - col_idx) % 2 == 0
    return torch.where(keep, samples, 0)
def ternary(*shape, dtype=torch.bfloat16):
    """Return a tensor of random values drawn from {-1, 0, 1}.

    Args:
        *shape: Dimensions of the result, passed to ``torch.randint`` as
            ``size``.
        dtype: Dtype of the result. This argument is now honoured; it was
            previously ignored in favour of a hard-coded ``torch.bfloat16``.

    Returns:
        Tensor of ``dtype`` whose entries are -1, 0 or 1.
    """
    # ``high`` is exclusive, so the integers drawn are -1, 0 and 1.
    return torch.randint(low=-1, high=2, size=shape, dtype=dtype)
# (label, factory) pairs. Each factory is called by ``get_tensors`` as
# ``factory(rows, cols, dtype=torch.bfloat16)``. Only "randn" is active;
# ``get_results`` reads the "randn" samples, so that entry must stay enabled.
original_setups = [
    # ("zeros", torch.zeros),
    ("randn", torch.randn),
    # ("checkerboard", checkerboard),
    # ("sparse", sparse),
    # ("rand", torch.rand),
    # ("ternary", ternary),
    # ("one bit", one_bit_random),
    # ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
    # ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
]
def get_results(clocks, power, iters=10):
    """Benchmark all active setups under the given GPU limits.

    Applies the limits with ``set_gpu_limits``, then measures every entry of
    ``original_setups`` ``iters`` times, visiting them in a freshly shuffled
    order on each round.

    Args:
        clocks: Clock value forwarded to ``set_gpu_limits``.
        power: Power limit forwarded to ``set_gpu_limits``.
        iters: Number of measurement rounds. Defaults to ``10``, the value
            that was previously hard-coded.

    Returns:
        Median of the ``get_flops`` samples recorded under the "randn" label.

    Raises:
        IndexError: If no "randn" samples were collected, i.e. that setup is
            not in ``original_setups`` or ``iters`` is 0.
    """
    set_gpu_limits(clocks, power)

    def median(x):
        """Return the middle value, or the mean of the two middle values."""
        ordered = sorted(x)
        mid = len(ordered) // 2
        if len(ordered) % 2:
            return ordered[mid]
        return (ordered[mid] + ordered[mid - 1]) / 2

    results = defaultdict(list)
    # Shuffle a copy so ``original_setups`` itself is left untouched.
    setups = list(original_setups)
    for _ in range(iters):
        random.shuffle(setups)
        for name, f in setups:
            # Fresh operands are generated for every single measurement.
            results[name].append(get_flops(*get_tensors(f)))
    return median(results["randn"])
# For each power limit, from 500 down to 150 in steps of 50, find the first
# clock setting at which randn throughput falls below 90% of the throughput
# measured at ``max_clocks`` under that same power limit.
# NOTE(review): nothing here restores the GPU clock lock or power limit once
# the sweep finishes.
start_clocks = 1980  # H100
max_clocks = 1980  # H100
for power in range(500, 100, -50):
    start_flops = get_results(max_clocks, power)
    for clocks in range(start_clocks, 200, -100):
        cur_flops = get_results(clocks, power)
        if cur_flops < start_flops * 0.9:
            print("Done: ", power, clocks)
            # The next (lower) power limit resumes the sweep from this clock
            # instead of starting over at the top.
            start_clocks = clocks
            break
Am I obtuse or isn't this just GPU speculative execution in action?
@photomz If you haven't seen it, this is the associated article: https://www.thonking.ai/p/strangely-matrix-multiplications
But speculative execution is an interesting thought - it's definitely a phenomenon that looks "fairly similar" on the surface, and I thought a bit about it.
But:
- I'm not actually sure GPUs do speculative execution? Since GPUs are generally focused on parallel execution, they usually have much shallower "pipes" than CPUs do.
- Matrix multiplications do not typically have any branching on data. Generally, I've seen speculative execution interact with branch prediction to result in differing performance based off of the input data. But in this case, there is no branching! The GPU executes the exact same assembly instructions in the exact same order, regardless of what the input data is.
- And related to the above, I'd have a hard time seeing how `randn` and `rand` would lead to differing speculative execution behavior on the GPUs.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Yeah I reran it on an A100: