
@Chillee
Last active July 31, 2024 06:20
Strangely, Matrix Multiplications Run Faster When Given "Predictable" Data! https://www.thonking.ai/p/strangely-matrix-multiplications
import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random

random.seed(0)

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    # do_bench reports milliseconds per iteration; convert to FLOPS.
    return (1e3 / ms) * flops

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    # Mask random bit patterns down to a single bit, yielding tensors of only 0s and 8s.
    x = torch.randn(*shape, dtype=dtype)
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    # Zero out all negative entries, leaving roughly 50% sparsity.
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(x < 0, 0, x)
    return x

original_setups = [
    ("randn", torch.randn),
    ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
    ("sparse", sparse),
    ("one bit", one_bit_random),
    ("rand", torch.rand),
    ("zeros", torch.zeros),
]

results = defaultdict(list)
setups = list(original_setups)
ITERS = 10
for _ in range(ITERS):
    # Shuffle so no setup systematically benefits from run order (e.g. thermal state).
    random.shuffle(setups)
    for name, f in setups:
        results[name].append(get_flops(*get_tensors(f)))

def median(x):
    x = sorted(x)
    if len(x) % 2 == 0:
        return (x[len(x) // 2] + x[(len(x) - 1) // 2]) / 2
    else:
        return x[len(x) // 2]

for name, _ in original_setups:
    # Median achieved throughput in TFLOPS.
    print(f"{name}: {median(results[name]) / 1e12}")
A second script locks the GPU's SM clock and power limit via nvidia-smi, then sweeps both to find, for each power limit, the clock speed at which achieved FLOPS drops more than 10% below the max-clock baseline:

import torch
torch.set_default_device('cuda')
from triton.testing import do_bench
from collections import defaultdict
from functools import partial
import random
import subprocess

random.seed(0)

def set_gpu_limits(ref_sm_clock=1810, power_limit=330):
    # Pin the SM clock and cap board power so runs are comparable.
    subprocess.check_output([
        "sudo", "nvidia-smi", "-i", "0",
        f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}",
    ])
    subprocess.check_output([
        "sudo", "nvidia-smi", "-i", "0",
        f"-pl={power_limit}",
    ])

def get_flops(A, B):
    ms = do_bench(lambda: torch.mm(A, B))
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    return (1e3 / ms) * flops

M = 8192
N = 8192
K = 8192

def get_tensors(f):
    A = f(M, K, dtype=torch.bfloat16)
    B = f(N, K, dtype=torch.bfloat16).t()
    return A, B

def one_bit_random(*shape, dtype=torch.bfloat16):
    x = torch.randn(*shape, dtype=dtype)
    x = (x.view(torch.int16) & 0b1000).to(dtype=dtype)
    return x

def sparse(*shape, dtype=torch.bfloat16):
    # Keep only ~10% of entries nonzero.
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where(torch.rand_like(x) > 0.1, 0, x)
    return x

def checkerboard(*shape, dtype=torch.bfloat16):
    # Zero every other entry in a checkerboard pattern (assumes a square matrix).
    x = torch.randn(*shape, dtype=dtype)
    x = torch.where((torch.arange(shape[0]).view(1, -1) - torch.arange(shape[1]).view(-1, 1)) % 2 == 0, x, 0)
    return x

def ternary(*shape, dtype=torch.bfloat16):
    # Uniform random values in {-1, 0, 1}.
    x = torch.randint(low=-1, high=2, size=shape, dtype=dtype)
    return x

original_setups = [
    # ("zeros", torch.zeros),
    ("randn", torch.randn),
    # ("checkerboard", checkerboard),
    # ("sparse", sparse),
    # ("rand", torch.rand),
    # ("ternary", ternary),
    # ("one bit", one_bit_random),
    # ("all_pi", lambda *shape, dtype: torch.full(shape, fill_value=3.1415926535897932384626, dtype=dtype)),
    # ("twos", lambda *shape, dtype: torch.full(shape, fill_value=2, dtype=dtype)),
]

def get_results(clocks, power):
    set_gpu_limits(clocks, power)
    results = defaultdict(list)
    setups = list(original_setups)
    ITERS = 10
    for _ in range(ITERS):
        random.shuffle(setups)
        for name, f in setups:
            results[name].append(get_flops(*get_tensors(f)))

    def median(x):
        x = sorted(x)
        if len(x) % 2 == 0:
            return (x[len(x) // 2] + x[(len(x) - 1) // 2]) / 2
        else:
            return x[len(x) // 2]

    # for name, _ in original_setups:
    #     print(f"{name}: {median(results[name]) / 1e12}")
    # print(median(results['zeros']) / median(results["randn"]))
    return median(results['randn'])

start_clocks = 1980  # H100 max SM clock
for power in reversed([150, 200, 250, 300, 350, 400, 450, 500]):
    max_clocks = 1980  # H100
    start_flops = get_results(max_clocks, power)
    for clocks in range(start_clocks, 200, -100):
        # print(power, clocks)
        cur_flops = get_results(clocks, power)
        if cur_flops < start_flops * 0.9:
            # First clock speed that loses >10% of the max-clock throughput at this power limit.
            print("Done: ", power, clocks)
            start_clocks = clocks
            break
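One practical note: the sweep exits with the GPU still pinned at whatever clock and power limit it last set. A small sketch for undoing that afterwards; the 700 W default here is an assumption for an H100 SXM board, so check nvidia-smi -q -d POWER for your GPU's actual default.

import subprocess

def reset_gpu_limits(default_power_limit=700):
    # Unpin the SM clocks that set_gpu_limits locked.
    subprocess.check_output(["sudo", "nvidia-smi", "-i", "0", "--reset-gpu-clocks"])
    # Restore the board power limit (700 W assumed for H100 SXM; adjust for your GPU).
    subprocess.check_output(["sudo", "nvidia-smi", "-i", "0", f"-pl={default_power_limit}"])

reset_gpu_limits()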
@TJ-Solergibert

And I got the same behavior with an 80GB A100…

@Chillee

Chillee commented Jun 21, 2024

@TJ-Solergibert Can you show your nvidia-smi?

@TJ-Solergibert

Running on pretty exotic H100s (my A100 cluster is down):
[screenshot: nvidia-smi output]

$ python3 mm_weird.py 
randn: 1055.2074280203815
twos: 805.9562233451499
sparse: 1085.084700223281
one bit: 897.8750065084391
rand: 906.2403213001224
zeros: 805.7438630174415

Have you tried re-running the script?

@Chillee

Chillee commented Jun 21, 2024

Yeah, I reran it on an A100:
[screenshot: A100 benchmark results]

@photomz

photomz commented Jul 31, 2024

Am I obtuse, or isn't this just GPU speculative execution in action?

@Chillee

Chillee commented Jul 31, 2024

@photomz If you haven't seen it, this is the associated article: https://www.thonking.ai/p/strangely-matrix-multiplications

But speculative execution is an interesting thought: it's definitely a phenomenon that looks "fairly similar" on the surface, and I thought a bit about it.

But:

  1. I'm not actually sure GPUs do speculative execution. Since GPUs hide latency through parallelism rather than deep pipelines, they have much shallower "pipes" than CPUs do.
  2. Matrix multiplications don't typically branch on data. Generally, I've seen speculative execution interact with branch prediction to produce differing performance depending on the input data. But in this case, there is no branching! The GPU executes the exact same assembly instructions in the exact same order, regardless of what the input data is (see the profiler sketch below).
  3. And, related to the above, I'd have a hard time seeing how randn and rand would lead to differing speculative-execution behavior on the GPU.
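One way to sanity-check point 2 is to profile the same matmul with different inputs and confirm that the identical CUDA kernel is launched either way. A minimal sketch using torch.profiler; the exact kernel name reported will vary with GPU and cuBLAS version.

import torch
from torch.profiler import profile, ProfilerActivity

torch.set_default_device('cuda')

def top_kernel(make):
    A = make(8192, 8192, dtype=torch.bfloat16)
    B = make(8192, 8192, dtype=torch.bfloat16).t()
    torch.mm(A, B)  # warm-up, so cuBLAS heuristic/kernel selection settles first
    with profile(activities=[ProfilerActivity.CUDA]) as prof:
        torch.mm(A, B)
        torch.cuda.synchronize()
    # Return the name of the CUDA kernel that dominated device time.
    return max(prof.key_averages(), key=lambda e: e.cuda_time_total).key

print(top_kernel(torch.randn))  # e.g. a cutlass/cuBLAS gemm kernel
print(top_kernel(torch.zeros))  # same kernel name: the instruction stream is data-independent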
