Last active
November 5, 2024 15:07
-
-
Save YouJiacheng/0263d36822f19e98caa3c111565f9441 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.utils.benchmark as benchmark | |
def benchmark_in_us(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds.

    Uses torch.utils.benchmark's blocked autorange, which handles CUDA
    synchronization and picks the repeat count automatically.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6
# reduce-overhead is slower
@torch.compile
# @torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5_original(G: torch.Tensor, steps: int) -> torch.Tensor:
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration.

    Runs in bfloat16. The coefficients (a, b, c) are a tuned quintic that
    drives the singular values of X toward 1.

    Args:
        G: 2-D input matrix. Not modified.
        steps: number of Newton-Schulz iterations.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    # BUGFIX: Tensor.bfloat16() returns G itself when G is already bfloat16,
    # so the previous in-place div_ mutated the caller's tensor. Normalize
    # out-of-place instead; for the copy path the result is identical.
    X = G.bfloat16()
    X = X / (X.norm() + 1e-7)
    # Work in the wide orientation so A = X @ X.T is the smaller Gram matrix.
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        # X <- a*X + b*A@X + c*A@A@X (quintic update, evaluated via B = A@X)
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X
# @torch.compile | |
@torch.compile(mode="max-autotune-no-cudagraphs") | |
def zeropower_via_newtonschulz5_optimized(G: torch.Tensor, steps: int): | |
assert len(G.shape) == 2 | |
a, b, c = (3.4445, -4.7750, 2.0315) | |
X = G.bfloat16() | |
X.div_(X.norm() + 1e-7) | |
E = torch.eye(min(G.size(0), G.size(1)), dtype=X.dtype, device=X.device) | |
if G.size(0) > G.size(1): | |
X = X.T | |
for _ in range(steps): | |
A = X @ X.T | |
# B = c * A | |
# torch.diagonal(B).add_(b) | |
# C = A @ B | |
# torch.diagonal(C).add_(a) | |
# X = C @ X | |
S = A @ (b * E + c * A) | |
torch.diagonal(S).add_(a) | |
X = S @ X | |
if G.size(0) > G.size(1): | |
X = X.T | |
return X | |
# All benchmark variants below call the optimized implementation.
zeropower_via_newtonschulz5 = zeropower_via_newtonschulz5_optimized
# Number of matrices to orthogonalize per benchmarked call.
n = 18
n_1_2 = n // 2  # split point for the 2-stream variant
n_1_3 = n // 3  # first split point for the 3-stream variant
n_2_3 = n_1_3 + (n - n_1_3) // 2  # second split point for the 3-stream variant
# Random float32 input matrices on the GPU (768 x 3072 each).
gs = [torch.rand(768, 3072, dtype=torch.float32, device="cuda") for _ in range(n)]
def single(gs: list[torch.Tensor], n: int):
    """Orthogonalize only the first matrix (one-call latency baseline)."""
    zeropower_via_newtonschulz5(gs[0], n)
def multiple(gs: list[torch.Tensor], n: int):
    """Orthogonalize every matrix sequentially on the current stream."""
    for grad in gs:
        zeropower_via_newtonschulz5(grad, n)
# Side streams used to overlap independent orthogonalizations.
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
s3 = torch.cuda.Stream()
def multiplexed_2_streams(gs: list[torch.Tensor], n: int):
    """Orthogonalize the matrices split across two CUDA streams."""
    halves = ((s1, gs[:n_1_2]), (s2, gs[n_1_2:]))
    for stream, chunk in halves:
        with torch.cuda.stream(stream):
            for grad in chunk:
                zeropower_via_newtonschulz5(grad, n)
def multiplexed_3_streams(gs: list[torch.Tensor], steps: int):
    """Orthogonalize the matrices split across three CUDA streams."""
    thirds = ((s1, gs[:n_1_3]), (s2, gs[n_1_3:n_2_3]), (s3, gs[n_2_3:]))
    for stream, chunk in thirds:
        with torch.cuda.stream(stream):
            for grad in chunk:
                zeropower_via_newtonschulz5(grad, steps)
# warmup: run both step counts once so torch.compile specialization
# happens before anything is timed
benchmark_in_us(single, gs, 4)
benchmark_in_us(single, gs, 5)
# (label, variant, per-call divisor): "Single" reports the raw time,
# the others report the per-matrix average over all n matrices.
variants = [
    ("Single", single, 1),
    ("Multiple", multiple, n),
    ("Multiplexed 2 streams", multiplexed_2_streams, n),
    ("Multiplexed 3 streams", multiplexed_3_streams, n),
]
results = []
for label, fn, divisor in variants:
    for steps in (4, 5):
        results.append((label, steps, benchmark_in_us(fn, gs, steps) / divisor))
for label, steps, runtime_us in results:
    print(f"{label}, {steps} steps: {runtime_us} us")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment