Last active
February 24, 2025 06:14
-
-
Save chsasank/407df67ac0c848d6259f0340887648a9 to your computer and use it in GitHub Desktop.
Measure Bandwidth and FLOPs with PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import time | |
import numpy as np | |
from torch import mps, cuda | |
# Number of timed repetitions averaged for every problem size in the
# benchmarks below.
num_trails = 10
def flops_benchmark(device):
    """Time square matrix multiplications on ``device`` and print the rate.

    For every size ``n`` in ``2 ** np.arange(8, 13, 0.25)`` an ``n x n``
    random matrix is multiplied by itself ``num_trails`` times. The mean
    wall-clock time per multiplication and the derived throughput are
    printed as one comma-separated row per size.

    Args:
        device: Device the matrices are created on; it is also handed to
            ``synchronize`` around each timed multiplication.

    Note:
        The last column is headed ``flops`` but holds TFLOP/s: a matmul is
        counted as ``2 * n**3`` operations and the rate is divided by
        ``1e12``.
    """
    test_range = 2 ** np.arange(8, 13, 0.25)
    print('size, elapsed_time, flops')
    for n in test_range:
        # The sweep yields NumPy floats; tensor shapes must be integers.
        n = int(n)
        a = torch.rand(n, n, device=device)

        # Untimed warm-up so one-time start-up cost (library/kernel
        # initialisation) is not charged to the first timed trial.
        torch.matmul(a, a)

        total = 0
        for _ in range(num_trails):
            # Drain pending work so only the matmul below is timed.
            synchronize(device)
            # perf_counter is monotonic and has the highest available
            # resolution, which matters for the sub-millisecond sizes.
            now = time.perf_counter()
            torch.matmul(a, a)
            # Wait for completion; accelerator kernels may run
            # asynchronously with respect to the Python thread.
            synchronize(device)
            total += time.perf_counter() - now
        total = total / num_trails
        tflops = 2 * n**3 / total / 1e12
        print(n, total, tflops, sep=", ")
def synchronize(device):
    """Block until the work queued on ``device`` has finished.

    Args:
        device: A ``torch.device`` or anything ``torch.device()`` accepts
            (for example the string ``"cpu"`` or ``"cuda:1"``).

    Returns:
        None. Device types other than ``cuda`` and ``mps`` (including
        ``cpu``) need no explicit wait here and are a no-op.
    """
    device = torch.device(device)
    if device.type == "cuda":
        # Pass the device through: a bare cuda.synchronize() waits on the
        # *current* CUDA device, which may not be the one being measured.
        cuda.synchronize(device)
    elif device.type == "mps":
        mps.synchronize()
def memory_bandwidth_benchmark(device, size=1024 * 1024 * 256):
    """Time tensor-to-tensor copies on ``device`` and print the bandwidth.

    For every element count in ``2 ** np.arange(20, 28, 0.5)`` two random
    1-D tensors are allocated and ``a.copy_(b)`` is timed ``num_trails``
    times. One comma-separated row (size in GB, mean seconds, GB/s) is
    printed per element count.

    Args:
        device: Device the tensors are created on; it is also handed to
            ``synchronize`` around each timed copy.
        size: Not used by the sweep. It is kept only so existing callers
            that pass it keep working. NOTE(review): the default is an
            element count (256 Mi), not a byte count.

    Returns:
        The bandwidth in GB/s measured for the last (largest) element count.
    """
    test_range = 2 ** (np.arange(20, 28, 0.5))
    print('size (GB), elapsed_time, bandwidth')
    for num_elements in test_range:
        # The sweep yields NumPy floats; tensor sizes must be integers.
        num_elements = int(num_elements)

        # Allocate once per size; allocation is not part of the measurement.
        a = torch.rand(num_elements, device=device)
        b = torch.rand(num_elements, device=device)

        # Untimed warm-up copy so one-time start-up cost is not measured.
        synchronize(device)
        a.copy_(b)
        synchronize(device)

        elapsed_time = 0
        for _ in range(num_trails):
            # perf_counter is monotonic and has the highest available
            # resolution, which matters for the sub-millisecond copies.
            start_time = time.perf_counter()
            a.copy_(b)
            # Wait for completion; accelerator copies may run
            # asynchronously with respect to the Python thread.
            synchronize(device)
            elapsed_time += time.perf_counter() - start_time
        elapsed_time = elapsed_time / num_trails

        bytes_copied = a.nelement() * a.element_size()  # bytes
        # Factor 2: each byte is counted once as a read of ``b`` and once
        # as a write to ``a``.
        bandwidth = 2 * bytes_copied / elapsed_time / 1e9  # GB/s
        print(bytes_copied / 1e9, elapsed_time, bandwidth, sep=', ')
    return bandwidth
if __name__ == "__main__":
    # Script entry point: run the compute benchmark, then the memory
    # benchmark, both on the CPU.
    device = torch.device('cpu')
    for run_benchmark in (flops_benchmark, memory_bandwidth_benchmark):
        run_benchmark(device)
M2 Max? How did you come up with that table? Here are my results (M2 Max):
% python3 gh-bench.py
size, elapsed_time, flops
256, 0.00033810138702392576, 0.09924369815621466
304, 5.3882598876953125e-05, 1.0428028560447433
362, 8.709430694580078e-05, 1.0893462423329427
430, 0.00013272762298583985, 1.1980475233626728
512, 0.00021550655364990234, 1.2456022865832768
608, 0.0002660274505615234, 1.6897181965664958
724, 0.0004967212677001953, 1.5280337230458825
861, 0.0007333993911743164, 1.7405997023749709
1024, 0.0012362480163574218, 1.7370977502779048
1217, 0.002026009559631348, 1.7793453189115058
1448, 0.003333377838134766, 1.8215921143213984
1722, 0.00507347583770752, 2.0129076047033174
2048, 0.00778050422668457, 2.2080663005205623
2435, 0.014724946022033692, 1.9609800746836268
2896, 0.022782731056213378, 2.1321604575037143
3444, 0.039263224601745604, 2.080814950801767
4096, 0.06567699909210205, 2.09264971560687
4870, 0.11124060153961182, 2.076603351679485
5792, 0.1638232946395874, 2.3721382666057873
6888, 0.29366025924682615, 2.2256877380014912
size (GB), elapsed_time, bandwidth
0.004194304, 4.9448013305664064e-05, 169.64499560671166
0.00593164, 5.79833984375e-05, 204.5978731789474
0.008388608, 7.803440093994141e-05, 214.99769073530092
0.01186328, 9.646415710449219e-05, 245.9624456605042
0.016777216, 0.00015966892242431642, 210.15004980637298
0.023726564, 0.0002486705780029297, 190.8272718915743
0.033554432, 0.0004671335220336914, 143.66098949147963
0.047453132, 0.0007523775100708007, 126.14181408887283
0.067108864, 0.0007260322570800781, 184.8646898139078
0.094906264, 0.0009252786636352539, 205.14093262916128
0.134217728, 0.0013917446136474609, 192.87694981372258
0.189812528, 0.001843738555908203, 205.89961346933018
0.268435456, 0.002697610855102539, 199.0171825504435
0.37962506, 0.0037004232406616213, 205.17926480870577
0.536870912, 0.005187273025512695, 206.99543261343456
0.759250124, 0.007373785972595215, 205.9322380176925
I am maintaining a better repo at https://github.com/chsasank/device-benchmarks
Please send in a PR :)
I chose the last row from the list
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Results on some devices