# Gist by @YouJiacheng (last active November 5, 2024).
# Source: https://gist.github.com/YouJiacheng/0263d36822f19e98caa3c111565f9441
# Benchmark: Newton-Schulz orthogonalization variants across CUDA streams.
import torch
import torch.utils.benchmark as benchmark
def benchmark_in_us(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and return its mean runtime in microseconds.

    Uses ``torch.utils.benchmark.Timer`` with ``blocked_autorange`` so the
    measurement adapts the number of inner-loop runs automatically.
    """
    timer = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"f": f, "args": args, "kwargs": kwargs},
    )
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6
# reduce-overhead is slower
@torch.compile
# @torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5_original(G: torch.Tensor, steps: int):
    # Baseline quintic Newton-Schulz iteration, kept as the reference
    # implementation that the "optimized" variant is benchmarked against.
    # Runs in bfloat16; returns a tensor of the same shape as G.
    assert len(G.shape) == 2
    # Coefficients of the quintic update X <- a*X + b*(A@X) + c*(A@A@X).
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Normalize so the iteration starts inside its region of convergence.
    # NOTE(review): if G is already bfloat16, .bfloat16() returns G itself,
    # so this in-place divide mutates the caller's tensor — in this script
    # the inputs are float32, so a fresh copy is made; confirm for other uses.
    X.div_(X.norm() + 1e-7)
    # Keep the wide orientation so A = X @ X.T is the smaller Gram matrix.
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        # c * A @ B parses as (c*A) @ B, which equals c * (A @ A @ X).
        X = a * X + b * B + c * A @ B
    # Undo the transpose so the output orientation matches the input.
    if G.size(0) > G.size(1):
        X = X.T
    return X
# @torch.compile
@torch.compile(mode="max-autotune-no-cudagraphs")
def zeropower_via_newtonschulz5_optimized(G: torch.Tensor, steps: int):
assert len(G.shape) == 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
X.div_(X.norm() + 1e-7)
E = torch.eye(min(G.size(0), G.size(1)), dtype=X.dtype, device=X.device)
if G.size(0) > G.size(1):
X = X.T
for _ in range(steps):
A = X @ X.T
# B = c * A
# torch.diagonal(B).add_(b)
# C = A @ B
# torch.diagonal(C).add_(a)
# X = C @ X
S = A @ (b * E + c * A)
torch.diagonal(S).add_(a)
X = S @ X
if G.size(0) > G.size(1):
X = X.T
return X
# Variant under test; swap to _original to benchmark the baseline instead.
zeropower_via_newtonschulz5 = zeropower_via_newtonschulz5_optimized
n = 18  # number of gradient matrices to orthogonalize
n_1_2 = n // 2  # split point for the 2-stream variant
n_1_3 = n // 3  # first split point for the 3-stream variant
n_2_3 = n_1_3 + (n - n_1_3) // 2  # second split point (midpoint of the rest)
# 768x3072 float32 matrices on the GPU; requires a CUDA device.
gs = [torch.rand(768, 3072, dtype=torch.float32, device="cuda") for _ in range(n)]
def single(gs: list[torch.Tensor], n: int):
    """Orthogonalize only the first matrix in ``gs`` with ``n`` NS steps."""
    first = gs[0]
    zeropower_via_newtonschulz5(first, n)
def multiple(gs: list[torch.Tensor], n: int):
    """Orthogonalize every matrix in ``gs`` sequentially on the current stream."""
    for grad in gs:
        zeropower_via_newtonschulz5(grad, n)
# Side streams for the multiplexed variants, so independent matmul chains
# can overlap on the device instead of serializing on the default stream.
s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()
s3 = torch.cuda.Stream()
def multiplexed_2_streams(gs: list[torch.Tensor], n: int):
    """Orthogonalize ``gs`` split across two CUDA streams (halves at n_1_2).

    Note: the parameter ``n`` is the Newton-Schulz step count here (called
    ``steps`` in the 3-stream variant); it shadows the module-level list
    length ``n``.
    """
    halves = (gs[:n_1_2], gs[n_1_2:])
    for stream, chunk in zip((s1, s2), halves):
        with torch.cuda.stream(stream):
            for grad in chunk:
                zeropower_via_newtonschulz5(grad, n)
def multiplexed_3_streams(gs: list[torch.Tensor], steps: int):
    """Orthogonalize ``gs`` split across three CUDA streams (chunks at n_1_3, n_2_3)."""
    chunks = (gs[:n_1_3], gs[n_1_3:n_2_3], gs[n_2_3:])
    for stream, chunk in zip((s1, s2, s3), chunks):
        with torch.cuda.stream(stream):
            for grad in chunk:
                zeropower_via_newtonschulz5(grad, steps)
# warmup
# First calls trigger torch.compile for each (shape, steps) combination,
# so compile time is excluded from the timed runs below.
benchmark_in_us(single, gs, 4)
benchmark_in_us(single, gs, 5)
# Latency of one orthogonalization in isolation.
runtime_single_4_steps = benchmark_in_us(single, gs, 4)
runtime_single_5_steps = benchmark_in_us(single, gs, 5)
# Amortized per-matrix latency when all n matrices run back to back.
runtime_multiple_4_steps = benchmark_in_us(multiple, gs, 4) / n
runtime_multiple_5_steps = benchmark_in_us(multiple, gs, 5) / n
# Amortized per-matrix latency with the work multiplexed over 2/3 streams.
runtime_mux_2_streams_4_steps = benchmark_in_us(multiplexed_2_streams, gs, 4) / n
runtime_mux_2_streams_5_steps = benchmark_in_us(multiplexed_2_streams, gs, 5) / n
runtime_mux_3_streams_4_steps = benchmark_in_us(multiplexed_3_streams, gs, 4) / n
runtime_mux_3_streams_5_steps = benchmark_in_us(multiplexed_3_streams, gs, 5) / n
print(f"Single, 4 steps: {runtime_single_4_steps} us")
print(f"Single, 5 steps: {runtime_single_5_steps} us")
print(f"Multiple, 4 steps: {runtime_multiple_4_steps} us")
print(f"Multiple, 5 steps: {runtime_multiple_5_steps} us")
print(f"Multiplexed 2 streams, 4 steps: {runtime_mux_2_streams_4_steps} us")
print(f"Multiplexed 2 streams, 5 steps: {runtime_mux_2_streams_5_steps} us")
print(f"Multiplexed 3 streams, 4 steps: {runtime_mux_3_streams_4_steps} us")
print(f"Multiplexed 3 streams, 5 steps: {runtime_mux_3_streams_5_steps} us")
# End of gist.