Last active
May 17, 2021 06:23
-
-
Save oliver-batchelor/06a47d20f55faf5cab76f5210e876f11 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from os.path import split | |
import torch | |
import opt_einsum as oe | |
import torch.utils.benchmark as benchmark | |
from torch import nn | |
import torch.nn.functional as F | |
# Benchmark configuration: per-instance linear layers mapping f1 -> f2 features.
f1 = 256                   # input feature size
f2 = 256                   # output feature size
b = 100000                 # batch size
splits = 2                 # subdivisions per axis (cubed below — presumably a 3D grid; TODO confirm)
n_instances = splits ** 3  # one Linear instance per cell

# Shared random inputs for all benchmark variants, resident on the GPU.
weights = torch.randn(n_instances, f2, f1).cuda()  # per-instance weight stack
features = torch.randn(b, f1).cuda()               # one feature row per batch element
inds = torch.randint(0, n_instances, [b]).cuda()   # instance index per batch element

# Module-based variants under comparison (Module.cuda() returns self).
module = nn.Linear(f1, f2).cuda()                       # single shared linear baseline
masked_module = nn.Linear(f1 * n_instances, f2).cuda()  # one wide linear over an expanded input
split_modules = nn.ModuleList(
    [nn.Linear(f1, f2) for _ in range(n_instances)]).cuda()  # one linear per instance
def einsum_linear(i, w, f):
    """Per-instance linear computed as a one-hot einsum contraction.

    i: (b,) instance index per batch element; w: (n, f2, f1) weight stack;
    f: (b, f1) features.  Returns (b, f2): each row multiplied by the weight
    matrix selected by its index.
    """
    one_hot = F.one_hot(i, w.shape[0]).float()  # (b, n) selector matrix
    return oe.contract("x n d, b x, b d -> b n", w, one_hot, f)
def bmm_linear(i, w, f):
    """Per-instance linear by gathering weights and batch-matrix-multiplying.

    i: (b,) instance index per batch element; w: (n, f2, f1) weight stack;
    f: (b, f1) features.  Returns (b, f2).

    Fix: the original returned the raw bmm result of shape (b, f2, 1); the
    trailing singleton dim is squeezed so the output shape matches the other
    implementations being benchmarked (e.g. einsum_linear's (b, f2)).
    """
    b_w = w[i]  # (b, f2, f1): one weight matrix per batch element
    return torch.bmm(b_w, f.unsqueeze(2)).squeeze(2)
def split_linear(i, modules, f):
    """Per-instance linear: sort rows by instance, run each contiguous group
    through its own Linear, then scatter the results back to input order.

    i: (b,) instance index per batch element; modules: indexable collection of
    nn.Linear, one per instance; f: (b, f1) features.  Returns (b, f2).

    Fix: the original zipped `modules` directly against the groups, which
    silently pairs the wrong module with a group whenever some instance index
    does not occur in `i` (unique_consecutive only reports present values).
    We now index `modules` by the values actually present.  Also replaces the
    deprecated `Tensor.new(shape)` with `torch.empty_like`.
    """
    sorted_i, order = torch.sort(i)
    # Instances actually present, in sorted order, with their group sizes.
    present, counts = torch.unique_consecutive(sorted_i, return_counts=True)
    groups = torch.split_with_sizes(f[order], tuple(counts))
    sorted_out = torch.cat(
        [modules[v](x) for v, x in zip(present.tolist(), groups)])
    # Undo the sort: position order[k] of the output receives sorted row k.
    outputs = torch.empty_like(sorted_out)
    outputs[order] = sorted_out
    return outputs
def test_linear(m, f): | |
return m.forward(f) | |
def masked_linear(i, m, f):
    """Per-instance linear via one wide Linear over a one-hot-expanded input.

    Each feature row is written into the slot of a (b, n, f1) buffer selected
    by its instance index; flattening to (b, n * f1) lets a single Linear with
    in_features == n * f1 apply per-instance weights in one matmul.

    i: (b,) instance index per batch element; m: nn.Linear whose in_features
    is a multiple of f.shape[1]; f: (b, f1) features.  Returns (b, out_features).

    Fixes: the original `sparse_f[:, i] = f` broadcast `f` into ALL batch
    rows at every indexed slot (wrong whenever b > 1); rows must be paired
    with their own index via arange.  The instance count is now derived from
    the layer itself rather than a module-level global.
    """
    batch, feat = f.shape
    n = m.in_features // feat  # number of instance slots in the wide layer
    sparse_f = f.new_zeros(batch, n, feat)
    sparse_f[torch.arange(batch, device=f.device), i] = f
    return m.forward(sparse_f.view(batch, -1))
# Benchmark timers: each Timer executes `stmt` with the shared CUDA inputs
# bound through `globals`; the `setup` import resolves the function name
# inside the timing harness's namespace.
t0 = benchmark.Timer(
    stmt='einsum_linear(i, w, f)',
    setup='from __main__ import einsum_linear',
    globals={'i': inds, 'w': weights, 'f': features})
t1 = benchmark.Timer(
    stmt='bmm_linear(i, w, f)',
    setup='from __main__ import bmm_linear',
    globals={'i': inds, 'w': weights, 'f': features})
t2 = benchmark.Timer(
    stmt='test_linear(m, f)',
    setup='from __main__ import test_linear',
    globals={'m': module, 'f': features})
t3 = benchmark.Timer(
    stmt='split_linear(i, m, f)',
    setup='from __main__ import split_linear',
    globals={'i': inds, 'm': split_modules, 'f': features})
# NOTE(review): masked_linear / masked_module are defined above but never
# timed here — presumably disabled during experimentation; confirm.
print(f"feature size inputs {f1}, outputs {f2}")
print(f"batch size {b}, instances {n_instances}")
# print(t0.timeit(10))  # einsum variant (disabled)
print(t3.timeit(10))    # sort-and-split variant
# print(t1.timeit(10))  # bmm gather variant (disabled)
print(t2.timeit(10))    # single shared Linear baseline
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment