Skip to content

Instantly share code, notes, and snippets.

@wkcn
Last active July 5, 2018 09:26
Show Gist options
Save wkcn/c73fd8a3692ec0605ece3713f11252de to your computer and use it in GitHub Desktop.
ConstantOP for MobulaOP
import mobula_op
# NOTE(review): need_top_grad=False presumably tells MobulaOP that backward
# does not need the upstream gradient `dy` -- confirm against MobulaOP docs.
@mobula_op.operator.register(need_top_grad=False)
class ConstantOP:
    """MobulaOP operator whose output is a fixed constant array.

    ``forward`` ignores its single input and always returns the array
    given at construction time. ``infer_shape`` passes the input shapes
    through unchanged and reports the constant's shape as the only output.

    Args:
        constant: anything accepted by ``mx.nd.array`` (a Python list or an
            ``mx.nd.NDArray`` in this file's examples). It is converted once
            in ``__init__``.
    """

    def __init__(self, constant):
        # Import locally. The original relied on `mx` being bound by the
        # script's `__main__` block, so constructing this op after a plain
        # `import` of the module raised NameError.
        import mxnet as mx
        self.constant = mx.nd.array(constant)

    def forward(self, dummy):
        """Return the stored constant; ``dummy`` is intentionally unused."""
        return self.constant

    def backward(self, dy):
        """Return ``[0]`` as the gradient list for the single input.

        The output does not depend on the input, so a zero gradient is
        expected. NOTE(review): presumably MobulaOP interprets the scalar
        ``0`` as an all-zero gradient for the input -- confirm against the
        MobulaOP backward contract.
        """
        return [0]

    def infer_shape(self, in_shape):
        """Keep the input shapes; the single output has the constant's shape."""
        return in_shape, [self.constant.shape]
if __name__ == '__main__':
    # These stay module-level: ConstantOP.__init__ looks up the global `mx`.
    import mxnet as mx
    import numpy as np

    # --- Imperative (NDArray) usage ---
    # The first argument is the (ignored) dummy input; the second is
    # presumably forwarded to ConstantOP.__init__ as `constant`.
    lhs = mx.nd.array([1, 2, 3])
    const_nd = mx.nd.array([4, 5, 6])
    summed = lhs + ConstantOP(lhs, const_nd)
    print(summed)  # [5,7,9]

    # --- Symbolic usage ---
    lhs_sym = mx.sym.Variable('a')
    net = lhs_sym + ConstantOP(lhs_sym, [4, 5, 6])
    executor = net.simple_bind(ctx=mx.context.current_context(), a=lhs.shape)
    executor.forward(a=np.array([1, 2, 3]))
    print(executor.outputs[0].asnumpy())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment