Skip to content

Instantly share code, notes, and snippets.

@muupan
Created March 6, 2017 08:05
Show Gist options
  • Save muupan/deea148120bc50f8de989e6fb2e03597 to your computer and use it in GitHub Desktop.
Save muupan/deea148120bc50f8de989e6fb2e03597 to your computer and use it in GitHub Desktop.
from timeit import default_timer as timer
import chainer
from chainer import cuda
from chainer import function
import chainer.functions as F
from chainer import utils
from chainer.utils import type_check
import cupy
class SumVariables(function.Function):

    """Element-wise sum of an arbitrary number of input arrays.

    Every input must be a floating-point array, and all inputs must share
    the same dtype and shape.
    """

    def check_type_forward(self, in_types):
        """Check that all inputs are float arrays of one dtype and shape.

        Args:
            in_types: Type information of the inputs, one entry per input.

        Raises:
            chainer.utils.type_check.InvalidType: If there are no inputs,
                the first input is not floating point, or any input differs
                from the first one in dtype or shape.
        """
        # Require at least one input so that indexing in_types[0] is valid
        # and a lone input is accepted instead of failing on in_types[1].
        type_check.expect(in_types.size() > 0)
        type_check.expect(in_types[0].dtype.kind == 'f')
        # Compare *every* remaining input against the first one; checking
        # only the second input would let a mismatched later input through.
        for i in range(1, len(in_types)):
            type_check.expect(
                in_types[0].dtype == in_types[i].dtype,
                in_types[0].shape == in_types[i].shape,
            )

    def forward_cpu(self, inputs):
        """Sum the input arrays on the CPU.

        Args:
            inputs (tuple): Input arrays.

        Returns:
            tuple: One-element tuple holding the element-wise sum.
        """
        y = sum(inputs)
        # force_array is applied in case the sum came back as a scalar
        # rather than an array.
        return utils.force_array(y),

    def forward_gpu(self, inputs):
        """Sum the input arrays on the GPU with a single elementwise kernel.

        The kernel source is generated for the number of inputs: it takes
        ``x0 .. x{n-1}`` and writes ``y = x0+x1+...``.

        Args:
            inputs (tuple): Input arrays.

        Returns:
            tuple: One-element tuple holding the element-wise sum.
        """
        n = len(inputs)
        names = ['x{}'.format(i) for i in range(n)]
        # The kernel name encodes n because each input count needs its own
        # kernel signature.
        kernel = cuda.elementwise(
            ', '.join('T {}'.format(name) for name in names),
            'T y',
            'y = ' + '+'.join(names),
            'sum_variable_{}'.format(n))
        y = kernel(*inputs)
        return y,

    def backward(self, inputs, grads):
        """Propagate the output gradient unchanged to every input.

        Args:
            inputs (tuple): Input arrays (only their count is used).
            grads (tuple): Gradients w.r.t. the output; only ``grads[0]``
                is used.

        Returns:
            tuple: ``grads[0]`` repeated once per input.
        """
        # The same gradient object is shared by all inputs; no copy is made.
        return (grads[0],) * len(inputs)
def sum_variables(xs):
    """Add up the given variables element-wise in one function call.

    Args:
        xs (tuple of ~chainer.Variable): Variables to add together.

    Returns:
        ~chainer.Variable: The element-wise sum of ``xs``.

    """
    summer = SumVariables()
    # The function takes each variable as a separate positional argument.
    return summer(*xs)
def normal_sum(xs):
    """Benchmark body: add ``xs`` with the builtin ``sum``.

    Args:
        xs: Iterable of objects that support ``+`` and whose sum exposes a
            ``data`` attribute convertible to ``float``.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    # Builtin sum starts from 0 and applies ``+`` once per element.
    total = sum(xs)
    return float(total.data)
def normal_sum_backward(xs):
    """Benchmark body: add ``xs`` with the builtin ``sum``, then backprop.

    Args:
        xs: Iterable of variables to add.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    total = sum(xs)
    # Run the backward pass before reading the value so it is part of the
    # work being measured.
    total.backward()
    return float(total.data)
def vstack_sum(xs):
    """Benchmark body: stack ``xs`` with ``F.vstack`` and reduce with ``F.sum``.

    Args:
        xs: Iterable of variables to add.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    # Build a real list (not a generator) so the inputs are prepared exactly
    # as in vstack_sum_backward and the two timings stay comparable.
    # x[None] prepends an axis to each variable before stacking.
    stacked = F.vstack([x[None] for x in xs])
    # Reducing over axis 0 collapses the stacked variables into their sum.
    s = F.sum(stacked, axis=0)
    return float(s.data)
def vstack_sum_backward(xs):
    """Benchmark body: ``F.vstack`` + ``F.sum`` over ``xs``, then backprop.

    Args:
        xs: Iterable of variables to add.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    # x[None] prepends an axis to each variable before stacking.
    expanded = [x[None] for x in xs]
    stacked = F.vstack(expanded)
    # Reducing over axis 0 collapses the stacked variables into their sum.
    total = F.sum(stacked, axis=0)
    total.backward()
    return float(total.data)
def custom_sum(xs):
    """Benchmark body: add ``xs`` with the custom ``sum_variables`` function.

    Args:
        xs: Variables to add.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    total = sum_variables(xs)
    return float(total.data)
def custom_sum_backward(xs):
    """Benchmark body: add ``xs`` with ``sum_variables``, then backprop.

    Args:
        xs: Variables to add.

    Returns:
        float: ``data`` of the summed result, converted to a Python float.
    """
    total = sum_variables(xs)
    # Run the backward pass before reading the value so it is part of the
    # work being measured.
    total.backward()
    return float(total.data)
# Benchmark configuration.
n_variables = 100  # how many variables are summed in each call
n_repeat = 1000    # how many timed calls each benchmark performs

# Inputs shared by all benchmarks: one (1, 1) float32 CuPy array of random
# values per variable, each wrapped in a chainer.Variable.
xs = [
    chainer.Variable(cupy.random.rand(1, 1).astype(cupy.float32))
    for _ in range(n_variables)
]
def measure(f):
    """Time ``n_repeat`` calls of ``f(xs)`` and print the elapsed seconds.

    Args:
        f: Callable taking the module-level ``xs`` and returning a number.
    """
    # One untimed call first (presumably to keep one-off start-up costs out
    # of the measurement).
    f(xs)
    began = timer()
    total = 0
    for _ in range(n_repeat):
        total += f(xs)
    # Sanity check on the accumulated results; it runs inside the timed
    # region.
    assert total > 0
    ended = timer()
    print('time', ended - began)
# Report the configuration, then run every benchmark in a fixed order:
# the three forward-only variants first, followed by the same three with a
# backward pass.
print('n_variables:', n_variables)
print('n_repeat:', n_repeat)
benchmarks = (
    ('sum(xs)', normal_sum),
    ('F.sum and F.vstack', vstack_sum),
    ('custom kernel', custom_sum),
    ('sum(xs) with backward', normal_sum_backward),
    ('F.sum and F.vstack with backward', vstack_sum_backward),
    ('custom kernel with backward', custom_sum_backward),
)
for label, benchmark in benchmarks:
    print(label)
    measure(benchmark)
@muupan
Copy link
Author

muupan commented Mar 6, 2017

n=100

n_variables: 100
n_repeat: 1000
sum(xs)
time 4.8308634031564
F.sum and F.vstack
time 3.626658985391259
custom kernel
time 0.3236686922609806
sum(xs) with backward
time 8.414123510941863
F.sum and F.vstack with backward
time 11.781718255020678
custom kernel with backward
time 2.262535797432065

n=10

n_variables: 10
n_repeat: 1000
sum(xs)
time 0.5071171717718244
F.sum and F.vstack
time 0.5300100678578019
custom kernel
time 0.10647348966449499
sum(xs) with backward
time 0.9468782944604754
F.sum and F.vstack with backward
time 1.5264564296230674
custom kernel with backward
time 0.37463031709194183

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment