Skip to content

Instantly share code, notes, and snippets.

@wkcn
Created August 16, 2018 03:00
Show Gist options
Save wkcn/63e11870c0d0e48e6fda95716b7171e4 to your computer and use it in GitHub Desktop.
MXNet nd.waitall()
import mxnet as mx
from mxnet import nd
class TestOP(mx.operator.CustomOp):
    """Identity custom operator that prints a marker on every forward pass.

    Appears intended as a probe for when MXNet's engine actually executes a
    ``mx.nd.Custom`` call relative to ``nd.waitall()``.
    """

    def __init__(self):
        super(TestOP, self).__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        """Print a marker, then copy the single input into the single output.

        Writing ``out_data[0]`` is required: if the operator never assigns it,
        the array returned by ``mx.nd.Custom`` is not filled by this op.
        """
        print("Run test OP")
        # self.assign honors the write request in req[0] (e.g. write / add / null).
        self.assign(out_data[0], req[0], in_data[0])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Pass the incoming gradient through unchanged (identity gradient)."""
        self.assign(in_grad[0], req[0], out_grad[0])
@mx.operator.register("TestOP")
class TestOPProp(mx.operator.CustomOpProp):
    """Operator property registered under op_type ``"TestOP"``.

    Declares one input (``data``) and one output (``out``), and reports the
    output shape as identical to the input shape.
    """

    def __init__(self):
        # need_top_grad=True is CustomOpProp's default; spelled out for clarity.
        super(TestOPProp, self).__init__(need_top_grad=True)

    def list_arguments(self):
        """Return the names of the operator's inputs."""
        arg_names = ['data']
        return arg_names

    def list_outputs(self):
        """Return the names of the operator's outputs."""
        out_names = ['out']
        return out_names

    def infer_shape(self, in_shape):
        """Return the argument shapes unchanged and reuse them as output shapes.

        No auxiliary-state shapes are returned.
        """
        arg_shapes = in_shape
        # One input and one output, so the output shape list mirrors the input's.
        out_shapes = in_shape
        return arg_shapes, out_shapes

    def create_operator(self, ctx, shapes, dtypes):
        """Build a new TestOP; ctx, shapes and dtypes are not used."""
        return TestOP()
# Allocate an input array (contents not initialized by nd.empty) and wait for
# all pending engine work before exercising the custom operator.
a = nd.empty((1, 512, 120 * 120))
nd.waitall()

# Run the custom op three times, blocking after each call until all pending
# operations (including the op's forward pass) have completed.
for _ in range(3):
    b = nd.Custom(a, op_type="TestOP")
    nd.waitall()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment