Created April 3, 2017 04:31
Save agibsonccc/4fd79ffe5ca9f97a0789434cd60fae6d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import jnius_config

# Location of the ND4J jars. The KERAS_DL4J_CLASSPATH environment variable
# overrides the fallback below, which is a developer-specific local path.
jnius_classpath = os.environ.get(
    'KERAS_DL4J_CLASSPATH',
    '/home/agibsonccc/code/nd4jjcc/target/nd4j-jcc-1.0-SNAPSHOT-dist/nd4j-jcc-1.0-SNAPSHOT/lib/*'
)
# pyjnius only honours jnius_config settings made before `jnius` is first
# imported, so the classpath has to be set ahead of the import below.
jnius_config.set_classpath(jnius_classpath)

from jnius import autoclass  # noqa: E402 -- must follow set_classpath()

# ND4J entry points.
nd4j = autoclass('org.nd4j.linalg.factory.Nd4j')
transforms = autoclass('org.nd4j.linalg.ops.transforms.Transforms')
indexing = autoclass('org.nd4j.linalg.indexing.NDArrayIndex')

# Plain Java classes.
system = autoclass('java.lang.System')
integer = autoclass('java.lang.Integer')

# Native-ops handle; get_buffer_from_arr uses it to turn a raw memory
# address into a JavaCPP Pointer.
native_ops_holder = autoclass('org.nd4j.nativeblas.NativeOpsHolder')
native_ops = native_ops_holder.getInstance().getDeviceNativeOps()

# JavaCPP typed pointer wrappers, one per supported element type.
DoublePointer = autoclass('org.bytedeco.javacpp.DoublePointer')
FloatPointer = autoclass('org.bytedeco.javacpp.FloatPointer')
IntPointer = autoclass('org.bytedeco.javacpp.IntPointer')
def get_buffer_from_arr(np_arr):
    '''
    Wrap the memory of a NumPy array in an ND4J DataBuffer.

    The buffer is built from the array's raw data address, and no reference
    to the NumPy object is kept here. Keep np_arr alive while the buffer is
    in use, assuming ND4J does not copy the data (TODO confirm).

    :param np_arr: the NumPy array whose data should be wrapped
    :return: an ND4J DataBuffer of np_arr.size elements starting at the
        array's data address
    :raises TypeError: if np_arr.dtype is not float64, float32 or int32
    '''
    # Choose the JavaCPP pointer type before touching native memory, so an
    # unsupported dtype fails fast instead of silently returning None.
    if np_arr.dtype == 'float64':
        pointer_type = DoublePointer
    elif np_arr.dtype == 'float32':
        pointer_type = FloatPointer
    elif np_arr.dtype == 'int32':
        # IntPointer views memory as 32-bit ints. int64 data is rejected
        # rather than mapped here, because it would be read as the wrong values.
        pointer_type = IntPointer
    else:
        raise TypeError(
            'Unsupported dtype %s: expected float64, float32 or int32'
            % np_arr.dtype)
    pointer = native_ops.pointerForAddress(get_array_address(np_arr))
    # ndarray.size is an int attribute (the element count), not a method.
    return nd4j.createBuffer(pointer_type(pointer), np_arr.size)
def get_array_address(np_arr):
    '''
    Look up where a NumPy array's data lives in memory.

    :param np_arr: the array (any object exposing __array_interface__)
    :return: the integer address of the array's first element
    '''
    # The array interface 'data' entry is a (pointer, read_only) pair.
    # Only the pointer is needed; the read-only flag is discarded.
    address, _read_only = np_arr.__array_interface__['data']
    return address
def from_np(np_arr):
    '''
    Create an ND4J INDArray over the memory of a NumPy array.

    The result is built from the array's raw data address, and no reference
    to the NumPy object is kept. Keep np_arr alive, and do not resize it,
    while the INDArray is in use (assuming ND4J does not copy; TODO confirm).

    :param np_arr: a C- or Fortran-contiguous NumPy array whose dtype is
        accepted by get_buffer_from_arr
    :return: an INDArray with np_arr's shape and matching element strides
    :raises ValueError: if np_arr is not contiguous
    '''
    flags = np_arr.flags
    # The backing buffer holds exactly np_arr.size elements. A non-contiguous
    # view (e.g. a slice or a negative-stride array) would make ND4J index
    # outside that buffer, so reject it.
    if not (flags['C_CONTIGUOUS'] or flags['F_CONTIGUOUS']):
        raise ValueError(
            'from_np requires a contiguous array; '
            'call numpy.ascontiguousarray first and keep the copy alive')
    data_buffer = get_buffer_from_arr(np_arr)
    # NumPy strides are measured in bytes; ND4J strides count elements.
    # NOTE(review): for Fortran-ordered input the strides are passed through
    # as-is; confirm ND4J honours explicit strides regardless of its order.
    element_strides = tuple(
        stride // np_arr.itemsize for stride in np_arr.strides)
    return nd4j.create(data_buffer, np_arr.shape, element_strides, 0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.