Skip to content

Instantly share code, notes, and snippets.

@bwasti
Created September 11, 2022 19:39
Show Gist options
  • Save bwasti/a3a4cdd18a28651b86cb81f365e50d23 to your computer and use it in GitHub Desktop.
Save bwasti/a3a4cdd18a28651b86cb81f365e50d23 to your computer and use it in GitHub Desktop.
import torch
import time
import sys
# Command-line interface: ``<script> {add|mm} N`` where N is the side length
# of the square matrices being benchmarked. Fail with a usage message rather
# than an IndexError/ValueError traceback when the arguments are wrong.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} {{add|mm}} N")
fn = sys.argv[1]
try:
    N = int(sys.argv[2])
except ValueError:
    sys.exit(f"N must be an integer, got {sys.argv[2]!r}")
if N < 1:
    sys.exit(f"N must be a positive integer, got {N}")
iters = 1000  # timed iterations per device; warm-up uses iters // 10

# Fail early with a clear message instead of a RuntimeError from .to(mps).
if not torch.backends.mps.is_available():
    sys.exit("MPS device is not available on this machine")
mps = torch.device("mps")

# CPU operands: a random N x N matrix and the N x N identity.
a = torch.randn(N, N)
b = torch.eye(N)
# MPS operands of the same shapes (a fresh random matrix, not a copy of `a`).
a_mps = torch.randn(N, N).to(mps)
b_mps = torch.eye(N).to(mps)
def add_comp():
    """Benchmark elementwise ``a + b`` on CPU, then on MPS, printing GB/s.

    Each device runs ``iters // 10`` warm-up additions followed by ``iters``
    timed additions. Throughput assumes each addition moves two N x N
    operand reads plus one N x N result write of 4-byte floats. The CPU
    figure is printed first, then the MPS figure.
    """
    gbytes = N**2 * 3 * 4 / 1e9  # 2 read + 1 write, float=4 bytes

    # CPU: operations complete before returning, so wall-clock time
    # covers the actual work.
    for _ in range(iters // 10):
        c = a + b
    t0 = time.perf_counter()
    for _ in range(iters):
        c = a + b
    t1 = time.perf_counter()
    print(gbytes * iters / (t1 - t0), "gbyte/s")

    # MPS: kernels are queued asynchronously. Drain the queue before
    # starting the clock, so warm-up work is excluded. Drain it again
    # before stopping the clock; otherwise only launch overhead is timed.
    # torch.mps.synchronize requires PyTorch >= 2.0.
    for _ in range(iters // 10):
        c_mps = a_mps + b_mps
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        c_mps = a_mps + b_mps
    torch.mps.synchronize()
    t1 = time.perf_counter()
    print(gbytes * iters / (t1 - t0), "gbyte/s")
def mm_comp():
    """Benchmark N x N matrix multiply ``a @ b`` on CPU, then MPS, in GFLOP/s.

    Each device runs ``iters // 10`` warm-up products followed by ``iters``
    timed products. Each product is counted as 2 * N**3 floating-point
    operations (one multiply and one add per inner-product term). The CPU
    figure is printed first, then the MPS figure.
    """
    flops = N**3 * 2

    # CPU: operations complete before returning, so wall-clock time
    # covers the actual work.
    for _ in range(iters // 10):
        c = a @ b
    t0 = time.perf_counter()
    for _ in range(iters):
        c = a @ b
    t1 = time.perf_counter()
    print(flops * iters / (t1 - t0) / 1e9, "gflops")

    # MPS: kernels are queued asynchronously. Drain the queue before
    # starting the clock, so warm-up work is excluded. Drain it again
    # before stopping the clock; otherwise only launch overhead is timed.
    # torch.mps.synchronize requires PyTorch >= 2.0.
    for _ in range(iters // 10):
        c_mps = a_mps @ b_mps
    torch.mps.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        c_mps = a_mps @ b_mps
    torch.mps.synchronize()
    t1 = time.perf_counter()
    print(flops * iters / (t1 - t0) / 1e9, "gflops")
# Map each command-line benchmark name to the function that runs it.
_BENCHMARKS = {"add": add_comp, "mm": mm_comp}
if fn not in _BENCHMARKS:
    # Previously an unknown name silently did nothing; report it instead.
    sys.exit(f"unknown benchmark {fn!r}; expected one of: {', '.join(_BENCHMARKS)}")
_BENCHMARKS[fn]()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment