Skip to content

Instantly share code, notes, and snippets.

@honno
Last active January 18, 2022 11:00
Show Gist options
  • Save honno/c95eeabc71f5ca65d40e5a5105eda672 to your computer and use it in GitHub Desktop.
Save honno/c95eeabc71f5ca65d40e5a5105eda672 to your computer and use it in GitHub Desktop.
# Wrapper of mxnet for use with github.com/data-apis/array-api-tests
# Tested with mxnet version 1.9.0
# How to use:
# 1. Place this file in `array_api_tests/_mxnet.py`
# 2. In `array_api_tests/_array_module.py` replace `array_module = None` with
# `from ._mxnet import array_module`
import mxnet as mx
import numpy as np
from numpy import array_api as nxp
# NOTE: this binds the real ``mx.np`` module object (not a copy), so every
# patch below modifies mxnet's numpy namespace for the whole process.
array_module = mx.np

# mxnet exposes a mix of ``np.dtype()`` instances and its own namespaced
# dtype objects; overwrite each supported name with a plain ``np.dtype``.
for dtype_name in ("bool", "uint8", "int32", "int64", "float32", "float64"):
    setattr(array_module, dtype_name, np.dtype(dtype_name))

# mxnet does not implement every dtype for every function, so remove these
# names from the namespace (presumably so the test suite treats them as
# unsupported); a name that is already absent is simply skipped.
for dtype_name in ("uint16", "uint32", "uint64", "int8", "int16"):
    try:
        delattr(array_module, dtype_name)
    except AttributeError:
        continue
# mxnet has no asarray function, so provide one on top of ``mx.np.array``.
def _infer_dtype(obj):
    """Return the Array API dtype name to use for Python scalar data.

    ``obj`` may be a ``bool``, ``int`` or ``float``, or an arbitrarily nested
    list/tuple of them. The precedence follows the Array API ``asarray``
    inference rules, using 32-bit defaults:

    * ``"float32"`` if any value is a float;
    * otherwise ``"int32"`` if any value is an int (bools may be mixed in);
    * otherwise ``"bool"`` if every value is a bool.

    Returns ``None`` when nothing can be inferred: an empty (possibly nested)
    sequence, or any element that is not a Python scalar or list/tuple (for
    example an existing array). The caller then defers to mxnet.
    """
    kinds = set()
    pending = [obj]
    while pending:
        item = pending.pop()
        # bool must be checked before int: ``bool`` is a subclass of ``int``.
        if isinstance(item, bool):
            kinds.add("bool")
        elif isinstance(item, int):
            kinds.add("int")
        elif isinstance(item, float):
            kinds.add("float")
        elif isinstance(item, (list, tuple)):
            pending.extend(item)
        else:
            return None
    if "float" in kinds:
        return "float32"
    if "int" in kinds:
        return "int32"
    if "bool" in kinds:
        return "bool"
    return None


def asarray(obj, dtype=None, copy=None):
    """Array API ``asarray`` built on ``mx.np.array``.

    Parameters
    ----------
    obj
        Data to convert: a Python scalar, a (nested) list/tuple of scalars,
        or anything else ``mx.np.array`` accepts.
    dtype
        Output dtype. When ``None`` and ``obj`` is Python scalar data, the
        dtype is inferred via ``_infer_dtype``, because mxnet otherwise
        tends to default to float32. For other inputs ``None`` is passed
        straight through to mxnet.
    copy
        Not supported; only ``None`` is accepted.

    Returns
    -------
    The array produced by ``mx.np.array``.

    Raises
    ------
    NotImplementedError
        If ``copy`` is anything other than ``None``.
    """
    if copy is not None:
        raise NotImplementedError()
    if dtype is None:
        inferred = _infer_dtype(obj)
        if inferred is not None:
            # Look the dtype up on the patched namespace so it matches the
            # dtype objects the rest of the shim exposes.
            dtype = getattr(array_module, inferred)
    return mx.np.array(obj, dtype=dtype)
# Attach what mxnet's numpy namespace lacks: the local ``asarray`` shim,
# and the ``iinfo``/``finfo`` info functions (mxnet has none), borrowed from
# ``numpy.array_api``.
array_module.asarray, array_module.iinfo, array_module.finfo = (
    asarray,
    nxp.iinfo,
    nxp.finfo,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment