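# Usage sketch (the file name below is an assumption; the gist does not name
# the file): benchmark PyTorch advanced indexing (x[idx]) against torch.gather
# for selecting K rows of an (N, D) tensor, over a grid of sizes, on CPU or
# CUDA, optionally timing the backward pass as well:
#
#   python benchmark_gather_vs_index.py --device cuda --with_backward --verbose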
import argparse
import time

import torch
import numpy as np


def int_list(s):
    return [int(x) for x in s.split(',')]


parser = argparse.ArgumentParser()
parser.add_argument('--Ns', type=int_list, default=[16, 64, 128, 256, 512])
parser.add_argument('--Ds', type=int_list, default=[1, 16, 256, 1024, 4096])
parser.add_argument('--Ks', type=int_list, default=[1, 16, 256, 1024, 4096])
parser.add_argument('--tolerance', type=float, default=1e-8)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cpu')
def main(args):
    verbose = args.verbose
    tol = args.tolerance
    with_backward = args.with_backward
    num_experiments = len(args.Ns) * len(args.Ds) * len(args.Ks)
    f_better = []
    g_better = []
    for device in [args.device]:
        print('Benchmarking with device = %s' % device)
        sames, f_speedups = [], []
        rows = []
        deltas = []
        i = 0
        for N in args.Ns:
            for D in args.Ds:
                for K in args.Ks:
                    i += 1
                    if i % 10 == 0:
                        print('Running experiment %d / %d' % (i, num_experiments))
                    same, f_speedup, f_time_us, g_time_us = benchmark(
                        N, D, K, tol, with_backward, device, verbose)
                    if f_speedup < 1.0:
                        f_better.append((N, D, K, f_time_us, g_time_us))
                    else:
                        g_better.append((N, D, K, f_time_us, g_time_us))
                    rows.append((N, D, K, f_time_us, g_time_us))
                    deltas.append(f_time_us - g_time_us)
                    sames.append(same)
                    f_speedups.append(f_speedup)
        print()
        print('Results with device = %s' % device)
        print('Differences within tolerance (%f):' % tol, all(sames))
        print('Forward gather speedup:')
        print('  Min:    ', np.min(f_speedups))
        print('  Max:    ', np.max(f_speedups))
        print('  Mean:   ', np.mean(f_speedups))
        print('  Median: ', np.median(f_speedups))
        total = len(f_better) + len(g_better)
        print('Test cases with faster indexing: {}/{}'.format(len(f_better), total))
        for row in f_better:
            print('N: {} D: {} K: {} index: {:4.0f} us gather: {:4.0f} us'.format(*row))
        print('Test cases with faster gather: {}/{}'.format(len(g_better), total))
        for row in g_better:
            print('N: {} D: {} K: {} index: {:4.0f} us gather: {:4.0f} us'.format(*row))
        # deltas[i] = index time - gather time, so the most negative delta is
        # the largest win for indexing and the most positive its largest loss.
        idx_fastest = np.argmin(deltas)
        print('Indexing is faster by at most {} us on N: {} D: {} K: {}'.format(
            -deltas[idx_fastest], *rows[idx_fastest][:3]))
        idx_slowest = np.argmax(deltas)
        print('Indexing is slower by at most {} us on N: {} D: {} K: {}'.format(
            deltas[idx_slowest], *rows[idx_slowest][:3]))
def timeit(f, x, idx, with_backward=False):
    if with_backward:
        # x must require grad *before* the forward pass for backward to work.
        x.requires_grad_(True)
    if x.grad is not None:
        x.grad.data.zero_()
    # Warmup pass, also used to estimate the per-call time.
    t0 = time.time()
    y = f(x, idx)
    if with_backward:
        dy = torch.ones_like(y)
        y.backward(gradient=dy)
    if x.is_cuda:
        # Synchronize so the warmup estimate is accurate and the timed loop
        # below starts from an idle device.
        torch.cuda.synchronize()
    delta = 1000.0 * (time.time() - t0)
    # Spend ~100 ms benchmarking.
    iters = max(1, int(100.0 / delta))
    t0 = time.time()
    for _ in range(iters):
        y = f(x, idx)
        if with_backward:
            y.backward(gradient=dy)
    if x.is_cuda:
        torch.cuda.synchronize()
    t1 = time.time()
    # Average time per iteration, in microseconds.
    t_us = 1000000.0 * (t1 - t0) / iters
    return y, t_us
def benchmark(N, D, K, tol, with_backward, device='cuda', verbose=False):
    index_times, gather_times = [], []
    y_diffs = []
    for _ in range(1):  # single trial; raise the range to average over more
        x = torch.randn(N, D, requires_grad=True, device=device)
        idx = torch.randint(N, size=(K,), device=device)
        # Time forward (and optionally backward) for indexing and for gather.
        y_index, t_index = timeit(index, x, idx, with_backward)
        y_gather, t_gather = timeit(gather, x, idx, with_backward)
        index_times.append(t_index)
        gather_times.append(t_gather)
        with torch.no_grad():
            y_diff = (y_index - y_gather).abs().sum()
            y_diffs.append(y_diff.item())
    y_diff = np.max(y_diffs)
    t_index = np.mean(index_times)
    t_gather = np.mean(gather_times)
    same = y_diff < tol
    # speedup > 1 means gather is faster than advanced indexing.
    speedup = t_index / t_gather
    return same, speedup, t_index, t_gather
def index(x, idx):
    # Advanced indexing: select the rows of x given by idx.
    return x[idx]


def gather(x, idx):
    # torch.gather requires the index to have the same number of dimensions
    # as the input, so expand idx from (K,) to (K, D) before gathering rows.
    idx = idx[:, None].expand(idx.shape[0], x.shape[1])
    return x.gather(0, idx)
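
# A minimal sanity check of the two selection paths on a tiny tensor
# (an illustrative addition, not part of the original gist; the helper name
# _sanity_check is made up). Both paths should pull out the same rows.
def _sanity_check():
    x = torch.arange(6.0).view(3, 2)  # [[0, 1], [2, 3], [4, 5]]
    idx = torch.tensor([2, 0])
    # Both return [[4, 5], [0, 1]]: row 2 of x, then row 0.
    assert torch.equal(index(x, idx), gather(x, idx))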
if __name__ == '__main__':
    main(parser.parse_args())