Last active: September 10, 2018 09:04
-
-
Save wkcn/166a139e43eef6f7f402840a03defe79 to your computer and use it in GitHub Desktop.
Count Time for MobulaOP
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
import time

import mxnet as mx
import numpy as np

import mobula_op
from mobula_op.test_utils import assert_almost_equal
def check_conv(data, weight, bias, kernel, stride, dilate, pad, num_filter, no_bias, conv):
    """Run one forward + backward pass of ``conv`` and block until it completes.

    This is the unit of work timed by the convolution benchmark.

    Args:
        data, weight, bias: MXNet NDArrays. ``attach_grad()`` is called on
            each of them in place, even on ``bias`` when ``no_bias`` is true.
        kernel, stride, dilate, pad: Convolution geometry tuples.
        num_filter: Number of output channels passed to ``conv``.
        no_bias: When true, ``bias`` is not passed to ``conv`` at all.
        conv: Convolution operator under test (e.g. ``mx.nd.Convolution``).
    """
    # No unused clones of the inputs are made: anything done here is counted
    # in the caller's timing window, so only the operator's own work belongs.
    for arr in (data, weight, bias):
        arr.attach_grad()

    kwargs = dict(
        data=data,
        weight=weight,
        num_filter=num_filter,
        no_bias=no_bias,
        kernel=kernel,
        stride=stride,
        dilate=dilate,
        pad=pad,
    )
    # Omit ``bias`` entirely (rather than passing None) when it is disabled.
    if not no_bias:
        kwargs["bias"] = bias

    with mx.autograd.record():
        out = conv(**kwargs)
    out.backward()
    # Block on all pending MXNet work so the caller's timer covers the full pass.
    mx.nd.waitall()
def check_fc(data, weight, bias, num_hidden, no_bias, flatten, fc):
    """Run one forward + backward pass of ``fc`` and block until it completes.

    This is the unit of work timed by the fully-connected benchmark.

    Args:
        data, weight, bias: MXNet NDArrays. ``attach_grad()`` is called on
            each of them in place, even on ``bias`` when ``no_bias`` is true.
        num_hidden: Output width passed to ``fc``.
        no_bias: When true, ``bias`` is not passed to ``fc`` at all.
        flatten: Forwarded unchanged to ``fc``.
        fc: FullyConnected operator under test (e.g. ``mx.nd.FullyConnected``).
    """
    # No unused clones of the inputs are made: anything done here is counted
    # in the caller's timing window, so only the operator's own work belongs.
    for arr in (data, weight, bias):
        arr.attach_grad()

    kwargs = dict(
        data=data,
        weight=weight,
        num_hidden=num_hidden,
        no_bias=no_bias,
        flatten=flatten,
    )
    # Omit ``bias`` entirely (rather than passing None) when it is disabled.
    if not no_bias:
        kwargs["bias"] = bias

    with mx.autograd.record():
        out = fc(**kwargs)
    out.backward()
    # Block on all pending MXNet work so the caller's timer covers the full pass.
    mx.nd.waitall()
def test_dot():
    """Time ``mobula_op.math.dot`` on a (1000, 2000) x (2000, 1000) float32 product.

    The call is run twice untimed as warm-up, then once timed. The elapsed
    seconds of the timed call are printed.
    """
    A, B, C = 1000, 2000, 1000
    a = np.random.random((A, B)).astype(np.float32)
    b = np.random.random((B, C)).astype(np.float32)
    # NOTE(review): inputs are NumPy arrays, so the surrounding mx.cpu/mx.gpu
    # context presumably does not change where this runs — confirm.

    # Warm up twice so any one-off first-call cost is excluded. This matches
    # test_fc/test_conv, which only time their third iteration.
    for _ in range(2):
        mobula_op.math.dot(a, b)

    # perf_counter is the monotonic, high-resolution clock meant for intervals.
    tic = time.perf_counter()
    mobula_op.math.dot(a, b)
    print(time.perf_counter() - tic)
def test_fc():
    """Benchmark FullyConnected forward+backward: mobula_op versus MXNet.

    For each operator, every (no_bias, flatten) combination is run three
    times. Each run covers both a 2-D input of shape (batch, K) and a 5-D
    input of shape (batch, 4, 2, 3, K). Only the third run is timed. The
    operator and its accumulated time in seconds are printed.
    """
    batch_size = 1000
    num_hidden = 1000
    K = 1000
    num_repeats = 1  # how many times the whole sweep is repeated per operator

    for fc in [mobula_op.op.FullyConnected, mx.nd.FullyConnected]:
        t = 0
        for _ in range(num_repeats):
            data = mx.nd.random.uniform(-1, 1, shape=(batch_size, K))
            data2 = mx.nd.random.uniform(-1, 1, shape=(batch_size, 4, 2, 3, K))
            bias = mx.nd.random.uniform(-1, 1, shape=(num_hidden,))
            for no_bias in [False, True]:
                for flatten in [False, True]:
                    # Create weights outside the timed region: random generation
                    # is not part of the operator cost being measured.
                    weight = mx.nd.random.uniform(-1, 1, shape=(num_hidden, K))
                    # With flatten, every non-batch axis of data2 is folded into
                    # the input dimension; without it, only the last axis is.
                    input_dim2 = data2.size // data2.shape[0] if flatten else data2.shape[-1]
                    weight2 = mx.nd.random.uniform(-1, 1, shape=(num_hidden, input_dim2))
                    for u in range(3):
                        timed = u == 2  # runs 0 and 1 are warm-up
                        if timed:
                            # Flush any pending asynchronous work before starting the clock.
                            mx.nd.waitall()
                            tic = time.perf_counter()
                        check_fc(data, weight, bias, num_hidden, no_bias, flatten, fc)
                        check_fc(data2, weight2, bias, num_hidden, no_bias, flatten, fc)
                        if timed:
                            t += time.perf_counter() - tic
        print(fc, t)
def test_conv():
    """Benchmark Convolution forward+backward: mobula_op versus MXNet.

    The (no_bias, stride, dilate, pad) sweep is listed below, but only the
    first ``max_timed_configs`` combinations are executed. At this problem
    size each run is expensive. Each executed configuration is run three
    times, and only the third run is timed. The operator and its accumulated
    time in seconds are printed.
    """
    N, C, H, W = 8, 256, 128, 128
    K = 5
    num_filter = 256
    num_repeats = 1  # how many times the whole sweep is repeated per operator
    max_timed_configs = 1

    # Same order as nested loops over no_bias, stride, dilate, pad (pad varies fastest).
    configs = list(itertools.product(
        [False, True],       # no_bias
        [(1, 1), (2, 2)],    # stride
        [(1, 1), (2, 2)],    # dilate
        [(0, 0), (1, 1)],    # pad
    ))[:max_timed_configs]

    for conv in [mobula_op.op.Convolution, mx.nd.Convolution]:
        t = 0
        for _ in range(num_repeats):
            data = mx.nd.random.uniform(-1, 1, shape=(N, C, H, W))
            weight = mx.nd.random.uniform(-1, 1, shape=(num_filter, C, K, K))
            bias = mx.nd.random.uniform(-1, 1, shape=(num_filter,))
            for no_bias, stride, dilate, pad in configs:
                for u in range(3):
                    timed = u == 2  # runs 0 and 1 are warm-up
                    if timed:
                        print("=====")
                        # Flush any pending asynchronous work before starting the clock.
                        mx.nd.waitall()
                        tic = time.perf_counter()
                    check_conv(data, weight, bias, (K, K), stride, dilate, pad,
                               num_filter, no_bias, conv)
                    if timed:
                        t += time.perf_counter() - tic
        print(conv, t)
def _run_benchmarks():
    """Run every benchmark first under the CPU context, then under GPU 0."""
    for label, ctx in (("CPU", mx.cpu(0)), ("GPU", mx.gpu(0))):
        print(label)
        with ctx:
            test_fc()
            test_conv()
            test_dot()


# Guarded so that importing this module (e.g. during test collection) does not
# launch the full CPU+GPU benchmark as a side effect.
if __name__ == "__main__":
    _run_benchmarks()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
(conv, t)