Created
May 13, 2018 08:04
-
-
Save wkcn/715f30defd736ce663a2f7399ffe960e to your computer and use it in GitHub Desktop.
CountOPTime-mx
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx | |
import time | |
from distutils.util import strtobool | |
# Reference timestamp captured once when this module is imported. The operator
# below only ever exchanges `time.time() - BT`, i.e. seconds elapsed since this
# point, rather than absolute wall-clock values.
BT = time.time()
class CountTimeOP(mx.operator.CustomOp):
    """Pass-through operator that prints the time elapsed since the previous node.

    The operator copies its data input to its first output unchanged and
    writes the current offset from ``BT`` to its second output. A node that is
    not the first of a chain also receives the upstream node's offset as a
    second input and prints the difference, labelled with ``cname``.

    Args:
        first: Truthy when this node starts a chain and therefore has no
            timestamp input to compare against.
        cname: Label printed in front of the measured duration.
    """

    def __init__(self, first, cname):
        super(CountTimeOP, self).__init__()
        self.first = first
        self.cname = cname

    def forward(self, is_train, req, in_data, out_data, aux):
        """Forward the data, print the elapsed time and emit a new timestamp.

        ``is_train`` and ``aux`` are not used.
        """
        if not self.first:
            # The converted array itself is not needed: asnumpy() is called so
            # that the upstream result is fully computed before the clock is
            # read (MXNet schedules work asynchronously).
            in_data[0].asnumpy()
            dt = (time.time() - BT) - in_data[1].asscalar()
            print("%s:%f" % (self.cname, dt))
        # Honour the requested write mode (null / write / inplace / add)
        # instead of unconditionally overwriting the outputs.
        self.assign(out_data[0], req[0], in_data[0])
        self.assign(out_data[1], req[1], time.time() - BT)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Pass the data gradient through; the timestamp input gets zero gradient."""
        self.assign(in_grad[0], req[0], out_grad[0])
        # A "first" node has no timestamp input, so in_grad holds a single
        # entry; indexing in_grad[1] unconditionally raised IndexError there.
        # Iterating over the remaining entries covers both configurations.
        for grad, grad_req in zip(in_grad[1:], req[1:]):
            self.assign(grad, grad_req, 0)
@mx.operator.register("CountTimeOP")
class CountTimeProp(mx.operator.CustomOpProp):
    """Operator properties for ``CountTimeOP``, registered as "CountTimeOP".

    Args:
        first: Boolean flag in any textual form accepted by ``strtobool``;
            when true the operator takes only the data input.
        cname: Label handed on to the created operator.
    """

    def __init__(self, first, cname):
        super(CountTimeProp, self).__init__(need_top_grad=True)
        # Stored as the 1/0 result of strtobool, not the raw argument.
        self.first = strtobool(first)
        self.cname = cname

    def list_arguments(self):
        """Return the input names; the timestamp input 't' is absent on a first node."""
        if not self.first:
            return ['data', 't']
        return ['data']

    def list_outputs(self):
        """Return the output names: the forwarded data and the new timestamp."""
        return ['out_data', 'out_t']

    def infer_shape(self, in_shape):
        """Return the input shapes unchanged plus the two output shapes.

        The first output has the shape of the data input; the timestamp
        output always has shape ``(1,)``.
        """
        output_shapes = [in_shape[0], (1,)]
        return in_shape, output_shapes

    def create_operator(self, ctx, shapes, dtypes):
        """Build the operator instance; ``ctx``, ``shapes`` and ``dtypes`` are unused."""
        return CountTimeOP(first=self.first, cname=self.cname)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment