Last active
January 18, 2022 11:00
-
-
Save honno/c95eeabc71f5ca65d40e5a5105eda672 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Wrapper of mxnet for use with github.com/data-apis/array-api-tests
# Tested with mxnet version 1.9.0
# How to use:
# 1. Place this file in `array_api_tests/_mxnet.py`
# 2. In `array_api_tests/_array_module.py` replace `array_module = None` with
#    `from ._mxnet import array_module`
import mxnet as mx | |
import numpy as np | |
from numpy import array_api as nxp | |
# The namespace handed to array-api-tests is mxnet's NumPy-like module.
array_module = mx.np
# mxnet mixes np.dtype() objects with its own namespaced dtypes, so pin each
# supported dtype name to a plain NumPy dtype object.
for dtype_name in ("bool", "uint8", "int32", "int64", "float32", "float64"):
    setattr(array_module, dtype_name, np.dtype(dtype_name))
# mxnet does not implement every dtype for every function, so drop these
# names from the namespace; a name the namespace never had is simply skipped.
for unsupported_name in ("uint16", "uint32", "uint64", "int8", "int16"):
    try:
        delattr(array_module, unsupported_name)
    except AttributeError:
        continue
# mxnet has no asarray function
def default_dtype_name(obj):
    """Return the dtype name ``asarray`` should use for a bare Python scalar.

    mxnet tends to default to float32, so Python scalars are mapped
    explicitly: ``bool`` -> ``"bool"``, ``int`` -> ``"int32"``,
    ``float`` -> ``"float32"``. Any other object (including sequences)
    yields ``None``, leaving the dtype choice to ``mx.np.array``.
    """
    # bool is a subclass of int, so it must be tested first; otherwise
    # True/False would be inferred as int32 and the bool case never reached.
    if isinstance(obj, bool):
        return "bool"
    if isinstance(obj, int):
        return "int32"
    if isinstance(obj, float):
        return "float32"
    # TODO: infer from nested sequences
    return None


def asarray(obj, dtype=None, copy=None):
    """Array API ``asarray`` built on top of ``mx.np.array``.

    Args:
        obj: Object to convert to an mxnet array.
        dtype: Target dtype. When ``None``, Python ``bool``/``int``/``float``
            scalars are given ``bool``/``int32``/``float32`` (see
            ``default_dtype_name``); other inputs use ``mx.np.array``'s default.
        copy: Only ``None`` is supported.

    Returns:
        The array produced by ``mx.np.array(obj, dtype=dtype)``.

    Raises:
        NotImplementedError: If ``copy`` is not ``None``.
    """
    if copy is not None:
        raise NotImplementedError("asarray(copy=...) is not supported by this mxnet wrapper")
    if dtype is None:
        inferred = default_dtype_name(obj)
        if inferred is not None:
            dtype = getattr(array_module, inferred)
    return mx.np.array(obj, dtype=dtype)
# Attach the asarray wrapper, and borrow the info functions mxnet lacks
# from numpy.array_api.
setattr(array_module, "asarray", asarray)
setattr(array_module, "iinfo", nxp.iinfo)
setattr(array_module, "finfo", nxp.finfo)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment